[llvm] [X86] Add support for `__bf16` to `f16` conversion (PR #134859)
Antonio Frighetto via llvm-commits
llvm-commits at lists.llvm.org
Sat Apr 26 10:26:44 PDT 2025
https://github.com/antoniofrighetto updated https://github.com/llvm/llvm-project/pull/134859
From 7f2a1120c9fc31912e7d98de99dd2a04f368b5c6 Mon Sep 17 00:00:00 2001
From: Antonio Frighetto <me at antoniofrighetto.com>
Date: Tue, 8 Apr 2025 16:06:44 +0200
Subject: [PATCH 1/3] [X86] Add support for `__bf16` to `f16` conversion
`__bf16` is a 16-bit brain floating-point type introduced with AVX-512 BF16 and
should be able to leverage the SSE/AVX registers already used for `f16`.
Fixes: https://github.com/llvm/llvm-project/issues/134222.
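For context, a minimal reproducer in the spirit of the linked issue (the exact
shape of the original test case is assumed here) is a 16-bit bit copy between
the two half-width types, which SelectionDAG turns into a `bitcast` that
previously failed instruction selection:

    _Float16 f16_from_bf16_bits(__bf16 b) {
      _Float16 f;
      __builtin_memcpy(&f, &b, sizeof f); /* lowers to a 16-bit bitcast in IR */
      return f;
    }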
---
llvm/lib/Target/X86/X86ISelLowering.cpp | 12 +++++++++++-
llvm/lib/Target/X86/X86InstrAVX512.td | 5 +++++
2 files changed, 16 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a4381b99dbae0..ad170f49aebb5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -661,10 +661,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
};
if (!Subtarget.useSoftFloat() && Subtarget.hasSSE2()) {
- // f16, f32 and f64 use SSE.
+ // f16, bf16, f32 and f64 use SSE.
// Set up the FP register classes.
addRegisterClass(MVT::f16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
: &X86::FR16RegClass);
+ addRegisterClass(MVT::bf16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
+ : &X86::FR16RegClass);
addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? &X86::FR32XRegClass
: &X86::FR32RegClass);
addRegisterClass(MVT::f64, Subtarget.hasAVX512() ? &X86::FR64XRegClass
@@ -676,6 +678,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
// non-optsize case.
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
+ // Set the operation action Custom for bitcast to do the customization
+ // later.
+ setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
+
for (auto VT : { MVT::f32, MVT::f64 }) {
// Use ANDPD to simulate FABS.
setOperationAction(ISD::FABS, VT, Custom);
@@ -32151,6 +32157,10 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
return DAG.getZExtOrTrunc(V, DL, DstVT);
}
+ // Bitcasts between f16 and bf16 should be legal.
+ if (DstVT == MVT::f16 || DstVT == MVT::bf16)
+ return Op;
+
assert((SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 ||
SrcVT == MVT::i64) && "Unexpected VT!");
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 0ab94cca41425..5705094dca55b 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -2456,6 +2456,11 @@ let Predicates = [HasFP16] in {
(VCMPSHZrmi FR16X:$src1, addr:$src2, (X86cmpm_imm_commute timm:$cc))>;
}
+let Predicates = [HasAVX512, HasBF16] in {
+ def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
+ def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
+}
+
// ----------------------------------------------------------------
// FPClass
From 22f39769356f9b844fb48a6ba22d42935f73e891 Mon Sep 17 00:00:00 2001
From: Antonio Frighetto <me at antoniofrighetto.com>
Date: Sat, 26 Apr 2025 19:21:25 +0200
Subject: [PATCH 2/3] !fixup update affected patterns w/ explicit cast, update
calling conventions
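With FR16X now carrying both `f16` and `bf16`, bare `FR16X:$src` operands in
patterns no longer pin down a value type, so the affected multiclasses gain an
explicit `ValueType` parameter and the patterns cast their operands. On the
ABI side, scalar `bf16` is made to follow the same conventions as `f16`; an
illustrative C sketch of the intended behavior (not a test from this patch):

    __bf16 first_arg(__bf16 a, __bf16 b) {
      /* On x86-64 SysV, a arrives in XMM0 and b in XMM1; the result is
         returned in XMM0, mirroring _Float16. */
      return a;
    }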
---
llvm/lib/Target/X86/X86CallingConv.td | 11 ++-
llvm/lib/Target/X86/X86InstrAVX10.td | 10 +-
llvm/lib/Target/X86/X86InstrAVX512.td | 110 ++++++++++-----------
llvm/lib/Target/X86/X86InstrVecCompiler.td | 4 +-
llvm/lib/Target/X86/X86RegisterInfo.td | 2 +-
5 files changed, 70 insertions(+), 67 deletions(-)
diff --git a/llvm/lib/Target/X86/X86CallingConv.td b/llvm/lib/Target/X86/X86CallingConv.td
index 0d087e057a2bd..d981c010b338d 100644
--- a/llvm/lib/Target/X86/X86CallingConv.td
+++ b/llvm/lib/Target/X86/X86CallingConv.td
@@ -364,6 +364,7 @@ def RetCC_X86_32_VectorCall : CallingConv<[
def RetCC_X86_64_C : CallingConv<[
// The X86-64 calling convention always returns FP values in XMM0.
CCIfType<[f16], CCAssignToReg<[XMM0, XMM1]>>,
+ CCIfType<[bf16], CCAssignToReg<[XMM0, XMM1]>>,
CCIfType<[f32], CCAssignToReg<[XMM0, XMM1]>>,
CCIfType<[f64], CCAssignToReg<[XMM0, XMM1]>>,
CCIfType<[f128], CCAssignToReg<[XMM0, XMM1]>>,
@@ -569,6 +570,10 @@ def CC_X86_64_C : CallingConv<[
CCIfSubtarget<"hasSSE1()",
CCAssignToReg<[XMM0, XMM1, XMM2, XMM3, XMM4, XMM5, XMM6, XMM7]>>>,
+ // The first 8 bf16 arguments are passed in XMM registers (part of AVX-512 BF16).
+ CCIfType<[bf16], CCIfSubtarget<"hasAVX512()",
+ CCAssignToReg<[XMM0, XMM1, XMM2, XMM3, XMM4, XMM5, XMM6, XMM7]>>>,
+
// The first 8 256-bit vector arguments are passed in YMM registers, unless
// this is a vararg function.
// FIXME: This isn't precisely correct; the x86-64 ABI document says that
@@ -586,7 +591,7 @@ def CC_X86_64_C : CallingConv<[
// Integer/FP values get stored in stack slots that are 8 bytes in size and
// 8-byte aligned if there are no more registers to hold them.
- CCIfType<[i32, i64, f16, f32, f64], CCAssignToStack<8, 8>>,
+ CCIfType<[i32, i64, bf16, f16, f32, f64], CCAssignToStack<8, 8>>,
// Long doubles get stack slots whose size and alignment depends on the
// subtarget.
@@ -649,7 +654,7 @@ def CC_X86_Win64_C : CallingConv<[
CCIfType<[f64], CCIfNotSubtarget<"hasSSE1()", CCBitConvertToType<i64>>>,
// The first 4 FP/Vector arguments are passed in XMM registers.
- CCIfType<[f16, f32, f64],
+ CCIfType<[bf16, f16, f32, f64],
CCAssignToRegWithShadow<[XMM0, XMM1, XMM2, XMM3],
[RCX , RDX , R8 , R9 ]>>,
@@ -672,7 +677,7 @@ def CC_X86_Win64_C : CallingConv<[
// Integer/FP values get stored in stack slots that are 8 bytes in size and
// 8-byte aligned if there are no more registers to hold them.
- CCIfType<[i8, i16, i32, i64, f16, f32, f64], CCAssignToStack<8, 8>>
+ CCIfType<[i8, i16, i32, i64, bf16, f16, f32, f64], CCAssignToStack<8, 8>>
]>;
def CC_X86_Win64_VectorCall : CallingConv<[
diff --git a/llvm/lib/Target/X86/X86InstrAVX10.td b/llvm/lib/Target/X86/X86InstrAVX10.td
index 2d2bf1f6c725e..b0c6f27c7a575 100644
--- a/llvm/lib/Target/X86/X86InstrAVX10.td
+++ b/llvm/lib/Target/X86/X86InstrAVX10.td
@@ -103,14 +103,14 @@ multiclass avx10_minmax_packed<string OpStr, AVX512VLVectorVTInfo VTI, SDNode Op
}
multiclass avx10_minmax_scalar<string OpStr, X86VectorVTInfo _, SDNode OpNode,
- SDNode OpNodeSAE> {
+ SDNode OpNodeSAE, ValueType CT> {
let ExeDomain = _.ExeDomain, Predicates = [HasAVX10_2] in {
let mayRaiseFPException = 1 in {
let isCodeGenOnly = 1 in {
def rri : AVX512Ii8<0x53, MRMSrcReg, (outs _.FRC:$dst),
(ins _.FRC:$src1, _.FRC:$src2, i32u8imm:$src3),
!strconcat(OpStr, "\t{$src3, $src2, $src1|$src1, $src2, $src3}"),
- [(set _.FRC:$dst, (OpNode _.FRC:$src1, _.FRC:$src2, (i32 timm:$src3)))]>,
+ [(set _.FRC:$dst, (OpNode (CT _.FRC:$src1), (CT _.FRC:$src2), (i32 timm:$src3)))]>,
Sched<[WriteFMAX]>;
def rmi : AVX512Ii8<0x53, MRMSrcMem, (outs _.FRC:$dst),
@@ -165,11 +165,11 @@ defm VMINMAXPS : avx10_minmax_packed<"vminmaxps", avx512vl_f32_info, X86vminmax>
avx10_minmax_packed_sae<"vminmaxps", avx512vl_f32_info, X86vminmaxSae>,
AVX512PDIi8Base, TA, EVEX_CD8<32, CD8VF>;
-defm VMINMAXSD : avx10_minmax_scalar<"vminmaxsd", v2f64x_info, X86vminmaxs, X86vminmaxsSae>,
+defm VMINMAXSD : avx10_minmax_scalar<"vminmaxsd", v2f64x_info, X86vminmaxs, X86vminmaxsSae, f64>,
AVX512AIi8Base, VEX_LIG, EVEX, VVVV, EVEX_CD8<64, CD8VT1>, REX_W;
-defm VMINMAXSH : avx10_minmax_scalar<"vminmaxsh", v8f16x_info, X86vminmaxs, X86vminmaxsSae>,
+defm VMINMAXSH : avx10_minmax_scalar<"vminmaxsh", v8f16x_info, X86vminmaxs, X86vminmaxsSae, f16>,
AVX512PSIi8Base, VEX_LIG, EVEX, VVVV, EVEX_CD8<16, CD8VT1>, TA;
-defm VMINMAXSS : avx10_minmax_scalar<"vminmaxss", v4f32x_info, X86vminmaxs, X86vminmaxsSae>,
+defm VMINMAXSS : avx10_minmax_scalar<"vminmaxss", v4f32x_info, X86vminmaxs, X86vminmaxsSae, f32>,
AVX512AIi8Base, VEX_LIG, EVEX, VVVV, EVEX_CD8<32, CD8VT1>;
//-------------------------------------------------
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 5705094dca55b..ebaba0ac29799 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -1942,7 +1942,7 @@ defm VPBLENDMW : blendmask_bw<0x66, "vpblendmw", SchedWriteVarBlend,
// avx512_cmp_scalar - AVX512 CMPSS and CMPSD
-multiclass avx512_cmp_scalar<X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE,
+multiclass avx512_cmp_scalar<X86VectorVTInfo _, ValueType CT, SDNode OpNode, SDNode OpNodeSAE,
PatFrag OpNode_su, PatFrag OpNodeSAE_su,
X86FoldableSchedWrite sched> {
defm rri : AVX512_maskable_cmp<0xC2, MRMSrcReg, _,
@@ -1983,8 +1983,8 @@ multiclass avx512_cmp_scalar<X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE,
(outs _.KRC:$dst), (ins _.FRC:$src1, _.FRC:$src2, u8imm:$cc),
!strconcat("vcmp", _.Suffix,
"\t{$cc, $src2, $src1, $dst|$dst, $src1, $src2, $cc}"),
- [(set _.KRC:$dst, (OpNode _.FRC:$src1,
- _.FRC:$src2,
+ [(set _.KRC:$dst, (OpNode (CT _.FRC:$src1),
+ (CT _.FRC:$src2),
timm:$cc))]>,
EVEX, VVVV, VEX_LIG, Sched<[sched]>, SIMD_EXC;
def rmi : AVX512Ii8<0xC2, MRMSrcMem,
@@ -2002,16 +2002,16 @@ multiclass avx512_cmp_scalar<X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE,
let Predicates = [HasAVX512] in {
let ExeDomain = SSEPackedSingle in
- defm VCMPSSZ : avx512_cmp_scalar<f32x_info, X86cmpms, X86cmpmsSAE,
+ defm VCMPSSZ : avx512_cmp_scalar<f32x_info, f32, X86cmpms, X86cmpmsSAE,
X86cmpms_su, X86cmpmsSAE_su,
SchedWriteFCmp.Scl>, AVX512XSIi8Base;
let ExeDomain = SSEPackedDouble in
- defm VCMPSDZ : avx512_cmp_scalar<f64x_info, X86cmpms, X86cmpmsSAE,
+ defm VCMPSDZ : avx512_cmp_scalar<f64x_info, f64, X86cmpms, X86cmpmsSAE,
X86cmpms_su, X86cmpmsSAE_su,
SchedWriteFCmp.Scl>, AVX512XDIi8Base, REX_W;
}
let Predicates = [HasFP16], ExeDomain = SSEPackedSingle in
- defm VCMPSHZ : avx512_cmp_scalar<f16x_info, X86cmpms, X86cmpmsSAE,
+ defm VCMPSHZ : avx512_cmp_scalar<f16x_info, f16, X86cmpms, X86cmpmsSAE,
X86cmpms_su, X86cmpmsSAE_su,
SchedWriteFCmp.Scl>, AVX512XSIi8Base, TA;
@@ -2456,11 +2456,6 @@ let Predicates = [HasFP16] in {
(VCMPSHZrmi FR16X:$src1, addr:$src2, (X86cmpm_imm_commute timm:$cc))>;
}
-let Predicates = [HasAVX512, HasBF16] in {
- def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
- def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
-}
-
// ----------------------------------------------------------------
// FPClass
@@ -3908,7 +3903,7 @@ def : Pat<(f64 (bitconvert VK64:$src)),
//===----------------------------------------------------------------------===//
multiclass avx512_move_scalar<string asm, SDNode OpNode, PatFrag vzload_frag,
- X86VectorVTInfo _, Predicate prd = HasAVX512> {
+ X86VectorVTInfo _, ValueType CT, Predicate prd = HasAVX512> {
let Predicates = !if (!eq (prd, HasFP16), [HasFP16], [prd, OptForSize]) in
def rr : AVX512PI<0x10, MRMSrcReg, (outs _.RC:$dst),
(ins _.RC:$src1, _.RC:$src2),
@@ -3960,7 +3955,7 @@ multiclass avx512_move_scalar<string asm, SDNode OpNode, PatFrag vzload_frag,
}
def mr: AVX512PI<0x11, MRMDestMem, (outs), (ins _.ScalarMemOp:$dst, _.FRC:$src),
!strconcat(asm, "\t{$src, $dst|$dst, $src}"),
- [(store _.FRC:$src, addr:$dst)], _.ExeDomain>,
+ [(store (CT _.FRC:$src), addr:$dst)], _.ExeDomain>,
EVEX, Sched<[WriteFStore]>;
let mayStore = 1, hasSideEffects = 0 in
def mrk: AVX512PI<0x11, MRMDestMem, (outs),
@@ -3970,13 +3965,13 @@ multiclass avx512_move_scalar<string asm, SDNode OpNode, PatFrag vzload_frag,
}
}
-defm VMOVSSZ : avx512_move_scalar<"vmovss", X86Movss, X86vzload32, f32x_info>,
+defm VMOVSSZ : avx512_move_scalar<"vmovss", X86Movss, X86vzload32, f32x_info, f32>,
VEX_LIG, TB, XS, EVEX_CD8<32, CD8VT1>;
-defm VMOVSDZ : avx512_move_scalar<"vmovsd", X86Movsd, X86vzload64, f64x_info>,
+defm VMOVSDZ : avx512_move_scalar<"vmovsd", X86Movsd, X86vzload64, f64x_info, f64>,
VEX_LIG, TB, XD, REX_W, EVEX_CD8<64, CD8VT1>;
-defm VMOVSHZ : avx512_move_scalar<"vmovsh", X86Movsh, X86vzload16, f16x_info,
+defm VMOVSHZ : avx512_move_scalar<"vmovsh", X86Movsh, X86vzload16, f16x_info, f16,
HasFP16>,
VEX_LIG, T_MAP5, XS, EVEX_CD8<16, CD8VT1>;
@@ -5364,7 +5359,7 @@ defm : avx512_logical_lowering_types<"VPANDN", X86andnp>;
//===----------------------------------------------------------------------===//
multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
- SDPatternOperator OpNode, SDNode VecNode,
+ ValueType CT, SDPatternOperator OpNode, SDNode VecNode,
X86FoldableSchedWrite sched, bit IsCommutable> {
let ExeDomain = _.ExeDomain, Uses = [MXCSR], mayRaiseFPException = 1 in {
defm rr : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst),
@@ -5383,7 +5378,7 @@ multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
def rr : I< opc, MRMSrcReg, (outs _.FRC:$dst),
(ins _.FRC:$src1, _.FRC:$src2),
OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
- [(set _.FRC:$dst, (OpNode _.FRC:$src1, _.FRC:$src2))]>,
+ [(set _.FRC:$dst, (OpNode (CT _.FRC:$src1), (CT _.FRC:$src2)))]>,
Sched<[sched]> {
let isCommutable = IsCommutable;
}
@@ -5408,8 +5403,8 @@ multiclass avx512_fp_scalar_round<bits<8> opc, string OpcodeStr,X86VectorVTInfo
EVEX_B, EVEX_RC, Sched<[sched]>;
}
multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
- SDPatternOperator OpNode, SDNode VecNode, SDNode SaeNode,
- X86FoldableSchedWrite sched, bit IsCommutable> {
+ ValueType CT, SDPatternOperator OpNode, SDNode VecNode,
+ SDNode SaeNode, X86FoldableSchedWrite sched, bit IsCommutable> {
let ExeDomain = _.ExeDomain in {
defm rr : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst),
(ins _.RC:$src1, _.RC:$src2), OpcodeStr,
@@ -5429,7 +5424,7 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
def rr : I< opc, MRMSrcReg, (outs _.FRC:$dst),
(ins _.FRC:$src1, _.FRC:$src2),
OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
- [(set _.FRC:$dst, (OpNode _.FRC:$src1, _.FRC:$src2))]>,
+ [(set _.FRC:$dst, (OpNode (CT _.FRC:$src1), (CT _.FRC:$src2)))]>,
Sched<[sched]> {
let isCommutable = IsCommutable;
}
@@ -5453,18 +5448,18 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
multiclass avx512_binop_s_round<bits<8> opc, string OpcodeStr, SDPatternOperator OpNode,
SDNode VecNode, SDNode RndNode,
X86SchedWriteSizes sched, bit IsCommutable> {
- defm SSZ : avx512_fp_scalar<opc, OpcodeStr#"ss", f32x_info, OpNode, VecNode,
+ defm SSZ : avx512_fp_scalar<opc, OpcodeStr#"ss", f32x_info, f32, OpNode, VecNode,
sched.PS.Scl, IsCommutable>,
avx512_fp_scalar_round<opc, OpcodeStr#"ss", f32x_info, RndNode,
sched.PS.Scl>,
TB, XS, EVEX, VVVV, VEX_LIG, EVEX_CD8<32, CD8VT1>;
- defm SDZ : avx512_fp_scalar<opc, OpcodeStr#"sd", f64x_info, OpNode, VecNode,
+ defm SDZ : avx512_fp_scalar<opc, OpcodeStr#"sd", f64x_info, f64, OpNode, VecNode,
sched.PD.Scl, IsCommutable>,
avx512_fp_scalar_round<opc, OpcodeStr#"sd", f64x_info, RndNode,
sched.PD.Scl>,
TB, XD, REX_W, EVEX, VVVV, VEX_LIG, EVEX_CD8<64, CD8VT1>;
let Predicates = [HasFP16] in
- defm SHZ : avx512_fp_scalar<opc, OpcodeStr#"sh", f16x_info, OpNode,
+ defm SHZ : avx512_fp_scalar<opc, OpcodeStr#"sh", f16x_info, f16, OpNode,
VecNode, sched.PH.Scl, IsCommutable>,
avx512_fp_scalar_round<opc, OpcodeStr#"sh", f16x_info, RndNode,
sched.PH.Scl>,
@@ -5474,14 +5469,14 @@ multiclass avx512_binop_s_round<bits<8> opc, string OpcodeStr, SDPatternOperator
multiclass avx512_binop_s_sae<bits<8> opc, string OpcodeStr, SDPatternOperator OpNode,
SDNode VecNode, SDNode SaeNode,
X86SchedWriteSizes sched, bit IsCommutable> {
- defm SSZ : avx512_fp_scalar_sae<opc, OpcodeStr#"ss", f32x_info, OpNode,
+ defm SSZ : avx512_fp_scalar_sae<opc, OpcodeStr#"ss", f32x_info, f32, OpNode,
VecNode, SaeNode, sched.PS.Scl, IsCommutable>,
TB, XS, EVEX, VVVV, VEX_LIG, EVEX_CD8<32, CD8VT1>;
- defm SDZ : avx512_fp_scalar_sae<opc, OpcodeStr#"sd", f64x_info, OpNode,
+ defm SDZ : avx512_fp_scalar_sae<opc, OpcodeStr#"sd", f64x_info, f64, OpNode,
VecNode, SaeNode, sched.PD.Scl, IsCommutable>,
TB, XD, REX_W, EVEX, VVVV, VEX_LIG, EVEX_CD8<64, CD8VT1>;
let Predicates = [HasFP16] in {
- defm SHZ : avx512_fp_scalar_sae<opc, OpcodeStr#"sh", f16x_info, OpNode,
+ defm SHZ : avx512_fp_scalar_sae<opc, OpcodeStr#"sh", f16x_info, f16, OpNode,
VecNode, SaeNode, sched.PH.Scl, IsCommutable>,
T_MAP5, XS, EVEX, VVVV, VEX_LIG, EVEX_CD8<16, CD8VT1>;
}
@@ -5502,13 +5497,13 @@ defm VMAX : avx512_binop_s_sae<0x5F, "vmax", X86any_fmax, X86fmaxs, X86fmaxSAEs,
// MIN/MAX nodes are commutable under "unsafe-fp-math". In this case we use
// X86fminc and X86fmaxc instead of X86fmin and X86fmax
multiclass avx512_comutable_binop_s<bits<8> opc, string OpcodeStr,
- X86VectorVTInfo _, SDNode OpNode,
+ X86VectorVTInfo _, ValueType CT, SDNode OpNode,
X86FoldableSchedWrite sched> {
let isCodeGenOnly = 1, Predicates = [HasAVX512], ExeDomain = _.ExeDomain in {
def rr : I< opc, MRMSrcReg, (outs _.FRC:$dst),
(ins _.FRC:$src1, _.FRC:$src2),
OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
- [(set _.FRC:$dst, (OpNode _.FRC:$src1, _.FRC:$src2))]>,
+ [(set _.FRC:$dst, (OpNode (CT _.FRC:$src1), (CT _.FRC:$src2)))]>,
Sched<[sched]> {
let isCommutable = 1;
}
@@ -5520,29 +5515,29 @@ multiclass avx512_comutable_binop_s<bits<8> opc, string OpcodeStr,
Sched<[sched.Folded, sched.ReadAfterFold]>;
}
}
-defm VMINCSSZ : avx512_comutable_binop_s<0x5D, "vminss", f32x_info, X86fminc,
+defm VMINCSSZ : avx512_comutable_binop_s<0x5D, "vminss", f32x_info, f32, X86fminc,
SchedWriteFCmp.Scl>, TB, XS,
EVEX, VVVV, VEX_LIG, EVEX_CD8<32, CD8VT1>, SIMD_EXC;
-defm VMINCSDZ : avx512_comutable_binop_s<0x5D, "vminsd", f64x_info, X86fminc,
+defm VMINCSDZ : avx512_comutable_binop_s<0x5D, "vminsd", f64x_info, f64, X86fminc,
SchedWriteFCmp.Scl>, TB, XD,
REX_W, EVEX, VVVV, VEX_LIG,
EVEX_CD8<64, CD8VT1>, SIMD_EXC;
-defm VMAXCSSZ : avx512_comutable_binop_s<0x5F, "vmaxss", f32x_info, X86fmaxc,
+defm VMAXCSSZ : avx512_comutable_binop_s<0x5F, "vmaxss", f32x_info, f32, X86fmaxc,
SchedWriteFCmp.Scl>, TB, XS,
EVEX, VVVV, VEX_LIG, EVEX_CD8<32, CD8VT1>, SIMD_EXC;
-defm VMAXCSDZ : avx512_comutable_binop_s<0x5F, "vmaxsd", f64x_info, X86fmaxc,
+defm VMAXCSDZ : avx512_comutable_binop_s<0x5F, "vmaxsd", f64x_info, f64, X86fmaxc,
SchedWriteFCmp.Scl>, TB, XD,
REX_W, EVEX, VVVV, VEX_LIG,
EVEX_CD8<64, CD8VT1>, SIMD_EXC;
-defm VMINCSHZ : avx512_comutable_binop_s<0x5D, "vminsh", f16x_info, X86fminc,
+defm VMINCSHZ : avx512_comutable_binop_s<0x5D, "vminsh", f16x_info, f16, X86fminc,
SchedWriteFCmp.Scl>, T_MAP5, XS,
EVEX, VVVV, VEX_LIG, EVEX_CD8<16, CD8VT1>, SIMD_EXC;
-defm VMAXCSHZ : avx512_comutable_binop_s<0x5F, "vmaxsh", f16x_info, X86fmaxc,
+defm VMAXCSHZ : avx512_comutable_binop_s<0x5F, "vmaxsh", f16x_info, f16, X86fmaxc,
SchedWriteFCmp.Scl>, T_MAP5, XS,
EVEX, VVVV, VEX_LIG, EVEX_CD8<16, CD8VT1>, SIMD_EXC;
@@ -7565,15 +7560,15 @@ def : Pat<(v2f64 (X86Movsd
// Convert float/double to signed/unsigned int 32/64 with truncation
multiclass avx512_cvt_s_all<bits<8> opc, string asm, X86VectorVTInfo _SrcRC,
- X86VectorVTInfo _DstRC, SDPatternOperator OpNode,
- SDNode OpNodeInt, SDNode OpNodeSAE,
+ X86VectorVTInfo _DstRC, ValueType CT,
+ SDPatternOperator OpNode, SDNode OpNodeInt, SDNode OpNodeSAE,
X86FoldableSchedWrite sched, string aliasStr,
Predicate prd = HasAVX512> {
let Predicates = [prd], ExeDomain = _SrcRC.ExeDomain in {
let isCodeGenOnly = 1 in {
def rr : AVX512<opc, MRMSrcReg, (outs _DstRC.RC:$dst), (ins _SrcRC.FRC:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
- [(set _DstRC.RC:$dst, (OpNode _SrcRC.FRC:$src))]>,
+ [(set _DstRC.RC:$dst, (OpNode (CT _SrcRC.FRC:$src)))]>,
EVEX, VEX_LIG, Sched<[sched]>, SIMD_EXC;
def rm : AVX512<opc, MRMSrcMem, (outs _DstRC.RC:$dst), (ins _SrcRC.ScalarMemOp:$src),
!strconcat(asm,"\t{$src, $dst|$dst, $src}"),
@@ -7607,29 +7602,29 @@ let Predicates = [prd], ExeDomain = _SrcRC.ExeDomain in {
_SrcRC.IntScalarMemOp:$src), 0, "att">;
}
-defm VCVTTSS2SIZ: avx512_cvt_s_all<0x2C, "vcvttss2si", f32x_info, i32x_info,
+defm VCVTTSS2SIZ: avx512_cvt_s_all<0x2C, "vcvttss2si", f32x_info, i32x_info, f32,
any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSS2I,
"{l}">, TB, XS, EVEX_CD8<32, CD8VT1>;
-defm VCVTTSS2SI64Z: avx512_cvt_s_all<0x2C, "vcvttss2si", f32x_info, i64x_info,
+defm VCVTTSS2SI64Z: avx512_cvt_s_all<0x2C, "vcvttss2si", f32x_info, i64x_info, f32,
any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSS2I,
"{q}">, REX_W, TB, XS, EVEX_CD8<32, CD8VT1>;
-defm VCVTTSD2SIZ: avx512_cvt_s_all<0x2C, "vcvttsd2si", f64x_info, i32x_info,
+defm VCVTTSD2SIZ: avx512_cvt_s_all<0x2C, "vcvttsd2si", f64x_info, i32x_info, f64,
any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSD2I,
"{l}">, TB, XD, EVEX_CD8<64, CD8VT1>;
-defm VCVTTSD2SI64Z: avx512_cvt_s_all<0x2C, "vcvttsd2si", f64x_info, i64x_info,
+defm VCVTTSD2SI64Z: avx512_cvt_s_all<0x2C, "vcvttsd2si", f64x_info, i64x_info, f64,
any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSD2I,
"{q}">, REX_W, TB, XD, EVEX_CD8<64, CD8VT1>;
-defm VCVTTSS2USIZ: avx512_cvt_s_all<0x78, "vcvttss2usi", f32x_info, i32x_info,
+defm VCVTTSS2USIZ: avx512_cvt_s_all<0x78, "vcvttss2usi", f32x_info, i32x_info, f32,
any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSS2I,
"{l}">, TB, XS, EVEX_CD8<32, CD8VT1>;
-defm VCVTTSS2USI64Z: avx512_cvt_s_all<0x78, "vcvttss2usi", f32x_info, i64x_info,
+defm VCVTTSS2USI64Z: avx512_cvt_s_all<0x78, "vcvttss2usi", f32x_info, i64x_info, f32,
any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSS2I,
"{q}">, TB, XS,REX_W, EVEX_CD8<32, CD8VT1>;
-defm VCVTTSD2USIZ: avx512_cvt_s_all<0x78, "vcvttsd2usi", f64x_info, i32x_info,
+defm VCVTTSD2USIZ: avx512_cvt_s_all<0x78, "vcvttsd2usi", f64x_info, i32x_info, f64,
any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSD2I,
"{l}">, TB, XD, EVEX_CD8<64, CD8VT1>;
-defm VCVTTSD2USI64Z: avx512_cvt_s_all<0x78, "vcvttsd2usi", f64x_info, i64x_info,
+defm VCVTTSD2USI64Z: avx512_cvt_s_all<0x78, "vcvttsd2usi", f64x_info, i64x_info, f64,
any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSD2I,
"{q}">, TB, XD, REX_W, EVEX_CD8<64, CD8VT1>;
@@ -7747,15 +7742,15 @@ def : Pat<(f32 (any_fpround FR64X:$src)),
(VCVTSD2SSZrr (f32 (IMPLICIT_DEF)), FR64X:$src)>,
Requires<[HasAVX512]>;
-def : Pat<(f32 (any_fpextend FR16X:$src)),
- (VCVTSH2SSZrr (f32 (IMPLICIT_DEF)), FR16X:$src)>,
+def : Pat<(f32 (any_fpextend (f16 FR16X:$src))),
+ (VCVTSH2SSZrr (f32 (IMPLICIT_DEF)), (f16 FR16X:$src))>,
Requires<[HasFP16]>;
def : Pat<(f32 (any_fpextend (loadf16 addr:$src))),
(VCVTSH2SSZrm (f32 (IMPLICIT_DEF)), addr:$src)>,
Requires<[HasFP16, OptForSize]>;
-def : Pat<(f64 (any_fpextend FR16X:$src)),
- (VCVTSH2SDZrr (f64 (IMPLICIT_DEF)), FR16X:$src)>,
+def : Pat<(f64 (any_fpextend (f16 FR16X:$src))),
+ (VCVTSH2SDZrr (f64 (IMPLICIT_DEF)), (f16 FR16X:$src))>,
Requires<[HasFP16]>;
def : Pat<(f64 (any_fpextend (loadf16 addr:$src))),
(VCVTSH2SDZrm (f64 (IMPLICIT_DEF)), addr:$src)>,
@@ -11587,6 +11582,9 @@ let Predicates = [HasBWI], AddedComplexity = -10 in {
def : Pat<(f16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (VPINSRWZrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16X)>;
}
+def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
+def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
+
//===----------------------------------------------------------------------===//
// VSHUFPS - VSHUFPD Operations
//===----------------------------------------------------------------------===//
@@ -12950,9 +12948,9 @@ def VMOVWmr : AVX512<0x7E, MRMDestMem, (outs),
(iPTR 0))), addr:$dst)]>,
T_MAP5, PD, EVEX, EVEX_CD8<16, CD8VT1>, Sched<[WriteFStore]>;
-def : Pat<(i16 (bitconvert FR16X:$src)),
+def : Pat<(i16 (bitconvert (f16 FR16X:$src))),
(i16 (EXTRACT_SUBREG
- (VMOVSH2Wrr (COPY_TO_REGCLASS FR16X:$src, VR128X)),
+ (VMOVSH2Wrr (COPY_TO_REGCLASS (f16 FR16X:$src), VR128X)),
sub_16bit))>;
def : Pat<(i16 (extractelt (v8i16 VR128X:$src), (iPTR 0))),
(i16 (EXTRACT_SUBREG (VMOVSH2Wrr VR128X:$src), sub_16bit))>;
@@ -13281,16 +13279,16 @@ defm VCVTSH2USI64Z: avx512_cvt_s_int_round<0x79, f16x_info, i64x_info, X86cvts2u
X86cvts2usiRnd, WriteCvtSS2I, "cvtsh2usi", "{q}", HasFP16>,
T_MAP5, XS, REX_W, EVEX_CD8<16, CD8VT1>;
-defm VCVTTSH2SIZ: avx512_cvt_s_all<0x2C, "vcvttsh2si", f16x_info, i32x_info,
+defm VCVTTSH2SIZ: avx512_cvt_s_all<0x2C, "vcvttsh2si", f16x_info, i32x_info, f16,
any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSS2I,
"{l}", HasFP16>, T_MAP5, XS, EVEX_CD8<16, CD8VT1>;
-defm VCVTTSH2SI64Z: avx512_cvt_s_all<0x2C, "vcvttsh2si", f16x_info, i64x_info,
+defm VCVTTSH2SI64Z: avx512_cvt_s_all<0x2C, "vcvttsh2si", f16x_info, i64x_info, f16,
any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSS2I,
"{q}", HasFP16>, REX_W, T_MAP5, XS, EVEX_CD8<16, CD8VT1>;
-defm VCVTTSH2USIZ: avx512_cvt_s_all<0x78, "vcvttsh2usi", f16x_info, i32x_info,
+defm VCVTTSH2USIZ: avx512_cvt_s_all<0x78, "vcvttsh2usi", f16x_info, i32x_info, f16,
any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSS2I,
"{l}", HasFP16>, T_MAP5, XS, EVEX_CD8<16, CD8VT1>;
-defm VCVTTSH2USI64Z: avx512_cvt_s_all<0x78, "vcvttsh2usi", f16x_info, i64x_info,
+defm VCVTTSH2USI64Z: avx512_cvt_s_all<0x78, "vcvttsh2usi", f16x_info, i64x_info, f16,
any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSS2I,
"{q}", HasFP16>, T_MAP5, XS, REX_W, EVEX_CD8<16, CD8VT1>;
diff --git a/llvm/lib/Target/X86/X86InstrVecCompiler.td b/llvm/lib/Target/X86/X86InstrVecCompiler.td
index 122627ca45d31..68609b66f7256 100644
--- a/llvm/lib/Target/X86/X86InstrVecCompiler.td
+++ b/llvm/lib/Target/X86/X86InstrVecCompiler.td
@@ -47,8 +47,8 @@ let Predicates = [NoVLX] in {
}
let Predicates = [HasVLX] in {
- def : Pat<(v8f16 (scalar_to_vector FR16X:$src)),
- (COPY_TO_REGCLASS FR16X:$src, VR128X)>;
+ def : Pat<(v8f16 (scalar_to_vector (f16 FR16X:$src))),
+ (COPY_TO_REGCLASS (f16 FR16X:$src), VR128X)>;
// Implicitly promote a 32-bit scalar to a vector.
def : Pat<(v4f32 (scalar_to_vector FR32X:$src)),
(COPY_TO_REGCLASS FR32X:$src, VR128X)>;
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.td b/llvm/lib/Target/X86/X86RegisterInfo.td
index 48459b3aca508..6dc29a8a51065 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.td
+++ b/llvm/lib/Target/X86/X86RegisterInfo.td
@@ -806,7 +806,7 @@ def FR32X : RegisterClass<"X86", [f32], 32, (sequence "XMM%u", 0, 31)>;
def FR64X : RegisterClass<"X86", [f64], 64, (add FR32X)>;
-def FR16X : RegisterClass<"X86", [f16], 16, (add FR32X)> {let Size = 32;}
+def FR16X : RegisterClass<"X86", [f16, bf16], 16, (add FR32X)> {let Size = 32;}
// Extended VR128 and VR256 for AVX-512 instructions
def VR128X : RegisterClass<"X86", [v4f32, v2f64, v8f16, v8bf16, v16i8, v8i16, v4i32, v2i64, f128],
From b56a6ba5184836f89fc398b18a22355a37fad146 Mon Sep 17 00:00:00 2001
From: Antonio Frighetto <me at antoniofrighetto.com>
Date: Sat, 26 Apr 2025 19:24:37 +0200
Subject: [PATCH 3/3] !fixup custom hooks for trunc/ext, add missing
instruction selection patterns
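Scalar `bf16` conversions have no native lowering here, so `FP_EXTEND` and
`FP_ROUND` are custom-lowered into compiler-rt soft-float calls. A hedged
sketch of the conversions this is expected to cover (helper names assumed to
match the usual RTLIB entries):

    float  bf16_to_f32(__bf16 x) { return (float)x;  } /* __extendbfsf2 */
    __bf16 f32_to_bf16(float  x) { return (__bf16)x; } /* __truncsfbf2  */
    __bf16 f64_to_bf16(double x) { return (__bf16)x; } /* __truncdfbf2  */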
---
llvm/lib/Target/X86/X86FastISel.cpp | 5 +-
llvm/lib/Target/X86/X86ISelLowering.cpp | 55 +++++++++++++++++++-
llvm/lib/Target/X86/X86InstrAVX512.td | 23 +++++++-
llvm/lib/Target/X86/X86InstrFragmentsSIMD.td | 4 ++
llvm/lib/Target/X86/X86InstrSSE.td | 22 ++++++--
5 files changed, 102 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Target/X86/X86FastISel.cpp b/llvm/lib/Target/X86/X86FastISel.cpp
index afe5d7f4bc7ed..2136da48dcf84 100644
--- a/llvm/lib/Target/X86/X86FastISel.cpp
+++ b/llvm/lib/Target/X86/X86FastISel.cpp
@@ -147,7 +147,8 @@ class X86FastISel final : public FastISel {
/// computed in an SSE register, not on the X87 floating point stack.
bool isScalarFPTypeInSSEReg(EVT VT) const {
return (VT == MVT::f64 && Subtarget->hasSSE2()) ||
- (VT == MVT::f32 && Subtarget->hasSSE1()) || VT == MVT::f16;
+ (VT == MVT::f32 && Subtarget->hasSSE1()) || VT == MVT::f16 ||
+ VT == MVT::bf16;
}
bool isTypeLegal(Type *Ty, MVT &VT, bool AllowI1 = false);
@@ -2283,6 +2284,7 @@ bool X86FastISel::X86FastEmitPseudoSelect(MVT RetVT, const Instruction *I) {
case MVT::i16: Opc = X86::CMOV_GR16; break;
case MVT::i32: Opc = X86::CMOV_GR32; break;
case MVT::f16:
+ case MVT::bf16:
Opc = Subtarget->hasAVX512() ? X86::CMOV_FR16X : X86::CMOV_FR16; break;
case MVT::f32:
Opc = Subtarget->hasAVX512() ? X86::CMOV_FR32X : X86::CMOV_FR32; break;
@@ -3972,6 +3974,7 @@ Register X86FastISel::fastMaterializeFloatZero(const ConstantFP *CF) {
switch (VT.SimpleTy) {
default: return 0;
case MVT::f16:
+ case MVT::bf16:
Opc = HasAVX512 ? X86::AVX512_FsFLD0SH : X86::FsFLD0SH;
break;
case MVT::f32:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ad170f49aebb5..daa19ac43d10f 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -678,9 +678,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
// non-optsize case.
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
- // Set the operation action Custom for bitcast to do the customization
- // later.
+ // Set the operation action to Custom for bitcast and conversions, and fall
+ // back to software libcalls for the latter for now.
setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
+ setOperationAction(ISD::FP_EXTEND, MVT::bf16, Custom);
+ setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
for (auto VT : { MVT::f32, MVT::f64 }) {
// Use ANDPD to simulate FABS.
@@ -22066,6 +22068,31 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
return Res;
}
+ if (SVT == MVT::bf16 && VT == MVT::f32) {
+ TargetLowering::CallLoweringInfo CLI(DAG);
+ Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
+
+ TargetLowering::ArgListTy Args;
+ TargetLowering::ArgListEntry Entry;
+ Entry.Node = In;
+ Entry.Ty = EVT(SVT).getTypeForEVT(*DAG.getContext());
+ Args.push_back(Entry);
+
+ SDValue Callee =
+ DAG.getExternalSymbol(getLibcallName(RTLIB::FPEXT_BF16_F32),
+ getPointerTy(DAG.getDataLayout()));
+ CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+ CallingConv::C, EVT(VT).getTypeForEVT(*DAG.getContext()), Callee,
+ std::move(Args));
+
+ SDValue Res;
+ std::tie(Res, Chain) = LowerCallTo(CLI);
+ if (IsStrict)
+ Res = DAG.getMergeValues({Res, Chain}, DL);
+
+ return Res;
+ }
+
if (!SVT.isVector() || SVT.getVectorElementType() == MVT::bf16)
return Op;
@@ -22149,6 +22176,30 @@ SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
Subtarget.hasAVXNECONVERT()))
return Op;
+
+ // Need a soft libcall if the target does not have BF16.
+ if (SVT.getScalarType() == MVT::f32 || SVT.getScalarType() == MVT::f64) {
+ TargetLowering::CallLoweringInfo CLI(DAG);
+ Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
+
+ TargetLowering::ArgListTy Args;
+ TargetLowering::ArgListEntry Entry;
+ Entry.Node = In;
+ Entry.Ty = EVT(SVT).getTypeForEVT(*DAG.getContext());
+ Args.push_back(Entry);
+ SDValue Callee = DAG.getExternalSymbol(
+ getLibcallName(SVT == MVT::f64 ? RTLIB::FPROUND_F64_BF16
+ : RTLIB::FPROUND_F32_BF16),
+ getPointerTy(DAG.getDataLayout()));
+ CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+ CallingConv::C, EVT(MVT::bf16).getTypeForEVT(*DAG.getContext()),
+ Callee, std::move(Args));
+
+ SDValue Res;
+ std::tie(Res, Chain) = LowerCallTo(CLI);
+ return IsStrict ? DAG.getMergeValues({Res, Chain}, DL) : Res;
+ }
+
return SDValue();
}
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index ebaba0ac29799..d476f92d34340 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -11585,6 +11585,13 @@ let Predicates = [HasBWI], AddedComplexity = -10 in {
def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
+let Predicates = [HasBWI, HasBF16] in {
+ def : Pat<(bf16 (load addr:$src)), (COPY_TO_REGCLASS (VPINSRWZrmi (v8i16 (IMPLICIT_DEF)), addr:$src, 0), FR16X)>;
+ def : Pat<(store bf16:$src, addr:$dst), (VPEXTRWZmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0)>;
+ def : Pat<(i16 (bitconvert bf16:$src)), (EXTRACT_SUBREG (VPEXTRWZrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128X)), 0), sub_16bit)>;
+ def : Pat<(bf16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (VPINSRWZrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16X)>;
+}
+
//===----------------------------------------------------------------------===//
// VSHUFPS - VSHUFPD Operations
//===----------------------------------------------------------------------===//
@@ -12809,9 +12816,17 @@ let Predicates = [HasBF16, HasVLX] in {
}
let Predicates = [HasBF16] in {
+ def : Pat<(v8bf16 (X86VBroadcastld16 addr:$src)),
+ (VPBROADCASTWrm addr:$src)>;
+ def : Pat<(v16bf16 (X86VBroadcastld16 addr:$src)),
+ (VPBROADCASTWYrm addr:$src)>;
def : Pat<(v32bf16 (X86VBroadcastld16 addr:$src)),
(VPBROADCASTWZrm addr:$src)>;
+ def : Pat<(v8bf16 (X86VBroadcast (v8bf16 VR128:$src))),
+ (VPBROADCASTWrr VR128:$src)>;
+ def : Pat<(v16bf16 (X86VBroadcast (v8bf16 VR128:$src))),
+ (VPBROADCASTWYrr VR128:$src)>;
def : Pat<(v32bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
(VPBROADCASTWZrr VR128X:$src)>;
@@ -12819,7 +12834,13 @@ let Predicates = [HasBF16] in {
(VCVTNEPS2BF16Zrr VR512:$src)>;
def : Pat<(v16bf16 (X86vfpround (loadv16f32 addr:$src))),
(VCVTNEPS2BF16Zrm addr:$src)>;
- // TODO: No scalar broadcast due to we don't support legal scalar bf16 so far.
+
+ def : Pat<(v8bf16 (X86VBroadcast (bf16 FR16X:$src))),
+ (VPBROADCASTWrr (COPY_TO_REGCLASS FR16X:$src, VR128))>;
+ def : Pat<(v16bf16 (X86VBroadcast (bf16 FR16X:$src))),
+ (VPBROADCASTWYrr (COPY_TO_REGCLASS FR16X:$src, VR128))>;
+ def : Pat<(v32bf16 (X86VBroadcast (bf16 FR16X:$src))),
+ (VPBROADCASTWZrr (COPY_TO_REGCLASS FR16X:$src, VR128X))>;
}
let Constraints = "$src1 = $dst" in {
diff --git a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
index de70570481fc2..b81119c1f6d26 100644
--- a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
+++ b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
@@ -1193,6 +1193,10 @@ def fp16imm0 : PatLeaf<(f16 fpimm), [{
return N->isExactlyValue(+0.0);
}]>;
+def bfp16imm0 : PatLeaf<(bf16 fpimm), [{
+ return N->isExactlyValue(+0.0);
+}]>;
+
def fp32imm0 : PatLeaf<(f32 fpimm), [{
return N->isExactlyValue(+0.0);
}]>;
diff --git a/llvm/lib/Target/X86/X86InstrSSE.td b/llvm/lib/Target/X86/X86InstrSSE.td
index 49a62fd3422d0..e2aa90b59bbfa 100644
--- a/llvm/lib/Target/X86/X86InstrSSE.td
+++ b/llvm/lib/Target/X86/X86InstrSSE.td
@@ -4048,6 +4048,19 @@ let Predicates = [HasAVX, NoBWI] in {
def : Pat<(f16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (VPINSRWrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16)>;
}
+let Predicates = [UseSSE2] in {
+ def : Pat<(bf16 (load addr:$src)), (COPY_TO_REGCLASS (PINSRWrmi (v8i16 (IMPLICIT_DEF)), addr:$src, 0), FR16X)>;
+ def : Pat<(store bf16:$src, addr:$dst), (MOV16mr addr:$dst, (EXTRACT_SUBREG (PEXTRWrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0), sub_16bit))>;
+ def : Pat<(i16 (bitconvert bf16:$src)), (EXTRACT_SUBREG (PEXTRWrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0), sub_16bit)>;
+ def : Pat<(bf16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (PINSRWrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16X)>;
+}
+
+let Predicates = [HasAVX, NoBWI] in {
+ def : Pat<(bf16 (load addr:$src)), (COPY_TO_REGCLASS (VPINSRWrmi (v8i16 (IMPLICIT_DEF)), addr:$src, 0), FR16X)>;
+ def : Pat<(i16 (bitconvert bf16:$src)), (EXTRACT_SUBREG (VPEXTRWrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0), sub_16bit)>;
+ def : Pat<(bf16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (VPINSRWrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16)>;
+}
+
//===---------------------------------------------------------------------===//
// SSE2 - Packed Mask Creation
//===---------------------------------------------------------------------===//
@@ -5279,12 +5292,15 @@ let Predicates = [HasAVX, NoBWI] in
defm PEXTRW : SS41I_extract16<0x15, "pextrw">;
-let Predicates = [UseSSE41] in
+let Predicates = [UseSSE41] in {
def : Pat<(store f16:$src, addr:$dst), (PEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16:$src, VR128)), 0)>;
+ def : Pat<(store bf16:$src, addr:$dst), (PEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0)>;
+}
-let Predicates = [HasAVX, NoBWI] in
+let Predicates = [HasAVX, NoBWI] in {
def : Pat<(store f16:$src, addr:$dst), (VPEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16:$src, VR128)), 0)>;
-
+ def : Pat<(store bf16:$src, addr:$dst), (VPEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0)>;
+}
/// SS41I_extract32 - SSE 4.1 extract 32 bits to int reg or memory destination
multiclass SS41I_extract32<bits<8> opc, string OpcodeStr> {