[llvm] [RISCV] Generalize reduction tree matching to fp sum reductions (PR #68599)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 9 08:39:32 PDT 2023


https://github.com/preames created https://github.com/llvm/llvm-project/pull/68599

This builds on the transform introduced in f0505c3 and generalized to all integer operations in 45a334d.  This change adds support for floating point summation.

A couple of notes:
* I chose to leave fmaxnum and fminnum unhandled for the moment.  They have a slightly different set of legality rules.
* We could form strictly sequenced FADD reductions for FADDs without fast math flags.  As the ordered reductions are more expensive, I left thinking about this as a future exercise.
* This can't yet match the full vector reduce + start value idiom.  That will be an upcoming set of changes.

>From 2eb5c69fb34437fb1a349c25a6e9097dd3d4ebd7 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Mon, 9 Oct 2023 08:14:32 -0700
Subject: [PATCH] [RISCV] Generalize reduction tree matching to fp sum
 reductions

This builds on the transform introduced in f0505c3 and generalized to all
integer operations in 45a334d.  This change adds support for floating point
summation.

A couple of notes:
* I chose to leave fmaxnum and fminnum unhandled for the moment.  They have
  a slightly different set of legality rules.
* We could form strictly sequenced FADD reductions for FADDs without fast
  math flags.  As the ordered reductions are more expensive, I left thinking
  about this as a future exercise.
* This can't yet match the full vector reduce + start value idiom.  That will
  be an upcoming set of changes.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  21 ++-
 .../rvv/fixed-vectors-reduction-formation.ll  | 159 ++++++++++++++++++
 2 files changed, 174 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index bd4150c87eabbd0..94c1c52a28462b7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11114,7 +11114,7 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
   }
 }
 
-/// Given an integer binary operator, return the generic ISD::VECREDUCE_OP
+/// Given a binary operator, return the *associative* generic ISD::VECREDUCE_OP
 /// which corresponds to it.
 static unsigned getVecReduceOpcode(unsigned Opc) {
   switch (Opc) {
@@ -11136,6 +11136,9 @@ static unsigned getVecReduceOpcode(unsigned Opc) {
     return ISD::VECREDUCE_OR;
   case ISD::XOR:
     return ISD::VECREDUCE_XOR;
+  case ISD::FADD:
+    // Note: This is the associative form of the generic reduction opcode.
+    return ISD::VECREDUCE_FADD;
   }
 }
 
@@ -11162,12 +11165,16 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
 
   const SDLoc DL(N);
   const EVT VT = N->getValueType(0);
+  const unsigned Opc = N->getOpcode();
 
-  // TODO: Handle floating point here.
-  if (!VT.isInteger())
+  // For FADD, we only handle the case with reassociation allowed.  We
+  // could handle strict reduction order, but at the moment, there's no
+  // known reason to, and the complexity isn't worth it.
+  // TODO: Handle fminnum and fmaxnum here
+  if (!VT.isInteger() &&
+      (Opc != ISD::FADD || !N->getFlags().hasAllowReassociation()))
     return SDValue();
 
-  const unsigned Opc = N->getOpcode();
   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
          "Inconsistent mappings");
@@ -11200,7 +11207,7 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
     EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
     SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                               DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec);
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
   }
 
   // Match (binop (reduce (extract_subvector V, 0),
@@ -11222,7 +11229,9 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
       EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
       SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
                                 DAG.getVectorIdxConstant(0, DL));
-      return DAG.getNode(ReduceOpc, DL, VT, Vec);
+      auto Flags = ReduceVec->getFlags();
+      Flags.intersectWith(N->getFlags());
+      return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
     }
   }
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index dd9a1118ab821d4..76df097a7697162 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -764,6 +764,165 @@ define i32 @reduce_umin_16xi32_prefix5(ptr %p) {
   %umin3 = call i32 @llvm.umin.i32(i32 %umin2, i32 %e4)
   ret i32 %umin3
 }
+
+define float @reduce_fadd_16xf32_prefix2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xf32_prefix2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %fadd0 = fadd fast float %e0, %e1
+  ret float %fadd0
+}
+
+define float @reduce_fadd_16xi32_prefix5(ptr %p) {
+; CHECK-LABEL: reduce_fadd_16xi32_prefix5:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lui a1, 524288
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vmv.s.x v10, a1
+; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 5
+; CHECK-NEXT:    vsetivli zero, 7, e32, m2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 6
+; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 7
+; CHECK-NEXT:    vfredusum.vs v8, v8, v10
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <16 x float>, ptr %p, align 256
+  %e0 = extractelement <16 x float> %v, i32 0
+  %e1 = extractelement <16 x float> %v, i32 1
+  %e2 = extractelement <16 x float> %v, i32 2
+  %e3 = extractelement <16 x float> %v, i32 3
+  %e4 = extractelement <16 x float> %v, i32 4
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  %fadd3 = fadd fast float %fadd2, %e4
+  ret float %fadd3
+}
+
+;; Corner case tests for fadd associativity
+
+; Negative test, not associative.  Would need strict opcode.
+define float @reduce_fadd_2xf32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd float %e0, %e1
+  ret float %fadd0
+}
+
+; Positive test - minimal set of fast math flags
+define float @reduce_fadd_2xf32_reassoc_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_reassoc_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa0, v8
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd reassoc float %e0, %e1
+  ret float %fadd0
+}
+
+; Negative test - wrong fast math flag.
+define float @reduce_fadd_2xf32_ninf_only(ptr %p) {
+; CHECK-LABEL: reduce_fadd_2xf32_ninf_only:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v8, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <2 x float>, ptr %p, align 256
+  %e0 = extractelement <2 x float> %v, i32 0
+  %e1 = extractelement <2 x float> %v, i32 1
+  %fadd0 = fadd ninf float %e0, %e1
+  ret float %fadd0
+}
+
+
+; Negative test - last fadd is not associative
+define float @reduce_fadd_4xi32_non_associative(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xi32_non_associative:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vslidedown.vi v9, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa5, v9
+; CHECK-NEXT:    lui a0, 524288
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vslideup.vi v8, v9, 3
+; CHECK-NEXT:    vfredusum.vs v8, v8, v9
+; CHECK-NEXT:    vfmv.f.s fa4, v8
+; CHECK-NEXT:    fadd.s fa0, fa4, fa5
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd fast float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd float %fadd1, %e3
+  ret float %fadd2
+}
+
+; Negative test - first fadd is not associative
+; We could form a reduce for elements 2 and 3.
+define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
+; CHECK-LABEL: reduce_fadd_4xi32_non_associative2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vfmv.f.s fa5, v8
+; CHECK-NEXT:    vslidedown.vi v9, v8, 1
+; CHECK-NEXT:    vfmv.f.s fa4, v9
+; CHECK-NEXT:    vslidedown.vi v9, v8, 2
+; CHECK-NEXT:    vfmv.f.s fa3, v9
+; CHECK-NEXT:    vslidedown.vi v8, v8, 3
+; CHECK-NEXT:    vfmv.f.s fa2, v8
+; CHECK-NEXT:    fadd.s fa5, fa5, fa4
+; CHECK-NEXT:    fadd.s fa4, fa3, fa2
+; CHECK-NEXT:    fadd.s fa0, fa5, fa4
+; CHECK-NEXT:    ret
+  %v = load <4 x float>, ptr %p, align 256
+  %e0 = extractelement <4 x float> %v, i32 0
+  %e1 = extractelement <4 x float> %v, i32 1
+  %e2 = extractelement <4 x float> %v, i32 2
+  %e3 = extractelement <4 x float> %v, i32 3
+  %fadd0 = fadd float %e0, %e1
+  %fadd1 = fadd fast float %fadd0, %e2
+  %fadd2 = fadd fast float %fadd1, %e3
+  ret float %fadd2
+}
+
+
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; RV32: {{.*}}
 ; RV64: {{.*}}



More information about the llvm-commits mailing list