[llvm] [RISCV] Form vredsum from explode_vector + scalar (left) reduce (PR #67821)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 29 15:13:04 PDT 2023
================
@@ -11122,6 +11122,85 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
}
}
+/// Perform two related transforms whose purpose is to incrementally recognize
+/// an explode_vector followed by scalar reduction as a vector reduction node.
+/// This exists to recover from a deficiency in SLP which can't handle
+/// forests with multiple roots sharing common nodes. In some cases, one
+/// of the trees will be vectorized, and the other will remain (unprofitably)
+/// scalarized.
+static SDValue combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+
+ // These transforms need to run before all integer types have been legalized
+ // to i64 (so that the vector element type matches the add type), and while
+ // it's safe to introduce odd sized vector types.
+ if (DAG.NewNodesMustHaveLegalTypes)
+ return SDValue();
+
+ const SDLoc DL(N);
+ const EVT VT = N->getValueType(0);
+ [[maybe_unused]] const unsigned Opc = N->getOpcode();
+ assert(Opc == ISD::ADD && "extend this to other reduction types");
+ const SDValue LHS = N->getOperand(0);
+ const SDValue RHS = N->getOperand(1);
+
+ if (!LHS.hasOneUse() || !RHS.hasOneUse())
+ return SDValue();
+
+ if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+ !isa<ConstantSDNode>(RHS.getOperand(1)))
+ return SDValue();
+
+ SDValue SrcVec = RHS.getOperand(0);
+ EVT SrcVecVT = SrcVec.getValueType();
+ assert(SrcVecVT.getVectorElementType() == VT);
+ if (SrcVecVT.isScalableVector())
+ return SDValue();
+
+ if (SrcVecVT.getScalarSizeInBits() > Subtarget.getELen())
+ return SDValue();
+
+ // match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
+ // reduce_op (extract_subvector [2 x VT] from V). This will form the
+ // root of our reduction tree. TODO: We could extend this to any two
+ // adjacent constant indices if desired.
+ if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+ LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
+ isOneConstant(RHS.getOperand(1))) {
+ EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
+ SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+ DAG.getVectorIdxConstant(0, DL));
+ return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
+ }
+
+ // Match (binop (reduce (extract_subvector V, 0),
+ // (extract_vector_elt V, sizeof(SubVec))))
+ // into a reduction of one more element from the original vector V.
+ if (LHS.getOpcode() != ISD::VECREDUCE_ADD)
+ return SDValue();
+
+ SDValue ReduceVec = LHS.getOperand(0);
+ if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+ ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
+ isNullConstant(ReduceVec.getOperand(1)) &&
+ isa<ConstantSDNode>(RHS.getOperand(1))) {
+ uint64_t Idx = cast<ConstantSDNode>(RHS.getOperand(1))->getLimitedValue();
+ if (ReduceVec.getValueType().getVectorNumElements() == Idx) {
+ // For illegal types (e.g. 3xi32), most will be combined again into a
+ // wider (hopefully legal) type. If this is a terminal state, we are
+ // relying on type legalization here to poduce something reasonable
----------------
topperc wrote:
poduce -> produce
https://github.com/llvm/llvm-project/pull/67821
More information about the llvm-commits mailing list