[llvm] [llvm][RISCV] Support fma codegen for zvfbfa (PR #172949)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 18 20:32:43 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Brandon Wu (4vtomat)
<details>
<summary>Changes</summary>
This patch supports codegen for both widening and non-widening fma.
---
Patch is 247.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172949.diff
13 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+7-4)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td (+42-14)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+38-14)
- (added) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfmadd-sdnode.ll (+225)
- (added) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfmsub-sdnode.ll (+245)
- (added) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfnmadd-sdnode.ll (+265)
- (added) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfnmsub-sdnode.ll (+245)
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmacc.ll (+969-4)
- (modified) llvm/test/CodeGen/RISCV/rvv/vfmadd-sdnode.ll (+376-477)
- (modified) llvm/test/CodeGen/RISCV/rvv/vfmsub-sdnode.ll (+450-8)
- (modified) llvm/test/CodeGen/RISCV/rvv/vfnmadd-sdnode.ll (+493-4)
- (modified) llvm/test/CodeGen/RISCV/rvv/vfnmsub-sdnode.ll (+435-4)
- (modified) llvm/test/CodeGen/RISCV/rvv/vfwmacc-sdnode.ll (+932-8)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 439b148576e23..2daaf1e0e3b87 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -93,7 +93,8 @@ static const unsigned ZvfbfaVPOps[] = {
static const unsigned ZvfbfaOps[] = {
ISD::FNEG, ISD::FABS, ISD::FCOPYSIGN, ISD::FADD,
ISD::FSUB, ISD::FMUL, ISD::FMINNUM, ISD::FMAXNUM,
- ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM, ISD::FMINIMUM, ISD::FMAXIMUM};
+ ISD::FMINIMUMNUM, ISD::FMAXIMUMNUM, ISD::FMINIMUM, ISD::FMAXIMUM,
+ ISD::FMA};
RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
const RISCVSubtarget &STI)
@@ -1100,7 +1101,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// TODO: Make more of these ops legal.
static const unsigned ZvfbfaPromoteOps[] = {ISD::FDIV,
- ISD::FMA,
ISD::FSQRT,
ISD::FCEIL,
ISD::FTRUNC,
@@ -17851,7 +17851,8 @@ struct NodeExtensionHelper {
}
bool isSupportedBF16Extend(MVT NarrowEltVT, const RISCVSubtarget &Subtarget) {
- return NarrowEltVT == MVT::bf16 && Subtarget.hasStdExtZvfbfwma();
+ return NarrowEltVT == MVT::bf16 &&
+ (Subtarget.hasStdExtZvfbfwma() || Subtarget.hasVInstructionsBF16());
}
/// Helper method to set the various fields of this struct based on the
@@ -18306,7 +18307,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
case RISCVISD::VFNMADD_VL:
case RISCVISD::VFNMSUB_VL:
Strategies.push_back(canFoldToVWWithSameExtension);
- if (Root->getOpcode() == RISCVISD::VFMADD_VL)
+ // FIXME: Once other widen operations are supported we can merge
+ // canFoldToVWWithSameExtension and canFoldToVWWithSameExtBF16.
+ if (Root->getOpcode() != RISCVISD::FMUL_VL)
Strategies.push_back(canFoldToVWWithSameExtBF16);
break;
case ISD::MUL:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index b3cc33d31761d..1bdbfe40b1521 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -674,10 +674,8 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
defvar vti = vtiToWti.Vti;
defvar wti = vtiToWti.Wti;
defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
- let Predicates = !listconcat(GetVTypePredicates<wti>.Predicates,
- !if(!eq(vti.Scalar, bf16),
- [HasStdExtZvfbfwma],
- GetVTypePredicates<vti>.Predicates)) in {
+ let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+ GetVTypePredicates<wti>.Predicates) in {
def : Pat<(fma (wti.Vector (riscv_fpextend_vl_sameuser
(vti.Vector vti.RegClass:$rs1),
(vti.Mask true_mask), (XLenVT srcvalue))),
@@ -685,9 +683,7 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
(vti.Vector vti.RegClass:$rs2),
(vti.Mask true_mask), (XLenVT srcvalue))),
(wti.Vector wti.RegClass:$rd)),
- (!cast<Instruction>(instruction_name#
- !if(!eq(vti.Scalar, bf16), "BF16", "")#
- "_VV_"#suffix)
+ (!cast<Instruction>(instruction_name#"_VV_"#suffix)
wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
// Value to indicate no rounding mode change in
// RISCVInsertReadWriteCSR
@@ -699,9 +695,7 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
(vti.Vector vti.RegClass:$rs2),
(vti.Mask true_mask), (XLenVT srcvalue))),
(wti.Vector wti.RegClass:$rd)),
- (!cast<Instruction>(instruction_name#
- !if(!eq(vti.Scalar, bf16), "BF16", "")#
- "_V"#vti.ScalarSuffix#"_"#suffix)
+ (!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#suffix)
wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
// Value to indicate no rounding mode change in
// RISCVInsertReadWriteCSR
@@ -712,7 +706,7 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM<string instruction_name,
}
multiclass VPatWidenFPNegMulAccSDNode_VV_VF_RM<string instruction_name> {
- foreach vtiToWti = AllWidenableFloatVectors in {
+ foreach vtiToWti = AllWidenableFloatAndBF16Vectors in {
defvar vti = vtiToWti.Vti;
defvar wti = vtiToWti.Wti;
defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
@@ -756,7 +750,7 @@ multiclass VPatWidenFPNegMulAccSDNode_VV_VF_RM<string instruction_name> {
}
multiclass VPatWidenFPMulSacSDNode_VV_VF_RM<string instruction_name> {
- foreach vtiToWti = AllWidenableFloatVectors in {
+ foreach vtiToWti = AllWidenableFloatAndBF16Vectors in {
defvar vti = vtiToWti.Vti;
defvar wti = vtiToWti.Wti;
defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
@@ -789,7 +783,7 @@ multiclass VPatWidenFPMulSacSDNode_VV_VF_RM<string instruction_name> {
}
multiclass VPatWidenFPNegMulSacSDNode_VV_VF_RM<string instruction_name> {
- foreach vtiToWti = AllWidenableFloatVectors in {
+ foreach vtiToWti = AllWidenableFloatAndBF16Vectors in {
defvar vti = vtiToWti.Vti;
defvar wti = vtiToWti.Wti;
defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
@@ -1235,7 +1229,7 @@ defm : VPatBinaryFPSDNode_R_VF_RM<any_fdiv, "PseudoVFRDIV", isSEWAware=1>;
defm : VPatWidenBinaryFPSDNode_VV_VF_RM<fmul, "PseudoVFWMUL">;
// 13.6 Vector Single-Width Floating-Point Fused Multiply-Add Instructions.
-foreach fvti = AllFloatVectors in {
+foreach fvti = AllFloatAndBF16Vectors in {
// NOTE: We choose VFMADD because it has the most commuting freedom. So it
// works best with how TwoAddressInstructionPass tries commuting.
defvar suffix = fvti.LMul.MX # "_E" # fvti.SEW;
@@ -1336,6 +1330,40 @@ defm : VPatWidenFPNegMulAccSDNode_VV_VF_RM<"PseudoVFWNMACC">;
defm : VPatWidenFPMulSacSDNode_VV_VF_RM<"PseudoVFWMSAC">;
defm : VPatWidenFPNegMulSacSDNode_VV_VF_RM<"PseudoVFWNMSAC">;
+// Zvfbfwma
+foreach vtiToWti = AllWidenableBF16ToFloatVectors in {
+ defvar vti = vtiToWti.Vti;
+ defvar wti = vtiToWti.Wti;
+ defvar suffix = vti.LMul.MX # "_E16";
+ let Predicates = [HasStdExtZvfbfwma, HasVInstructionsAnyF] in {
+ def : Pat<(fma (wti.Vector (riscv_fpextend_vl_sameuser
+ (vti.Vector vti.RegClass:$rs1),
+ (vti.Mask true_mask), (XLenVT srcvalue))),
+ (wti.Vector (riscv_fpextend_vl_sameuser
+ (vti.Vector vti.RegClass:$rs2),
+ (vti.Mask true_mask), (XLenVT srcvalue))),
+ (wti.Vector wti.RegClass:$rd)),
+ (!cast<Instruction>("PseudoVFWMACCBF16_VV_"#suffix)
+ wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+ // Value to indicate no rounding mode change in
+ // RISCVInsertReadWriteCSR
+ FRM_DYN,
+ vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
+ def : Pat<(fma (wti.Vector (SplatFPOp
+ (fpext_oneuse (vti.Scalar vti.ScalarRegClass:$rs1)))),
+ (wti.Vector (riscv_fpextend_vl_oneuse
+ (vti.Vector vti.RegClass:$rs2),
+ (vti.Mask true_mask), (XLenVT srcvalue))),
+ (wti.Vector wti.RegClass:$rd)),
+ (!cast<Instruction>("PseudoVFWMACCBF16_V"#vti.ScalarSuffix#"_"#suffix)
+ wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+ // Value to indicate no rounding mode change in
+ // RISCVInsertReadWriteCSR
+ FRM_DYN,
+ vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>;
+ }
+}
+
foreach vti = AllFloatAndBF16Vectors in {
let Predicates = GetVTypePredicates<vti>.Predicates in {
// 13.8. Vector Floating-Point Square-Root Instruction
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 4c41667560a98..ab94566350319 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -1812,7 +1812,7 @@ multiclass VPatFPMulAddVL_VV_VF<SDPatternOperator vop, string instruction_name>
}
multiclass VPatFPMulAddVL_VV_VF_RM<SDPatternOperator vop, string instruction_name> {
- foreach vti = AllFloatVectors in {
+ foreach vti = AllFloatAndBF16Vectors in {
defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
let Predicates = GetVTypePredicates<vti>.Predicates in {
def : Pat<(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rd,
@@ -1843,22 +1843,18 @@ multiclass VPatFPMulAddVL_VV_VF_RM<SDPatternOperator vop, string instruction_nam
multiclass VPatWidenFPMulAccVL_VV_VF_RM<SDNode vop, string instruction_name,
list<VTypeInfoToWide> vtiToWtis =
- AllWidenableFloatVectors> {
+ AllWidenableFloatAndBF16Vectors> {
foreach vtiToWti = vtiToWtis in {
defvar vti = vtiToWti.Vti;
defvar wti = vtiToWti.Wti;
defvar suffix = vti.LMul.MX # "_E" # vti.SEW;
- let Predicates = !listconcat(GetVTypePredicates<wti>.Predicates,
- !if(!eq(vti.Scalar, bf16),
- [HasStdExtZvfbfwma],
- GetVTypePredicates<vti>.Predicates)) in {
+ let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+ GetVTypePredicates<wti>.Predicates) in {
def : Pat<(vop (vti.Vector vti.RegClass:$rs1),
(vti.Vector vti.RegClass:$rs2),
(wti.Vector wti.RegClass:$rd), (vti.Mask VMV0:$vm),
VLOpFrag),
- (!cast<Instruction>(instruction_name#
- !if(!eq(vti.Scalar, bf16), "BF16", "")#
- "_VV_"#suffix#"_MASK")
+ (!cast<Instruction>(instruction_name#"_VV_"#suffix#"_MASK")
wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask VMV0:$vm),
// Value to indicate no rounding mode change in
@@ -1869,9 +1865,7 @@ multiclass VPatWidenFPMulAccVL_VV_VF_RM<SDNode vop, string instruction_name,
(vti.Vector vti.RegClass:$rs2),
(wti.Vector wti.RegClass:$rd), (vti.Mask VMV0:$vm),
VLOpFrag),
- (!cast<Instruction>(instruction_name#
- !if(!eq(vti.Scalar, bf16), "BF16", "")#
- "_V"#vti.ScalarSuffix#"_"#suffix#"_MASK")
+ (!cast<Instruction>(instruction_name#"_V"#vti.ScalarSuffix#"_"#suffix#"_MASK")
wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
(vti.Mask VMV0:$vm),
// Value to indicate no rounding mode change in
@@ -2344,12 +2338,42 @@ defm : VPatFPMulAddVL_VV_VF_RM<any_riscv_vfnmadd_vl, "PseudoVFNMADD">;
defm : VPatFPMulAddVL_VV_VF_RM<any_riscv_vfnmsub_vl, "PseudoVFNMSUB">;
// 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions
-defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACC",
- AllWidenableFloatAndBF16Vectors>;
+defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmadd_vl, "PseudoVFWMACC">;
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwnmadd_vl, "PseudoVFWNMACC">;
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwmsub_vl, "PseudoVFWMSAC">;
defm : VPatWidenFPMulAccVL_VV_VF_RM<riscv_vfwnmsub_vl, "PseudoVFWNMSAC">;
+// Zvfbfwma
+foreach vtiToWti = AllWidenableBF16ToFloatVectors in {
+ defvar vti = vtiToWti.Vti;
+ defvar wti = vtiToWti.Wti;
+ defvar suffix = vti.LMul.MX # "_E16";
+ let Predicates = [HasStdExtZvfbfwma, HasVInstructionsAnyF] in {
+ def : Pat<(riscv_vfwmadd_vl (vti.Vector vti.RegClass:$rs1),
+ (vti.Vector vti.RegClass:$rs2),
+ (wti.Vector wti.RegClass:$rd), (vti.Mask VMV0:$vm),
+ VLOpFrag),
+ (!cast<Instruction>("PseudoVFWMACCBF16_VV_"#suffix#"_MASK")
+ wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+ (vti.Mask VMV0:$vm),
+ // Value to indicate no rounding mode change in
+ // RISCVInsertReadWriteCSR
+ FRM_DYN,
+ GPR:$vl, vti.Log2SEW, TA_MA)>;
+ def : Pat<(riscv_vfwmadd_vl (vti.Vector (SplatFPOp vti.ScalarRegClass:$rs1)),
+ (vti.Vector vti.RegClass:$rs2),
+ (wti.Vector wti.RegClass:$rd), (vti.Mask VMV0:$vm),
+ VLOpFrag),
+ (!cast<Instruction>("PseudoVFWMACCBF16_V"#vti.ScalarSuffix#"_"#suffix#"_MASK")
+ wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
+ (vti.Mask VMV0:$vm),
+ // Value to indicate no rounding mode change in
+ // RISCVInsertReadWriteCSR
+ FRM_DYN,
+ GPR:$vl, vti.Log2SEW, TA_MA)>;
+ }
+}
+
// 13.11. Vector Floating-Point MIN/MAX Instructions
defm : VPatBinaryFPVL_VV_VF<riscv_vfmin_vl, "PseudoVFMIN", isSEWAware=1>;
defm : VPatBinaryFPVL_VV_VF<riscv_vfmax_vl, "PseudoVFMAX", isSEWAware=1>;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfmadd-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfmadd-sdnode.ll
new file mode 100644
index 0000000000000..814ee7739ee62
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfmadd-sdnode.ll
@@ -0,0 +1,225 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+experimental-zvfbfa,+v \
+; RUN: -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+experimental-zvfbfa,+v \
+; RUN: -verify-machineinstrs < %s | FileCheck %s
+
+define <1 x bfloat> @vfmadd_vv_v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, <1 x bfloat> %vc) {
+; CHECK-LABEL: vfmadd_vv_v1bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
+; CHECK-NEXT: vfmadd.vv v8, v9, v10
+; CHECK-NEXT: ret
+ %vd = call <1 x bfloat> @llvm.fma.v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, <1 x bfloat> %vc)
+ ret <1 x bfloat> %vd
+}
+
+define <1 x bfloat> @vfmadd_vf_v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmadd_vf_v1bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
+; CHECK-NEXT: vfmadd.vf v8, fa0, v9
+; CHECK-NEXT: ret
+ %head = insertelement <1 x bfloat> poison, bfloat %c, i32 0
+ %splat = shufflevector <1 x bfloat> %head, <1 x bfloat> poison, <1 x i32> zeroinitializer
+ %vd = call <1 x bfloat> @llvm.fma.v1bf16(<1 x bfloat> %va, <1 x bfloat> %splat, <1 x bfloat> %vb)
+ ret <1 x bfloat> %vd
+}
+
+define <2 x bfloat> @vfmadd_vv_v2bf16(<2 x bfloat> %va, <2 x bfloat> %vb, <2 x bfloat> %vc) {
+; CHECK-LABEL: vfmadd_vv_v2bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT: vfmadd.vv v8, v9, v10
+; CHECK-NEXT: ret
+ %vd = call <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat> %va, <2 x bfloat> %vb, <2 x bfloat> %vc)
+ ret <2 x bfloat> %vd
+}
+
+define <2 x bfloat> @vfmadd_vf_v2bf16(<2 x bfloat> %va, <2 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmadd_vf_v2bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT: vfmadd.vf v8, fa0, v9
+; CHECK-NEXT: ret
+ %head = insertelement <2 x bfloat> poison, bfloat %c, i32 0
+ %splat = shufflevector <2 x bfloat> %head, <2 x bfloat> poison, <2 x i32> zeroinitializer
+ %vd = call <2 x bfloat> @llvm.fma.v2bf16(<2 x bfloat> %va, <2 x bfloat> %splat, <2 x bfloat> %vb)
+ ret <2 x bfloat> %vd
+}
+
+define <4 x bfloat> @vfmadd_vv_v4bf16(<4 x bfloat> %va, <4 x bfloat> %vb, <4 x bfloat> %vc) {
+; CHECK-LABEL: vfmadd_vv_v4bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT: vfmadd.vv v8, v9, v10
+; CHECK-NEXT: ret
+ %vd = call <4 x bfloat> @llvm.fma.v4bf16(<4 x bfloat> %va, <4 x bfloat> %vb, <4 x bfloat> %vc)
+ ret <4 x bfloat> %vd
+}
+
+define <4 x bfloat> @vfmadd_vf_v4bf16(<4 x bfloat> %va, <4 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmadd_vf_v4bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT: vfmadd.vf v8, fa0, v9
+; CHECK-NEXT: ret
+ %head = insertelement <4 x bfloat> poison, bfloat %c, i32 0
+ %splat = shufflevector <4 x bfloat> %head, <4 x bfloat> poison, <4 x i32> zeroinitializer
+ %vd = call <4 x bfloat> @llvm.fma.v4bf16(<4 x bfloat> %va, <4 x bfloat> %splat, <4 x bfloat> %vb)
+ ret <4 x bfloat> %vd
+}
+
+define <8 x bfloat> @vfmadd_vv_v8bf16(<8 x bfloat> %va, <8 x bfloat> %vb, <8 x bfloat> %vc) {
+; CHECK-LABEL: vfmadd_vv_v8bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, ma
+; CHECK-NEXT: vfmadd.vv v8, v9, v10
+; CHECK-NEXT: ret
+ %vd = call <8 x bfloat> @llvm.fma.v8bf16(<8 x bfloat> %va, <8 x bfloat> %vb, <8 x bfloat> %vc)
+ ret <8 x bfloat> %vd
+}
+
+define <8 x bfloat> @vfmadd_vf_v8bf16(<8 x bfloat> %va, <8 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmadd_vf_v8bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, ma
+; CHECK-NEXT: vfmadd.vf v8, fa0, v9
+; CHECK-NEXT: ret
+ %head = insertelement <8 x bfloat> poison, bfloat %c, i32 0
+ %splat = shufflevector <8 x bfloat> %head, <8 x bfloat> poison, <8 x i32> zeroinitializer
+ %vd = call <8 x bfloat> @llvm.fma.v8bf16(<8 x bfloat> %va, <8 x bfloat> %splat, <8 x bfloat> %vb)
+ ret <8 x bfloat> %vd
+}
+
+define <16 x bfloat> @vfmadd_vv_v16bf16(<16 x bfloat> %va, <16 x bfloat> %vb, <16 x bfloat> %vc) {
+; CHECK-LABEL: vfmadd_vv_v16bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT: vfmadd.vv v8, v10, v12
+; CHECK-NEXT: ret
+ %vd = call <16 x bfloat> @llvm.fma.v16bf16(<16 x bfloat> %va, <16 x bfloat> %vb, <16 x bfloat> %vc)
+ ret <16 x bfloat> %vd
+}
+
+define <16 x bfloat> @vfmadd_vf_v16bf16(<16 x bfloat> %va, <16 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmadd_vf_v16bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
+; CHECK-NEXT: vfmadd.vf v8, fa0, v10
+; CHECK-NEXT: ret
+ %head = insertelement <16 x bfloat> poison, bfloat %c, i32 0
+ %splat = shufflevector <16 x bfloat> %head, <16 x bfloat> poison, <16 x i32> zeroinitializer
+ %vd = call <16 x bfloat> @llvm.fma.v16bf16(<16 x bfloat> %va, <16 x bfloat> %splat, <16 x bfloat> %vb)
+ ret <16 x bfloat> %vd
+}
+
+define <1 x bfloat> @vfmacc_vv_v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, <1 x bfloat> %vc) {
+; CHECK-LABEL: vfmacc_vv_v1bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
+; CHECK-NEXT: vfmacc.vv v8, v10, v9
+; CHECK-NEXT: ret
+ %vd = call <1 x bfloat> @llvm.fma.v1bf16(<1 x bfloat> %vb, <1 x bfloat> %vc, <1 x bfloat> %va)
+ ret <1 x bfloat> %vd
+}
+
+define <1 x bfloat> @vfmacc_vf_v1bf16(<1 x bfloat> %va, <1 x bfloat> %vb, bfloat %c) {
+; CHECK-LABEL: vfmacc_vf_v1bf16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 1, e16, mf4, ta, ma
+; CHECK-NEXT: vfmacc.vf v8, fa0, v9
+; CHECK-NEXT: ret
+ %head = insertelement <1 x bfloat> poison, bfloat %c, i32 0
+ %splat = shufflevector <1 x bfloat> %head, <1 x bfloat> poison, <1 x i32> zeroinitializer
+ %vd = call <1 x bfloat> @llvm.fma.v1bf16(<1 x bfloat> %vb, <1 ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/172949
More information about the llvm-commits
mailing list