[llvm] [RISCV] Move vnclip patterns into DAGCombiner. (PR #93728)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Wed May 29 13:04:35 PDT 2024
https://github.com/topperc created https://github.com/llvm/llvm-project/pull/93728
Similar to #93596, this moves the signed vnclip patterns into DAG combine.
This will allow us to support more than one level of truncate in a
future patch.
There's a pre-commit that refactors the vnclipu code to make it easier to share code.
>From 5f7c0162b261925227807f67d1193648e1a31608 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 29 May 2024 12:24:10 -0700
Subject: [PATCH 1/2] [RISCV] Refactor combineTruncToVnclipu to prepare for
adding signed vnclip support.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 51 +++++++++++++--------
1 file changed, 32 insertions(+), 19 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0242cfe178524..bb0df509a3c2c 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16185,8 +16185,8 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
// Combine (truncate_vector_vl (umin X, C)) -> (vnclipu_vl X) if C is maximum
// value for the truncated type.
-static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
+static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
MVT VT = N->getSimpleValueType(0);
@@ -16194,15 +16194,15 @@ static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
SDValue Mask = N->getOperand(1);
SDValue VL = N->getOperand(2);
- SDValue Src = N->getOperand(0);
+ auto MatchMinMax = [&VL, &Mask](SDValue V, unsigned Opc, unsigned OpcVL,
+ APInt &SplatVal) {
+ if (V.getOpcode() != Opc &&
+ !(V.getOpcode() == OpcVL && V.getOperand(2).isUndef() &&
+ V.getOperand(3) == Mask && V.getOperand(4) == VL))
+ return SDValue();
- // Src must be a UMIN or UMIN_VL.
- if (Src.getOpcode() != ISD::UMIN &&
- !(Src.getOpcode() == RISCVISD::UMIN_VL && Src.getOperand(2).isUndef() &&
- Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
- return SDValue();
+ SDValue Op = V.getOperand(1);
- auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
// Peek through conversion between fixed and scalable vectors.
if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
isNullConstant(Op.getOperand(2)) &&
@@ -16213,32 +16213,45 @@ static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
Op = Op.getOperand(1).getOperand(0);
if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
- return true;
+ return V.getOperand(0);
if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
Op.getOperand(2) == VL) {
if (auto *Op1 = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
SplatVal =
Op1->getAPIntValue().sextOrTrunc(Op.getScalarValueSizeInBits());
- return true;
+ return V.getOperand(0);
}
}
- return false;
+ return SDValue();
};
- APInt C;
- if (!IsSplat(Src.getOperand(1), C))
- return SDValue();
+ auto DetectUSatPattern = [&](SDValue V) {
+ // Src must be a UMIN or UMIN_VL.
+ APInt C;
+ SDValue UMin = MatchMinMax(V, ISD::UMIN, RISCVISD::UMIN_VL, C);
+ if (!UMin)
+ return SDValue();
+
+ if (!C.isMask(VT.getScalarSizeInBits()))
+ return SDValue();
- if (!C.isMask(VT.getScalarSizeInBits()))
+ return UMin;
+ };
+
+ SDValue Val;
+ unsigned ClipOpc;
+ if ((Val = DetectUSatPattern(N->getOperand(0)))) {
+ ClipOpc = RISCVISD::VNCLIPU_VL;
+ } else
return SDValue();
SDLoc DL(N);
// Rounding mode here is arbitrary since we aren't shifting out any bits.
return DAG.getNode(
- RISCVISD::VNCLIPU_VL, DL, VT,
- {Src.getOperand(0), DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
+ ClipOpc, DL, VT,
+ {Val, DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
VL});
}
@@ -16462,7 +16475,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::TRUNCATE_VECTOR_VL:
if (SDValue V = combineTruncOfSraSext(N, DAG))
return V;
- return combineTruncToVnclipu(N, DAG, Subtarget);
+ return combineTruncToVnclip(N, DAG, Subtarget);
case ISD::TRUNCATE:
return performTRUNCATECombine(N, DAG, Subtarget);
case ISD::SELECT:
>From 7faf709b0d8e6e8dd829db7048506a1dd8d0572d Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 29 May 2024 13:00:52 -0700
Subject: [PATCH 2/2] [RISCV] Move vnclip patterns into DAGCombiner.
Similar to #93596, this moves the signed vnclip patterns into DAG combine.
This will allow us to support more than one level of truncate in a
future patch.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 22 ++++++++++
.../Target/RISCV/RISCVInstrInfoVSDPatterns.td | 34 ----------------
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 40 -------------------
3 files changed, 22 insertions(+), 74 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index bb0df509a3c2c..f4b64df927418 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16240,10 +16240,32 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
return UMin;
};
+ auto DetectSSatPattern = [&](SDValue V) {
+ unsigned NumDstBits = VT.getScalarSizeInBits();
+ unsigned NumSrcBits = V.getScalarValueSizeInBits();
+ APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
+ APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
+
+ APInt CMin, CMax;
+ if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, CMin))
+ if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, CMax))
+ if (CMin == SignedMax && CMax == SignedMin)
+ return SMax;
+
+ if (SDValue SMax = MatchMinMax(V, ISD::SMAX, RISCVISD::SMAX_VL, CMax))
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, RISCVISD::SMIN_VL, CMin))
+ if (CMin == SignedMax && CMax == SignedMin)
+ return SMin;
+
+ return SDValue();
+ };
+
SDValue Val;
unsigned ClipOpc;
if ((Val = DetectUSatPattern(N->getOperand(0)))) {
ClipOpc = RISCVISD::VNCLIPU_VL;
+ } else if ((Val = DetectSSatPattern(N->getOperand(0)))) {
+ ClipOpc = RISCVISD::VNCLIP_VL;
} else
return SDValue();
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 691f2052ab29d..3163e4bafd4b0 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -1168,40 +1168,6 @@ defm : VPatAVGADD_VV_VX_RM<avgflooru, 0b10, suffix = "U">;
defm : VPatAVGADD_VV_VX_RM<avgceils, 0b00>;
defm : VPatAVGADD_VV_VX_RM<avgceilu, 0b00, suffix = "U">;
-// 12.5. Vector Narrowing Fixed-Point Clip Instructions
-multiclass VPatTruncSatClipSDNode<VTypeInfo vti, VTypeInfo wti> {
- defvar sew = vti.SEW;
- defvar uminval = !sub(!shl(1, sew), 1);
- defvar sminval = !sub(!shl(1, !sub(sew, 1)), 1);
- defvar smaxval = !sub(0, !shl(1, !sub(sew, 1)));
-
- let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
- GetVTypePredicates<wti>.Predicates) in {
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (smin
- (wti.Vector (smax (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))))),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))))),
- (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (smax
- (wti.Vector (smin (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))))),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))))),
- (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
- }
-}
-
-foreach vtiToWti = AllWidenableIntVectors in
- defm : VPatTruncSatClipSDNode<vtiToWti.Vti, vtiToWti.Wti>;
-
// 15. Vector Mask Instructions
// 15.1. Vector Mask-Register Logical Instructions
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 610a72dd02b38..ce8133a5a297b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2470,46 +2470,6 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
-// 12.5. Vector Narrowing Fixed-Point Clip Instructions
-multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
- defvar sew = vti.SEW;
- defvar uminval = !sub(!shl(1, sew), 1);
- defvar sminval = !sub(!shl(1, !sub(sew, 1)), 1);
- defvar smaxval = !sub(0, !shl(1, !sub(sew, 1)));
-
- let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
- GetVTypePredicates<wti>.Predicates) in {
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (riscv_smin_vl
- (wti.Vector (riscv_smax_vl
- (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))),
- (wti.Vector undef),(wti.Mask V0), VLOpFrag)),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))),
- (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
- (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (riscv_smax_vl
- (wti.Vector (riscv_smin_vl
- (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))),
- (wti.Vector undef),(wti.Mask V0), VLOpFrag)),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))),
- (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
- (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
- }
-}
-
-foreach vtiToWti = AllWidenableIntVectors in
- defm : VPatTruncSatClipVL<vtiToWti.Vti, vtiToWti.Wti>;
-
// 13. Vector Floating-Point Instructions
// 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
More information about the llvm-commits
mailing list