[llvm] [LLVM][CodeGen][SVE] Add isel for bfloat unordered reductions. (PR #143540)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 10 07:24:50 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

<details>
<summary>Changes</summary>

The omissions are VECREDUCE_SEQ_* and VECREDUCE_FMUL. The former goes down a different code path, and the latter is generally unsupported across all element types.

A future extension is to use BFDOT for add reductions when available, especially for the nxv8bf16 case.


---
Full diff: https://github.com/llvm/llvm-project/pull/143540.diff


4 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+27-9) 
- (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+10-5) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+3-1) 
- (added) llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll (+235) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 4a1cd642233ef..1fc5fc66c56e5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -188,6 +188,7 @@ class VectorLegalizer {
   void PromoteSETCC(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
   void PromoteSTRICT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
+  void PromoteVECREDUCE(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
 public:
   VectorLegalizer(SelectionDAG& dag) :
@@ -500,20 +501,14 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   case ISD::VECREDUCE_UMAX:
   case ISD::VECREDUCE_UMIN:
   case ISD::VECREDUCE_FADD:
-  case ISD::VECREDUCE_FMUL:
-  case ISD::VECTOR_FIND_LAST_ACTIVE:
-    Action = TLI.getOperationAction(Node->getOpcode(),
-                                    Node->getOperand(0).getValueType());
-    break;
   case ISD::VECREDUCE_FMAX:
-  case ISD::VECREDUCE_FMIN:
   case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMIN:
   case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECTOR_FIND_LAST_ACTIVE:
     Action = TLI.getOperationAction(Node->getOpcode(),
                                     Node->getOperand(0).getValueType());
-    // Defer non-vector results to LegalizeDAG.
-    if (Action == TargetLowering::Promote)
-      Action = TargetLowering::Legal;
     break;
   case ISD::VECREDUCE_SEQ_FADD:
   case ISD::VECREDUCE_SEQ_FMUL:
@@ -688,6 +683,22 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node,
   Results.push_back(Round.getValue(1));
 }
 
+void VectorLegalizer::PromoteVECREDUCE(SDNode *Node,
+                                       SmallVectorImpl<SDValue> &Results) {
+  MVT OpVT = Node->getOperand(0).getSimpleValueType();
+  assert(OpVT.isFloatingPoint() && "Expected floating point reduction!");
+  MVT NewOpVT = TLI.getTypeToPromoteTo(Node->getOpcode(), OpVT);
+
+  SDLoc DL(Node);
+  SDValue NewOp = DAG.getNode(ISD::FP_EXTEND, DL, NewOpVT, Node->getOperand(0));
+  SDValue Rdx =
+      DAG.getNode(Node->getOpcode(), DL, NewOpVT.getVectorElementType(), NewOp,
+                  Node->getFlags());
+  SDValue Res = DAG.getNode(ISD::FP_ROUND, DL, Node->getValueType(0), Rdx,
+                            DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
+  Results.push_back(Res);
+}
+
 void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   // For a few operations there is a specific concept for promotion based on
   // the operand's type.
@@ -719,6 +730,13 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
   case ISD::STRICT_FMA:
     PromoteSTRICT(Node, Results);
     return;
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMINIMUM:
+    PromoteVECREDUCE(Node, Results);
+    return;
   case ISD::FP_ROUND:
   case ISD::FP_EXTEND:
     // These operations are used to do promotion so they can't be promoted
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index a0ffb4b6d5a4c..0d23666383cda 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -11412,13 +11412,9 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
   SDValue Op = Node->getOperand(0);
   EVT VT = Op.getValueType();
 
-  if (VT.isScalableVector())
-    report_fatal_error(
-        "Expanding reductions for scalable vectors is undefined.");
-
   // Try to use a shuffle reduction for power of two vectors.
   if (VT.isPow2VectorType()) {
-    while (VT.getVectorNumElements() > 1) {
+    while (VT.getVectorElementCount().isKnownMultipleOf(2)) {
       EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
       if (!isOperationLegalOrCustom(BaseOpcode, HalfVT))
         break;
@@ -11427,9 +11423,18 @@ SDValue TargetLowering::expandVecReduce(SDNode *Node, SelectionDAG &DAG) const {
       std::tie(Lo, Hi) = DAG.SplitVector(Op, dl);
       Op = DAG.getNode(BaseOpcode, dl, HalfVT, Lo, Hi, Node->getFlags());
       VT = HalfVT;
+
+      // Stop if splitting is enough to make the reduction legal.
+      if (isOperationLegalOrCustom(Node->getOpcode(), HalfVT))
+        return DAG.getNode(Node->getOpcode(), dl, Node->getValueType(0), Op,
+                           Node->getFlags());
     }
   }
 
+  if (VT.isScalableVector())
+    report_fatal_error(
+        "Expanding reductions for scalable vectors is undefined.");
+
   EVT EltVT = VT.getVectorElementType();
   unsigned NumElts = VT.getVectorNumElements();
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index caac00c5b2faa..9322f615827d9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1780,7 +1780,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     for (auto Opcode :
          {ISD::FCEIL, ISD::FDIV, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
-          ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC}) {
+          ISD::FROUND, ISD::FROUNDEVEN, ISD::FSQRT, ISD::FTRUNC, ISD::SETCC,
+          ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMAXIMUM,
+          ISD::VECREDUCE_FMIN, ISD::VECREDUCE_FMINIMUM}) {
       setOperationPromotedToType(Opcode, MVT::nxv2bf16, MVT::nxv2f32);
       setOperationPromotedToType(Opcode, MVT::nxv4bf16, MVT::nxv4f32);
       setOperationPromotedToType(Opcode, MVT::nxv8bf16, MVT::nxv8f32);
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
new file mode 100644
index 0000000000000..eb462c780437f
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-reductions.ll
@@ -0,0 +1,235 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve,+bf16            < %s | FileCheck %s
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; FADDV
+
+define bfloat @faddv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    faddv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call fast bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat zeroinitializer, <vscale x 2 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @faddv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    faddv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call fast bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat zeroinitializer, <vscale x 4 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @faddv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: faddv_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    faddv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call fast bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat zeroinitializer, <vscale x 8 x bfloat> %a)
+  ret bfloat %res
+}
+
+; FMAXNMV
+
+define bfloat @fmaxv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fmaxnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fmaxv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fmaxnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fmaxv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fmaxv_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    fmaxnm z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    fmaxnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat> %a)
+  ret bfloat %res
+}
+
+; FMINNMV
+
+define bfloat @fminv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fminnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fminv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fminnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fminv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fminv_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    fminnm z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    fminnmv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat> %a)
+  ret bfloat %res
+}
+
+; FMAXV
+
+define bfloat @fmaximumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fmaxv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fmaximumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fmaxv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fmaximumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fmaximumv_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    fmax z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    fmaxv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat> %a)
+  ret bfloat %res
+}
+
+; FMINV
+
+define bfloat @fminimumv_nxv2bf16(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv2bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fminv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fminimumv_nxv4bf16(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv4bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fminv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat> %a)
+  ret bfloat %res
+}
+
+define bfloat @fminimumv_nxv8bf16(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fminimumv_nxv8bf16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    fmin z0.s, p0/m, z0.s, z1.s
+; CHECK-NEXT:    fminv s0, p0, z0.s
+; CHECK-NEXT:    bfcvt h0, s0
+; CHECK-NEXT:    ret
+  %res = call bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat> %a)
+  ret bfloat %res
+}
+
+declare bfloat @llvm.vector.reduce.fadd.nxv2bf16(bfloat, <vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fadd.nxv4bf16(bfloat, <vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fadd.nxv8bf16(bfloat, <vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmax.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmax.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmax.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmin.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmin.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmin.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fmaximum.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmaximum.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fmaximum.nxv8bf16(<vscale x 8 x bfloat>)
+
+declare bfloat @llvm.vector.reduce.fminimum.nxv2bf16(<vscale x 2 x bfloat>)
+declare bfloat @llvm.vector.reduce.fminimum.nxv4bf16(<vscale x 4 x bfloat>)
+declare bfloat @llvm.vector.reduce.fminimum.nxv8bf16(<vscale x 8 x bfloat>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/143540


More information about the llvm-commits mailing list