[llvm] [LLVM][CodeGen][SVE] Add isel for bfloat unordered reductions. (PR #143540)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 10 07:24:18 PDT 2025
https://github.com/paulwalker-arm created https://github.com/llvm/llvm-project/pull/143540
The omissions are VECREDUCE_SEQ_* and VECREDUCE_FMUL. The former are handled by a different code path, and the latter is generally unsupported for all element types.
A future extension is to use BFDOT for add reductions when available, especially for the nxv8bf16 case.
>From 50a32d242e9c2f7c035863353a89b98f61b0e758 Mon Sep 17 00:00:00 2001
From: Paul Walker <paul.walker at arm.com>
Date: Tue, 10 Jun 2025 14:57:13 +0100
Subject: [PATCH] [LLVM][CodeGen][SVE] Add isel for bfloat unordered
reductions.
The omissions are VECREDUCE_SEQ_* and VECREDUCE_FMUL. The former goes down a
different code path and the latter is generally unsupported across all
element types.
---
.../SelectionDAG/LegalizeVectorOps.cpp | 36 ++-
.../CodeGen/SelectionDAG/TargetLowering.cpp | 15 +-
.../Target/AArch64/AArch64ISelLowering.cpp | 4 +-
.../CodeGen/AArch64/sve-bf16-reductions.ll | 235 ++++++++++++++++++
4 files changed, 275 insertions(+), 15 deletions(-)
create mode 100644 llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 4a1cd642233ef..1fc5fc66c56e5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -188,6 +188,7 @@ class VectorLegalizer {
void PromoteSETCC(SDNode *Node, SmallVectorImpl<SDValue> &Results);
void PromoteSTRICT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
+ void PromoteVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results);
public:
VectorLegalizer(SelectionDAG& dag) :
@@ -500,20 +501,14 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
case ISD::VECREDUCE_FADD:
- case ISD::VECREDUCE_FMUL:
- case ISD::VECTOR_FIND_LAST_ACTIVE:
- Action = TLI.getOperationAction(Node->getOpcode(),
- Node->getOperand(0).getValueType());
- break;
case ISD::VECREDUCE_FMAX:
- case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMAXIMUM:
+ case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMINIMUM:
+ case ISD::VECREDUCE_FMUL:
+ case ISD::VECTOR_FIND_LAST_ACTIVE:
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
- // Defer non-vector results to LegalizeDAG.
- if (Action == TargetLowering::Promote)
- Action = TargetLowering::Legal;
break;
case ISD::VECREDUCE_SEQ_FADD:
case ISD::VECREDUCE_SEQ_FMUL:
@@ -683,6 +683,36 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node,
   Results.push_back(Round.getValue(1));
 }
 
+/// Promote a floating-point vector reduction by performing it on wider
+/// elements.
+///
+/// The vector operand is FP_EXTENDed to the type returned by
+/// TLI.getTypeToPromoteTo (e.g. nxv8bf16 -> nxv8f32 for AArch64 SVE), the
+/// reduction is performed in that type and its scalar result is FP_ROUNDed
+/// back to the node's original result type.
+void VectorLegalizer::PromoteVECREDUCE(SDNode *Node,
+                                       SmallVectorImpl<SDValue> &Results) {
+  MVT OpVT = Node->getOperand(0).getSimpleValueType();
+  assert(OpVT.isFloatingPoint() && "Expected floating point reduction!");
+  MVT NewOpVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OpVT);
+
+  SDLoc DL(Node);
+  SDValue NewOp = DAG.getNode(ISD::FP_EXTEND, DL, NewOpVT, Node->getOperand(0));
+  SDValue Rdx =
+      DAG.getNode(Node->getOpcode(), DL, NewOpVT.getVectorElementType(), NewOp,
+                  Node->getFlags());
+
+  // Only VECREDUCE_FADD performs arithmetic on the extended values. The
+  // min/max reductions select one of their exactly-extended inputs (modulo
+  // NaN quietening), so rounding back cannot change the value. Convey that
+  // via FP_ROUND's trunc operand (1 == value known not to change).
+  bool IsValuePreserving = Node->getOpcode() != ISD::VECREDUCE_FADD;
+  SDValue Res = DAG.getNode(
+      ISD::FP_ROUND, DL, Node->getValueType(0), Rdx,
+      DAG.getIntPtrConstant(IsValuePreserving, DL, /*isTarget=*/true));
+  Results.push_back(Res);
+}
+
 void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   // For a few operations there is a specific concept for promotion based on
   // the operand's type.
@@ -719,6 +730,13 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::STRICT_FMA:
PromoteSTRICT(Node, Results);
return;
+ case ISD::VECREDUCE_FADD:
+ case ISD::VECREDUCE_FMAX:
+ case ISD::VECREDUCE_FMAXIMUM:
+ case ISD::VECREDUCE_FMIN:
+ case ISD::VECREDUCE_FMINIMUM:
+ PromoteVECREDUCE(Node, Results);
+ return;
case ISD::FP_ROUND:
case ISD::FP_EXTEND:
// These operations are used to do promotion so they can't be promoted
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index a0ffb4b6d5a4c..0d23666383cda 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11412,13 +11412,9 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
SDValue Op = Node->getOperand(0);
EVT VT = Op.getValueType();
- if (VT.isScalableVector())
- report_fatal_error(
- "Expanding reductions for scalable vectors is undefined.");
-
// Try to use a shuffle reduction for power of two vectors.
if (VT.isPow2VectorType()) {
- while (VT.getVectorNumElements() > 1) {
+ while (VT.getVectorElementCount().isKnownMultipleOf(2)) {
EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
if (!isOperationLegalOrCustom(BaseOpcode, HalfVT))
break;
@@ -11427,9 +11423,18 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
std::tie(Lo, Hi) = DAG.SplitVector(Op, dl);
Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi, Node->getFlags());
VT = HalfVT;
+
+ // Stop if splitting is enough to make the reduction legal.
+ if (isOperationLegalOrCustom(Node->getOpcode(), HalfVT))
+ return DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0), Op,
+ Node->getFlags());
}
}
+ if (VT.isScalableVector())
+ report_fatal_error(
+ "Expanding reductions for scalable vectors is undefined.");
+
EVT EltVT = VT.getVectorElementType();
unsigned NumElts = VT.getVectorNumElements();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index caac00c5b2faa..9322f615827d9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1780,7 +1780,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (auto Opcode :
{ISD::FCEIL, ISD::FDIV, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
- ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC}) {
+ ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC,
+ ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMAXIMUM,
+ ISD::VECREDUCE_FMIN, ISD::VECREDUCE_FMINIMUM}) {
setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
new file mode 100644
index 0000000000000..eb462c780437f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
@@ -0,0 +1,235 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve,+bf16 < %s | FileCheck %s
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; FADDV
+
+define bfloat @faddv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: faddv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call fast bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat zeroinitializer, <vscale x 2 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: faddv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call fast bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat zeroinitializer, <vscale x 4 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpkhi z1.s, z0.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: fadd z0.s, z0.s, z1.s
+; CHECK-NEXT: faddv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call fast bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat zeroinitializer, <vscale x 8 x bfloat> %a)
+ ret bfloat %res
+}
+
+; FMAXNMV
+
+define bfloat @fmaxv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fmaxnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fmaxv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fmaxnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fmaxv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpkhi z1.s, z0.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: fmaxnm z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: fmaxnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat> %a)
+ ret bfloat %res
+}
+
+; FMINNMV
+
+define bfloat @fminv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fminnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fminv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fminnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fminv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpkhi z1.s, z0.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: fminnm z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: fminnmv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat> %a)
+ ret bfloat %res
+}
+
+; FMAXV
+
+define bfloat @fmaximumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fmaxv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fmaximumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fmaxv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fmaximumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpkhi z1.s, z0.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: fmax z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: fmaxv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat> %a)
+ ret bfloat %res
+}
+
+; FMINV
+
+define bfloat @fminimumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv2bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.d
+; CHECK-NEXT: fminv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fminimumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv4bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: fminv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat> %a)
+ ret bfloat %res
+}
+
+define bfloat @fminimumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv8bf16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpkhi z1.s, z0.h
+; CHECK-NEXT: uunpklo z0.s, z0.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: lsl z1.s, z1.s, #16
+; CHECK-NEXT: lsl z0.s, z0.s, #16
+; CHECK-NEXT: fmin z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT: fminv s0, p0, z0.s
+; CHECK-NEXT: bfcvt h0, s0
+; CHECK-NEXT: ret
+ %res = call bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat> %a)
+ ret bfloat %res
+}
+
+declare bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat, <vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat, <vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat, <vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat>)
More information about the llvm-commits
mailing list