[llvm] bee9a92 - [RISCV] Use reduction result type for EXTRACT_VECTOR_ELT in lowerReductionSeq.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 13 09:15:36 PST 2022


Author: Craig Topper
Date: 2022-12-13T09:10:36-08:00
New Revision: bee9a92aec07db15f7ec89278c381654d60e4d02

URL: https://github.com/llvm/llvm-project/commit/bee9a92aec07db15f7ec89278c381654d60e4d02
DIFF: https://github.com/llvm/llvm-project/commit/bee9a92aec07db15f7ec89278c381654d60e4d02.diff

LOG: [RISCV] Use reduction result type for EXTRACT_VECTOR_ELT in lowerReductionSeq.

Remove the calls to getSExtOrTrunc in lowerVECREDUCE and lowerVPREDUCE.

Reduction ISD nodes produce a scalar result, and that result is
allowed to be wider than the vector element type due to type
legalization. EXTRACT_VECTOR_ELT permits the same widening of its
result type, for the same reason.

We can therefore pass the reduction node's result type through
directly as the result type of the EXTRACT_VECTOR_ELT.

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D139757

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index df8ec9a2927a..bd6afe51fac5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5846,12 +5846,11 @@ static bool hasNonZeroAVL(SDValue AVL) {
 
 /// Helper to lower a reduction sequence of the form:
 /// scalar = reduce_op vec, scalar_start
-static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue,
-                                 SDValue Vec, SDValue Mask, SDValue VL,
-                                 SDLoc DL, SelectionDAG &DAG,
+static SDValue lowerReductionSeq(unsigned RVVOpcode, MVT ResVT,
+                                 SDValue StartValue, SDValue Vec, SDValue Mask,
+                                 SDValue VL, SDLoc DL, SelectionDAG &DAG,
                                  const RISCVSubtarget &Subtarget) {
   const MVT VecVT = Vec.getSimpleValueType();
-  const MVT VecEltVT = VecVT.getVectorElementType();
   const MVT M1VT = getLMUL1VT(VecVT);
   const MVT XLenVT = Subtarget.getXLenVT();
 
@@ -5868,7 +5867,7 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue,
   SDValue PassThru = hasNonZeroAVL(VL) ? DAG.getUNDEF(M1VT) : InitialValue;
   SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, PassThru, Vec,
                                   InitialValue, Mask, VL);
-  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
                      DAG.getConstant(0, DL, XLenVT));
 }
 
@@ -5908,9 +5907,8 @@ SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op,
 
   SDValue NeutralElem =
       DAG.getNeutralElement(BaseOpc, DL, VecEltVT, SDNodeFlags());
-  SDValue Elt0 = lowerReductionSeq(RVVOpcode, NeutralElem, Vec, Mask, VL,
-                                   DL, DAG, Subtarget);
-  return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType());
+  return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), NeutralElem, Vec,
+                           Mask, VL, DL, DAG, Subtarget);
 }
 
 // Given a reduction op, this function returns the matching reduction opcode,
@@ -5961,8 +5959,8 @@ SDValue RISCVTargetLowering::lowerFPVECREDUCE(SDValue Op,
   }
 
   auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget);
-  return lowerReductionSeq(RVVOpcode, ScalarVal, VectorVal, Mask, VL, DL, DAG,
-                           Subtarget);
+  return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), ScalarVal,
+                           VectorVal, Mask, VL, DL, DAG, Subtarget);
 }
 
 static unsigned getRVVVPReductionOp(unsigned ISDOpcode) {
@@ -6017,11 +6015,8 @@ SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,
 
   SDValue VL = Op.getOperand(3);
   SDValue Mask = Op.getOperand(2);
-  SDValue Elt0 = lowerReductionSeq(RVVOpcode, Op.getOperand(0), Vec, Mask, VL,
-                                   DL, DAG, Subtarget);
-  if (!VecVT.isInteger())
-    return Elt0;
-  return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType());
+  return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), Op.getOperand(0),
+                           Vec, Mask, VL, DL, DAG, Subtarget);
 }
 
 SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,


        


More information about the llvm-commits mailing list