[llvm] [RISCV] Form vredsum from explode_vector + scalar (left) reduce (PR #67821)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 29 09:40:03 PDT 2023


================
@@ -11122,6 +11122,85 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
   }
 }
 
+/// Perform two related transforms whose purpose is to incrementally recognize
+/// an explode_vector followed by scalar reduction as a vector reduction node.
+/// This exists to recover from a deficiency in SLP which can't handle
+/// forests with multiple roots sharing common nodes.  In some cases, one
+/// of the trees will be vectorized, and the other will remain (unprofitably)
+/// scalarized.
+static SDValue combineBinOpOfExtractToReduceTree(SDNode *N, SelectionDAG &DAG,
+                                                 const RISCVSubtarget &Subtarget) {
+
+  // This transforms need to run before all integer types have been legalized
+  // to i64 (so that the vector element type matches the add type), and while
+  // it's safe to introduce odd sized vector types.
+  if (DAG.NewNodesMustHaveLegalTypes)
+    return SDValue();
+
+  const SDLoc DL(N);
+  const EVT VT = N->getValueType(0);
+  const unsigned Opc = N->getOpcode();
+  assert(Opc == ISD::ADD && "extend this to other reduction types");
+  const SDValue LHS = N->getOperand(0);
+  const SDValue RHS = N->getOperand(1);
+
+  if (!LHS.hasOneUse() || !RHS.hasOneUse())
+    return SDValue();
+
+  if (RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
+      !isa<ConstantSDNode>(RHS.getOperand(1)))
+    return SDValue();
+
+  SDValue SrcVec = RHS.getOperand(0);
+  EVT SrcVecVT = SrcVec.getValueType();
+  assert(SrcVecVT.getVectorElementType() == VT);
+  if (SrcVecVT.isScalableVector())
+    return SDValue();
+
+  if (SrcVecVT.getScalarSizeInBits() > Subtarget.getELen())
+    return SDValue();
+
+  // match binop (extract_vector_elt V, 0), (extract_vector_elt V, 1) to
+  // reduce_op (extract_subvector [2 x VT] from V).  This will form the
+  // root of our reduction tree. TODO: We could extend this to any two
+  // adjacent constant indices if desired.
+  if (LHS.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
+      LHS.getOperand(0) == SrcVec && isNullConstant(LHS.getOperand(1)) &&
+      isOneConstant(RHS.getOperand(1))) {
+    EVT ReduceVT = EVT::getVectorVT(*DAG.getContext(), VT, 2);
+    SDValue Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ReduceVT, SrcVec,
+                              DAG.getVectorIdxConstant(0, DL));
+    return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, Vec);
+  }
+
+  // Match (binop reduce (extract_subvector V, 0),
+  //                      extract_vector_elt V, sizeof(SubVec))
+  // into a reduction of one more element from the original vector V.
+  if (LHS.getOpcode() != ISD::VECREDUCE_ADD)
+    return SDValue();
+
+  SDValue ReduceVec = LHS.getOperand(0);
+  if (ReduceVec.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+      ReduceVec.hasOneUse() && ReduceVec.getOperand(0) == RHS.getOperand(0) &&
+      isNullConstant(ReduceVec.getOperand(1)) &&
+      isa<ConstantSDNode>(RHS.getOperand(1))) {
----------------
topperc wrote:

RHS.getOperand(1) is known to be a ConstantSDNode from line 11151 right?

https://github.com/llvm/llvm-project/pull/67821


More information about the llvm-commits mailing list