[llvm] 1ebe8f4 - [RISCV] Share reduction lowering code for vp.reduce

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 9 12:23:14 PST 2022


Author: Philip Reames
Date: 2022-12-09T12:22:59-08:00
New Revision: 1ebe8f4c45b8715c34c5a56e930694fa98478c93

URL: https://github.com/llvm/llvm-project/commit/1ebe8f4c45b8715c34c5a56e930694fa98478c93
DIFF: https://github.com/llvm/llvm-project/commit/1ebe8f4c45b8715c34c5a56e930694fa98478c93.diff

LOG: [RISCV] Share reduction lowering code for vp.reduce

We can consolidate code and clarify edge case behavior at the same time.

There are two functional differences here.

First, I remove the ResVT handling, and always use the reduction element type. This appears to be dead code. There's no test coverage, and this code doesn't need to account for scalar type legalization anyway.

Second, if the VL happens to be known non-zero, we can avoid passing through start. This is mostly needed to allow reuse of the existing code; I don't consider it interesting as an optimization on its own.

Differential Revision: https://reviews.llvm.org/D139733

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 11461fe8614e7..69e1c8cf1fa04 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5796,6 +5796,13 @@ SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op,
   return DAG.getNode(BaseOpc, DL, XLenVT, SetCC, Op.getOperand(0));
 }
 
+static bool hasNonZeroAVL(SDValue AVL) {
+  auto *RegisterAVL = dyn_cast<RegisterSDNode>(AVL);
+  auto *ImmAVL = dyn_cast<ConstantSDNode>(AVL);
+  return (RegisterAVL && RegisterAVL->getReg() == RISCV::X0) ||
+         (ImmAVL && ImmAVL->getZExtValue() >= 1);
+}
+
 /// Helper to lower a reduction sequence of the form:
 /// scalar = reduce_op vec, scalar_start
 static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, SDValue Vec, SDValue Mask, SDValue VL,
@@ -5808,7 +5815,8 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, SDValue
   SDValue InitialSplat =
       lowerScalarSplat(SDValue(), StartValue, DAG.getConstant(1, DL, XLenVT),
                        M1VT, DL, DAG, Subtarget);
-  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT), Vec,
+  SDValue PassThru = hasNonZeroAVL(VL) ? DAG.getUNDEF(M1VT) : InitialSplat;
+  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, PassThru, Vec,
                                   InitialSplat, Mask, VL);
   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
                      DAG.getConstant(0, DL, XLenVT));
@@ -5951,29 +5959,17 @@ SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,
     return SDValue();
 
   MVT VecVT = VecEVT.getSimpleVT();
-  MVT VecEltVT = VecVT.getVectorElementType();
   unsigned RVVOpcode = getRVVVPReductionOp(Op.getOpcode());
 
-  MVT ContainerVT = VecVT;
   if (VecVT.isFixedLengthVector()) {
-    ContainerVT = getContainerForFixedLengthVector(VecVT);
+    auto ContainerVT = getContainerForFixedLengthVector(VecVT);
     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
   }
 
   SDValue VL = Op.getOperand(3);
   SDValue Mask = Op.getOperand(2);
-
-  MVT M1VT = getLMUL1VT(ContainerVT);
-  MVT XLenVT = Subtarget.getXLenVT();
-  MVT ResVT = !VecVT.isInteger() || VecEltVT.bitsGE(XLenVT) ? VecEltVT : XLenVT;
-
-  SDValue StartSplat = lowerScalarSplat(SDValue(), Op.getOperand(0),
-                                        DAG.getConstant(1, DL, XLenVT), M1VT,
-                                        DL, DAG, Subtarget);
-  SDValue Reduction =
-      DAG.getNode(RVVOpcode, DL, M1VT, StartSplat, Vec, StartSplat, Mask, VL);
-  SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
-                             DAG.getConstant(0, DL, XLenVT));
+  SDValue Elt0 = lowerReductionSeq(RVVOpcode, Op.getOperand(0), Vec, Mask, VL,
+                                   DL, DAG, Subtarget);
   if (!VecVT.isInteger())
     return Elt0;
   return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType());


        


More information about the llvm-commits mailing list