[llvm] [LLVM][CodeGen][SVE] Use BFDOT for fadd reductions. (PR #147981)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 14 09:49:03 PDT 2025
https://github.com/paulwalker-arm updated https://github.com/llvm/llvm-project/pull/147981
>From 57d7a097f30b91ef5f8941844dbf3a45c7568c75 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Thu, 10 Jul 2025 15:40:13 +0100
Subject: [PATCH 1/2] [LLVM][CodeGen][SVE] Use BFDOT for reductions.
We typically lower bfloat add reductions by promoting the input to
a float vector, reducing it with FADDV, and then rounding the result
back to bfloat. By using BFDOT we get the promotion for "free" whilst
partially reducing to a vector of floats, which can then be reduced
and rounded as is done today.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 19 ++++++++++++
.../CodeGen/AArch64/sve-bf16-reductions.ll | 30 ++++++++++++-------
2 files changed, 38 insertions(+), 11 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 331c8036e26f1..49e7ba49ff668 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1784,6 +1784,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
}
+ if (Subtarget->hasBF16())
+ setOperationAction(ISD::VECREDUCE_FADD, MVT::nxv8bf16, Custom);
+
if (!Subtarget->hasSVEB16B16()) {
for (auto Opcode : {ISD::FADD, ISD::FMA, ISD::FMAXIMUM, ISD::FMAXNUM,
ISD::FMINIMUM, ISD::FMINNUM, ISD::FMUL, ISD::FSUB}) {
@@ -16063,6 +16066,22 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
if (SrcVT.getVectorElementType() == MVT::i1)
return LowerPredReductionToSVE(Op, DAG);
+ if (SrcVT == MVT::nxv8bf16 && Op.getOpcode() == ISD::VECREDUCE_FADD) {
+ assert(Subtarget->hasBF16() &&
+ "VECREDUCE custom lowering expected +bf16!");
+ SDLoc DL(Op);
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_bfdot, DL, MVT::i64);
+ SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
+ SDValue One = DAG.getConstantFP(1.0, DL, MVT::nxv4f32);
+ // Use BFDOT's implicitly promotion to float with partial reduction.
+ SDValue BFDOT = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4f32, ID,
+ Zero, Src, One);
+ SDValue FADDV = DAG.getNode(ISD::VECREDUCE_FADD, DL, MVT::f32, BFDOT);
+ return DAG.getNode(ISD::FP_ROUND, DL, MVT::bf16, FADDV,
+ DAG.getTargetConstant(0, DL, MVT::i64));
+ }
+
switch (Op.getOpcode()) {
case ISD::VECREDUCE_ADD:
return LowerReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
index 7f79c9c5431ea..d05b66b5842fd 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
@@ -31,17 +31,25 @@ define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
}
define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
-; CHECK-LABEL: faddv_nxv8bf16:
-; CHECK: // %bb.0:
-; CHECK-NEXT: uunpkhi z1.s, z0.h
-; CHECK-NEXT: uunpklo z0.s, z0.h
-; CHECK-NEXT: ptrue p0.s
-; CHECK-NEXT: lsl z1.s, z1.s, #16
-; CHECK-NEXT: lsl z0.s, z0.s, #16
-; CHECK-NEXT: fadd z0.s, z0.s, z1.s
-; CHECK-NEXT: faddv s0, p0, z0.s
-; CHECK-NEXT: bfcvt h0, s0
-; CHECK-NEXT: ret
+; SVE-LABEL: faddv_nxv8bf16:
+; SVE: // %bb.0:
+; SVE-NEXT: movi v1.2d, #0000000000000000
+; SVE-NEXT: fmov z2.s, #1.00000000
+; SVE-NEXT: ptrue p0.s
+; SVE-NEXT: bfdot z1.s, z0.h, z2.h
+; SVE-NEXT: faddv s0, p0, z1.s
+; SVE-NEXT: bfcvt h0, s0
+; SVE-NEXT: ret
+;
+; SME-LABEL: faddv_nxv8bf16:
+; SME: // %bb.0:
+; SME-NEXT: fmov z1.s, #1.00000000
+; SME-NEXT: mov z2.s, #0 // =0x0
+; SME-NEXT: ptrue p0.s
+; SME-NEXT: bfdot z2.s, z0.h, z1.h
+; SME-NEXT: faddv s0, p0, z2.s
+; SME-NEXT: bfcvt h0, s0
+; SME-NEXT: ret
%res = call fast bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat zeroinitializer, <vscale x 8 x bfloat> %a)
ret bfloat %res
}
>From 6e4bf2fa64e96b822dc4784e2de613fded6c8d79 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Mon, 14 Jul 2025 16:48:39 +0000
Subject: [PATCH 2/2] Use source VT (nxv8bf16) when creating the vector of ones.
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 +-
llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll | 5 +++--
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 49e7ba49ff668..f1fec8db1b79f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16073,7 +16073,7 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_bfdot, DL, MVT::i64);
SDValue Zero = DAG.getConstantFP(0.0, DL, MVT::nxv4f32);
- SDValue One = DAG.getConstantFP(1.0, DL, MVT::nxv4f32);
+ SDValue One = DAG.getConstantFP(1.0, DL, MVT::nxv8bf16);
// Use BFDOT's implicitly promotion to float with partial reduction.
SDValue BFDOT = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv4f32, ID,
Zero, Src, One);
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
index d05b66b5842fd..3acea3a95e787 100644
--- a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
@@ -30,11 +30,12 @@ define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
ret bfloat %res
}
+; NOTE: f16(1.875) == bf16(1.0)
define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
; SVE-LABEL: faddv_nxv8bf16:
; SVE: // %bb.0:
; SVE-NEXT: movi v1.2d, #0000000000000000
-; SVE-NEXT: fmov z2.s, #1.00000000
+; SVE-NEXT: fmov z2.h, #1.87500000
; SVE-NEXT: ptrue p0.s
; SVE-NEXT: bfdot z1.s, z0.h, z2.h
; SVE-NEXT: faddv s0, p0, z1.s
@@ -43,7 +44,7 @@ define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
;
; SME-LABEL: faddv_nxv8bf16:
; SME: // %bb.0:
-; SME-NEXT: fmov z1.s, #1.00000000
+; SME-NEXT: fmov z1.h, #1.87500000
; SME-NEXT: mov z2.s, #0 // =0x0
; SME-NEXT: ptrue p0.s
; SME-NEXT: bfdot z2.s, z0.h, z1.h
More information about the llvm-commits
mailing list