[llvm] [RISCV] Move vnclipu patterns into DAGCombiner. (PR #93596)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Wed May 29 12:08:22 PDT 2024
https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/93596
From 20f81ebbbfcabd65f079fb507fd8281ce07f23af Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 28 May 2024 11:29:22 -0700
Subject: [PATCH] [RISCV] Move vnclipu patterns into DAGCombiner.
I plan to add support for multiple layers of vnclipu. For example,
i32->i8 using 2 vnclipu instructions. First clipping to 65535, then
clipping to 255. Similar for signed vnclip.
This scales poorly if we need to add patterns with 2 or 3 truncates.
Instead, move the code to DAGCombiner with new ISD opcodes to represent
VNCLIP(U).
This patch just moves the existing patterns into DAG combine. Support
for multiple truncates will be added as a follow up. A similar patch will
be made for the signed vnclip patterns.
---
.../Target/RISCV/MCTargetDesc/RISCVBaseInfo.h | 9 ++
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 70 +++++++++++-
llvm/lib/Target/RISCV/RISCVISelLowering.h | 4 +
.../Target/RISCV/RISCVInstrInfoVSDPatterns.td | 7 --
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 104 ++++++++++++++++--
5 files changed, 174 insertions(+), 20 deletions(-)
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
index 08f056f78979a..550904516ac8e 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -373,6 +373,15 @@ inline static bool isValidRoundingMode(unsigned Mode) {
}
} // namespace RISCVFPRndMode
+namespace RISCVVXRndMode {
+enum RoundingMode {
+ RNU = 0,
+ RNE = 1,
+ RDN = 2,
+ ROD = 3,
+};
+} // namespace RISCVVXRndMode
+
//===----------------------------------------------------------------------===//
// Floating-point Immediates
//
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f4da46f82a810..0242cfe178524 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5960,7 +5960,7 @@ static bool hasMergeOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 128 &&
+ 130 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -5986,7 +5986,7 @@ static bool hasMaskOp(unsigned Opcode) {
Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
"not a RISC-V target specific op");
static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
- 128 &&
+ 130 &&
RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
ISD::FIRST_TARGET_STRICTFP_OPCODE ==
21 &&
@@ -16183,6 +16183,66 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
}
+// Combine (truncate_vector_vl (umin X, C)) -> (vnclipu_vl X) if C is maximum
+// value for the truncated type.
+static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
+
+ MVT VT = N->getSimpleValueType(0);
+
+ SDValue Mask = N->getOperand(1);
+ SDValue VL = N->getOperand(2);
+
+ SDValue Src = N->getOperand(0);
+
+ // Src must be a UMIN or UMIN_VL.
+ if (Src.getOpcode() != ISD::UMIN &&
+ !(Src.getOpcode() == RISCVISD::UMIN_VL && Src.getOperand(2).isUndef() &&
+ Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
+ return SDValue();
+
+ auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
+ // Peek through conversion between fixed and scalable vectors.
+ if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
+ isNullConstant(Op.getOperand(2)) &&
+ Op.getOperand(1).getValueType().isFixedLengthVector() &&
+ Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+ Op.getOperand(1).getOperand(0).getValueType() == Op.getValueType() &&
+ isNullConstant(Op.getOperand(1).getOperand(1)))
+ Op = Op.getOperand(1).getOperand(0);
+
+ if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
+ return true;
+
+ if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
+ Op.getOperand(2) == VL) {
+ if (auto *Op1 = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+ SplatVal =
+ Op1->getAPIntValue().sextOrTrunc(Op.getScalarValueSizeInBits());
+ return true;
+ }
+ }
+
+ return false;
+ };
+
+ APInt C;
+ if (!IsSplat(Src.getOperand(1), C))
+ return SDValue();
+
+ if (!C.isMask(VT.getScalarSizeInBits()))
+ return SDValue();
+
+ SDLoc DL(N);
+ // Rounding mode here is arbitrary since we aren't shifting out any bits.
+ return DAG.getNode(
+ RISCVISD::VNCLIPU_VL, DL, VT,
+ {Src.getOperand(0), DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
+ DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
+ VL});
+}
+
SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SelectionDAG &DAG = DCI.DAG;
@@ -16400,7 +16460,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
return SDValue();
case RISCVISD::TRUNCATE_VECTOR_VL:
- return combineTruncOfSraSext(N, DAG);
+ if (SDValue V = combineTruncOfSraSext(N, DAG))
+ return V;
+ return combineTruncToVnclipu(N, DAG, Subtarget);
case ISD::TRUNCATE:
return performTRUNCATECombine(N, DAG, Subtarget);
case ISD::SELECT:
@@ -20019,6 +20081,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(UADDSAT_VL)
NODE_NAME_CASE(SSUBSAT_VL)
NODE_NAME_CASE(USUBSAT_VL)
+ NODE_NAME_CASE(VNCLIP_VL)
+ NODE_NAME_CASE(VNCLIPU_VL)
NODE_NAME_CASE(FADD_VL)
NODE_NAME_CASE(FSUB_VL)
NODE_NAME_CASE(FMUL_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 856ce06ba1c4f..3b8eb3c88901a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -273,6 +273,10 @@ enum NodeType : unsigned {
// Rounding averaging adds of unsigned integers.
AVGCEILU_VL,
+ // Operands are (source, shift, merge, mask, roundmode, vl)
+ VNCLIPU_VL,
+ VNCLIP_VL,
+
MULHS_VL,
MULHU_VL,
FADD_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 66df24f2a458d..691f2052ab29d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -1196,13 +1196,6 @@ multiclass VPatTruncSatClipSDNode<VTypeInfo vti, VTypeInfo wti> {
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (umin (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))))), (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
}
}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 91f3abe22331e..610a72dd02b38 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -36,6 +36,18 @@ def SDT_RISCVIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
SDTCisSameNumEltsAs<0, 4>,
SDTCisVT<5, XLenVT>]>;
+// Input: (vector, vector/scalar, merge, mask, roundmode, vl)
+def SDT_RISCVVNBinOp_RM_VL : SDTypeProfile<1, 6, [SDTCisVec<0>, SDTCisInt<0>,
+ SDTCisSameAs<0, 3>,
+ SDTCisSameNumEltsAs<0, 1>,
+ SDTCisVec<1>,
+ SDTCisOpSmallerThanOp<2, 1>,
+ SDTCisSameAs<0, 2>,
+ SDTCisSameNumEltsAs<0, 4>,
+ SDTCVecEltisVT<4, i1>,
+ SDTCisVT<5, XLenVT>,
+ SDTCisVT<6, XLenVT>]>;
+
def SDT_RISCVFPUnOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
SDTCisVec<0>, SDTCisFP<0>,
SDTCVecEltisVT<2, i1>,
@@ -120,6 +132,9 @@ def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
def riscv_ssubsat_vl : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
def riscv_usubsat_vl : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
+def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;
+
def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
@@ -635,6 +650,34 @@ class VPatBinaryVL_V<SDPatternOperator vop,
op2_reg_class:$rs2,
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
+multiclass VPatBinaryRM_VL_V<SDNode vop,
+ string instruction_name,
+ string suffix,
+ ValueType result_type,
+ ValueType op1_type,
+ ValueType op2_type,
+ ValueType mask_type,
+ int sew,
+ LMULInfo vlmul,
+ VReg result_reg_class,
+ VReg op1_reg_class,
+ VReg op2_reg_class> {
+ def : Pat<(result_type (vop
+ (op1_type op1_reg_class:$rs1),
+ (op2_type op2_reg_class:$rs2),
+ (result_type result_reg_class:$merge),
+ (mask_type V0),
+ (XLenVT timm:$roundmode),
+ VLOpFrag)),
+ (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
+ result_reg_class:$merge,
+ op1_reg_class:$rs1,
+ op2_reg_class:$rs2,
+ (mask_type V0),
+ (XLenVT timm:$roundmode),
+ GPR:$vl, sew, TAIL_AGNOSTIC)>;
+}
+
class VPatBinaryVL_V_RM<SDPatternOperator vop,
string instruction_name,
string suffix,
@@ -795,6 +838,35 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
xop_kind:$rs2,
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
+multiclass VPatBinaryRM_VL_XI<SDNode vop,
+ string instruction_name,
+ string suffix,
+ ValueType result_type,
+ ValueType vop1_type,
+ ValueType vop2_type,
+ ValueType mask_type,
+ int sew,
+ LMULInfo vlmul,
+ VReg result_reg_class,
+ VReg vop_reg_class,
+ ComplexPattern SplatPatKind,
+ DAGOperand xop_kind> {
+ def : Pat<(result_type (vop
+ (vop1_type vop_reg_class:$rs1),
+ (vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
+ (result_type result_reg_class:$merge),
+ (mask_type V0),
+ (XLenVT timm:$roundmode),
+ VLOpFrag)),
+ (!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
+ result_reg_class:$merge,
+ vop_reg_class:$rs1,
+ xop_kind:$rs2,
+ (mask_type V0),
+ (XLenVT timm:$roundmode),
+ GPR:$vl, sew, TAIL_AGNOSTIC)>;
+}
+
multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
list<VTypeInfo> vtilist = AllIntegerVectors,
bit isSEWAware = 0> {
@@ -893,6 +965,24 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
}
}
+multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
+ foreach VtiToWti = AllWidenableIntVectors in {
+ defvar vti = VtiToWti.Vti;
+ defvar wti = VtiToWti.Wti;
+ defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
+ vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+ vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
+ defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
+ vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+ vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
+ defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
+ vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+ vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
+ !cast<ComplexPattern>(SplatPat#_#uimm5),
+ uimm5>;
+ }
+}
+
class VPatBinaryVL_VF<SDPatternOperator vop,
string instruction_name,
ValueType result_type,
@@ -2376,6 +2466,10 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgflooru_vl, 0b10, suffix="U">;
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
+// 12.5. Vector Narrowing Fixed-Point Clip Instructions
+defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
+defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
+
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
defvar sew = vti.SEW;
@@ -2410,16 +2504,6 @@ multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
(!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
(vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
(vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (riscv_umin_vl
- (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))),
- (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
- (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
}
}
More information about the llvm-commits
mailing list