[PATCH] D139733: [RISCV] Share reduction lowering code for vp.reduce

Philip Reames via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 9 11:22:00 PST 2022


reames created this revision.
reames added reviewers: craig.topper, kito-cheng, asb, frasercrmck.
Herald added subscribers: sunshaoce, VincentWu, StephenFan, vkmr, evandro, luismarques, apazos, sameer.abuasal, s.egerton, Jim, benna, psnobl, jocewei, PkmX, the_o, brucehoult, MartinMosbeck, rogfer01, edward-jones, zzheng, jrtc27, shiva0217, niosHD, sabuasal, bollu, simoncook, johnrusso, rbar, hiraditya, arichardson, mcrosier.
Herald added a project: All.
reames requested review of this revision.
Herald added subscribers: pcwang-thead, eopXD, MaskRay.
Herald added a project: LLVM.

We can consolidate code and clarify edge case behavior at the same time.

There are two functional differences here.

First, I remove the ResVT handling and always use the reduction element type.  The ResVT handling appears to be dead code: there's no test coverage, and I think such a construct wouldn't be legal anyway.  This is the main reason I posted this for review, as I want someone to double-check me on this point.

Second, if the VL happens to be known non-zero, we can avoid passing through start.  This is mostly needed to allow reuse of the existing code; I don't consider it interesting as an optimization on its own.
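
As a rough illustration of the second point, here is a minimal standalone sketch of the passthru choice.  It does not use the real SelectionDAG API: the RegisterAVL, ImmediateAVL and UnknownAVL types, the X0 constant, and choosePassThru are hypothetical stand-ins for RegisterSDNode, ConstantSDNode, RISCV::X0 and the new ternary in lowerReductionSeq.  It assumes that an X0 register AVL stands for VLMAX, and that a reduction with VL == 0 leaves its destination untouched, which is why the start value is kept as the passthru when VL may be zero.

```cpp
#include <cassert>
#include <cstdint>
#include <variant>

// Hypothetical stand-ins for the DAG nodes inspected by hasNonZeroAVL.
struct RegisterAVL { unsigned Reg; };     // models RegisterSDNode
struct ImmediateAVL { uint64_t Value; };  // models ConstantSDNode
struct UnknownAVL {};                     // any other runtime value
using AVL = std::variant<RegisterAVL, ImmediateAVL, UnknownAVL>;

// Placeholder for RISCV::X0; as an AVL it is assumed to mean VLMAX.
constexpr unsigned X0 = 0;

// Mirrors the helper in the patch: true only when VL is provably >= 1.
bool hasNonZeroAVL(const AVL &A) {
  if (const auto *R = std::get_if<RegisterAVL>(&A))
    return R->Reg == X0;
  if (const auto *I = std::get_if<ImmediateAVL>(&A))
    return I->Value >= 1;
  return false; // unknown runtime value: VL might be zero
}

enum class PassThru { Undef, StartSplat };

// With VL possibly zero the reduction may not write its result, so the
// passthru must already hold the start value; otherwise undef is fine.
PassThru choosePassThru(const AVL &A) {
  return hasNonZeroAVL(A) ? PassThru::Undef : PassThru::StartSplat;
}

// Small self-check showing the intended behaviour of the sketch.
inline void checkExamples() {
  assert(choosePassThru(ImmediateAVL{4}) == PassThru::Undef);
  assert(choosePassThru(ImmediateAVL{0}) == PassThru::StartSplat);
  assert(choosePassThru(UnknownAVL{}) == PassThru::StartSplat);
}
```

In this model, only a constant VL of at least 1 or the VLMAX encoding selects the undef passthru; everything else conservatively keeps the start splat.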


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D139733

Files:
  llvm/lib/Target/RISCV/RISCVISelLowering.cpp


Index: llvm/lib/Target/RISCV/RISCVISelLowering.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5796,6 +5796,13 @@
   return DAG.getNode(BaseOpc, DL, XLenVT, SetCC, Op.getOperand(0));
 }
 
+static bool hasNonZeroAVL(SDValue AVL) {
+  auto *RegisterAVL = dyn_cast<RegisterSDNode>(AVL);
+  auto *ImmAVL = dyn_cast<ConstantSDNode>(AVL);
+  return (RegisterAVL && RegisterAVL->getReg() == RISCV::X0) ||
+         (ImmAVL && ImmAVL->getZExtValue() >= 1);
+}
+
 /// Helper to lower a reduction sequence of the form:
 /// scalar = reduce_op vec, scalar_start
 static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, SDValue Vec, SDValue Mask, SDValue VL,
@@ -5808,7 +5815,8 @@
   SDValue InitialSplat =
       lowerScalarSplat(SDValue(), StartValue, DAG.getConstant(1, DL, XLenVT),
                        M1VT, DL, DAG, Subtarget);
-  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT), Vec,
+  SDValue PassThru = hasNonZeroAVL(VL) ? DAG.getUNDEF(M1VT) : InitialSplat;
+  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, PassThru, Vec,
                                   InitialSplat, Mask, VL);
   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
                      DAG.getConstant(0, DL, XLenVT));
@@ -5951,29 +5959,17 @@
     return SDValue();
 
   MVT VecVT = VecEVT.getSimpleVT();
-  MVT VecEltVT = VecVT.getVectorElementType();
   unsigned RVVOpcode = getRVVVPReductionOp(Op.getOpcode());
 
-  MVT ContainerVT = VecVT;
   if (VecVT.isFixedLengthVector()) {
-    ContainerVT = getContainerForFixedLengthVector(VecVT);
+    auto ContainerVT = getContainerForFixedLengthVector(VecVT);
     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
   }
 
   SDValue VL = Op.getOperand(3);
   SDValue Mask = Op.getOperand(2);
-
-  MVT M1VT = getLMUL1VT(ContainerVT);
-  MVT XLenVT = Subtarget.getXLenVT();
-  MVT ResVT = !VecVT.isInteger() || VecEltVT.bitsGE(XLenVT) ? VecEltVT : XLenVT;
-
-  SDValue StartSplat = lowerScalarSplat(SDValue(), Op.getOperand(0),
-                                        DAG.getConstant(1, DL, XLenVT), M1VT,
-                                        DL, DAG, Subtarget);
-  SDValue Reduction =
-      DAG.getNode(RVVOpcode, DL, M1VT, StartSplat, Vec, StartSplat, Mask, VL);
-  SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
-                             DAG.getConstant(0, DL, XLenVT));
+  SDValue Elt0 = lowerReductionSeq(RVVOpcode, Op.getOperand(0), Vec, Mask, VL,
+                                   DL, DAG, Subtarget);
   if (!VecVT.isInteger())
     return Elt0;
   return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType());


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D139733.481704.patch
Type: text/x-patch
Size: 2776 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20221209/37fe8902/attachment.bin>
