[llvm] [llvm][RISCV] Support fma codegen for zvfbfa (PR #172949)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 18 20:32:43 PST 2025


llvmbot wrote:


@llvm/pr-subscribers-backend-risc-v

Author: Brandon Wu (4vtomat)

Changes:

This patch adds codegen support for both the widening and non-widening forms of fma.
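
For context, a minimal sketch of the two forms the patch covers is shown below. The first function comes from the new fixed-vectors-vfmadd-sdnode.ll test in the diff; with +experimental-zvfbfa it now selects vfmadd.vv directly instead of being promoted to f32. The second is an illustrative widening example (the function name, scalable types, and element count are assumptions, not taken from the patch): an fpext-then-fma shape that the updated combine is intended to fold into the widening bf16 multiply-accumulate pseudos.

```llvm
; Non-widening form, from the new fixed-vectors-vfmadd-sdnode.ll test:
; with +experimental-zvfbfa this selects a bf16 vfmadd.vv directly.
define <1 x bfloat> @vfmadd_vv_v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, <1 x bfloat> %vc) {
  %vd = call <1 x bfloat> @llvm.fma.v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, <1 x bfloat> %vc)
  ret <1 x bfloat> %vd
}

; Widening form (illustrative sketch, not copied from the patch): bf16
; operands extended to f32 and consumed by an f32 fma, the shape the
; combine can now fold into the widening bf16 FMA pseudos.
define <vscale x 1 x float> @vfwmadd_sketch(<vscale x 1 x bfloat> %a,
                                            <vscale x 1 x bfloat> %b,
                                            <vscale x 1 x float> %acc) {
  %ae = fpext <vscale x 1 x bfloat> %a to <vscale x 1 x float>
  %be = fpext <vscale x 1 x bfloat> %b to <vscale x 1 x float>
  %r = call <vscale x 1 x float> @llvm.fma.nxv1f32(<vscale x 1 x float> %ae,
                                                   <vscale x 1 x float> %be,
                                                   <vscale x 1 x float> %acc)
  ret <vscale x 1 x float> %r
}
```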


---

Patch is 247.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172949.diff


13 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+7-4) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td (+42-14) 
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+38-14) 
- (added) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfmadd-sdnode.ll (+225) 
- (added) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfmsub-sdnode.ll (+245) 
- (added) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfnmadd-sdnode.ll (+265) 
- (added) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfnmsub-sdnode.ll (+245) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmacc.ll (+969-4) 
- (modified) llvm/test/CodeGen/RISCV/rvv/vfmadd-sdnode.ll (+376-477) 
- (modified) llvm/test/CodeGen/RISCV/rvv/vfmsub-sdnode.ll (+450-8) 
- (modified) llvm/test/CodeGen/RISCV/rvv/vfnmadd-sdnode.ll (+493-4) 
- (modified) llvm/test/CodeGen/RISCV/rvv/vfnmsub-sdnode.ll (+435-4) 
- (modified) llvm/test/CodeGen/RISCV/rvv/vfwmacc-sdnode.ll (+932-8) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 439b148576e23..2daaf1e0e3b87 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -93,7 +93,8 @@ static const unsigned ZvfbfaVPOps[] = {
 static const unsigned ZvfbfaOps[] = {
     ISD::FNEG,        ISD::FABS,        ISD::FCOPYSIGN, ISD::FADD,
     ISD::FSUB,        ISD::FMUL,        ISD::FMINNUM,   ISD::FMAXNUM,
-    ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM, ISD::FMINIMUM,  ISD::FMAXIMUM};
+    ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM, ISD::FMINIMUM,  ISD::FMAXIMUM,
+    ISD::FMA};
 
 RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                                          const RISCVSubtarget &STI)
@@ -1100,7 +1101,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
     // TODO: Make more of these ops legal.
     static const unsigned ZvfbfaPromoteOps[] = {ISD::FDIV,
-                                                ISD::FMA,
                                                 ISD::FSQRT,
                                                 ISD::FCEIL,
                                                 ISD::FTRUNC,
@@ -17851,7 +17851,8 @@ struct NodeExtensionHelper {
   }
 
   bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
-    return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma();
+    return NarrowEltVT == MVT::bf16 &&
+           (Subtarget.hasStdExtZvfbfwma() || Subtarget.hasVInstructionsBF16());
   }
 
   /// Helper method to set the various fields of this struct based on the
@@ -18306,7 +18307,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   case RISCVISD::VFNMADD_VL:
   case RISCVISD::VFNMSUB_VL:
     Strategies.push_back(canFoldToVWWithSameExtension);
-    if (Root->getOpcode() == RISCVISD::VFMADD_VL)
+    // FIXME: Once other widen operations are supported we can merge
+    // canFoldToVWWithSameExtension and canFoldToVWWithSameExtBF16.
+    if (Root->getOpcode() != RISCVISD::FMUL_VL)
       Strategies.push_back(canFoldToVWWithSameExtBF16);
     break;
   case ISD::MUL:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index b3cc33d31761d..1bdbfe40b1521 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -674,10 +674,8 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
     defvar vti = vtiToWti.Vti;
     defvar wti = vtiToWti.Wti;
     defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
-    let Predicates = !listconcat(GetVTypePredicates<wti>.Predicates,
-                                 !if(!eq(vti.Scalar, bf16),
-                                     [HasStdExtZvfbfwma],
-                                     GetVTypePredicates<vti>.Predicates)) in {
+    let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+                                 GetVTypePredicates<wti>.Predicates) in {
       def : Pat<(fma (wti.Vector (riscv_fpextend_vl_sameuser
                                       (vti.Vector vti.RegClass:$rs1),
                                       (vti.Mask true_mask), (XLenVT srcvalue))),
@@ -685,9 +683,7 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
                                       (vti.Vector vti.RegClass:$rs2),
                                       (vti.Mask true_mask), (XLenVT srcvalue))),
                      (wti.Vector wti.RegClass:$rd)),
-                (!cast<Instruction>(instruction_name#
-                                    !if(!eq(vti.Scalar, bf16), "BF16", "")#
-                                    "_VV_"#suffix)
+                (!cast<Instruction>(instruction_name#"_VV_"#suffix)
                    wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                    // Value to indicate no rounding mode change in
                    // RISCVInsertReadWriteCSR
@@ -699,9 +695,7 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
                                       (vti.Vector vti.RegClass:$rs2),
                                       (vti.Mask true_mask), (XLenVT srcvalue))),
                      (wti.Vector wti.RegClass:$rd)),
-                (!cast<Instruction>(instruction_name#
-                                    !if(!eq(vti.Scalar, bf16), "BF16", "")#
-                                    "_V"#vti.ScalarSuffix#"_"#suffix)
+                (!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#suffix)
                    wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
                    // Value to indicate no rounding mode change in
                    // RISCVInsertReadWriteCSR
@@ -712,7 +706,7 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
 }
 
 multiclass VPatWidenFPNegMulAccSDNode_VV_VF_RM<string instruction_name> {
-  foreach vtiToWti = AllWidenableFloatVectors in {
+  foreach vtiToWti = AllWidenableFloatAndBF16Vectors in {
     defvar vti = vtiToWti.Vti;
     defvar wti = vtiToWti.Wti;
     defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
@@ -756,7 +750,7 @@ multiclass VPatWidenFPNegMulAccSDNode_VV_VF_RM<string instruction_name> {
 }
 
 multiclass VPatWidenFPMulSacSDNode_VV_VF_RM<string instruction_name> {
-  foreach vtiToWti = AllWidenableFloatVectors in {
+  foreach vtiToWti = AllWidenableFloatAndBF16Vectors in {
     defvar vti = vtiToWti.Vti;
     defvar wti = vtiToWti.Wti;
     defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
@@ -789,7 +783,7 @@ multiclass VPatWidenFPMulSacSDNode_VV_VF_RM<string instruction_name> {
 }
 
 multiclass VPatWidenFPNegMulSacSDNode_VV_VF_RM<string instruction_name> {
-  foreach vtiToWti = AllWidenableFloatVectors in {
+  foreach vtiToWti = AllWidenableFloatAndBF16Vectors in {
     defvar vti = vtiToWti.Vti;
     defvar wti = vtiToWti.Wti;
     defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
@@ -1235,7 +1229,7 @@ defm : VPatBinaryFPSDNode_R_VF_RM<any_fdiv, "PseudoVFRDIV", isSEWAware=1>;
 defm : VPatWidenBinaryFPSDNode_VV_VF_RM<fmul, "PseudoVFWMUL">;
 
 // 13.6 Vector Single-Width Floating-Point Fused Multiply-Add Instructions.
-foreach fvti = AllFloatVectors in {
+foreach fvti = AllFloatAndBF16Vectors in {
   // NOTE: We choose VFMADD because it has the most commuting freedom. So it
   // works best with how TwoAddressInstructionPass tries commuting.
   defvar suffix = fvti.LMul.MX # "_E" # fvti.SEW;
@@ -1336,6 +1330,40 @@ defm : VPatWidenFPNegMulAccSDNode_VV_VF_RM<"PseudoVFWNMACC">;
 defm : VPatWidenFPMulSacSDNode_VV_VF_RM<"PseudoVFWMSAC">;
 defm : VPatWidenFPNegMulSacSDNode_VV_VF_RM<"PseudoVFWNMSAC">;
 
+// Zvfbfwma
+foreach vtiToWti = AllWidenableBF16ToFloatVectors in {
+  defvar vti = vtiToWti.Vti;
+  defvar wti = vtiToWti.Wti;
+  defvar suffix = vti.LMul.MX # "_E16";
+  let Predicates = [HasStdExtZvfbfwma, HasVInstructionsAnyF] in {
+    def : Pat<(fma (wti.Vector (riscv_fpextend_vl_sameuser
+                                    (vti.Vector vti.RegClass:$rs1),
+                                    (vti.Mask true_mask), (XLenVT srcvalue))),
+                   (wti.Vector (riscv_fpextend_vl_sameuser
+                                    (vti.Vector vti.RegClass:$rs2),
+                                    (vti.Mask true_mask), (XLenVT srcvalue))),
+                   (wti.Vector wti.RegClass:$rd)),
+              (!cast<Instruction>("PseudoVFWMACCBF16_VV_"#suffix)
+                 wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                 // Value to indicate no rounding mode change in
+                 // RISCVInsertReadWriteCSR
+                 FRM_DYN,
+                 vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
+    def : Pat<(fma (wti.Vector (SplatFPOp
+                                    (fpext_oneuse (vti.Scalar vti.ScalarRegClass:$rs1)))),
+                   (wti.Vector (riscv_fpextend_vl_oneuse
+                                    (vti.Vector vti.RegClass:$rs2),
+                                    (vti.Mask true_mask), (XLenVT srcvalue))),
+                   (wti.Vector wti.RegClass:$rd)),
+              (!cast<Instruction>("PseudoVFWMACCBF16_V"#vti.ScalarSuffix#"_"#suffix)
+                 wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+                 // Value to indicate no rounding mode change in
+                 // RISCVInsertReadWriteCSR
+                 FRM_DYN,
+                 vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
+  }
+}
+
 foreach vti = AllFloatAndBF16Vectors in {
   let Predicates = GetVTypePredicates<vti>.Predicates in {
     // 13.8. Vector Floating-Point Square-Root Instruction
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 4c41667560a98..ab94566350319 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -1812,7 +1812,7 @@ multiclass VPatFPMulAddVL_VV_VF<SDPatternOperator vop, string instruction_name>
 }
 
 multiclass VPatFPMulAddVL_VV_VF_RM<SDPatternOperator vop, string instruction_name> {
-  foreach vti = AllFloatVectors in {
+  foreach vti = AllFloatAndBF16Vectors in {
   defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
   let Predicates = GetVTypePredicates<vti>.Predicates in {
     def : Pat<(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rd,
@@ -1843,22 +1843,18 @@ multiclass VPatFPMulAddVL_VV_VF_RM<SDPatternOperator vop, string instruction_nam
 
 multiclass VPatWidenFPMulAccVL_VV_VF_RM<SDNode vop, string instruction_name,
                                         list<VTypeInfoToWide> vtiToWtis =
-                                        AllWidenableFloatVectors> {
+                                        AllWidenableFloatAndBF16Vectors> {
   foreach vtiToWti = vtiToWtis in {
     defvar vti = vtiToWti.Vti;
     defvar wti = vtiToWti.Wti;
     defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
-    let Predicates = !listconcat(GetVTypePredicates<wti>.Predicates,
-                                 !if(!eq(vti.Scalar, bf16),
-                                     [HasStdExtZvfbfwma],
-                                     GetVTypePredicates<vti>.Predicates)) in {
+    let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+                                 GetVTypePredicates<wti>.Predicates) in {
       def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
                      (vti.Vector vti.RegClass:$rs2),
                      (wti.Vector wti.RegClass:$rd), (vti.Mask VMV0:$vm),
                      VLOpFrag),
-                (!cast<Instruction>(instruction_name#
-                                    !if(!eq(vti.Scalar, bf16), "BF16", "")#
-                                    "_VV_"#suffix#"_MASK")
+                (!cast<Instruction>(instruction_name#"_VV_"#suffix#"_MASK")
                    wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask VMV0:$vm),
                    // Value to indicate no rounding mode change in
@@ -1869,9 +1865,7 @@ multiclass VPatWidenFPMulAccVL_VV_VF_RM<SDNode vop, string instruction_name,
                      (vti.Vector vti.RegClass:$rs2),
                      (wti.Vector wti.RegClass:$rd), (vti.Mask VMV0:$vm),
                      VLOpFrag),
-                (!cast<Instruction>(instruction_name#
-                                    !if(!eq(vti.Scalar, bf16), "BF16", "")#
-                                    "_V"#vti.ScalarSuffix#"_"#suffix#"_MASK")
+                (!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#suffix#"_MASK")
                    wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
                    (vti.Mask VMV0:$vm),
                    // Value to indicate no rounding mode change in
@@ -2344,12 +2338,42 @@ defm : VPatFPMulAddVL_VV_VF_RM<any_riscv_vfnmadd_vl, "PseudoVFNMADD">;
 defm : VPatFPMulAddVL_VV_VF_RM<any_riscv_vfnmsub_vl, "PseudoVFNMSUB">;
 
 // 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
-defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACC",
-                                    AllWidenableFloatAndBF16Vectors>;
+defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACC">;
 defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwnmadd_vl, "PseudoVFWNMACC">;
 defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmsub_vl, "PseudoVFWMSAC">;
 defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwnmsub_vl, "PseudoVFWNMSAC">;
 
+// Zvfbfwma
+foreach vtiToWti = AllWidenableBF16ToFloatVectors in {
+  defvar vti = vtiToWti.Vti;
+  defvar wti = vtiToWti.Wti;
+  defvar suffix = vti.LMul.MX # "_E16";
+  let Predicates = [HasStdExtZvfbfwma, HasVInstructionsAnyF] in {
+    def : Pat<(riscv_vfwmadd_vl (vti.Vector vti.RegClass:$rs1),
+                                (vti.Vector vti.RegClass:$rs2),
+                                (wti.Vector wti.RegClass:$rd), (vti.Mask VMV0:$vm),
+                                VLOpFrag),
+              (!cast<Instruction>("PseudoVFWMACCBF16_VV_"#suffix#"_MASK")
+                 wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                 (vti.Mask VMV0:$vm),
+                 // Value to indicate no rounding mode change in
+                 // RISCVInsertReadWriteCSR
+                 FRM_DYN,
+                 GPR:$vl, vti.Log2SEW, TA_MA)>;
+    def : Pat<(riscv_vfwmadd_vl (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs1)),
+                                (vti.Vector vti.RegClass:$rs2),
+                                (wti.Vector wti.RegClass:$rd), (vti.Mask VMV0:$vm),
+                                VLOpFrag),
+              (!cast<Instruction>("PseudoVFWMACCBF16_V"#vti.ScalarSuffix#"_"#suffix#"_MASK")
+                 wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+                 (vti.Mask VMV0:$vm),
+                 // Value to indicate no rounding mode change in
+                 // RISCVInsertReadWriteCSR
+                 FRM_DYN,
+                 GPR:$vl, vti.Log2SEW, TA_MA)>;
+  }
+}
+
 // 13.11. Vector Floating-Point MIN/MAX Instructions
 defm : VPatBinaryFPVL_VV_VF<riscv_vfmin_vl, "PseudoVFMIN", isSEWAware=1>;
 defm : VPatBinaryFPVL_VV_VF<riscv_vfmax_vl, "PseudoVFMAX", isSEWAware=1>;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfmadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfmadd-sdnode.ll
new file mode 100644
index 0000000000000..814ee7739ee62
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfmadd-sdnode.ll
@@ -0,0 +1,225 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+experimental-zvfbfa,+v \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+experimental-zvfbfa,+v \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s
+
+define <1 x bfloat> @vfmadd_vv_v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, <1 x bfloat> %vc) {
+; CHECK-LABEL: vfmadd_vv_v1bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e16, mf4, ta, ma
+; CHECK-NEXT:    vfmadd.vv v8, v9, v10
+; CHECK-NEXT:    ret
+  %vd = call <1 x bfloat> @llvm.fma.v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, <1 x bfloat> %vc)
+  ret <1 x bfloat> %vd
+}
+
+define <1 x bfloat> @vfmadd_vf_v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmadd_vf_v1bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e16, mf4, ta, ma
+; CHECK-NEXT:    vfmadd.vf v8, fa0, v9
+; CHECK-NEXT:    ret
+  %head = insertelement <1 x bfloat> poison, bfloat %c, i32 0
+  %splat = shufflevector <1 x bfloat> %head, <1 x bfloat> poison, <1 x i32> zeroinitializer
+  %vd = call <1 x bfloat> @llvm.fma.v1bf16(<1 x bfloat> %va, <1 x bfloat> %splat, <1 x bfloat> %vb)
+  ret <1 x bfloat> %vd
+}
+
+define <2 x bfloat> @vfmadd_vv_v2bf16(<2 x bfloat> %va, <2 x bfloat> %vb, <2 x bfloat> %vc) {
+; CHECK-LABEL: vfmadd_vv_v2bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vfmadd.vv v8, v9, v10
+; CHECK-NEXT:    ret
+  %vd = call <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat> %va, <2 x bfloat> %vb, <2 x bfloat> %vc)
+  ret <2 x bfloat> %vd
+}
+
+define <2 x bfloat> @vfmadd_vf_v2bf16(<2 x bfloat> %va, <2 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmadd_vf_v2bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vfmadd.vf v8, fa0, v9
+; CHECK-NEXT:    ret
+  %head = insertelement <2 x bfloat> poison, bfloat %c, i32 0
+  %splat = shufflevector <2 x bfloat> %head, <2 x bfloat> poison, <2 x i32> zeroinitializer
+  %vd = call <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat> %va, <2 x bfloat> %splat, <2 x bfloat> %vb)
+  ret <2 x bfloat> %vd
+}
+
+define <4 x bfloat> @vfmadd_vv_v4bf16(<4 x bfloat> %va, <4 x bfloat> %vb, <4 x bfloat> %vc) {
+; CHECK-LABEL: vfmadd_vv_v4bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vfmadd.vv v8, v9, v10
+; CHECK-NEXT:    ret
+  %vd = call <4 x bfloat> @llvm.fma.v4bf16(<4 x bfloat> %va, <4 x bfloat> %vb, <4 x bfloat> %vc)
+  ret <4 x bfloat> %vd
+}
+
+define <4 x bfloat> @vfmadd_vf_v4bf16(<4 x bfloat> %va, <4 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmadd_vf_v4bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vfmadd.vf v8, fa0, v9
+; CHECK-NEXT:    ret
+  %head = insertelement <4 x bfloat> poison, bfloat %c, i32 0
+  %splat = shufflevector <4 x bfloat> %head, <4 x bfloat> poison, <4 x i32> zeroinitializer
+  %vd = call <4 x bfloat> @llvm.fma.v4bf16(<4 x bfloat> %va, <4 x bfloat> %splat, <4 x bfloat> %vb)
+  ret <4 x bfloat> %vd
+}
+
+define <8 x bfloat> @vfmadd_vv_v8bf16(<8 x bfloat> %va, <8 x bfloat> %vb, <8 x bfloat> %vc) {
+; CHECK-LABEL: vfmadd_vv_v8bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
+; CHECK-NEXT:    vfmadd.vv v8, v9, v10
+; CHECK-NEXT:    ret
+  %vd = call <8 x bfloat> @llvm.fma.v8bf16(<8 x bfloat> %va, <8 x bfloat> %vb, <8 x bfloat> %vc)
+  ret <8 x bfloat> %vd
+}
+
+define <8 x bfloat> @vfmadd_vf_v8bf16(<8 x bfloat> %va, <8 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmadd_vf_v8bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
+; CHECK-NEXT:    vfmadd.vf v8, fa0, v9
+; CHECK-NEXT:    ret
+  %head = insertelement <8 x bfloat> poison, bfloat %c, i32 0
+  %splat = shufflevector <8 x bfloat> %head, <8 x bfloat> poison, <8 x i32> zeroinitializer
+  %vd = call <8 x bfloat> @llvm.fma.v8bf16(<8 x bfloat> %va, <8 x bfloat> %splat, <8 x bfloat> %vb)
+  ret <8 x bfloat> %vd
+}
+
+define <16 x bfloat> @vfmadd_vv_v16bf16(<16 x bfloat> %va, <16 x bfloat> %vb, <16 x bfloat> %vc) {
+; CHECK-LABEL: vfmadd_vv_v16bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT:    vfmadd.vv v8, v10, v12
+; CHECK-NEXT:    ret
+  %vd = call <16 x bfloat> @llvm.fma.v16bf16(<16 x bfloat> %va, <16 x bfloat> %vb, <16 x bfloat> %vc)
+  ret <16 x bfloat> %vd
+}
+
+define <16 x bfloat> @vfmadd_vf_v16bf16(<16 x bfloat> %va, <16 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmadd_vf_v16bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT:    vfmadd.vf v8, fa0, v10
+; CHECK-NEXT:    ret
+  %head = insertelement <16 x bfloat> poison, bfloat %c, i32 0
+  %splat = shufflevector <16 x bfloat> %head, <16 x bfloat> poison, <16 x i32> zeroinitializer
+  %vd = call <16 x bfloat> @llvm.fma.v16bf16(<16 x bfloat> %va, <16 x bfloat> %splat, <16 x bfloat> %vb)
+  ret <16 x bfloat> %vd
+}
+
+define <1 x bfloat> @vfmacc_vv_v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, <1 x bfloat> %vc) {
+; CHECK-LABEL: vfmacc_vv_v1bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e16, mf4, ta, ma
+; CHECK-NEXT:    vfmacc.vv v8, v10, v9
+; CHECK-NEXT:    ret
+  %vd = call <1 x bfloat> @llvm.fma.v1bf16(<1 x bfloat> %vb, <1 x bfloat> %vc, <1 x bfloat> %va)
+  ret <1 x bfloat> %vd
+}
+
+define <1 x bfloat> @vfmacc_vf_v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmacc_vf_v1bf16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e16, mf4, ta, ma
+; CHECK-NEXT:    vfmacc.vf v8, fa0, v9
+; CHECK-NEXT:    ret
+  %head = insertelement <1 x bfloat> poison, bfloat %c, i32 0
+  %splat = shufflevector <1 x bfloat> %head, <1 x bfloat> poison, <1 x i32> zeroinitializer
+  %vd = call <1 x bfloat> @llvm.fma.v1bf16(<1 x bfloat> %vb, <1 ...
[truncated]

``````````



https://github.com/llvm/llvm-project/pull/172949

