[llvm] [X86] Add support for `__bf16` to `f16` conversion (PR #134859)

Antonio Frighetto via llvm-commits llvm-commits@lists.llvm.org
Sat Apr 26 10:26:44 PDT 2025


https://github.com/antoniofrighetto updated https://github.com/llvm/llvm-project/pull/134859

From 7f2a1120c9fc31912e7d98de99dd2a04f368b5c6 Mon Sep 17 00:00:00 2001
From: Antonio Frighetto <me@antoniofrighetto.com>
Date: Tue, 8 Apr 2025 16:06:44 +0200
Subject: [PATCH 1/3] [X86] Add support for `__bf16` to `f16` conversion

`__bf16` is a 16-bit brain floating-point type, introduced with the
AVX-512 BF16 extension, and should be able to share the SSE/AVX registers
already used for `f16`.

Fixes: https://github.com/llvm/llvm-project/issues/134222.
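
A sketch of the kind of source that exercises the new lowering (illustrative
only; the exact reproducer is in the linked issue):

  // clang -O2 --target=x86_64 repro.c
  _Float16 f16_from_bf16_bits(__bf16 b) {
    _Float16 h;
    __builtin_memcpy(&h, &b, sizeof h); // 16-bit bf16 -> f16 bitcast
    return h;
  }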
---
 llvm/lib/Target/X86/X86ISelLowering.cpp | 12 +++++++++++-
 llvm/lib/Target/X86/X86InstrAVX512.td   |  5 +++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a4381b99dbae0..ad170f49aebb5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -661,10 +661,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
   };
 
   if (!Subtarget.useSoftFloat() && Subtarget.hasSSE2()) {
-    // f16, f32 and f64 use SSE.
+    // f16, bf16, f32 and f64 use SSE.
     // Set up the FP register classes.
     addRegisterClass(MVT::f16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
                                                      : &X86::FR16RegClass);
+    addRegisterClass(MVT::bf16, Subtarget.hasAVX512() ? &X86::FR16XRegClass
+                                                      : &X86::FR16RegClass);
     addRegisterClass(MVT::f32, Subtarget.hasAVX512() ? &X86::FR32XRegClass
                                                      : &X86::FR32RegClass);
     addRegisterClass(MVT::f64, Subtarget.hasAVX512() ? &X86::FR64XRegClass
@@ -676,6 +678,10 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     // non-optsize case.
     setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
 
+    // Set the operation action to Custom for bitcast, so that it can be
+    // lowered later.
+    setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
+
     for (auto VT : { MVT::f32, MVT::f64 }) {
       // Use ANDPD to simulate FABS.
       setOperationAction(ISD::FABS, VT, Custom);
@@ -32151,6 +32157,10 @@ static SDValue LowerBITCAST(SDValue Op, const X86Subtarget &Subtarget,
     return DAG.getZExtOrTrunc(V, DL, DstVT);
   }
 
+  // Bitcasts between f16 and bf16 should be legal.
+  if (DstVT == MVT::f16 || DstVT == MVT::bf16)
+    return Op;
+
   assert((SrcVT == MVT::v2i32 || SrcVT == MVT::v4i16 || SrcVT == MVT::v8i8 ||
           SrcVT == MVT::i64) && "Unexpected VT!");
 
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 0ab94cca41425..5705094dca55b 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -2456,6 +2456,11 @@ let Predicates = [HasFP16] in {
             (VCMPSHZrmi FR16X:$src1, addr:$src2, (X86cmpm_imm_commute timm:$cc))>;
 }
 
+let Predicates = [HasAVX512, HasBF16] in {
+  def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
+  def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
+}
+
 // ----------------------------------------------------------------
 // FPClass
 

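With the two patterns above, both directions of the 16-bit bitcast become
free, since source and destination live in the same FR16X class; in C terms
(a sketch, assuming a Clang that enables __bf16 and _Float16 for the target):

  __bf16 bf16_bits_from_f16(_Float16 h) {
    __bf16 b;
    __builtin_memcpy(&b, &h, sizeof b); // selects to a plain register copy
    return b;
  }
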
From 22f39769356f9b844fb48a6ba22d42935f73e891 Mon Sep 17 00:00:00 2001
From: Antonio Frighetto <me@antoniofrighetto.com>
Date: Sat, 26 Apr 2025 19:21:25 +0200
Subject: [PATCH 2/3] !fixup update affected patterns w/ explicit cast, update
 calling conventions

---
 llvm/lib/Target/X86/X86CallingConv.td      |  11 ++-
 llvm/lib/Target/X86/X86InstrAVX10.td       |  10 +-
 llvm/lib/Target/X86/X86InstrAVX512.td      | 110 ++++++++++-----------
 llvm/lib/Target/X86/X86InstrVecCompiler.td |   4 +-
 llvm/lib/Target/X86/X86RegisterInfo.td     |   2 +-
 5 files changed, 70 insertions(+), 67 deletions(-)

diff --git a/llvm/lib/Target/X86/X86CallingConv.td b/llvm/lib/Target/X86/X86CallingConv.td
index 0d087e057a2bd..d981c010b338d 100644
--- a/llvm/lib/Target/X86/X86CallingConv.td
+++ b/llvm/lib/Target/X86/X86CallingConv.td
@@ -364,6 +364,7 @@ def RetCC_X86_32_VectorCall : CallingConv<[
 def RetCC_X86_64_C : CallingConv<[
   // The X86-64 calling convention always returns FP values in XMM0.
   CCIfType<[f16], CCAssignToReg<[XMM0, XMM1]>>,
+  CCIfType<[bf16], CCAssignToReg<[XMM0, XMM1]>>,
   CCIfType<[f32], CCAssignToReg<[XMM0, XMM1]>>,
   CCIfType<[f64], CCAssignToReg<[XMM0, XMM1]>>,
   CCIfType<[f128], CCAssignToReg<[XMM0, XMM1]>>,
@@ -569,6 +570,10 @@ def CC_X86_64_C : CallingConv<[
             CCIfSubtarget<"hasSSE1()",
             CCAssignToReg<[XMM0, XMM1, XMM2, XMM3, XMM4, XMM5, XMM6, XMM7]>>>,
 
+  // The first 8 bf16 arguments are passed in XMM registers (requires AVX-512).
+  CCIfType<[bf16], CCIfSubtarget<"hasAVX512()",
+            CCAssignToReg<[XMM0, XMM1, XMM2, XMM3, XMM4, XMM5, XMM6, XMM7]>>>,
+
   // The first 8 256-bit vector arguments are passed in YMM registers, unless
   // this is a vararg function.
   // FIXME: This isn't precisely correct; the x86-64 ABI document says that
@@ -586,7 +591,7 @@ def CC_X86_64_C : CallingConv<[
 
   // Integer/FP values get stored in stack slots that are 8 bytes in size and
   // 8-byte aligned if there are no more registers to hold them.
-  CCIfType<[i32, i64, f16, f32, f64], CCAssignToStack<8, 8>>,
+  CCIfType<[i32, i64, bf16, f16, f32, f64], CCAssignToStack<8, 8>>,
 
   // Long doubles get stack slots whose size and alignment depends on the
   // subtarget.
@@ -649,7 +654,7 @@ def CC_X86_Win64_C : CallingConv<[
   CCIfType<[f64], CCIfNotSubtarget<"hasSSE1()", CCBitConvertToType<i64>>>,
 
   // The first 4 FP/Vector arguments are passed in XMM registers.
-  CCIfType<[f16, f32, f64],
+  CCIfType<[bf16, f16, f32, f64],
            CCAssignToRegWithShadow<[XMM0, XMM1, XMM2, XMM3],
                                    [RCX , RDX , R8  , R9  ]>>,
 
@@ -672,7 +677,7 @@ def CC_X86_Win64_C : CallingConv<[
 
   // Integer/FP values get stored in stack slots that are 8 bytes in size and
   // 8-byte aligned if there are no more registers to hold them.
-  CCIfType<[i8, i16, i32, i64, f16, f32, f64], CCAssignToStack<8, 8>>
+  CCIfType<[i8, i16, i32, i64, bf16, f16, f32, f64], CCAssignToStack<8, 8>>
 ]>;
 
 def CC_X86_Win64_VectorCall : CallingConv<[
diff --git a/llvm/lib/Target/X86/X86InstrAVX10.td b/llvm/lib/Target/X86/X86InstrAVX10.td
index 2d2bf1f6c725e..b0c6f27c7a575 100644
--- a/llvm/lib/Target/X86/X86InstrAVX10.td
+++ b/llvm/lib/Target/X86/X86InstrAVX10.td
@@ -103,14 +103,14 @@ multiclass avx10_minmax_packed<string OpStr, AVX512VLVectorVTInfo VTI, SDNode Op
 }
 
 multiclass avx10_minmax_scalar<string OpStr, X86VectorVTInfo _, SDNode OpNode,
-                                SDNode OpNodeSAE> {
+                                SDNode OpNodeSAE, ValueType CT> {
   let ExeDomain = _.ExeDomain, Predicates = [HasAVX10_2] in {
     let mayRaiseFPException = 1 in {
       let isCodeGenOnly = 1 in {
         def rri : AVX512Ii8<0x53, MRMSrcReg, (outs _.FRC:$dst),
                             (ins _.FRC:$src1, _.FRC:$src2, i32u8imm:$src3),
                              !strconcat(OpStr, "\t{$src3, $src2, $src1|$src1, $src2, $src3}"),
-                             [(set _.FRC:$dst, (OpNode _.FRC:$src1, _.FRC:$src2, (i32 timm:$src3)))]>,
+                             [(set _.FRC:$dst, (OpNode (CT _.FRC:$src1), (CT _.FRC:$src2), (i32 timm:$src3)))]>,
                        Sched<[WriteFMAX]>;
 
         def rmi : AVX512Ii8<0x53, MRMSrcMem, (outs _.FRC:$dst),
@@ -165,11 +165,11 @@ defm VMINMAXPS : avx10_minmax_packed<"vminmaxps", avx512vl_f32_info, X86vminmax>
                  avx10_minmax_packed_sae<"vminmaxps", avx512vl_f32_info, X86vminmaxSae>,
                  AVX512PDIi8Base, TA, EVEX_CD8<32, CD8VF>;
 
-defm VMINMAXSD : avx10_minmax_scalar<"vminmaxsd", v2f64x_info, X86vminmaxs, X86vminmaxsSae>,
+defm VMINMAXSD : avx10_minmax_scalar<"vminmaxsd", v2f64x_info, X86vminmaxs, X86vminmaxsSae, f64>,
                  AVX512AIi8Base, VEX_LIG, EVEX, VVVV, EVEX_CD8<64, CD8VT1>, REX_W;
-defm VMINMAXSH : avx10_minmax_scalar<"vminmaxsh", v8f16x_info, X86vminmaxs, X86vminmaxsSae>,
+defm VMINMAXSH : avx10_minmax_scalar<"vminmaxsh", v8f16x_info, X86vminmaxs, X86vminmaxsSae, f16>,
                  AVX512PSIi8Base, VEX_LIG, EVEX, VVVV, EVEX_CD8<16, CD8VT1>, TA;
-defm VMINMAXSS : avx10_minmax_scalar<"vminmaxss", v4f32x_info, X86vminmaxs, X86vminmaxsSae>,
+defm VMINMAXSS : avx10_minmax_scalar<"vminmaxss", v4f32x_info, X86vminmaxs, X86vminmaxsSae, f32>,
                  AVX512AIi8Base, VEX_LIG, EVEX, VVVV, EVEX_CD8<32, CD8VT1>;
 
 //-------------------------------------------------
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 5705094dca55b..ebaba0ac29799 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -1942,7 +1942,7 @@ defm VPBLENDMW : blendmask_bw<0x66, "vpblendmw", SchedWriteVarBlend,
 
 // avx512_cmp_scalar - AVX512 CMPSS and CMPSD
 
-multiclass avx512_cmp_scalar<X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE,
+multiclass avx512_cmp_scalar<X86VectorVTInfo _, ValueType CT, SDNode OpNode, SDNode OpNodeSAE,
                              PatFrag OpNode_su, PatFrag OpNodeSAE_su,
                              X86FoldableSchedWrite sched> {
   defm  rri  : AVX512_maskable_cmp<0xC2, MRMSrcReg, _,
@@ -1983,8 +1983,8 @@ multiclass avx512_cmp_scalar<X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE,
                         (outs _.KRC:$dst), (ins _.FRC:$src1, _.FRC:$src2, u8imm:$cc),
                         !strconcat("vcmp", _.Suffix,
                                    "\t{$cc, $src2, $src1, $dst|$dst, $src1, $src2, $cc}"),
-                        [(set _.KRC:$dst, (OpNode _.FRC:$src1,
-                                                  _.FRC:$src2,
+                        [(set _.KRC:$dst, (OpNode (CT _.FRC:$src1),
+                                                  (CT _.FRC:$src2),
                                                   timm:$cc))]>,
                         EVEX, VVVV, VEX_LIG, Sched<[sched]>, SIMD_EXC;
     def rmi : AVX512Ii8<0xC2, MRMSrcMem,
@@ -2002,16 +2002,16 @@ multiclass avx512_cmp_scalar<X86VectorVTInfo _, SDNode OpNode, SDNode OpNodeSAE,
 
 let Predicates = [HasAVX512] in {
   let ExeDomain = SSEPackedSingle in
-  defm VCMPSSZ : avx512_cmp_scalar<f32x_info, X86cmpms, X86cmpmsSAE,
+  defm VCMPSSZ : avx512_cmp_scalar<f32x_info, f32, X86cmpms, X86cmpmsSAE,
                                    X86cmpms_su, X86cmpmsSAE_su,
                                    SchedWriteFCmp.Scl>, AVX512XSIi8Base;
   let ExeDomain = SSEPackedDouble in
-  defm VCMPSDZ : avx512_cmp_scalar<f64x_info, X86cmpms, X86cmpmsSAE,
+  defm VCMPSDZ : avx512_cmp_scalar<f64x_info, f64, X86cmpms, X86cmpmsSAE,
                                    X86cmpms_su, X86cmpmsSAE_su,
                                    SchedWriteFCmp.Scl>, AVX512XDIi8Base, REX_W;
 }
 let Predicates = [HasFP16], ExeDomain = SSEPackedSingle in
-  defm VCMPSHZ : avx512_cmp_scalar<f16x_info, X86cmpms, X86cmpmsSAE,
+  defm VCMPSHZ : avx512_cmp_scalar<f16x_info, f16, X86cmpms, X86cmpmsSAE,
                                    X86cmpms_su, X86cmpmsSAE_su,
                                    SchedWriteFCmp.Scl>, AVX512XSIi8Base, TA;
 
@@ -2456,11 +2456,6 @@ let Predicates = [HasFP16] in {
             (VCMPSHZrmi FR16X:$src1, addr:$src2, (X86cmpm_imm_commute timm:$cc))>;
 }
 
-let Predicates = [HasAVX512, HasBF16] in {
-  def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
-  def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
-}
-
 // ----------------------------------------------------------------
 // FPClass
 
@@ -3908,7 +3903,7 @@ def : Pat<(f64 (bitconvert VK64:$src)),
 //===----------------------------------------------------------------------===//
 
 multiclass avx512_move_scalar<string asm, SDNode OpNode, PatFrag vzload_frag,
-                              X86VectorVTInfo _, Predicate prd = HasAVX512> {
+                              X86VectorVTInfo _, ValueType CT, Predicate prd = HasAVX512> {
   let Predicates = !if (!eq (prd, HasFP16), [HasFP16], [prd, OptForSize]) in
   def rr : AVX512PI<0x10, MRMSrcReg, (outs _.RC:$dst),
              (ins _.RC:$src1, _.RC:$src2),
@@ -3960,7 +3955,7 @@ multiclass avx512_move_scalar<string asm, SDNode OpNode, PatFrag vzload_frag,
   }
   def mr: AVX512PI<0x11, MRMDestMem, (outs), (ins _.ScalarMemOp:$dst, _.FRC:$src),
              !strconcat(asm, "\t{$src, $dst|$dst, $src}"),
-             [(store _.FRC:$src, addr:$dst)],  _.ExeDomain>,
+             [(store (CT _.FRC:$src), addr:$dst)],  _.ExeDomain>,
              EVEX, Sched<[WriteFStore]>;
   let mayStore = 1, hasSideEffects = 0 in
   def mrk: AVX512PI<0x11, MRMDestMem, (outs),
@@ -3970,13 +3965,13 @@ multiclass avx512_move_scalar<string asm, SDNode OpNode, PatFrag vzload_frag,
   }
 }
 
-defm VMOVSSZ : avx512_move_scalar<"vmovss", X86Movss, X86vzload32, f32x_info>,
+defm VMOVSSZ : avx512_move_scalar<"vmovss", X86Movss, X86vzload32, f32x_info, f32>,
                                   VEX_LIG, TB, XS, EVEX_CD8<32, CD8VT1>;
 
-defm VMOVSDZ : avx512_move_scalar<"vmovsd", X86Movsd, X86vzload64, f64x_info>,
+defm VMOVSDZ : avx512_move_scalar<"vmovsd", X86Movsd, X86vzload64, f64x_info, f64>,
                                   VEX_LIG, TB, XD, REX_W, EVEX_CD8<64, CD8VT1>;
 
-defm VMOVSHZ : avx512_move_scalar<"vmovsh", X86Movsh, X86vzload16, f16x_info,
+defm VMOVSHZ : avx512_move_scalar<"vmovsh", X86Movsh, X86vzload16, f16x_info, f16,
                                   HasFP16>,
                                   VEX_LIG, T_MAP5, XS, EVEX_CD8<16, CD8VT1>;
 
@@ -5364,7 +5359,7 @@ defm : avx512_logical_lowering_types<"VPANDN", X86andnp>;
 //===----------------------------------------------------------------------===//
 
 multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
-                            SDPatternOperator OpNode, SDNode VecNode,
+                            ValueType CT, SDPatternOperator OpNode, SDNode VecNode,
                             X86FoldableSchedWrite sched, bit IsCommutable> {
   let ExeDomain = _.ExeDomain, Uses = [MXCSR], mayRaiseFPException = 1 in {
   defm rr : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst),
@@ -5383,7 +5378,7 @@ multiclass avx512_fp_scalar<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
   def rr : I< opc, MRMSrcReg, (outs _.FRC:$dst),
                          (ins _.FRC:$src1, _.FRC:$src2),
                           OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
-                          [(set _.FRC:$dst, (OpNode _.FRC:$src1, _.FRC:$src2))]>,
+                          [(set _.FRC:$dst, (OpNode (CT _.FRC:$src1), (CT _.FRC:$src2)))]>,
                           Sched<[sched]> {
     let isCommutable = IsCommutable;
   }
@@ -5408,8 +5403,8 @@ multiclass avx512_fp_scalar_round<bits<8> opc, string OpcodeStr,X86VectorVTInfo
                           EVEX_B, EVEX_RC, Sched<[sched]>;
 }
 multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
-                                SDPatternOperator OpNode, SDNode VecNode, SDNode SaeNode,
-                                X86FoldableSchedWrite sched, bit IsCommutable> {
+                                ValueType CT, SDPatternOperator OpNode, SDNode VecNode,
+                                SDNode SaeNode, X86FoldableSchedWrite sched, bit IsCommutable> {
   let ExeDomain = _.ExeDomain in {
   defm rr : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst),
                            (ins _.RC:$src1, _.RC:$src2), OpcodeStr,
@@ -5429,7 +5424,7 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
   def rr : I< opc, MRMSrcReg, (outs _.FRC:$dst),
                          (ins _.FRC:$src1, _.FRC:$src2),
                           OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
-                          [(set _.FRC:$dst, (OpNode _.FRC:$src1, _.FRC:$src2))]>,
+                          [(set _.FRC:$dst, (OpNode (CT _.FRC:$src1), (CT _.FRC:$src2)))]>,
                           Sched<[sched]> {
     let isCommutable = IsCommutable;
   }
@@ -5453,18 +5448,18 @@ multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
 multiclass avx512_binop_s_round<bits<8> opc, string OpcodeStr, SDPatternOperator OpNode,
                                 SDNode VecNode, SDNode RndNode,
                                 X86SchedWriteSizes sched, bit IsCommutable> {
-  defm SSZ : avx512_fp_scalar<opc, OpcodeStr#"ss", f32x_info, OpNode, VecNode,
+  defm SSZ : avx512_fp_scalar<opc, OpcodeStr#"ss", f32x_info, f32, OpNode, VecNode,
                               sched.PS.Scl, IsCommutable>,
              avx512_fp_scalar_round<opc, OpcodeStr#"ss", f32x_info, RndNode,
                               sched.PS.Scl>,
                               TB, XS, EVEX, VVVV, VEX_LIG,  EVEX_CD8<32, CD8VT1>;
-  defm SDZ : avx512_fp_scalar<opc, OpcodeStr#"sd", f64x_info, OpNode, VecNode,
+  defm SDZ : avx512_fp_scalar<opc, OpcodeStr#"sd", f64x_info, f64, OpNode, VecNode,
                               sched.PD.Scl, IsCommutable>,
              avx512_fp_scalar_round<opc, OpcodeStr#"sd", f64x_info, RndNode,
                               sched.PD.Scl>,
                               TB, XD, REX_W, EVEX, VVVV, VEX_LIG, EVEX_CD8<64, CD8VT1>;
   let Predicates = [HasFP16] in
-    defm SHZ : avx512_fp_scalar<opc, OpcodeStr#"sh", f16x_info, OpNode,
+    defm SHZ : avx512_fp_scalar<opc, OpcodeStr#"sh", f16x_info, f16, OpNode,
                                 VecNode, sched.PH.Scl, IsCommutable>,
                avx512_fp_scalar_round<opc, OpcodeStr#"sh", f16x_info, RndNode,
                                 sched.PH.Scl>,
@@ -5474,14 +5469,14 @@ multiclass avx512_binop_s_round<bits<8> opc, string OpcodeStr, SDPatternOperator
 multiclass avx512_binop_s_sae<bits<8> opc, string OpcodeStr, SDPatternOperator OpNode,
                               SDNode VecNode, SDNode SaeNode,
                               X86SchedWriteSizes sched, bit IsCommutable> {
-  defm SSZ : avx512_fp_scalar_sae<opc, OpcodeStr#"ss", f32x_info, OpNode,
+  defm SSZ : avx512_fp_scalar_sae<opc, OpcodeStr#"ss", f32x_info, f32, OpNode,
                               VecNode, SaeNode, sched.PS.Scl, IsCommutable>,
                               TB, XS, EVEX, VVVV, VEX_LIG,  EVEX_CD8<32, CD8VT1>;
-  defm SDZ : avx512_fp_scalar_sae<opc, OpcodeStr#"sd", f64x_info, OpNode,
+  defm SDZ : avx512_fp_scalar_sae<opc, OpcodeStr#"sd", f64x_info, f64, OpNode,
                               VecNode, SaeNode, sched.PD.Scl, IsCommutable>,
                               TB, XD, REX_W, EVEX, VVVV, VEX_LIG, EVEX_CD8<64, CD8VT1>;
   let Predicates = [HasFP16] in {
-    defm SHZ : avx512_fp_scalar_sae<opc, OpcodeStr#"sh", f16x_info, OpNode,
+    defm SHZ : avx512_fp_scalar_sae<opc, OpcodeStr#"sh", f16x_info, f16, OpNode,
                                 VecNode, SaeNode, sched.PH.Scl, IsCommutable>,
                                 T_MAP5, XS, EVEX, VVVV, VEX_LIG, EVEX_CD8<16, CD8VT1>;
   }
@@ -5502,13 +5497,13 @@ defm VMAX : avx512_binop_s_sae<0x5F, "vmax", X86any_fmax, X86fmaxs, X86fmaxSAEs,
 // MIN/MAX nodes are commutable under "unsafe-fp-math". In this case we use
 // X86fminc and X86fmaxc instead of X86fmin and X86fmax
 multiclass avx512_comutable_binop_s<bits<8> opc, string OpcodeStr,
-                                    X86VectorVTInfo _, SDNode OpNode,
+                                    X86VectorVTInfo _, ValueType CT, SDNode OpNode,
                                     X86FoldableSchedWrite sched> {
   let isCodeGenOnly = 1, Predicates = [HasAVX512], ExeDomain = _.ExeDomain in {
   def rr : I< opc, MRMSrcReg, (outs _.FRC:$dst),
                          (ins _.FRC:$src1, _.FRC:$src2),
                           OpcodeStr#"\t{$src2, $src1, $dst|$dst, $src1, $src2}",
-                          [(set _.FRC:$dst, (OpNode _.FRC:$src1, _.FRC:$src2))]>,
+                          [(set _.FRC:$dst, (OpNode (CT _.FRC:$src1), (CT _.FRC:$src2)))]>,
                           Sched<[sched]> {
     let isCommutable = 1;
   }
@@ -5520,29 +5515,29 @@ multiclass avx512_comutable_binop_s<bits<8> opc, string OpcodeStr,
                          Sched<[sched.Folded, sched.ReadAfterFold]>;
   }
 }
-defm VMINCSSZ : avx512_comutable_binop_s<0x5D, "vminss", f32x_info, X86fminc,
+defm VMINCSSZ : avx512_comutable_binop_s<0x5D, "vminss", f32x_info, f32, X86fminc,
                                          SchedWriteFCmp.Scl>, TB, XS,
                                          EVEX, VVVV, VEX_LIG, EVEX_CD8<32, CD8VT1>, SIMD_EXC;
 
-defm VMINCSDZ : avx512_comutable_binop_s<0x5D, "vminsd", f64x_info, X86fminc,
+defm VMINCSDZ : avx512_comutable_binop_s<0x5D, "vminsd", f64x_info, f64, X86fminc,
                                          SchedWriteFCmp.Scl>, TB, XD,
                                          REX_W, EVEX, VVVV, VEX_LIG,
                                          EVEX_CD8<64, CD8VT1>, SIMD_EXC;
 
-defm VMAXCSSZ : avx512_comutable_binop_s<0x5F, "vmaxss", f32x_info, X86fmaxc,
+defm VMAXCSSZ : avx512_comutable_binop_s<0x5F, "vmaxss", f32x_info, f32, X86fmaxc,
                                          SchedWriteFCmp.Scl>, TB, XS,
                                          EVEX, VVVV, VEX_LIG, EVEX_CD8<32, CD8VT1>, SIMD_EXC;
 
-defm VMAXCSDZ : avx512_comutable_binop_s<0x5F, "vmaxsd", f64x_info, X86fmaxc,
+defm VMAXCSDZ : avx512_comutable_binop_s<0x5F, "vmaxsd", f64x_info, f64, X86fmaxc,
                                          SchedWriteFCmp.Scl>, TB, XD,
                                          REX_W, EVEX, VVVV, VEX_LIG,
                                          EVEX_CD8<64, CD8VT1>, SIMD_EXC;
 
-defm VMINCSHZ : avx512_comutable_binop_s<0x5D, "vminsh", f16x_info, X86fminc,
+defm VMINCSHZ : avx512_comutable_binop_s<0x5D, "vminsh", f16x_info, f16, X86fminc,
                                          SchedWriteFCmp.Scl>, T_MAP5, XS,
                                          EVEX, VVVV, VEX_LIG, EVEX_CD8<16, CD8VT1>, SIMD_EXC;
 
-defm VMAXCSHZ : avx512_comutable_binop_s<0x5F, "vmaxsh", f16x_info, X86fmaxc,
+defm VMAXCSHZ : avx512_comutable_binop_s<0x5F, "vmaxsh", f16x_info, f16, X86fmaxc,
                                          SchedWriteFCmp.Scl>, T_MAP5, XS,
                                          EVEX, VVVV, VEX_LIG, EVEX_CD8<16, CD8VT1>, SIMD_EXC;
 
@@ -7565,15 +7560,15 @@ def : Pat<(v2f64 (X86Movsd
 
 // Convert float/double to signed/unsigned int 32/64 with truncation
 multiclass avx512_cvt_s_all<bits<8> opc, string asm, X86VectorVTInfo _SrcRC,
-                            X86VectorVTInfo _DstRC, SDPatternOperator OpNode,
-                            SDNode OpNodeInt, SDNode OpNodeSAE,
+                            X86VectorVTInfo _DstRC, ValueType CT,
+                            SDPatternOperator OpNode, SDNode OpNodeInt, SDNode OpNodeSAE,
                             X86FoldableSchedWrite sched, string aliasStr,
                             Predicate prd = HasAVX512> {
 let Predicates = [prd], ExeDomain = _SrcRC.ExeDomain in {
   let isCodeGenOnly = 1 in {
   def rr : AVX512<opc, MRMSrcReg, (outs _DstRC.RC:$dst), (ins _SrcRC.FRC:$src),
               !strconcat(asm,"\t{$src, $dst|$dst, $src}"),
-              [(set _DstRC.RC:$dst, (OpNode _SrcRC.FRC:$src))]>,
+              [(set _DstRC.RC:$dst, (OpNode (CT _SrcRC.FRC:$src)))]>,
               EVEX, VEX_LIG, Sched<[sched]>, SIMD_EXC;
   def rm : AVX512<opc, MRMSrcMem, (outs _DstRC.RC:$dst), (ins _SrcRC.ScalarMemOp:$src),
               !strconcat(asm,"\t{$src, $dst|$dst, $src}"),
@@ -7607,29 +7602,29 @@ let Predicates = [prd], ExeDomain = _SrcRC.ExeDomain in {
                                           _SrcRC.IntScalarMemOp:$src), 0, "att">;
 }
 
-defm VCVTTSS2SIZ: avx512_cvt_s_all<0x2C, "vcvttss2si", f32x_info, i32x_info,
+defm VCVTTSS2SIZ: avx512_cvt_s_all<0x2C, "vcvttss2si", f32x_info, i32x_info, f32,
                         any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSS2I,
                         "{l}">, TB, XS, EVEX_CD8<32, CD8VT1>;
-defm VCVTTSS2SI64Z: avx512_cvt_s_all<0x2C, "vcvttss2si", f32x_info, i64x_info,
+defm VCVTTSS2SI64Z: avx512_cvt_s_all<0x2C, "vcvttss2si", f32x_info, i64x_info, f32,
                         any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSS2I,
                         "{q}">, REX_W, TB, XS, EVEX_CD8<32, CD8VT1>;
-defm VCVTTSD2SIZ: avx512_cvt_s_all<0x2C, "vcvttsd2si", f64x_info, i32x_info,
+defm VCVTTSD2SIZ: avx512_cvt_s_all<0x2C, "vcvttsd2si", f64x_info, i32x_info, f64,
                         any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSD2I,
                         "{l}">, TB, XD, EVEX_CD8<64, CD8VT1>;
-defm VCVTTSD2SI64Z: avx512_cvt_s_all<0x2C, "vcvttsd2si", f64x_info, i64x_info,
+defm VCVTTSD2SI64Z: avx512_cvt_s_all<0x2C, "vcvttsd2si", f64x_info, i64x_info, f64,
                         any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSD2I,
                         "{q}">, REX_W, TB, XD, EVEX_CD8<64, CD8VT1>;
 
-defm VCVTTSS2USIZ: avx512_cvt_s_all<0x78, "vcvttss2usi", f32x_info, i32x_info,
+defm VCVTTSS2USIZ: avx512_cvt_s_all<0x78, "vcvttss2usi", f32x_info, i32x_info, f32,
                         any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSS2I,
                         "{l}">, TB, XS, EVEX_CD8<32, CD8VT1>;
-defm VCVTTSS2USI64Z: avx512_cvt_s_all<0x78, "vcvttss2usi", f32x_info, i64x_info,
+defm VCVTTSS2USI64Z: avx512_cvt_s_all<0x78, "vcvttss2usi", f32x_info, i64x_info, f32,
                         any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSS2I,
                         "{q}">, TB, XS,REX_W, EVEX_CD8<32, CD8VT1>;
-defm VCVTTSD2USIZ: avx512_cvt_s_all<0x78, "vcvttsd2usi", f64x_info, i32x_info,
+defm VCVTTSD2USIZ: avx512_cvt_s_all<0x78, "vcvttsd2usi", f64x_info, i32x_info, f64,
                         any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSD2I,
                         "{l}">, TB, XD, EVEX_CD8<64, CD8VT1>;
-defm VCVTTSD2USI64Z: avx512_cvt_s_all<0x78, "vcvttsd2usi", f64x_info, i64x_info,
+defm VCVTTSD2USI64Z: avx512_cvt_s_all<0x78, "vcvttsd2usi", f64x_info, i64x_info, f64,
                         any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSD2I,
                         "{q}">, TB, XD, REX_W, EVEX_CD8<64, CD8VT1>;
 
@@ -7747,15 +7742,15 @@ def : Pat<(f32 (any_fpround FR64X:$src)),
           (VCVTSD2SSZrr (f32 (IMPLICIT_DEF)), FR64X:$src)>,
            Requires<[HasAVX512]>;
 
-def : Pat<(f32 (any_fpextend FR16X:$src)),
-          (VCVTSH2SSZrr (f32 (IMPLICIT_DEF)), FR16X:$src)>,
+def : Pat<(f32 (any_fpextend (f16 FR16X:$src))),
+          (VCVTSH2SSZrr (f32 (IMPLICIT_DEF)), (f16 FR16X:$src))>,
           Requires<[HasFP16]>;
 def : Pat<(f32 (any_fpextend (loadf16 addr:$src))),
           (VCVTSH2SSZrm (f32 (IMPLICIT_DEF)), addr:$src)>,
           Requires<[HasFP16, OptForSize]>;
 
-def : Pat<(f64 (any_fpextend FR16X:$src)),
-          (VCVTSH2SDZrr (f64 (IMPLICIT_DEF)), FR16X:$src)>,
+def : Pat<(f64 (any_fpextend (f16 FR16X:$src))),
+          (VCVTSH2SDZrr (f64 (IMPLICIT_DEF)), (f16 FR16X:$src))>,
           Requires<[HasFP16]>;
 def : Pat<(f64 (any_fpextend (loadf16 addr:$src))),
           (VCVTSH2SDZrm (f64 (IMPLICIT_DEF)), addr:$src)>,
@@ -11587,6 +11582,9 @@ let Predicates = [HasBWI], AddedComplexity = -10 in {
   def : Pat<(f16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (VPINSRWZrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16X)>;
 }
 
+def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
+def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
+
 //===----------------------------------------------------------------------===//
 // VSHUFPS - VSHUFPD Operations
 //===----------------------------------------------------------------------===//
@@ -12950,9 +12948,9 @@ def VMOVWmr  : AVX512<0x7E, MRMDestMem, (outs),
                                      (iPTR 0))), addr:$dst)]>,
                        T_MAP5, PD, EVEX, EVEX_CD8<16, CD8VT1>, Sched<[WriteFStore]>;
 
-def : Pat<(i16 (bitconvert FR16X:$src)),
+def : Pat<(i16 (bitconvert (f16 FR16X:$src))),
           (i16 (EXTRACT_SUBREG
-                (VMOVSH2Wrr (COPY_TO_REGCLASS FR16X:$src, VR128X)),
+                (VMOVSH2Wrr (COPY_TO_REGCLASS (f16 FR16X:$src), VR128X)),
                 sub_16bit))>;
 def : Pat<(i16 (extractelt (v8i16 VR128X:$src), (iPTR 0))),
           (i16 (EXTRACT_SUBREG (VMOVSH2Wrr VR128X:$src), sub_16bit))>;
@@ -13281,16 +13279,16 @@ defm VCVTSH2USI64Z: avx512_cvt_s_int_round<0x79, f16x_info, i64x_info, X86cvts2u
                                    X86cvts2usiRnd, WriteCvtSS2I, "cvtsh2usi", "{q}", HasFP16>,
                                    T_MAP5, XS, REX_W, EVEX_CD8<16, CD8VT1>;
 
-defm VCVTTSH2SIZ: avx512_cvt_s_all<0x2C, "vcvttsh2si", f16x_info, i32x_info,
+defm VCVTTSH2SIZ: avx512_cvt_s_all<0x2C, "vcvttsh2si", f16x_info, i32x_info, f16,
                         any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSS2I,
                         "{l}", HasFP16>, T_MAP5, XS, EVEX_CD8<16, CD8VT1>;
-defm VCVTTSH2SI64Z: avx512_cvt_s_all<0x2C, "vcvttsh2si", f16x_info, i64x_info,
+defm VCVTTSH2SI64Z: avx512_cvt_s_all<0x2C, "vcvttsh2si", f16x_info, i64x_info, f16,
                         any_fp_to_sint, X86cvtts2Int, X86cvtts2IntSAE, WriteCvtSS2I,
                         "{q}", HasFP16>, REX_W, T_MAP5, XS, EVEX_CD8<16, CD8VT1>;
-defm VCVTTSH2USIZ: avx512_cvt_s_all<0x78, "vcvttsh2usi", f16x_info, i32x_info,
+defm VCVTTSH2USIZ: avx512_cvt_s_all<0x78, "vcvttsh2usi", f16x_info, i32x_info, f16,
                         any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSS2I,
                         "{l}", HasFP16>, T_MAP5, XS, EVEX_CD8<16, CD8VT1>;
-defm VCVTTSH2USI64Z: avx512_cvt_s_all<0x78, "vcvttsh2usi", f16x_info, i64x_info,
+defm VCVTTSH2USI64Z: avx512_cvt_s_all<0x78, "vcvttsh2usi", f16x_info, i64x_info, f16,
                         any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSS2I,
                         "{q}", HasFP16>, T_MAP5, XS, REX_W, EVEX_CD8<16, CD8VT1>;
 
diff --git a/llvm/lib/Target/X86/X86InstrVecCompiler.td b/llvm/lib/Target/X86/X86InstrVecCompiler.td
index 122627ca45d31..68609b66f7256 100644
--- a/llvm/lib/Target/X86/X86InstrVecCompiler.td
+++ b/llvm/lib/Target/X86/X86InstrVecCompiler.td
@@ -47,8 +47,8 @@ let Predicates = [NoVLX] in {
 }
 
 let Predicates = [HasVLX] in {
-  def : Pat<(v8f16 (scalar_to_vector FR16X:$src)),
-            (COPY_TO_REGCLASS FR16X:$src, VR128X)>;
+  def : Pat<(v8f16 (scalar_to_vector (f16 FR16X:$src))),
+            (COPY_TO_REGCLASS (f16 FR16X:$src), VR128X)>;
   // Implicitly promote a 32-bit scalar to a vector.
   def : Pat<(v4f32 (scalar_to_vector FR32X:$src)),
             (COPY_TO_REGCLASS FR32X:$src, VR128X)>;
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.td b/llvm/lib/Target/X86/X86RegisterInfo.td
index 48459b3aca508..6dc29a8a51065 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.td
+++ b/llvm/lib/Target/X86/X86RegisterInfo.td
@@ -806,7 +806,7 @@ def FR32X : RegisterClass<"X86", [f32], 32, (sequence "XMM%u", 0, 31)>;
 
 def FR64X : RegisterClass<"X86", [f64], 64, (add FR32X)>;
 
-def FR16X : RegisterClass<"X86", [f16], 16, (add FR32X)> {let Size = 32;}
+def FR16X : RegisterClass<"X86", [f16, bf16], 16, (add FR32X)> {let Size = 32;}
 
 // Extended VR128 and VR256 for AVX-512 instructions
 def VR128X : RegisterClass<"X86", [v4f32, v2f64, v8f16, v8bf16, v16i8, v8i16, v4i32, v2i64, f128],

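The calling-convention additions make scalar __bf16 behave like the other
SSE scalars at the ABI boundary: returned in XMM0/XMM1 and, per the predicate
above, passed in XMM0-XMM7 when AVX-512 is available, spilling to 8-byte
stack slots otherwise. A sketch:

  __bf16 pick_first(__bf16 a, __bf16 b) {
    return a; // a arrives in XMM0, b in XMM1; the result comes back in XMM0
  }
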
From b56a6ba5184836f89fc398b18a22355a37fad146 Mon Sep 17 00:00:00 2001
From: Antonio Frighetto <me@antoniofrighetto.com>
Date: Sat, 26 Apr 2025 19:24:37 +0200
Subject: [PATCH 3/3] !fixup custom hooks for trunc/ext, add missing
 instruction selection patterns

---
 llvm/lib/Target/X86/X86FastISel.cpp          |  5 +-
 llvm/lib/Target/X86/X86ISelLowering.cpp      | 55 +++++++++++++++++++-
 llvm/lib/Target/X86/X86InstrAVX512.td        | 23 +++++++-
 llvm/lib/Target/X86/X86InstrFragmentsSIMD.td |  4 ++
 llvm/lib/Target/X86/X86InstrSSE.td           | 22 ++++++--
 5 files changed, 102 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/X86/X86FastISel.cpp b/llvm/lib/Target/X86/X86FastISel.cpp
index afe5d7f4bc7ed..2136da48dcf84 100644
--- a/llvm/lib/Target/X86/X86FastISel.cpp
+++ b/llvm/lib/Target/X86/X86FastISel.cpp
@@ -147,7 +147,8 @@ class X86FastISel final : public FastISel {
   /// computed in an SSE register, not on the X87 floating point stack.
   bool isScalarFPTypeInSSEReg(EVT VT) const {
     return (VT == MVT::f64 && Subtarget->hasSSE2()) ||
-           (VT == MVT::f32 && Subtarget->hasSSE1()) || VT == MVT::f16;
+           (VT == MVT::f32 && Subtarget->hasSSE1()) || VT == MVT::f16 ||
+           VT == MVT::bf16;
   }
 
   bool isTypeLegal(Type *Ty, MVT &VT, bool AllowI1 = false);
@@ -2283,6 +2284,7 @@ bool X86FastISel::X86FastEmitPseudoSelect(MVT RetVT, const Instruction *I) {
   case MVT::i16: Opc = X86::CMOV_GR16;  break;
   case MVT::i32: Opc = X86::CMOV_GR32;  break;
   case MVT::f16:
+  case MVT::bf16:
     Opc = Subtarget->hasAVX512() ? X86::CMOV_FR16X : X86::CMOV_FR16; break;
   case MVT::f32:
     Opc = Subtarget->hasAVX512() ? X86::CMOV_FR32X : X86::CMOV_FR32; break;
@@ -3972,6 +3974,7 @@ Register X86FastISel::fastMaterializeFloatZero(const ConstantFP *CF) {
   switch (VT.SimpleTy) {
   default: return 0;
   case MVT::f16:
+  case MVT::bf16:
     Opc = HasAVX512 ? X86::AVX512_FsFLD0SH : X86::FsFLD0SH;
     break;
   case MVT::f32:
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ad170f49aebb5..daa19ac43d10f 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -678,9 +678,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
     // non-optsize case.
     setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
 
-    // Set the operation action to Custom for bitcast, so that it can be
-    // lowered later.
+    // Set the operation action to Custom for bitcast and conversions, falling
+    // back to software libcalls for the latter for now.
     setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
+    setOperationAction(ISD::FP_EXTEND, MVT::bf16, Custom);
+    setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
 
     for (auto VT : { MVT::f32, MVT::f64 }) {
       // Use ANDPD to simulate FABS.
@@ -22066,6 +22068,31 @@ SDValue X86TargetLowering::LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const {
     return Res;
   }
 
+  if (SVT == MVT::bf16 && VT == MVT::f32) {
+    TargetLowering::CallLoweringInfo CLI(DAG);
+    Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
+
+    TargetLowering::ArgListTy Args;
+    TargetLowering::ArgListEntry Entry;
+    Entry.Node = In;
+    Entry.Ty = EVT(SVT).getTypeForEVT(*DAG.getContext());
+    Args.push_back(Entry);
+
+    SDValue Callee =
+        DAG.getExternalSymbol(getLibcallName(RTLIB::FPEXT_BF16_F32),
+                              getPointerTy(DAG.getDataLayout()));
+    CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+        CallingConv::C, EVT(VT).getTypeForEVT(*DAG.getContext()), Callee,
+        std::move(Args));
+
+    SDValue Res;
+    std::tie(Res, Chain) = LowerCallTo(CLI);
+    if (IsStrict)
+      Res = DAG.getMergeValues({Res, Chain}, DL);
+
+    return Res;
+  }
+
   if (!SVT.isVector() || SVT.getVectorElementType() == MVT::bf16)
     return Op;
 
@@ -22149,6 +22176,30 @@ SDValue X86TargetLowering::LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
         ((Subtarget.hasBF16() && Subtarget.hasVLX()) ||
          Subtarget.hasAVXNECONVERT()))
       return Op;
+
+    // Need a soft libcall if the target does not have BF16.
+    if (SVT.getScalarType() == MVT::f32 || SVT.getScalarType() == MVT::f64) {
+      TargetLowering::CallLoweringInfo CLI(DAG);
+      Chain = IsStrict ? Op.getOperand(0) : DAG.getEntryNode();
+
+      TargetLowering::ArgListTy Args;
+      TargetLowering::ArgListEntry Entry;
+      Entry.Node = In;
+      Entry.Ty = EVT(SVT).getTypeForEVT(*DAG.getContext());
+      Args.push_back(Entry);
+      SDValue Callee = DAG.getExternalSymbol(
+          getLibcallName(SVT == MVT::f64 ? RTLIB::FPROUND_F64_BF16
+                                         : RTLIB::FPROUND_F32_BF16),
+          getPointerTy(DAG.getDataLayout()));
+      CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
+          CallingConv::C, EVT(MVT::bf16).getTypeForEVT(*DAG.getContext()),
+          Callee, std::move(Args));
+
+      SDValue Res;
+      std::tie(Res, Chain) = LowerCallTo(CLI);
+      return IsStrict ? DAG.getMergeValues({Res, Chain}, DL) : Res;
+    }
+
     return SDValue();
   }
 
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index ebaba0ac29799..d476f92d34340 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -11585,6 +11585,13 @@ let Predicates = [HasBWI], AddedComplexity = -10 in {
 def : Pat<(f16 (bitconvert (bf16 FR16X:$src))), (f16 FR16X:$src)>;
 def : Pat<(bf16 (bitconvert (f16 FR16X:$src))), (bf16 FR16X:$src)>;
 
+let Predicates = [HasBWI, HasBF16] in {
+  def : Pat<(bf16 (load addr:$src)), (COPY_TO_REGCLASS (VPINSRWZrmi (v8i16 (IMPLICIT_DEF)), addr:$src, 0), FR16X)>;
+  def : Pat<(store bf16:$src, addr:$dst), (VPEXTRWZmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0)>;
+  def : Pat<(i16 (bitconvert bf16:$src)), (EXTRACT_SUBREG (VPEXTRWZrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128X)), 0), sub_16bit)>;
+  def : Pat<(bf16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (VPINSRWZrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16X)>;
+}
+
 //===----------------------------------------------------------------------===//
 // VSHUFPS - VSHUFPD Operations
 //===----------------------------------------------------------------------===//
@@ -12809,9 +12816,17 @@ let Predicates = [HasBF16, HasVLX] in {
 }
 
 let Predicates = [HasBF16] in {
+  def : Pat<(v8bf16 (X86VBroadcastld16 addr:$src)),
+            (VPBROADCASTWrm addr:$src)>;
+  def : Pat<(v16bf16 (X86VBroadcastld16 addr:$src)),
+            (VPBROADCASTWYrm addr:$src)>;
   def : Pat<(v32bf16 (X86VBroadcastld16 addr:$src)),
             (VPBROADCASTWZrm addr:$src)>;
 
+  def : Pat<(v8bf16 (X86VBroadcast (v8bf16 VR128:$src))),
+            (VPBROADCASTWrr VR128:$src)>;
+  def : Pat<(v16bf16 (X86VBroadcast (v8bf16 VR128:$src))),
+            (VPBROADCASTWYrr VR128:$src)>;
   def : Pat<(v32bf16 (X86VBroadcast (v8bf16 VR128X:$src))),
             (VPBROADCASTWZrr VR128X:$src)>;
 
@@ -12819,7 +12834,13 @@ let Predicates = [HasBF16] in {
             (VCVTNEPS2BF16Zrr VR512:$src)>;
   def : Pat<(v16bf16 (X86vfpround (loadv16f32 addr:$src))),
             (VCVTNEPS2BF16Zrm addr:$src)>;
-  // TODO: No scalar broadcast due to we don't support legal scalar bf16 so far.
+
+  def : Pat<(v8bf16 (X86VBroadcast (bf16 FR16X:$src))),
+            (VPBROADCASTWrr (COPY_TO_REGCLASS FR16X:$src, VR128))>;
+  def : Pat<(v16bf16 (X86VBroadcast (bf16 FR16X:$src))),
+            (VPBROADCASTWYrr (COPY_TO_REGCLASS FR16X:$src, VR128))>;
+  def : Pat<(v32bf16 (X86VBroadcast (bf16 FR16X:$src))),
+            (VPBROADCASTWZrr (COPY_TO_REGCLASS FR16X:$src, VR128X))>;
 }
 
 let Constraints = "$src1 = $dst" in {
diff --git a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
index de70570481fc2..b81119c1f6d26 100644
--- a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
+++ b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
@@ -1193,6 +1193,10 @@ def fp16imm0 : PatLeaf<(f16 fpimm), [{
   return N->isExactlyValue(+0.0);
 }]>;
 
+def bfp16imm0 : PatLeaf<(bf16 fpimm), [{
+  return N->isExactlyValue(+0.0);
+}]>;
+
 def fp32imm0 : PatLeaf<(f32 fpimm), [{
   return N->isExactlyValue(+0.0);
 }]>;
diff --git a/llvm/lib/Target/X86/X86InstrSSE.td b/llvm/lib/Target/X86/X86InstrSSE.td
index 49a62fd3422d0..e2aa90b59bbfa 100644
--- a/llvm/lib/Target/X86/X86InstrSSE.td
+++ b/llvm/lib/Target/X86/X86InstrSSE.td
@@ -4048,6 +4048,19 @@ let Predicates = [HasAVX, NoBWI] in {
   def : Pat<(f16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (VPINSRWrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16)>;
 }
 
+let Predicates = [UseSSE2] in {
+  def : Pat<(bf16 (load addr:$src)), (COPY_TO_REGCLASS (PINSRWrmi (v8i16 (IMPLICIT_DEF)), addr:$src, 0), FR16X)>;
+  def : Pat<(store bf16:$src, addr:$dst), (MOV16mr addr:$dst, (EXTRACT_SUBREG (PEXTRWrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0), sub_16bit))>;
+  def : Pat<(i16 (bitconvert bf16:$src)), (EXTRACT_SUBREG (PEXTRWrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0), sub_16bit)>;
+  def : Pat<(bf16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (PINSRWrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16X)>;
+}
+
+let Predicates = [HasAVX, NoBWI] in {
+  def : Pat<(bf16 (load addr:$src)), (COPY_TO_REGCLASS (VPINSRWrmi (v8i16 (IMPLICIT_DEF)), addr:$src, 0), FR16X)>;
+  def : Pat<(i16 (bitconvert bf16:$src)), (EXTRACT_SUBREG (VPEXTRWrri (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0), sub_16bit)>;
+  def : Pat<(bf16 (bitconvert i16:$src)), (COPY_TO_REGCLASS (VPINSRWrri (v8i16 (IMPLICIT_DEF)), (INSERT_SUBREG (IMPLICIT_DEF), GR16:$src, sub_16bit), 0), FR16)>;
+}
+
 //===---------------------------------------------------------------------===//
 // SSE2 - Packed Mask Creation
 //===---------------------------------------------------------------------===//
@@ -5279,12 +5292,15 @@ let Predicates = [HasAVX, NoBWI] in
 
 defm PEXTRW      : SS41I_extract16<0x15, "pextrw">;
 
-let Predicates = [UseSSE41] in
+let Predicates = [UseSSE41] in {
   def : Pat<(store f16:$src, addr:$dst), (PEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16:$src, VR128)), 0)>;
+  def : Pat<(store bf16:$src, addr:$dst), (PEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0)>;
+}
 
-let Predicates = [HasAVX, NoBWI] in
+let Predicates = [HasAVX, NoBWI] in {
   def : Pat<(store f16:$src, addr:$dst), (VPEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16:$src, VR128)), 0)>;
-
+  def : Pat<(store bf16:$src, addr:$dst), (VPEXTRWmri addr:$dst, (v8i16 (COPY_TO_REGCLASS FR16X:$src, VR128)), 0)>;
+}
 
 /// SS41I_extract32 - SSE 4.1 extract 32 bits to int reg or memory destination
 multiclass SS41I_extract32<bits<8> opc, string OpcodeStr> {

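With the Custom FP_EXTEND/FP_ROUND hooks in place, scalar bf16 conversions
that have no native instruction are routed to compiler-rt. A sketch of the
expected lowering (libcall names follow the RTLIB entries used above; the
exact symbols depend on the runtime in use):

  float  widen(__bf16 b)  { return (float)b;  } // FPEXT_BF16_F32   -> __extendbfsf2
  __bf16 narrow(double d) { return (__bf16)d; } // FPROUND_F64_BF16 -> __truncdfbf2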

