[llvm] 25da9bb - [RISCV] Allow swapped operands in reduction formation (#68634)

via llvm-commits <llvm-commits at lists.llvm.org>
Mon Oct 23 10:38:01 PDT 2023


Author: Philip Reames
Date: 2023-10-23T10:37:56-07:00
New Revision: 25da9bb7d44c91b0339382af6c91b6a346685212

URL: https://github.com/llvm/llvm-project/commit/25da9bb7d44c91b0339382af6c91b6a346685212
DIFF: https://github.com/llvm/llvm-project/commit/25da9bb7d44c91b0339382af6c91b6a346685212.diff

LOG: [RISCV] Allow swapped operands in reduction formation (#68634)

Very straightforward, but worth landing on its own in advance of a
more complicated generalization.
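
For readers skimming the diff: the existing combine rewrites a tree of binops over extractelement values (add in the tests below) into a single vector reduction, and previously required the extractelement operand to be on a fixed side of each commutative binop. This patch canonicalizes the extract to the RHS before matching, so either operand order is accepted. A minimal standalone sketch of that canonicalization idea, using a hypothetical Node type rather than the SelectionDAG API in RISCVISelLowering.cpp, is:

#include <algorithm>
#include <cassert>
#include <iostream>

// Toy model of the canonicalization this patch adds: for a commutative
// binop, swap the operands so that the extract (if any) always ends up
// on the RHS, so the matching code only has to inspect one side.
enum class Kind { ExtractElt, Other };

struct Node {
  Kind K;
  unsigned Index; // lane index, meaningful only when K == Kind::ExtractElt
};

// Returns true and the extract's lane index if either operand of the
// commutative binop is an extract; mirrors the "swap, then match RHS"
// structure of the patched code.
bool matchBinOpWithExtract(Node LHS, Node RHS, unsigned &ExtractIdx) {
  if (RHS.K != Kind::ExtractElt)
    std::swap(LHS, RHS); // canonicalize: move the extract to the RHS
  if (RHS.K != Kind::ExtractElt)
    return false;        // neither operand is an extract
  ExtractIdx = RHS.Index;
  return true;
}

int main() {
  Node Ext{Kind::ExtractElt, 3}, Plain{Kind::Other, 0};
  unsigned Idx = 0;
  assert(matchBinOpWithExtract(Plain, Ext, Idx) && Idx == 3); // previously matched form
  assert(matchBinOpWithExtract(Ext, Plain, Idx) && Idx == 3); // swapped form, now matched
  std::cout << "both operand orders match, lane = " << Idx << "\n";
  return 0;
}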

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index af52e01c27a8627..1f56ca17b785bc0 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11379,16 +11379,20 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   const unsigned ReduceOpc = getVecReduceOpcode(Opc);
   assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
          "Inconsistent mappings");
-  const SDValue LHS = N->getOperand(0);
-  const SDValue RHS = N->getOperand(1);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
 
   if (!LHS.hasOneUse() || !RHS.hasOneUse())
     return SDValue();
 
+  if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+    std::swap(LHS, RHS);
+
   if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
       !isa<ConstantSDNode>(RHS.getOperand(1)))
     return SDValue();
 
+  uint64_t RHSIdx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
   SDValue SrcVec = RHS.getOperand(0);
   EVT SrcVecVT = SrcVec.getValueType();
   assert(SrcVecVT.getVectorElementType() == VT);
@@ -11401,14 +11405,17 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   // match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
   // reduce_op (extract_subvector [2 x VT] from V).  This will form the
   // root of our reduction tree. TODO: We could extend this to any two
-  // adjacent constant indices if desired.
+  // adjacent aligned constant indices if desired.
   if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-      LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
-      isOneConstant(RHS.getOperand(1))) {
-    EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
-    SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
-                              DAG.getVectorIdxConstant(0, DL));
-    return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+      LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
+    uint64_t LHSIdx =
+      cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
+    if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
+      EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
+      SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+                                DAG.getVectorIdxConstant(0, DL));
+      return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+    }
   }
 
   // Match (binop (reduce (extract_subvector V, 0),
@@ -11420,20 +11427,18 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
   SDValue ReduceVec = LHS.getOperand(0);
   if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
       ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
-      isNullConstant(ReduceVec.getOperand(1))) {
-    uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
-    if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
-      // For illegal types (e.g. 3xi32), most will be combined again into a
-      // wider (hopefully legal) type.  If this is a terminal state, we are
-      // relying on type legalization here to produce something reasonable
-      // and this lowering quality could probably be improved. (TODO)
-      EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
-      SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
-                                DAG.getVectorIdxConstant(0, DL));
-      auto Flags = ReduceVec->getFlags();
-      Flags.intersectWith(N->getFlags());
-      return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
-    }
+      isNullConstant(ReduceVec.getOperand(1)) &&
+      ReduceVec.getValueType().getVectorNumElements() == RHSIdx) {
+    // For illegal types (e.g. 3xi32), most will be combined again into a
+    // wider (hopefully legal) type.  If this is a terminal state, we are
+    // relying on type legalization here to produce something reasonable
+    // and this lowering quality could probably be improved. (TODO)
+    EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, RHSIdx + 1);
+    SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+                              DAG.getVectorIdxConstant(0, DL));
+    auto Flags = ReduceVec->getFlags();
+    Flags.intersectWith(N->getFlags());
+    return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
   }
 
   return SDValue();

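A note on the hunk above before the test changes below: besides hoisting RHSIdx, it relaxes the root-of-tree check so the two constant extract lanes may appear in either operand order as long as they are exactly {0, 1}. A minimal sketch of that check, using a hypothetical helper name rather than the in-tree code, is:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the relaxed check in the patch: the two
// constant extract lanes may appear in either operand order, but
// together they must be exactly {0, 1} to form the root of the
// reduction tree (an extract_subvector of the first two lanes).
static bool isReductionRootPair(uint64_t LHSIdx, uint64_t RHSIdx) {
  return std::min(LHSIdx, RHSIdx) == 0 && std::max(LHSIdx, RHSIdx) == 1;
}

int main() {
  assert(isReductionRootPair(0, 1));  // original order
  assert(isReductionRootPair(1, 0));  // swapped order, now accepted
  assert(!isReductionRootPair(1, 2)); // adjacent, but not starting at lane 0
  return 0;
}
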
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index 76df097a7697162..fd4a54b468f15fd 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -34,7 +34,6 @@ define i32 @reduce_sum_4xi32(<4 x i32> %v) {
   ret i32 %add2
 }
 
-
 define i32 @reduce_sum_8xi32(<8 x i32> %v) {
 ; CHECK-LABEL: reduce_sum_8xi32:
 ; CHECK:       # %bb.0:
@@ -449,6 +448,68 @@ define i32 @reduce_sum_16xi32_prefix15(ptr %p) {
   ret i32 %add13
 }
 
+; Check that we can match with the operand order reversed, but the
+; reduction order unchanged.
+define i32 @reduce_sum_4xi32_op_order(<4 x i32> %v) {
+; CHECK-LABEL: reduce_sum_4xi32_op_order:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT:    vmv.s.x v9, zero
+; CHECK-NEXT:    vredsum.vs v8, v8, v9
+; CHECK-NEXT:    vmv.x.s a0, v8
+; CHECK-NEXT:    ret
+  %e0 = extractelement <4 x i32> %v, i32 0
+  %e1 = extractelement <4 x i32> %v, i32 1
+  %e2 = extractelement <4 x i32> %v, i32 2
+  %e3 = extractelement <4 x i32> %v, i32 3
+  %add0 = add i32 %e1, %e0
+  %add1 = add i32 %e2, %add0
+  %add2 = add i32 %add1, %e3
+  ret i32 %add2
+}
+
+; Negative test - The reduction order isn't compatible with the current
+; incremental matching scheme.
+define i32 @reduce_sum_4xi32_reduce_order(<4 x i32> %v) {
+; RV32-LABEL: reduce_sum_4xi32_reduce_order:
+; RV32:       # %bb.0:
+; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; RV32-NEXT:    vmv.x.s a0, v8
+; RV32-NEXT:    vslidedown.vi v9, v8, 1
+; RV32-NEXT:    vmv.x.s a1, v9
+; RV32-NEXT:    vslidedown.vi v9, v8, 2
+; RV32-NEXT:    vmv.x.s a2, v9
+; RV32-NEXT:    vslidedown.vi v8, v8, 3
+; RV32-NEXT:    vmv.x.s a3, v8
+; RV32-NEXT:    add a1, a1, a2
+; RV32-NEXT:    add a0, a0, a3
+; RV32-NEXT:    add a0, a0, a1
+; RV32-NEXT:    ret
+;
+; RV64-LABEL: reduce_sum_4xi32_reduce_order:
+; RV64:       # %bb.0:
+; RV64-NEXT:    vsetivli zero, 1, e32, m1, ta, ma
+; RV64-NEXT:    vmv.x.s a0, v8
+; RV64-NEXT:    vslidedown.vi v9, v8, 1
+; RV64-NEXT:    vmv.x.s a1, v9
+; RV64-NEXT:    vslidedown.vi v9, v8, 2
+; RV64-NEXT:    vmv.x.s a2, v9
+; RV64-NEXT:    vslidedown.vi v8, v8, 3
+; RV64-NEXT:    vmv.x.s a3, v8
+; RV64-NEXT:    add a1, a1, a2
+; RV64-NEXT:    add a0, a0, a3
+; RV64-NEXT:    addw a0, a0, a1
+; RV64-NEXT:    ret
+  %e0 = extractelement <4 x i32> %v, i32 0
+  %e1 = extractelement <4 x i32> %v, i32 1
+  %e2 = extractelement <4 x i32> %v, i32 2
+  %e3 = extractelement <4 x i32> %v, i32 3
+  %add0 = add i32 %e1, %e2
+  %add1 = add i32 %e0, %add0
+  %add2 = add i32 %add1, %e3
+  ret i32 %add2
+}
+
 ;; Most of the cornercases are exercised above, the following just
 ;; makes sure that other opcodes work as expected.
 
@@ -923,6 +984,3 @@ define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
 }
 
 
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; RV32: {{.*}}
-; RV64: {{.*}}
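
For context on the two tests added above: reduce_sum_4xi32_op_order now matches because the extract can be moved to the RHS at every step, while reduce_sum_4xi32_reduce_order still does not, since its innermost add combines lanes 1 and 2 and never forms the {0, 1} root that the incremental scheme grows from. A toy model of that incremental matching, using a hypothetical Expr type rather than SelectionDAG nodes, is:

#include <algorithm>
#include <cassert>

// Toy model of the incremental matching scheme: an add tree is accepted
// only if, at every step, one operand is the running reduction of lanes
// [0, N) and the other is the extract of lane N, in either operand
// order (the relaxation this patch adds).  Expr is a hypothetical
// stand-in for the SelectionDAG nodes the real code inspects.
struct Expr {
  bool IsExtract;        // extractelement %v, Lane when true, add otherwise
  unsigned Lane;         // meaningful only when IsExtract is true
  const Expr *L = nullptr;
  const Expr *R = nullptr;
};

// Returns true and the number of lanes covered if E sums lanes [0, NumLanes).
static bool matchPrefixSum(const Expr &E, unsigned &NumLanes) {
  if (E.IsExtract)
    return false;                  // need at least one add
  const Expr *Red = E.L, *Ext = E.R;
  if (!Ext->IsExtract)
    std::swap(Red, Ext);           // extract may sit on either side
  if (!Ext->IsExtract)
    return false;
  if (Red->IsExtract) {
    // Root of the tree: two extracts whose lanes must be exactly {0, 1},
    // in either order (the min/max check from the patch).
    if (std::min(Red->Lane, Ext->Lane) != 0 ||
        std::max(Red->Lane, Ext->Lane) != 1)
      return false;
    NumLanes = 2;
    return true;
  }
  unsigned Covered = 0;
  if (!matchPrefixSum(*Red, Covered) || Ext->Lane != Covered)
    return false;                  // extract must be the next adjacent lane
  NumLanes = Covered + 1;
  return true;
}

int main() {
  Expr E0{true, 0}, E1{true, 1}, E2{true, 2}, E3{true, 3};
  unsigned N = 0;
  // reduce_sum_4xi32_op_order: add(add(e2, add(e1, e0)), e3) covers 4 lanes.
  Expr A0{false, 0, &E1, &E0}, A1{false, 0, &E2, &A0}, A2{false, 0, &A1, &E3};
  assert(matchPrefixSum(A2, N) && N == 4);
  // reduce_sum_4xi32_reduce_order: the innermost add(e1, e2) skips lane 0,
  // so the incremental scheme rejects the whole tree.
  Expr B0{false, 0, &E1, &E2}, B1{false, 0, &E0, &B0}, B2{false, 0, &B1, &E3};
  assert(!matchPrefixSum(B2, N));
  return 0;
}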

More information about the llvm-commits mailing list