[llvm] [RISCV] Move TRUNCATE_VECTOR_VL combine into a helper function. NFC (PR #93574)
via llvm-commits
llvm-commits at lists.llvm.org
Tue May 28 09:29:03 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Craig Topper (topperc)
<details>
<summary>Changes</summary>
I plan to add other combines on TRUNCATE_VECTOR_VL.
---
Full diff: https://github.com/llvm/llvm-project/pull/93574.diff
1 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+53-50)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0e5a7d393b6c..47b1cc1ba6460 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16087,6 +16087,57 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
return true;
}
+static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+ // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+ // This would be benefit for the cases where X and Y are both the same value
+ // type of low precision vectors. Since the truncate would be lowered into
+ // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+ // restriction, such pattern would be expanded into a series of "vsetvli"
+ // and "vnsrl" instructions later to reach this point.
+ auto IsTruncNode = [](SDValue V) {
+ if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+ return false;
+ SDValue VL = V.getOperand(2);
+ auto *C = dyn_cast<ConstantSDNode>(VL);
+ // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
+ bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
+ (isa<RegisterSDNode>(VL) &&
+ cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+ return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+ };
+
+ SDValue Op = N->getOperand(0);
+
+ // We need to first find the inner level of TRUNCATE_VECTOR_VL node
+ // to distinguish such pattern.
+ while (IsTruncNode(Op)) {
+ if (!Op.hasOneUse())
+ return SDValue();
+ Op = Op.getOperand(0);
+ }
+
+ if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
+ return SDValue();
+
+ SDValue N0 = Op.getOperand(0);
+ SDValue N1 = Op.getOperand(1);
+ if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
+ N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
+ return SDValue();
+
+ SDValue N00 = N0.getOperand(0);
+ SDValue N10 = N1.getOperand(0);
+ if (!N00.getValueType().isVector() ||
+ N00.getValueType() != N10.getValueType() ||
+ N->getValueType(0) != N10.getValueType())
+ return SDValue();
+
+ unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+ SDValue SMin =
+ DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
+ DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+ return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
+}
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
@@ -16304,56 +16355,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
}
return SDValue();
- case RISCVISD::TRUNCATE_VECTOR_VL: {
- // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
- // This would be benefit for the cases where X and Y are both the same value
- // type of low precision vectors. Since the truncate would be lowered into
- // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
- // restriction, such pattern would be expanded into a series of "vsetvli"
- // and "vnsrl" instructions later to reach this point.
- auto IsTruncNode = [](SDValue V) {
- if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
- return false;
- SDValue VL = V.getOperand(2);
- auto *C = dyn_cast<ConstantSDNode>(VL);
- // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
- bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
- (isa<RegisterSDNode>(VL) &&
- cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
- return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
- IsVLMAXForVMSET;
- };
-
- SDValue Op = N->getOperand(0);
-
- // We need to first find the inner level of TRUNCATE_VECTOR_VL node
- // to distinguish such pattern.
- while (IsTruncNode(Op)) {
- if (!Op.hasOneUse())
- return SDValue();
- Op = Op.getOperand(0);
- }
-
- if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
- SDValue N0 = Op.getOperand(0);
- SDValue N1 = Op.getOperand(1);
- if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
- N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
- SDValue N00 = N0.getOperand(0);
- SDValue N10 = N1.getOperand(0);
- if (N00.getValueType().isVector() &&
- N00.getValueType() == N10.getValueType() &&
- N->getValueType(0) == N10.getValueType()) {
- unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
- SDValue SMin = DAG.getNode(
- ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
- DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
- return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
- }
- }
- }
- break;
- }
+ case RISCVISD::TRUNCATE_VECTOR_VL:
+ return combineTruncOfSraSext(N, DAG);
case ISD::TRUNCATE:
return performTRUNCATECombine(N, DAG, Subtarget);
case ISD::SELECT:
``````````
</details>
https://github.com/llvm/llvm-project/pull/93574
More information about the llvm-commits mailing list