[llvm] 08b20d8 - [RISCV] Generalize reduction tree matching to fp sum reductions (#68599)

via llvm-commits <llvm-commits@lists.llvm.org>
Mon Oct 9 11:31:47 PDT 2023


Author: Philip Reames
Date: 2023-10-09T11:31:41-07:00
New Revision: 08b20d838549e8ff3939c6793b5067b0d9b36337

URL: https://github.com/llvm/llvm-project/commit/08b20d838549e8ff3939c6793b5067b0d9b36337
DIFF: https://github.com/llvm/llvm-project/commit/08b20d838549e8ff3939c6793b5067b0d9b36337.diff

LOG: [RISCV] Generalize reduction tree matching to fp sum reductions (#68599)

This builds on the transform introduced in f0505c3 and generalized to
all integer operations in 45a334d. This change adds support for
floating-point summation; a sketch of the matched idiom follows the
notes below.

A couple of notes:
* I chose to leave fmaxnum and fminnum unhandled for the moment. They
have a slightly different set of legality rules.
* We could form strictly sequenced FADD reductions for FADDs without
fast math flags. Since the ordered reductions are more expensive, I
left this as a future exercise.
* This can't yet match the full vector reduce + start value idiom. That
will be addressed in an upcoming set of changes.
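
For illustration, here is a minimal sketch of the idiom the combine now
matches (the function name is made up; the pattern mirrors the tests in
the diff below):

  define float @sum3_prefix(ptr %p) {
    %v = load <4 x float>, ptr %p
    %e0 = extractelement <4 x float> %v, i32 0
    %e1 = extractelement <4 x float> %v, i32 1
    %e2 = extractelement <4 x float> %v, i32 2
    %fadd0 = fadd reassoc float %e0, %e1
    %fadd1 = fadd reassoc float %fadd0, %e2
    ret float %fadd1
  }

With reassoc (or fast) on every fadd in the chain, the scalar adds over
the leading lanes can be folded into an unordered reduction over a
prefix subvector (vfredusum on RISC-V); without the flag, the adds are
left as scalar code.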

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b66a596f1cbe692..6be3fa71479be5c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11299,7 +11299,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
   }
 }
 
-/// Given an integer binary operator, return the generic ISD::VECREDUCE_OP
+/// Given a binary operator, return the *associative* generic ISD::VECREDUCE_OP
 /// which corresponds to it.
 static unsigned getVecReduceOpcode(unsigned Opc) {
   switch (Opc) {
@@ -11321,6 +11321,9 @@ static unsigned getVecReduceOpcode(unsigned Opc) {
     return ISD::VECREDUCE_OR;
   case ISD::XOR:
     return ISD::VECREDUCE_XOR;
+  case ISD::FADD:
+    // Note: This is the associative form of the generic reduction opcode.
+    return ISD::VECREDUCE_FADD;
   }
 }
 
@@ -11347,12 +11350,16 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
 
   const SDLoc DL(N);
   const EVT VT = N->getValueType(0);
+  const unsigned Opc = N->getOpcode();
 
-  // TODO: Handle floating point here.
-  if (!VT.isInteger())
+  // For FADD, we only handle the case with reassociation allowed.  We
+  // could handle strict reduction order, but at the moment, there's no
+  // known reason to, and the complexity isn't worth it.
+  // TODO: Handle fminnum and fmaxnum here
+  if (!VT.isInteger() &&
+      (Opc != ISD::FADD || !N->getFlags().hasAllowReassociation()))
     return SDValue();
 
-  const unsigned Opc = N->getOpcode();
   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
          "Inconsistent mappings");
@@ -11385,7 +11392,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
     EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
     SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                               DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec);
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
   }
 
   // Match (binop (reduce (extract_subvector V, 0),
@@ -11407,7 +11414,9 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
       EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
       SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                                 DAG.getVectorIdxConstant(0, DL));
-      return DAG.getNode(ReduceOpc, DL, VT, Vec);
+      auto Flags = ReduceVec->getFlags();
+      Flags.intersectWith(N->getFlags());
+      return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
     }
   }
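
A note on the flag handling in the widening case above: when an
existing reduce is combined with one more scalar op, the result keeps
only the fast-math flags common to both nodes (the intersectWith call).
A hedged illustration, with flags chosen just for the example:

  %fadd0 = fadd reassoc nnan float %e0, %e1 ; first folded to a reassoc+nnan reduce
  %fadd1 = fadd reassoc float %fadd0, %e2   ; widened reduce keeps only reassoc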
 

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index dd9a1118ab821d4..76df097a7697162 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -764,6 +764,165 @@ define i32 @reduce_umin_16xi32_prefix5(ptr %p) {
   %umin3 = call i32 @llvm.umin.i32(i32 %umin2, i32 %e4)
   ret i32 %umin3
 }
+
+define float @reduce_fadd_16xf32_prefix2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xf32_prefix2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %fadd0 = fadd fast float %e0, %e1
+  ret float %fadd0
+}
+
+define float @reduce_fadd_16xf32_prefix5(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xf32_prefix5:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lui a1, 524288
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v10, a1
+; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 5
+; CHECK-NEXT:    vsetivli zero, 7, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 6
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 7
+; CHECK-NEXT:    vfredusum.vs v8, v8, v10
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %e2 = extractelement <16 x float> %v, i32 2
+  %e3 = extractelement <16 x float> %v, i32 3
+  %e4 = extractelement <16 x float> %v, i32 4
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  %fadd3 = fadd fast float %fadd2, %e4
+  ret float %fadd3
+}
+
+;; Corner case tests for fadd associativity
+
+; Negative test, not associative.  Would need strict opcode.
+define float @reduce_fadd_2xf32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd float %e0, %e1
+  ret float %fadd0
+}
+
+; Positive test - minimal set of fast math flags
+define float @reduce_fadd_2xf32_reassoc_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_reassoc_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd reassoc float %e0, %e1
+  ret float %fadd0
+}
+
+; Negative test - wrong fast math flag.
+define float @reduce_fadd_2xf32_ninf_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_ninf_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd ninf float %e0, %e1
+  ret float %fadd0
+}
+
+
+; Negative test - last fadd is not associative
+define float @reduce_fadd_4xf32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xf32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vslidedown.vi v9, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa5, v9
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vslideup.vi v8, v9, 3
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa4, fa5
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd float %fadd1, %e3
+  ret float %fadd2
+}
+
+; Negative test - first fadd is not associative
+; We could form a reduce for elements 2 and 3.
+define float @reduce_fadd_4xf32_non_associative2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xf32_non_associative2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v9, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v9
+; CHECK-NEXT:    vslidedown.vi v9, v8, 2
+; CHECK-NEXT:    vfmv.f.s fa3, v9
+; CHECK-NEXT:    vslidedown.vi v8, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa2, v8
+; CHECK-NEXT:    fadd.s fa5, fa5, fa4
+; CHECK-NEXT:    fadd.s fa4, fa3, fa2
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  ret float %fadd2
+}
+
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; RV32: {{.*}}
 ; RV64: {{.*}}

