[llvm] 060b302 - [RISCV] Move TRUNCATE_VECTOR_VL combine into a helper function. NFC (#93574)
via llvm-commits
llvm-commits at lists.llvm.org
Tue May 28 14:50:01 PDT 2024
Author: Craig Topper
Date: 2024-05-28T14:49:57-07:00
New Revision: 060b3023e198d197b47c652f19af5f7dea3a22cc
URL: https://github.com/llvm/llvm-project/commit/060b3023e198d197b47c652f19af5f7dea3a22cc
DIFF: https://github.com/llvm/llvm-project/commit/060b3023e198d197b47c652f19af5f7dea3a22cc.diff
LOG: [RISCV] Move TRUNCATE_VECTOR_VL combine into a helper function. NFC (#93574)
I plan to add other combines on TRUNCATE_VECTOR_VL.
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c826892c1668e..5fc613c1b2a14 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16087,6 +16087,57 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
return true;
}
+// DAG combine for a RISCVISD::TRUNCATE_VECTOR_VL node N. Returns the
+// replacement value, or an empty SDValue() when the pattern does not match.
+static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+ // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+ // This is beneficial when X and Y are vectors of the same low-precision
+ // element type. RVV can only truncate SEW*2->SEW per instruction, so the
+ // truncate is lowered into n levels of TRUNCATE_VECTOR_VL that would
+ // otherwise expand into a series of "vsetvli" and "vnsrl" instructions.
+ // This combine folds that whole chain into a single narrow sra.
+ auto IsTruncNode = [](SDValue V) {
+ if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+ return false;
+ SDValue VL = V.getOperand(2);
+ auto *C = dyn_cast<ConstantSDNode>(VL);
+ // Only accept the unmasked, full-length form: the mask (operand 1) must be
+ // a VMSET_VL, and VL (operand 2) must be VLMAX, i.e. an all-ones constant
+ // or the X0 register.
+ bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
+ (isa<RegisterSDNode>(VL) &&
+ cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+ return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+ };
+
+ SDValue Op = N->getOperand(0);
+ // NOTE(review): IsTruncNode is applied only to N's operands, so N's own
+ // VL/mask is never checked here. Confirm that dropping it is always safe.
+
+ // Walk down through the chain of full-length TRUNCATE_VECTOR_VL nodes to
+ // the value being truncated. Bail out if an intermediate truncate has
+ // other users, because it would not become dead after the fold.
+ while (IsTruncNode(Op)) {
+ if (!Op.hasOneUse())
+ return SDValue();
+ Op = Op.getOperand(0);
+ }
+
+ // The innermost value must be a single-use arithmetic shift right.
+ if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
+ return SDValue();
+
+ // Shifted value must be sext(X) and shift amount must be zext(Y), both
+ // single-use so they can be removed.
+ SDValue N0 = Op.getOperand(0);
+ SDValue N1 = Op.getOperand(1);
+ if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
+ N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
+ return SDValue();
+
+ // X, Y and the final truncated result must all share one vector type so
+ // the narrow sra can be built directly in N's result type.
+ SDValue N00 = N0.getOperand(0);
+ SDValue N10 = N1.getOperand(0);
+ if (!N00.getValueType().isVector() ||
+ N00.getValueType() != N10.getValueType() ||
+ N->getValueType(0) != N10.getValueType())
+ return SDValue();
+
+ // Clamp the shift amount to (narrow width - 1). Shifting the sign-extended
+ // value by >= the narrow width leaves only sign bits in the truncated
+ // result, which is exactly what sra(X, width - 1) produces.
+ // NOTE(review): smin treats Y as signed even though it was zero-extended.
+ // Amounts with the top bit set stay negative here. This presumably relies
+ // on such amounts already being out of range for the wide sra; confirm.
+ unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+ SDValue SMin =
+ DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
+ DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+ return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
+}
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
@@ -16304,56 +16355,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
}
return SDValue();
- case RISCVISD::TRUNCATE_VECTOR_VL: {
- // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
- // This would be benefit for the cases where X and Y are both the same value
- // type of low precision vectors. Since the truncate would be lowered into
- // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
- // restriction, such pattern would be expanded into a series of "vsetvli"
- // and "vnsrl" instructions later to reach this point.
- auto IsTruncNode = [](SDValue V) {
- if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
- return false;
- SDValue VL = V.getOperand(2);
- auto *C = dyn_cast<ConstantSDNode>(VL);
- // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
- bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
- (isa<RegisterSDNode>(VL) &&
- cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
- return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
- IsVLMAXForVMSET;
- };
-
- SDValue Op = N->getOperand(0);
-
- // We need to first find the inner level of TRUNCATE_VECTOR_VL node
- // to distinguish such pattern.
- while (IsTruncNode(Op)) {
- if (!Op.hasOneUse())
- return SDValue();
- Op = Op.getOperand(0);
- }
-
- if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
- SDValue N0 = Op.getOperand(0);
- SDValue N1 = Op.getOperand(1);
- if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
- N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
- SDValue N00 = N0.getOperand(0);
- SDValue N10 = N1.getOperand(0);
- if (N00.getValueType().isVector() &&
- N00.getValueType() == N10.getValueType() &&
- N->getValueType(0) == N10.getValueType()) {
- unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
- SDValue SMin = DAG.getNode(
- ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
- DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
- return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
- }
- }
- }
- break;
- }
+ case RISCVISD::TRUNCATE_VECTOR_VL:
+ return combineTruncOfSraSext(N, DAG);
case ISD::TRUNCATE:
return performTRUNCATECombine(N, DAG, Subtarget);
case ISD::SELECT:
More information about the llvm-commits mailing list