[llvm] [RISCV] Verify the VL and Mask on the outer TRUNCATE_VECTOR_VL in combineTruncOfSraSext. (PR #93578)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Tue May 28 23:01:24 PDT 2024
https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/93578
>From e10966f1484d655c9ca532e3942cde38acc0328a Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 28 May 2024 09:22:05 -0700
Subject: [PATCH 1/4] [RISCV] Move TRUNCATE_VECTOR_VL combine into a helper
function. NFC
I plan to add more combines for TRUNCATE_VECTOR_VL.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 103 ++++++++++----------
1 file changed, 53 insertions(+), 50 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f0e5a7d393b6c..47b1cc1ba6460 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16087,6 +16087,57 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
return true;
}
+static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+ // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+ // This would be benefit for the cases where X and Y are both the same value
+ // type of low precision vectors. Since the truncate would be lowered into
+ // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+ // restriction, such pattern would be expanded into a series of "vsetvli"
+ // and "vnsrl" instructions later to reach this point.
+ auto IsTruncNode = [](SDValue V) {
+ if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+ return false;
+ SDValue VL = V.getOperand(2);
+ auto *C = dyn_cast<ConstantSDNode>(VL);
+ // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
+ bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
+ (isa<RegisterSDNode>(VL) &&
+ cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+ return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+ };
+
+ SDValue Op = N->getOperand(0);
+
+ // We need to first find the inner level of TRUNCATE_VECTOR_VL node
+ // to distinguish such pattern.
+ while (IsTruncNode(Op)) {
+ if (!Op.hasOneUse())
+ return SDValue();
+ Op = Op.getOperand(0);
+ }
+
+ if (Op.getOpcode() != ISD::SRA || !Op.hasOneUse())
+ return SDValue();
+
+ SDValue N0 = Op.getOperand(0);
+ SDValue N1 = Op.getOperand(1);
+ if (N0.getOpcode() != ISD::SIGN_EXTEND || !N0.hasOneUse() ||
+ N1.getOpcode() != ISD::ZERO_EXTEND || !N1.hasOneUse())
+ return SDValue();
+
+ SDValue N00 = N0.getOperand(0);
+ SDValue N10 = N1.getOperand(0);
+ if (!N00.getValueType().isVector() ||
+ N00.getValueType() != N10.getValueType() ||
+ N->getValueType(0) != N10.getValueType())
+ return SDValue();
+
+ unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
+ SDValue SMin =
+ DAG.getNode(ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
+ DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
+ return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
+}
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
@@ -16304,56 +16355,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
}
return SDValue();
- case RISCVISD::TRUNCATE_VECTOR_VL: {
- // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
- // This would be benefit for the cases where X and Y are both the same value
- // type of low precision vectors. Since the truncate would be lowered into
- // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
- // restriction, such pattern would be expanded into a series of "vsetvli"
- // and "vnsrl" instructions later to reach this point.
- auto IsTruncNode = [](SDValue V) {
- if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
- return false;
- SDValue VL = V.getOperand(2);
- auto *C = dyn_cast<ConstantSDNode>(VL);
- // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
- bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
- (isa<RegisterSDNode>(VL) &&
- cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
- return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
- IsVLMAXForVMSET;
- };
-
- SDValue Op = N->getOperand(0);
-
- // We need to first find the inner level of TRUNCATE_VECTOR_VL node
- // to distinguish such pattern.
- while (IsTruncNode(Op)) {
- if (!Op.hasOneUse())
- return SDValue();
- Op = Op.getOperand(0);
- }
-
- if (Op.getOpcode() == ISD::SRA && Op.hasOneUse()) {
- SDValue N0 = Op.getOperand(0);
- SDValue N1 = Op.getOperand(1);
- if (N0.getOpcode() == ISD::SIGN_EXTEND && N0.hasOneUse() &&
- N1.getOpcode() == ISD::ZERO_EXTEND && N1.hasOneUse()) {
- SDValue N00 = N0.getOperand(0);
- SDValue N10 = N1.getOperand(0);
- if (N00.getValueType().isVector() &&
- N00.getValueType() == N10.getValueType() &&
- N->getValueType(0) == N10.getValueType()) {
- unsigned MaxShAmt = N10.getValueType().getScalarSizeInBits() - 1;
- SDValue SMin = DAG.getNode(
- ISD::SMIN, SDLoc(N1), N->getValueType(0), N10,
- DAG.getConstant(MaxShAmt, SDLoc(N1), N->getValueType(0)));
- return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
- }
- }
- }
- break;
- }
+ case RISCVISD::TRUNCATE_VECTOR_VL:
+ return combineTruncOfSraSext(N, DAG);
case ISD::TRUNCATE:
return performTRUNCATECombine(N, DAG, Subtarget);
case ISD::SELECT:
>From 4e227983c1e3c290724f09e4968610e7b0c21689 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 28 May 2024 09:55:48 -0700
Subject: [PATCH 2/4] [RISCV] Verify the VL and Mask on the outer
TRUNCATE_VECTOR_VL in combineTruncOfSraSext.
Previously we checked the VL and mask of any additional TRUNCATE_VECTOR_VL
nodes we peeked through, but not those of the outermost node.
This moves the check to the outer node and then verifies that all the
additional nodes have the same VL and mask.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 23 ++++++++++++---------
1 file changed, 13 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 47b1cc1ba6460..288e874276e07 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16088,22 +16088,25 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
}
static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
+ SDValue Mask = N->getOperand(1);
+ SDValue VL = N->getOperand(2);
+
+ bool IsVLMAX = isAllOnesConstant(VL) ||
+ (isa<RegisterSDNode>(VL) &&
+ cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+ if (!IsVLMAX || Mask.getOpcode() != RISCVISD::VMSET_VL ||
+ Mask.getOperand(0) != VL)
+ return SDValue();
+
// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
// This would be benefit for the cases where X and Y are both the same value
// type of low precision vectors. Since the truncate would be lowered into
// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
// restriction, such pattern would be expanded into a series of "vsetvli"
// and "vnsrl" instructions later to reach this point.
- auto IsTruncNode = [](SDValue V) {
- if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
- return false;
- SDValue VL = V.getOperand(2);
- auto *C = dyn_cast<ConstantSDNode>(VL);
- // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
- bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
- (isa<RegisterSDNode>(VL) &&
- cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
- return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+ auto IsTruncNode = [&](SDValue V) {
+ return V.getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL &&
+ V.getOperand(1) == Mask && V.getOperand(2) == VL;
};
SDValue Op = N->getOperand(0);
>From 78777ec5442cd9a71c639ce685512e699241200b Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 28 May 2024 16:12:22 -0700
Subject: [PATCH 3/4] fixup! move comment.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6a9a61e480294..f4da46f82a810 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16128,6 +16128,12 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
return true;
}
+// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+// This would be benefit for the cases where X and Y are both the same value
+// type of low precision vectors. Since the truncate would be lowered into
+// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+// restriction, such pattern would be expanded into a series of "vsetvli"
+// and "vnsrl" instructions later to reach this point.
static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
SDValue Mask = N->getOperand(1);
SDValue VL = N->getOperand(2);
@@ -16139,12 +16145,6 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
Mask.getOperand(0) != VL)
return SDValue();
- // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
- // This would be benefit for the cases where X and Y are both the same value
- // type of low precision vectors. Since the truncate would be lowered into
- // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
- // restriction, such pattern would be expanded into a series of "vsetvli"
- // and "vnsrl" instructions later to reach this point.
auto IsTruncNode = [&](SDValue V) {
return V.getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL &&
V.getOperand(1) == Mask && V.getOperand(2) == VL;
>From 0a0682e168dd275e2fd7139bbd6c5ca472418630 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 28 May 2024 22:58:30 -0700
Subject: [PATCH 4/4] fixup! Update new test.
---
llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll | 18 +++++++++++-------
1 file changed, 11 insertions(+), 7 deletions(-)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
index 8dbb57fd15cf1..1bd83734a03cb 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
@@ -1,6 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
define <vscale x 1 x i8> @vsra_vv_nxv1i8(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb) {
; CHECK-LABEL: vsra_vv_nxv1i8:
@@ -937,13 +937,17 @@ define <vscale x 8 x i32> @vsra_vi_mask_nxv8i32(<vscale x 8 x i32> %va, <vscale
; Negative test. We shouldn't look through the vp.trunc as it isn't vlmax like
; the rest of the code.
-define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext_mixed_trunc(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb, <vscale x 1 x i1> %m, i32 %evl) {
+define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext_mixed_trunc(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb, <vscale x 1 x i1> %m, i32 zeroext %evl) {
; CHECK-LABEL: vsra_vv_nxv1i8_sext_zext_mixed_trunc:
; CHECK: # %bb.0:
-; CHECK-NEXT: li a0, 7
-; CHECK-NEXT: vsetvli a1, zero, e8, mf8, ta, ma
-; CHECK-NEXT: vmin.vx v9, v8, a0
-; CHECK-NEXT: vsra.vv v8, v8, v9
+; CHECK-NEXT: vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vsext.vf4 v9, v8
+; CHECK-NEXT: vzext.vf4 v10, v8
+; CHECK-NEXT: vsra.vv v8, v9, v10
+; CHECK-NEXT: vsetvli zero, zero, e16, mf4, ta, ma
+; CHECK-NEXT: vnsrl.wi v8, v8, 0
+; CHECK-NEXT: vsetvli zero, a0, e8, mf8, ta, ma
+; CHECK-NEXT: vnsrl.wi v8, v8, 0, v0.t
; CHECK-NEXT: ret
%sexted_va = sext <vscale x 1 x i8> %va to <vscale x 1 x i32>
%zexted_vb = zext <vscale x 1 x i8> %va to <vscale x 1 x i32>
More information about the llvm-commits
mailing list