[llvm] 1ebe8f4 - [RISCV] Share reduction lowering code for vp.reduce
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 9 12:23:14 PST 2022
Author: Philip Reames
Date: 2022-12-09T12:22:59-08:00
New Revision: 1ebe8f4c45b8715c34c5a56e930694fa98478c93
URL: https://github.com/llvm/llvm-project/commit/1ebe8f4c45b8715c34c5a56e930694fa98478c93
DIFF: https://github.com/llvm/llvm-project/commit/1ebe8f4c45b8715c34c5a56e930694fa98478c93.diff
LOG: [RISCV] Share reduction lowering code for vp.reduce
We can consolidate code and clarify edge case behavior at the same time.
There are two functional differences here.
First, I remove the ResVT handling, and always use the reduction element type. This appears to be dead code. There's no test coverage, and this code doesn't need to account for scalar type legalization anyway.
Second, if the VL happens to be known non-zero, we can avoid passing the start value through as the passthru operand (an undef passthru is used instead). This is mostly needed to allow reuse of the existing code; I don't consider it interesting as an optimization on its own.
Differential Revision: https://reviews.llvm.org/D139733
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 11461fe8614e7..69e1c8cf1fa04 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5796,6 +5796,13 @@ SDValue RISCVTargetLowering::lowerVectorMaskVecReduction(SDValue Op,
return DAG.getNode(BaseOpc, DL, XLenVT, SetCC, Op.getOperand(0));
}
+static bool hasNonZeroAVL(SDValue AVL) {
+ auto *RegisterAVL = dyn_cast<RegisterSDNode>(AVL);
+ auto *ImmAVL = dyn_cast<ConstantSDNode>(AVL);
+ return (RegisterAVL && RegisterAVL->getReg() == RISCV::X0) ||
+ (ImmAVL && ImmAVL->getZExtValue() >= 1);
+}
+
/// Helper to lower a reduction sequence of the form:
/// scalar = reduce_op vec, scalar_start
static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, SDValue Vec, SDValue Mask, SDValue VL,
@@ -5808,7 +5815,8 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, SDValue
SDValue InitialSplat =
lowerScalarSplat(SDValue(), StartValue, DAG.getConstant(1, DL, XLenVT),
M1VT, DL, DAG, Subtarget);
- SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT), Vec,
+ SDValue PassThru = hasNonZeroAVL(VL) ? DAG.getUNDEF(M1VT) : InitialSplat;
+ SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, PassThru, Vec,
InitialSplat, Mask, VL);
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
DAG.getConstant(0, DL, XLenVT));
@@ -5951,29 +5959,17 @@ SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,
return SDValue();
MVT VecVT = VecEVT.getSimpleVT();
- MVT VecEltVT = VecVT.getVectorElementType();
unsigned RVVOpcode = getRVVVPReductionOp(Op.getOpcode());
- MVT ContainerVT = VecVT;
if (VecVT.isFixedLengthVector()) {
- ContainerVT = getContainerForFixedLengthVector(VecVT);
+ auto ContainerVT = getContainerForFixedLengthVector(VecVT);
Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
}
SDValue VL = Op.getOperand(3);
SDValue Mask = Op.getOperand(2);
-
- MVT M1VT = getLMUL1VT(ContainerVT);
- MVT XLenVT = Subtarget.getXLenVT();
- MVT ResVT = !VecVT.isInteger() || VecEltVT.bitsGE(XLenVT) ? VecEltVT : XLenVT;
-
- SDValue StartSplat = lowerScalarSplat(SDValue(), Op.getOperand(0),
- DAG.getConstant(1, DL, XLenVT), M1VT,
- DL, DAG, Subtarget);
- SDValue Reduction =
- DAG.getNode(RVVOpcode, DL, M1VT, StartSplat, Vec, StartSplat, Mask, VL);
- SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
- DAG.getConstant(0, DL, XLenVT));
+ SDValue Elt0 = lowerReductionSeq(RVVOpcode, Op.getOperand(0), Vec, Mask, VL,
+ DL, DAG, Subtarget);
if (!VecVT.isInteger())
return Elt0;
return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType());
More information about the llvm-commits
mailing list