[llvm] [RISCV] Allow swapped operands in reduction formation (PR #68634)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 9 14:36:08 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

<details>
<summary>Changes</summary>

Very straightforward, but worth landing on its own in advance of a more complicated generalization.

---
Full diff: https://github.com/llvm/llvm-project/pull/68634.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+28-23) 
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll (+62-4) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6be3fa71479be5c..b0fc99f6eff860b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11363,16 +11363,20 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
          "Inconsistent mappings");
-  const SDValue LHS = N->getOperand(0);
-  const SDValue RHS = N->getOperand(1);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
 
   if (!LHS.hasOneUse() || !RHS.hasOneUse())
     return SDValue();
 
+  if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+    std::swap(LHS, RHS);
+
   if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
       !isa<ConstantSDNode>(RHS.getOperand(1)))
     return SDValue();
 
+  uint64_t RHSIdx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
   SDValue SrcVec = RHS.getOperand(0);
   EVT SrcVecVT = SrcVec.getValueType();
   assert(SrcVecVT.getVectorElementType() == VT);
@@ -11385,14 +11389,17 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   // match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
   // reduce_op (extract_subvector [2 x VT] from V).  This will form the
   // root of our reduction tree. TODO: We could extend this to any two
-  // adjacent constant indices if desired.
+  // adjacent aligned constant indices if desired.
   if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-      LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
-      isOneConstant(RHS.getOperand(1))) {
-    EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
-    SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
-                              DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+      LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
+    uint64_t LHSIdx =
+      cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
+    if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
+      EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
+      SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+                                DAG.getVectorIdxConstant(0, DL));
+      return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+    }
   }
 
   // Match (binop (reduce (extract_subvector V, 0),
@@ -11404,20 +11411,18 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   SDValue ReduceVec = LHS.getOperand(0);
   if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
       ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
-      isNullConstant(ReduceVec.getOperand(1))) {
-    uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
-    if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
-      // For illegal types (e.g. 3xi32), most will be combined again into a
-      // wider (hopefully legal) type.  If this is a terminal state, we are
-      // relying on type legalization here to produce something reasonable
-      // and this lowering quality could probably be improved. (TODO)
-      EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
-      SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
-                                DAG.getVectorIdxConstant(0, DL));
-      auto Flags = ReduceVec->getFlags();
-      Flags.intersectWith(N->getFlags());
-      return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
-    }
+      isNullConstant(ReduceVec.getOperand(1)) &&
+      ReduceVec.getValueType().getVectorNumElements() == RHSIdx) {
+    // For illegal types (e.g. 3xi32), most will be combined again into a
+    // wider (hopefully legal) type.  If this is a terminal state, we are
+    // relying on type legalization here to produce something reasonable
+    // and this lowering quality could probably be improved. (TODO)
+    EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, RHSIdx + 1);
+    SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+                              DAG.getVectorIdxConstant(0, DL));
+    auto Flags = ReduceVec->getFlags();
+    Flags.intersectWith(N->getFlags());
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
   }
 
   return SDValue();
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index 76df097a7697162..fd4a54b468f15fd 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -34,7 +34,6 @@ define i32 @reduce_sum_4xi32(<4 x i32> %v) {
   ret i32 %add2
 }
 
-
 define i32 @reduce_sum_8xi32(<8 x i32> %v) {
 ; CHECK-LABEL: reduce_sum_8xi32:
 ; CHECK:       # %bb.0:
@@ -449,6 +448,68 @@ define i32 @reduce_sum_16xi32_prefix15(ptr %p) {
   ret i32 %add13
 }
 
+; Check that we can match with the operand ordered reversed, but the
+; reduction order unchanged.
+define i32 @reduce_sum_4xi32_op_order(<4 x i32> %v) {
+; CHECK-LABEL: reduce_sum_4xi32_op_order:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v9
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+  %e0 = extractelement <4 x i32> %v, i32 0
+  %e1 = extractelement <4 x i32> %v, i32 1
+  %e2 = extractelement <4 x i32> %v, i32 2
+  %e3 = extractelement <4 x i32> %v, i32 3
+  %add0 = add i32 %e1, %e0
+  %add1 = add i32 %e2, %add0
+  %add2 = add i32 %add1, %e3
+  ret i32 %add2
+}
+
+; Negative test - Reduction order isn't compatibile with current
+; incremental matching scheme.
+define i32 @reduce_sum_4xi32_reduce_order(<4 x i32> %v) {
+; RV32-LABEL: reduce_sum_4xi32_reduce_order:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; RV32-NEXT:    vmv.x.s a0, v8
+; RV32-NEXT:    vslidedown.vi v9, v8, 1
+; RV32-NEXT:    vmv.x.s a1, v9
+; RV32-NEXT:    vslidedown.vi v9, v8, 2
+; RV32-NEXT:    vmv.x.s a2, v9
+; RV32-NEXT:    vslidedown.vi v8, v8, 3
+; RV32-NEXT:    vmv.x.s a3, v8
+; RV32-NEXT:    add a1, a1, a2
+; RV32-NEXT:    add a0, a0, a3
+; RV32-NEXT:    add a0, a0, a1
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: reduce_sum_4xi32_reduce_order:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; RV64-NEXT:    vmv.x.s a0, v8
+; RV64-NEXT:    vslidedown.vi v9, v8, 1
+; RV64-NEXT:    vmv.x.s a1, v9
+; RV64-NEXT:    vslidedown.vi v9, v8, 2
+; RV64-NEXT:    vmv.x.s a2, v9
+; RV64-NEXT:    vslidedown.vi v8, v8, 3
+; RV64-NEXT:    vmv.x.s a3, v8
+; RV64-NEXT:    add a1, a1, a2
+; RV64-NEXT:    add a0, a0, a3
+; RV64-NEXT:    addw a0, a0, a1
+; RV64-NEXT:    ret
+  %e0 = extractelement <4 x i32> %v, i32 0
+  %e1 = extractelement <4 x i32> %v, i32 1
+  %e2 = extractelement <4 x i32> %v, i32 2
+  %e3 = extractelement <4 x i32> %v, i32 3
+  %add0 = add i32 %e1, %e2
+  %add1 = add i32 %e0, %add0
+  %add2 = add i32 %add1, %e3
+  ret i32 %add2
+}
+
 ;; Most of the cornercases are exercised above, the following just
 ;; makes sure that other opcodes work as expected.
 
@@ -923,6 +984,3 @@ define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
 }
 
 
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; RV32: {{.*}}
-; RV64: {{.*}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/68634


More information about the llvm-commits mailing list