[llvm] [LLVM][CodeGen][SVE] Use BFDOT for fadd reductions. (PR #147981)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 10 08:04:12 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

<details>
<summary>Changes</summary>

We typically lower bfloat add reductions by promoting the input to a float vector, reducing it with FADDV, and then rounding the result. By using BFDOT we get the promotion for "free" while also partially reducing to a vector of floats, which can then be reduced and rounded as is done today.

---
Full diff: https://github.com/llvm/llvm-project/pull/147981.diff


2 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+19) 
- (modified) llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll (+19-11) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 331c8036e26f1..49e7ba49ff668 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1784,6 +1784,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
     }
 
+    if (Subtarget->hasBF16())
+      setOperationAction(ISD::VECREDUCE_FADD, MVT::nxv8bf16, Custom);
+
     if (!Subtarget->hasSVEB16B16()) {
       for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
                           ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
@@ -16063,6 +16066,22 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
     if (SrcVT.getVectorElementType() == MVT::i1)
       return LowerPredReductionToSVE(Op, DAG);
 
+    if (SrcVT == MVT::nxv8bf16 && Op.getOpcode() == ISD::VECREDUCE_FADD) {
+      assert(Subtarget->hasBF16() &&
+             "VECREDUCE custom lowering expected +bf16!");
+      SDLoc DL(Op);
+      SDValue ID =
+          DAG.getTargetConstant(Intrinsic::aarch64_sve_bfdot, DL, MVT::i64);
+      SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
+      SDValue One = DAG.getConstantFP(1.0, DL, MVT::nxv4f32);
+      // Use BFDOT's implicit promotion to float, which also performs a partial reduction.
+      SDValue BFDOT = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4f32, ID,
+                                  Zero, Src, One);
+      SDValue FADDV = DAG.getNode(ISD::VECREDUCE_FADD, DL, MVT::f32, BFDOT);
+      return DAG.getNode(ISD::FP_ROUND, DL, MVT::bf16, FADDV,
+                         DAG.getTargetConstant(0, DL, MVT::i64));
+    }
+
     switch (Op.getOpcode()) {
     case ISD::VECREDUCE_ADD:
       return LowerReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
index 7f79c9c5431ea..d05b66b5842fd 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
@@ -31,17 +31,25 @@ define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
 }
 
 define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
-; CHECK-LABEL: faddv_nxv8bf16:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    uunpkhi z1.s, z0.h
-; CHECK-NEXT:    uunpklo z0.s, z0.h
-; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    lsl z1.s, z1.s, #16
-; CHECK-NEXT:    lsl z0.s, z0.s, #16
-; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
-; CHECK-NEXT:    faddv s0, p0, z0.s
-; CHECK-NEXT:    bfcvt h0, s0
-; CHECK-NEXT:    ret
+; SVE-LABEL: faddv_nxv8bf16:
+; SVE:       // %bb.0:
+; SVE-NEXT:    movi v1.2d, #0000000000000000
+; SVE-NEXT:    fmov z2.s, #1.00000000
+; SVE-NEXT:    ptrue p0.s
+; SVE-NEXT:    bfdot z1.s, z0.h, z2.h
+; SVE-NEXT:    faddv s0, p0, z1.s
+; SVE-NEXT:    bfcvt h0, s0
+; SVE-NEXT:    ret
+;
+; SME-LABEL: faddv_nxv8bf16:
+; SME:       // %bb.0:
+; SME-NEXT:    fmov z1.s, #1.00000000
+; SME-NEXT:    mov z2.s, #0 // =0x0
+; SME-NEXT:    ptrue p0.s
+; SME-NEXT:    bfdot z2.s, z0.h, z1.h
+; SME-NEXT:    faddv s0, p0, z2.s
+; SME-NEXT:    bfcvt h0, s0
+; SME-NEXT:    ret
   %res = call fast bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat zeroinitializer, <vscale x 8 x bfloat> %a)
   ret bfloat %res
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/147981


More information about the llvm-commits mailing list