[llvm] [RISCV] Generalize reduction tree matching to fp sum reductions (PR #68599)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 9 08:40:49 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

<details>
<summary>Changes</summary>

This builds on the transform introduced in f0505c3, which was generalized to all integer operations in 45a334d.  This change adds support for floating point summation.

A couple of notes:
* I chose to leave fmaxnum and fminnum unhandled for the moment.  They have a slightly different set of legality rules.
* We could form strictly sequenced FADD reductions for FADDs without fast math flags.  As the ordered reductions are more expensive, I left thinking about this as a future exercise.
* This can't yet match the full vector reduce + start value idiom.  That will be an upcoming set of changes.

---
Full diff: https://github.com/llvm/llvm-project/pull/68599.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+15-6) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll (+159) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index bd4150c87eabbd0..94c1c52a28462b7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11114,7 +11114,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
   }
 }
 
-/// Given an integer binary operator, return the generic ISD::VECREDUCE_OP
+/// Given a binary operator, return the *associative* generic ISD::VECREDUCE_OP
 /// which corresponds to it.
 static unsigned getVecReduceOpcode(unsigned Opc) {
   switch (Opc) {
@@ -11136,6 +11136,9 @@ static unsigned getVecReduceOpcode(unsigned Opc) {
     return ISD::VECREDUCE_OR;
   case ISD::XOR:
     return ISD::VECREDUCE_XOR;
+  case ISD::FADD:
+    // Note: This is the associative form of the generic reduction opcode.
+    return ISD::VECREDUCE_FADD;
   }
 }
 
@@ -11162,12 +11165,16 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
 
   const SDLoc DL(N);
   const EVT VT = N->getValueType(0);
+  const unsigned Opc = N->getOpcode();
 
-  // TODO: Handle floating point here.
-  if (!VT.isInteger())
+  // For FADD, we only handle the case with reassociation allowed.  We
+  // could handle strict reduction order, but at the moment, there's no
+  // known reason to, and the complexity isn't worth it.
+  // TODO: Handle fminnum and fmaxnum here
+  if (!VT.isInteger() &&
+      (Opc != ISD::FADD || !N->getFlags().hasAllowReassociation()))
     return SDValue();
 
-  const unsigned Opc = N->getOpcode();
   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
          "Inconsistent mappings");
@@ -11200,7 +11207,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
     EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
     SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                               DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec);
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
   }
 
   // Match (binop (reduce (extract_subvector V, 0),
@@ -11222,7 +11229,9 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
       EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
       SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                                 DAG.getVectorIdxConstant(0, DL));
-      return DAG.getNode(ReduceOpc, DL, VT, Vec);
+      auto Flags = ReduceVec->getFlags();
+      Flags.intersectWith(N->getFlags());
+      return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
     }
   }
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index dd9a1118ab821d4..76df097a7697162 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -764,6 +764,165 @@ define i32 @reduce_umin_16xi32_prefix5(ptr %p) {
   %umin3 = call i32 @llvm.umin.i32(i32 %umin2, i32 %e4)
   ret i32 %umin3
 }
+
+define float @reduce_fadd_16xf32_prefix2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xf32_prefix2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %fadd0 = fadd fast float %e0, %e1
+  ret float %fadd0
+}
+
+define float @reduce_fadd_16xi32_prefix5(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xi32_prefix5:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lui a1, 524288
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v10, a1
+; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 5
+; CHECK-NEXT:    vsetivli zero, 7, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 6
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 7
+; CHECK-NEXT:    vfredusum.vs v8, v8, v10
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %e2 = extractelement <16 x float> %v, i32 2
+  %e3 = extractelement <16 x float> %v, i32 3
+  %e4 = extractelement <16 x float> %v, i32 4
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  %fadd3 = fadd fast float %fadd2, %e4
+  ret float %fadd3
+}
+
+;; Corner case tests for fadd associativity
+
+; Negative test, not associative.  Would need strict opcode.
+define float @reduce_fadd_2xf32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd float %e0, %e1
+  ret float %fadd0
+}
+
+; Positive test - minimal set of fast math flags
+define float @reduce_fadd_2xf32_reassoc_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_reassoc_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd reassoc float %e0, %e1
+  ret float %fadd0
+}
+
+; Negative test - wrong fast math flag.
+define float @reduce_fadd_2xf32_ninf_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_ninf_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd ninf float %e0, %e1
+  ret float %fadd0
+}
+
+
+; Negative test - last fadd is not associative
+define float @reduce_fadd_4xi32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xi32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vslidedown.vi v9, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa5, v9
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vslideup.vi v8, v9, 3
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa4, fa5
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd float %fadd1, %e3
+  ret float %fadd2
+}
+
+; Negative test - first fadd is not associative
+; We could form a reduce for elements 2 and 3.
+define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xi32_non_associative2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v9, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v9
+; CHECK-NEXT:    vslidedown.vi v9, v8, 2
+; CHECK-NEXT:    vfmv.f.s fa3, v9
+; CHECK-NEXT:    vslidedown.vi v8, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa2, v8
+; CHECK-NEXT:    fadd.s fa5, fa5, fa4
+; CHECK-NEXT:    fadd.s fa4, fa3, fa2
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  ret float %fadd2
+}
+
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; RV32: {{.*}}
 ; RV64: {{.*}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/68599


More information about the llvm-commits mailing list