[llvm] [LLVM][CodeGen][SVE] Use BFDOT for fadd reductions. (PR #147981)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 10 08:04:12 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-aarch64
Author: Paul Walker (paulwalker-arm)
<details>
<summary>Changes</summary>
We typically lower bfloat fadd reductions by promoting the input to a float vector, reducing it with FADDV, and then rounding the result back to bfloat. By using BFDOT we get the promotion for "free" whilst partially reducing the input to a vector of floats, which can then be reduced and rounded as is done today.
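To make the idea concrete, here is a minimal scalar sketch in plain C++ (not LLVM API; the helper names are made up for illustration) of the data flow the new lowering relies on: BFDOT with a zero accumulator and a splat-of-1.0 multiplier sums each adjacent pair of promoted bfloat elements into one float lane, halving the element count, after which the existing FADDV-plus-round path applies. The sketch deliberately ignores BFDOT's exact numerical behaviour and only shows the shape of the computation.

```cpp
// Scalar model of the lowering, assuming bfloat16 values are stored as raw
// 16-bit patterns.  Helper names are illustrative only.
#include <cstdint>
#include <cstring>
#include <vector>

// Promote a bfloat16 bit pattern to float by placing it in the high 16 bits.
static float BF16ToFloat(uint16_t Bits) {
  uint32_t Wide = static_cast<uint32_t>(Bits) << 16;
  float F;
  std::memcpy(&F, &Wide, sizeof(F));
  return F;
}

// Model of "bfdot z_acc.s, z_a.h, z_one.h" with z_acc = 0.0 and z_one = 1.0:
// each 32-bit float lane receives the sum of two promoted bfloat lanes,
// i.e. a partial reduction from 2*N bfloats to N floats.
static std::vector<float> BFDOTPartialReduce(const std::vector<uint16_t> &A) {
  std::vector<float> Acc(A.size() / 2, 0.0f);
  for (size_t I = 0; I != Acc.size(); ++I)
    Acc[I] += BF16ToFloat(A[2 * I]) * 1.0f + BF16ToFloat(A[2 * I + 1]) * 1.0f;
  return Acc;
}

// The remaining steps match the existing lowering: FADDV over the float
// vector, followed by rounding back to bfloat (rounding omitted here).
static float FAddV(const std::vector<float> &V) {
  float Sum = 0.0f;
  for (float F : V)
    Sum += F;
  return Sum;
}
```

As the updated CHECK lines in the test below show, this replaces the old uunpkhi/uunpklo + lsl + fadd promotion sequence with a single bfdot feeding the same faddv/bfcvt tail.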
---
Full diff: https://github.com/llvm/llvm-project/pull/147981.diff
2 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+19)
- (modified) llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll (+19-11)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 331c8036e26f1..49e7ba49ff668 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1784,6 +1784,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
}
+ if (Subtarget->hasBF16())
+ setOperationAction(ISD::VECREDUCE_FADD, MVT::nxv8bf16, Custom);
+
if (!Subtarget->hasSVEB16B16()) {
for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
@@ -16063,6 +16066,22 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
if (SrcVT.getVectorElementType() == MVT::i1)
return LowerPredReductionToSVE(Op, DAG);
+ if (SrcVT == MVT::nxv8bf16 && Op.getOpcode() == ISD::VECREDUCE_FADD) {
+ assert(Subtarget->hasBF16() &&
+ "VECREDUCE custom lowering expected +bf16!");
+ SDLoc DL(Op);
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_bfdot, DL, MVT::i64);
+ SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
+ SDValue One = DAG.getConstantFP(1.0, DL, MVT::nxv4f32);
+ // Use BFDOT's implicit promotion to float with partial reduction.
+ SDValue BFDOT = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4f32, ID,
+ Zero, Src, One);
+ SDValue FADDV = DAG.getNode(ISD::VECREDUCE_FADD, DL, MVT::f32, BFDOT);
+ return DAG.getNode(ISD::FP_ROUND, DL, MVT::bf16, FADDV,
+ DAG.getTargetConstant(0, DL, MVT::i64));
+ }
+
switch (Op.getOpcode()) {
case ISD::VECREDUCE_ADD:
return LowerReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
index 7f79c9c5431ea..d05b66b5842fd 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
@@ -31,17 +31,25 @@ define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
}
define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
-; CHECK-LABEL: faddv_nxv8bf16:
-; CHECK: // %bb.0:
-; CHECK-NEXT: uunpkhi z1.s, z0.h
-; CHECK-NEXT: uunpklo z0.s, z0.h
-; CHECK-NEXT: ptrue p0.s
-; CHECK-NEXT: lsl z1.s, z1.s, #16
-; CHECK-NEXT: lsl z0.s, z0.s, #16
-; CHECK-NEXT: fadd z0.s, z0.s, z1.s
-; CHECK-NEXT: faddv s0, p0, z0.s
-; CHECK-NEXT: bfcvt h0, s0
-; CHECK-NEXT: ret
+; SVE-LABEL: faddv_nxv8bf16:
+; SVE: // %bb.0:
+; SVE-NEXT: movi v1.2d, #0000000000000000
+; SVE-NEXT: fmov z2.s, #1.00000000
+; SVE-NEXT: ptrue p0.s
+; SVE-NEXT: bfdot z1.s, z0.h, z2.h
+; SVE-NEXT: faddv s0, p0, z1.s
+; SVE-NEXT: bfcvt h0, s0
+; SVE-NEXT: ret
+;
+; SME-LABEL: faddv_nxv8bf16:
+; SME: // %bb.0:
+; SME-NEXT: fmov z1.s, #1.00000000
+; SME-NEXT: mov z2.s, #0 // =0x0
+; SME-NEXT: ptrue p0.s
+; SME-NEXT: bfdot z2.s, z0.h, z1.h
+; SME-NEXT: faddv s0, p0, z2.s
+; SME-NEXT: bfcvt h0, s0
+; SME-NEXT: ret
%res = call fast bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat zeroinitializer, <vscale x 8 x bfloat> %a)
ret bfloat %res
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/147981