[llvm] [LLVM][CodeGen][SVE] Use BFDOT for fadd reductions. (PR #147981)

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 14 09:49:03 PDT 2025


https://github.com/paulwalker-arm updated https://github.com/llvm/llvm-project/pull/147981

>From 57d7a097f30b91ef5f8941844dbf3a45c7568c75 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Thu, 10 Jul 2025 15:40:13 +0100
Subject: [PATCH 1/2] [LLVM][CodeGen][SVE] Use BFDOT for fadd reductions.

We typically lower bfloat add reductions by promoting the input to
a float vector and reducing it with FADDV, before rounding the result.
By using BFDOT we get the promotion for "free" whilst also partially
reducing to a vector of floats, which can then be reduced and rounded
as is done today.
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 19 ++++++++++++
 .../CodeGen/AArch64/sve-bf16-reductions.ll    | 30 ++++++++++++-------
 2 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 331c8036e26f1..49e7ba49ff668 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1784,6 +1784,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
     }
 
+    if (Subtarget->hasBF16())
+      setOperationAction(ISD::VECREDUCE_FADD, MVT::nxv8bf16, Custom);
+
     if (!Subtarget->hasSVEB16B16()) {
       for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
                           ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
@@ -16063,6 +16066,22 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
     if (SrcVT.getVectorElementType() == MVT::i1)
       return LowerPredReductionToSVE(Op, DAG);
 
+    if (SrcVT == MVT::nxv8bf16 && Op.getOpcode() == ISD::VECREDUCE_FADD) {
+      assert(Subtarget->hasBF16() &&
+             "VECREDUCE custom lowering expected +bf16!");
+      SDLoc DL(Op);
+      SDValue ID =
+          DAG.getTargetConstant(Intrinsic::aarch64_sve_bfdot, DL, MVT::i64);
+      SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
+      SDValue One = DAG.getConstantFP(1.0, DL, MVT::nxv4f32);
+      // Use BFDOT's implicitly promotion to float with partial reduction.
+      SDValue BFDOT = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4f32, ID,
+                                  Zero, Src, One);
+      SDValue FADDV = DAG.getNode(ISD::VECREDUCE_FADD, DL, MVT::f32, BFDOT);
+      return DAG.getNode(ISD::FP_ROUND, DL, MVT::bf16, FADDV,
+                         DAG.getTargetConstant(0, DL, MVT::i64));
+    }
+
     switch (Op.getOpcode()) {
     case ISD::VECREDUCE_ADD:
       return LowerReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
index 7f79c9c5431ea..d05b66b5842fd 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
@@ -31,17 +31,25 @@ define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
 }
 
 define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
-; CHECK-LABEL: faddv_nxv8bf16:
-; CHECK:       // %bb.0:
-; CHECK-NEXT:    uunpkhi z1.s, z0.h
-; CHECK-NEXT:    uunpklo z0.s, z0.h
-; CHECK-NEXT:    ptrue p0.s
-; CHECK-NEXT:    lsl z1.s, z1.s, #16
-; CHECK-NEXT:    lsl z0.s, z0.s, #16
-; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
-; CHECK-NEXT:    faddv s0, p0, z0.s
-; CHECK-NEXT:    bfcvt h0, s0
-; CHECK-NEXT:    ret
+; SVE-LABEL: faddv_nxv8bf16:
+; SVE:       // %bb.0:
+; SVE-NEXT:    movi v1.2d, #0000000000000000
+; SVE-NEXT:    fmov z2.s, #1.00000000
+; SVE-NEXT:    ptrue p0.s
+; SVE-NEXT:    bfdot z1.s, z0.h, z2.h
+; SVE-NEXT:    faddv s0, p0, z1.s
+; SVE-NEXT:    bfcvt h0, s0
+; SVE-NEXT:    ret
+;
+; SME-LABEL: faddv_nxv8bf16:
+; SME:       // %bb.0:
+; SME-NEXT:    fmov z1.s, #1.00000000
+; SME-NEXT:    mov z2.s, #0 // =0x0
+; SME-NEXT:    ptrue p0.s
+; SME-NEXT:    bfdot z2.s, z0.h, z1.h
+; SME-NEXT:    faddv s0, p0, z2.s
+; SME-NEXT:    bfcvt h0, s0
+; SME-NEXT:    ret
   %res = call fast bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat zeroinitializer, <vscale x 8 x bfloat> %a)
   ret bfloat %res
 }

>From 6e4bf2fa64e96b822dc4784e2de613fded6c8d79 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Mon, 14 Jul 2025 16:48:39 +0000
Subject: [PATCH 2/2] Use source VT when creating neutral vector.

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp  | 2 +-
 llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 49e7ba49ff668..f1fec8db1b79f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16073,7 +16073,7 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
       SDValue ID =
           DAG.getTargetConstant(Intrinsic::aarch64_sve_bfdot, DL, MVT::i64);
       SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
-      SDValue One = DAG.getConstantFP(1.0, DL, MVT::nxv4f32);
+      SDValue One = DAG.getConstantFP(1.0, DL, MVT::nxv8bf16);
       // Use BFDOT's implicitly promotion to float with partial reduction.
       SDValue BFDOT = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4f32, ID,
                                   Zero, Src, One);
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
index d05b66b5842fd..3acea3a95e787 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
@@ -30,11 +30,12 @@ define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
   ret bfloat %res
 }
 
+; NOTE: f16(1.875) and bf16(1.0) share the same bit pattern (0x3F80).
 define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
 ; SVE-LABEL: faddv_nxv8bf16:
 ; SVE:       // %bb.0:
 ; SVE-NEXT:    movi v1.2d, #0000000000000000
-; SVE-NEXT:    fmov z2.s, #1.00000000
+; SVE-NEXT:    fmov z2.h, #1.87500000
 ; SVE-NEXT:    ptrue p0.s
 ; SVE-NEXT:    bfdot z1.s, z0.h, z2.h
 ; SVE-NEXT:    faddv s0, p0, z1.s
@@ -43,7 +44,7 @@ define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
 ;
 ; SME-LABEL: faddv_nxv8bf16:
 ; SME:       // %bb.0:
-; SME-NEXT:    fmov z1.s, #1.00000000
+; SME-NEXT:    fmov z1.h, #1.87500000
 ; SME-NEXT:    mov z2.s, #0 // =0x0
 ; SME-NEXT:    ptrue p0.s
 ; SME-NEXT:    bfdot z2.s, z0.h, z1.h



More information about the llvm-commits mailing list