[llvm] 4e5b3f6 - [RISCV] Consolidate a bit of common logic for forming reductions

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 9 08:18:58 PST 2022


Author: Philip Reames
Date: 2022-12-09T08:18:51-08:00
New Revision: 4e5b3f63075b03597a0ae5de894a7d181e3444f3

URL: https://github.com/llvm/llvm-project/commit/4e5b3f63075b03597a0ae5de894a7d181e3444f3
DIFF: https://github.com/llvm/llvm-project/commit/4e5b3f63075b03597a0ae5de894a7d181e3444f3.diff

LOG: [RISCV] Consolidate a bit of common logic for forming reductions

There are several patches in flight which change this code, so it is better to have only one copy.

The VP case is left separate for the moment as the result value type differs.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 2f10e947b12d1..11461fe8614e7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5796,6 +5796,25 @@ SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op,
   return DAG.getNode(BaseOpc, DL, XLenVT, SetCC, Op.getOperand(0));
 }
 
+/// Helper to lower a reduction sequence of the form:
+/// scalar = reduce_op vec, scalar_start
+static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, SDValue Vec, SDValue Mask, SDValue VL,
+                                 SDLoc DL, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) {
+  const MVT VecVT = Vec.getSimpleValueType();
+  const MVT VecEltVT = VecVT.getVectorElementType();
+  const MVT M1VT = getLMUL1VT(VecVT);
+  const MVT XLenVT = Subtarget.getXLenVT();
+
+  SDValue InitialSplat =
+      lowerScalarSplat(SDValue(), StartValue, DAG.getConstant(1, DL, XLenVT),
+                       M1VT, DL, DAG, Subtarget);
+  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT), Vec,
+                                  InitialSplat, Mask, VL);
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
+                     DAG.getConstant(0, DL, XLenVT));
+}
+
+
 SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
                                             SelectionDAG &DAG) const {
   SDLoc DL(Op);
@@ -5828,20 +5847,12 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
   }
 
-  MVT M1VT = getLMUL1VT(ContainerVT);
-  MVT XLenVT = Subtarget.getXLenVT();
-
   auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
 
   SDValue NeutralElem =
       DAG.getNeutralElement(BaseOpc, DL, VecEltVT, SDNodeFlags());
-  SDValue IdentitySplat =
-      lowerScalarSplat(SDValue(), NeutralElem, DAG.getConstant(1, DL, XLenVT),
-                       M1VT, DL, DAG, Subtarget);
-  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT), Vec,
-                                  IdentitySplat, Mask, VL);
-  SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
-                             DAG.getConstant(0, DL, XLenVT));
+  SDValue Elt0 = lowerReductionSeq(RVVOpcode, NeutralElem, Vec, Mask, VL,
+                                   DL, DAG, Subtarget);
   return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType());
 }
 
@@ -5892,18 +5903,9 @@ SDValue RISCVTargetLowering::lowerFPVECREDUCE(SDValue Op,
     VectorVal = convertToScalableVector(ContainerVT, VectorVal, DAG, Subtarget);
   }
 
-  MVT M1VT = getLMUL1VT(VectorVal.getSimpleValueType());
-  MVT XLenVT = Subtarget.getXLenVT();
-
   auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
-
-  SDValue ScalarSplat =
-      lowerScalarSplat(SDValue(), ScalarVal, DAG.getConstant(1, DL, XLenVT),
-                       M1VT, DL, DAG, Subtarget);
-  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT),
-                                  VectorVal, ScalarSplat, Mask, VL);
-  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
-                     DAG.getConstant(0, DL, XLenVT));
+  return lowerReductionSeq(RVVOpcode, ScalarVal, VectorVal, Mask, VL, DL, DAG,
+                           Subtarget);
 }
 
 static unsigned getRVVVPReductionOp(unsigned ISDOpcode) {


        


More information about the llvm-commits mailing list