[llvm] [RISCV] Allow swapped operands in reduction formation (PR #68634)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 9 14:36:08 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-backend-risc-v
Changes
Very straightforward, but worth landing on its own in advance of a more complicated generalization.
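For context, a minimal IR sketch of the newly handled shape (it mirrors the `reduce_sum_4xi32_op_order` test added in this patch): the extracts may now appear as either operand of each add, and the combine still forms a `vredsum`.

```llvm
; Sketch only; the function name is illustrative. With this change the chain
; below is recognized as a sum reduction even though the leaf pair uses index
; order {1, 0} and one extract appears as the left-hand operand of its add.
define i32 @sum_swapped_operands(<4 x i32> %v) {
  %e0 = extractelement <4 x i32> %v, i32 0
  %e1 = extractelement <4 x i32> %v, i32 1
  %e2 = extractelement <4 x i32> %v, i32 2
  %e3 = extractelement <4 x i32> %v, i32 3
  %add0 = add i32 %e1, %e0    ; leaf pair with indices {1, 0} rather than {0, 1}
  %add1 = add i32 %e2, %add0  ; extract on the left, partial sum on the right
  %add2 = add i32 %add1, %e3
  ret i32 %add2
}
```

The accumulation must still proceed in index order; reductions that combine elements out of order remain unmatched, as the new negative test below shows.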
---
Full diff: https://github.com/llvm/llvm-project/pull/68634.diff
2 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+28-23)
- (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll (+62-4)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6be3fa71479be5c..b0fc99f6eff860b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11363,16 +11363,20 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
const unsigned ReduceOpc = getVecReduceOpcode(Opc);
assert(Opc == ISD::getVecReduceBaseOpcode(ReduceOpc) &&
"Inconsistent mappings");
- const SDValue LHS = N->getOperand(0);
- const SDValue RHS = N->getOperand(1);
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
if (!LHS.hasOneUse() || !RHS.hasOneUse())
return SDValue();
+ if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+ std::swap(LHS, RHS);
+
if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
!isa<ConstantSDNode>(RHS.getOperand(1)))
return SDValue();
+ uint64_t RHSIdx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
SDValue SrcVec = RHS.getOperand(0);
EVT SrcVecVT = SrcVec.getValueType();
assert(SrcVecVT.getVectorElementType() == VT);
@@ -11385,14 +11389,17 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
// match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
// reduce_op (extract_subvector [2 x VT] from V). This will form the
// root of our reduction tree. TODO: We could extend this to any two
- // adjacent constant indices if desired.
+ // adjacent aligned constant indices if desired.
if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
- LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
- isOneConstant(RHS.getOperand(1))) {
- EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
- SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
- DAG.getVectorIdxConstant(0, DL));
- return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+ LHS.getOperand(0) == SrcVec && isa<ConstantSDNode>(LHS.getOperand(1))) {
+ uint64_t LHSIdx =
+ cast<ConstantSDNode>(LHS.getOperand(1))->getLimitedValue();
+ if (0 == std::min(LHSIdx, RHSIdx) && 1 == std::max(LHSIdx, RHSIdx)) {
+ EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
+ SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+ DAG.getVectorIdxConstant(0, DL));
+ return DAG.getNode(ReduceOpc, DL, VT, Vec, N->getFlags());
+ }
}
// Match (binop (reduce (extract_subvector V, 0),
@@ -11404,20 +11411,18 @@ combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
SDValue ReduceVec = LHS.getOperand(0);
if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
- isNullConstant(ReduceVec.getOperand(1))) {
- uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
- if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
- // For illegal types (e.g. 3xi32), most will be combined again into a
- // wider (hopefully legal) type. If this is a terminal state, we are
- // relying on type legalization here to produce something reasonable
- // and this lowering quality could probably be improved. (TODO)
- EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, Idx + 1);
- SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
- DAG.getVectorIdxConstant(0, DL));
- auto Flags = ReduceVec->getFlags();
- Flags.intersectWith(N->getFlags());
- return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
- }
+ isNullConstant(ReduceVec.getOperand(1)) &&
+ ReduceVec.getValueType().getVectorNumElements() == RHSIdx) {
+ // For illegal types (e.g. 3xi32), most will be combined again into a
+ // wider (hopefully legal) type. If this is a terminal state, we are
+ // relying on type legalization here to produce something reasonable
+ // and this lowering quality could probably be improved. (TODO)
+ EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, RHSIdx + 1);
+ SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+ DAG.getVectorIdxConstant(0, DL));
+ auto Flags = ReduceVec->getFlags();
+ Flags.intersectWith(N->getFlags());
+ return DAG.getNode(ReduceOpc, DL, VT, Vec, Flags);
}
return SDValue();
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
index 76df097a7697162..fd4a54b468f15fd 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll
@@ -34,7 +34,6 @@ define i32 @reduce_sum_4xi32(<4 x i32> %v) {
ret i32 %add2
}
-
define i32 @reduce_sum_8xi32(<8 x i32> %v) {
; CHECK-LABEL: reduce_sum_8xi32:
; CHECK: # %bb.0:
@@ -449,6 +448,68 @@ define i32 @reduce_sum_16xi32_prefix15(ptr %p) {
ret i32 %add13
}
+; Check that we can match with the operand order reversed, but the
+; reduction order unchanged.
+define i32 @reduce_sum_4xi32_op_order(<4 x i32> %v) {
+; CHECK-LABEL: reduce_sum_4xi32_op_order:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; CHECK-NEXT: vmv.s.x v9, zero
+; CHECK-NEXT: vredsum.vs v8, v8, v9
+; CHECK-NEXT: vmv.x.s a0, v8
+; CHECK-NEXT: ret
+ %e0 = extractelement <4 x i32> %v, i32 0
+ %e1 = extractelement <4 x i32> %v, i32 1
+ %e2 = extractelement <4 x i32> %v, i32 2
+ %e3 = extractelement <4 x i32> %v, i32 3
+ %add0 = add i32 %e1, %e0
+ %add1 = add i32 %e2, %add0
+ %add2 = add i32 %add1, %e3
+ ret i32 %add2
+}
+
+; Negative test - Reduction order isn't compatible with current
+; incremental matching scheme.
+define i32 @reduce_sum_4xi32_reduce_order(<4 x i32> %v) {
+; RV32-LABEL: reduce_sum_4xi32_reduce_order:
+; RV32: # %bb.0:
+; RV32-NEXT: vsetivli zero, 1, e32, m1, ta, ma
+; RV32-NEXT: vmv.x.s a0, v8
+; RV32-NEXT: vslidedown.vi v9, v8, 1
+; RV32-NEXT: vmv.x.s a1, v9
+; RV32-NEXT: vslidedown.vi v9, v8, 2
+; RV32-NEXT: vmv.x.s a2, v9
+; RV32-NEXT: vslidedown.vi v8, v8, 3
+; RV32-NEXT: vmv.x.s a3, v8
+; RV32-NEXT: add a1, a1, a2
+; RV32-NEXT: add a0, a0, a3
+; RV32-NEXT: add a0, a0, a1
+; RV32-NEXT: ret
+;
+; RV64-LABEL: reduce_sum_4xi32_reduce_order:
+; RV64: # %bb.0:
+; RV64-NEXT: vsetivli zero, 1, e32, m1, ta, ma
+; RV64-NEXT: vmv.x.s a0, v8
+; RV64-NEXT: vslidedown.vi v9, v8, 1
+; RV64-NEXT: vmv.x.s a1, v9
+; RV64-NEXT: vslidedown.vi v9, v8, 2
+; RV64-NEXT: vmv.x.s a2, v9
+; RV64-NEXT: vslidedown.vi v8, v8, 3
+; RV64-NEXT: vmv.x.s a3, v8
+; RV64-NEXT: add a1, a1, a2
+; RV64-NEXT: add a0, a0, a3
+; RV64-NEXT: addw a0, a0, a1
+; RV64-NEXT: ret
+ %e0 = extractelement <4 x i32> %v, i32 0
+ %e1 = extractelement <4 x i32> %v, i32 1
+ %e2 = extractelement <4 x i32> %v, i32 2
+ %e3 = extractelement <4 x i32> %v, i32 3
+ %add0 = add i32 %e1, %e2
+ %add1 = add i32 %e0, %add0
+ %add2 = add i32 %add1, %e3
+ ret i32 %add2
+}
+
;; Most of the cornercases are exercised above, the following just
;; makes sure that other opcodes work as expected.
@@ -923,6 +984,3 @@ define float @reduce_fadd_4xi32_non_associative2(ptr %p) {
}
-;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
-; RV32: {{.*}}
-; RV64: {{.*}}
``````````
https://github.com/llvm/llvm-project/pull/68634