[llvm] [RISCV] Replace VNCLIP RISCVISD opcodes with TRUNCATE_VECTOR_VL_SSAT/… (PR #100173)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 23 10:58:30 PDT 2024
https://github.com/topperc created https://github.com/llvm/llvm-project/pull/100173
[RISCV] Replace VNCLIP RISCVISD opcodes with TRUNCATE_VECTOR_VL_SSAT/USAT opcodes.
These new opcodes drop the shift amount, rounding mode, and passthru, making them exactly like TRUNCATE_VECTOR_VL. The shift amount, rounding mode, and passthru are added in isel patterns, similar to how we translate TRUNCATE_VECTOR_VL to vnsrl with a shift of 0.
This should simplify #99418 a little.
From e781297b635ee2f1abcca60e8a5ea778228f5523 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 23 Jul 2024 10:25:21 -0700
Subject: [PATCH] [RISCV] Replace VNCLIP RISCVISD opcodes with
TRUNCATE_VECTOR_VL_SSAT/USAT opcodes.
These new opcodes drop the shift amount, rounding mode, and passthru,
making them exactly like TRUNCATE_VECTOR_VL. The shift amount,
rounding mode, and passthru are added in isel patterns, similar to
how we translate TRUNCATE_VECTOR_VL to vnsrl with a shift of 0.
This should simplify #99418 a little.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 25 ++--
llvm/lib/Target/RISCV/RISCVISelLowering.h | 10 +-
.../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 115 +++++-------------
3 files changed, 44 insertions(+), 106 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 22cdfdcfd80d9..dd7b0b4ed5ef7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2997,13 +2997,9 @@ static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
CvtEltVT = MVT::getIntegerVT(CvtEltVT.getSizeInBits() / 2);
CvtContainerVT = CvtContainerVT.changeVectorElementType(CvtEltVT);
// Rounding mode here is arbitrary since we aren't shifting out any bits.
- unsigned ClipOpc = IsSigned ? RISCVISD::VNCLIP_VL : RISCVISD::VNCLIPU_VL;
- Res = DAG.getNode(
- ClipOpc, DL, CvtContainerVT,
- {Res, DAG.getConstant(0, DL, CvtContainerVT),
- DAG.getUNDEF(CvtContainerVT), Mask,
- DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
- VL});
+ unsigned ClipOpc = IsSigned ? RISCVISD::TRUNCATE_VECTOR_VL_SSAT
+ : RISCVISD::TRUNCATE_VECTOR_VL_USAT;
+ Res = DAG.getNode(ClipOpc, DL, CvtContainerVT, Res, Mask, VL);
}
SDValue SplatZero = DAG.getNode(
@@ -16643,9 +16639,9 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
SDValue Val;
unsigned ClipOpc;
if ((Val = DetectUSatPattern(Src)))
- ClipOpc = RISCVISD::VNCLIPU_VL;
+ ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
else if ((Val = DetectSSatPattern(Src)))
- ClipOpc = RISCVISD::VNCLIP_VL;
+ ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
else
return SDValue();
@@ -16654,12 +16650,7 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
do {
MVT ValEltVT = MVT::getIntegerVT(ValVT.getScalarSizeInBits() / 2);
ValVT = ValVT.changeVectorElementType(ValEltVT);
- // Rounding mode here is arbitrary since we aren't shifting out any bits.
- Val = DAG.getNode(
- ClipOpc, DL, ValVT,
- {Val, DAG.getConstant(0, DL, ValVT), DAG.getUNDEF(VT), Mask,
- DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
- VL});
+ Val = DAG.getNode(ClipOpc, DL, ValVT, Val, Mask, VL);
} while (ValVT != VT);
return Val;
@@ -20463,6 +20454,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
NODE_NAME_CASE(READ_VLENB)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+ NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
+ NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
NODE_NAME_CASE(VSLIDEUP_VL)
NODE_NAME_CASE(VSLIDE1UP_VL)
NODE_NAME_CASE(VSLIDEDOWN_VL)
@@ -20506,8 +20499,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(UADDSAT_VL)
NODE_NAME_CASE(SSUBSAT_VL)
NODE_NAME_CASE(USUBSAT_VL)
- NODE_NAME_CASE(VNCLIP_VL)
- NODE_NAME_CASE(VNCLIPU_VL)
NODE_NAME_CASE(FADD_VL)
NODE_NAME_CASE(FSUB_VL)
NODE_NAME_CASE(FMUL_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 0b0ad9229f0b3..e469a4b1238c7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -181,6 +181,12 @@ enum NodeType : unsigned {
// Truncates a RVV integer vector by one power-of-two. Carries both an extra
// mask and VL operand.
TRUNCATE_VECTOR_VL,
+ // Truncates a RVV integer vector by one power-of-two. If the value doesn't
+ // fit in the destination type, the result is saturated. These correspond to
+ // vnclip and vnclipu with a shift of 0. Carries both an extra mask and VL
+ // operand.
+ TRUNCATE_VECTOR_VL_SSAT,
+ TRUNCATE_VECTOR_VL_USAT,
// Matches the semantics of vslideup/vslidedown. The first operand is the
// pass-thru operand, the second is the source vector, the third is the XLenVT
// index (either constant or non-constant), the fourth is the mask, the fifth
@@ -273,10 +279,6 @@ enum NodeType : unsigned {
// Rounding averaging adds of unsigned integers.
AVGCEILU_VL,
- // Operands are (source, shift, merge, mask, roundmode, vl)
- VNCLIPU_VL,
- VNCLIP_VL,
-
MULHS_VL,
MULHU_VL,
FADD_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index cc294bf9254e8..2ed71f6b88974 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -132,9 +132,6 @@ def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
def riscv_ssubsat_vl : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
def riscv_usubsat_vl : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
-def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
-def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;
-
def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>;
def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
@@ -408,12 +405,17 @@ def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C),
[(riscv_sext_vl node:$A, node:$B, node:$C),
(riscv_zext_vl node:$A, node:$B, node:$C)]>;
+def SDT_RISCVVTRUNCATE_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
+ SDTCisSameNumEltsAs<0, 1>,
+ SDTCisSameNumEltsAs<0, 2>,
+ SDTCVecEltisVT<2, i1>,
+ SDTCisVT<3, XLenVT>]>;
def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
- SDTypeProfile<1, 3, [SDTCisVec<0>,
- SDTCisSameNumEltsAs<0, 1>,
- SDTCisSameNumEltsAs<0, 2>,
- SDTCVecEltisVT<2, i1>,
- SDTCisVT<3, XLenVT>]>>;
+ SDT_RISCVVTRUNCATE_VL>;
+def riscv_trunc_vector_vl_ssat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_SSAT",
+ SDT_RISCVVTRUNCATE_VL>;
+def riscv_trunc_vector_vl_usat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_USAT",
+ SDT_RISCVVTRUNCATE_VL>;
def SDT_RISCVVWIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
SDTCisInt<1>,
@@ -650,34 +652,6 @@ class VPatBinaryVL_V<SDPatternOperator vop,
op2_reg_class:$rs2,
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
-multiclass VPatBinaryRM_VL_V<SDNode vop,
- string instruction_name,
- string suffix,
- ValueType result_type,
- ValueType op1_type,
- ValueType op2_type,
- ValueType mask_type,
- int sew,
- LMULInfo vlmul,
- VReg result_reg_class,
- VReg op1_reg_class,
- VReg op2_reg_class> {
- def : Pat<(result_type (vop
- (op1_type op1_reg_class:$rs1),
- (op2_type op2_reg_class:$rs2),
- (result_type result_reg_class:$merge),
- (mask_type V0),
- (XLenVT timm:$roundmode),
- VLOpFrag)),
- (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
- result_reg_class:$merge,
- op1_reg_class:$rs1,
- op2_reg_class:$rs2,
- (mask_type V0),
- (XLenVT timm:$roundmode),
- GPR:$vl, sew, TAIL_AGNOSTIC)>;
-}
-
class VPatBinaryVL_V_RM<SDPatternOperator vop,
string instruction_name,
string suffix,
@@ -838,35 +812,6 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
xop_kind:$rs2,
(mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
-multiclass VPatBinaryRM_VL_XI<SDNode vop,
- string instruction_name,
- string suffix,
- ValueType result_type,
- ValueType vop1_type,
- ValueType vop2_type,
- ValueType mask_type,
- int sew,
- LMULInfo vlmul,
- VReg result_reg_class,
- VReg vop_reg_class,
- ComplexPattern SplatPatKind,
- DAGOperand xop_kind> {
- def : Pat<(result_type (vop
- (vop1_type vop_reg_class:$rs1),
- (vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
- (result_type result_reg_class:$merge),
- (mask_type V0),
- (XLenVT timm:$roundmode),
- VLOpFrag)),
- (!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
- result_reg_class:$merge,
- vop_reg_class:$rs1,
- xop_kind:$rs2,
- (mask_type V0),
- (XLenVT timm:$roundmode),
- GPR:$vl, sew, TAIL_AGNOSTIC)>;
-}
-
multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
list<VTypeInfo> vtilist = AllIntegerVectors,
bit isSEWAware = 0> {
@@ -965,24 +910,6 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
}
}
-multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
- foreach VtiToWti = AllWidenableIntVectors in {
- defvar vti = VtiToWti.Vti;
- defvar wti = VtiToWti.Wti;
- defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
- vti.Vector, wti.Vector, vti.Vector, vti.Mask,
- vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
- defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
- vti.Vector, wti.Vector, vti.Vector, vti.Mask,
- vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
- defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
- vti.Vector, wti.Vector, vti.Vector, vti.Mask,
- vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
- !cast<ComplexPattern>(SplatPat#_#uimm5),
- uimm5>;
- }
-}
-
class VPatBinaryVL_VF<SDPatternOperator vop,
string instruction_name,
ValueType result_type,
@@ -2468,8 +2395,26 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
// 12.5. Vector Narrowing Fixed-Point Clip Instructions
-defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
-defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
+foreach vtiTowti = AllWidenableIntVectors in {
+ defvar vti = vtiTowti.Vti;
+ defvar wti = vtiTowti.Wti;
+ let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+ GetVTypePredicates<wti>.Predicates) in {
+ // Rounding mode here is arbitrary since we aren't shifting out any bits.
+ def : Pat<(vti.Vector (riscv_trunc_vector_vl_ssat (wti.Vector wti.RegClass:$rs1),
+ (vti.Mask V0),
+ VLOpFrag)),
+ (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
+ (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
+ (vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
+ def : Pat<(vti.Vector (riscv_trunc_vector_vl_usat (wti.Vector wti.RegClass:$rs1),
+ (vti.Mask V0),
+ VLOpFrag)),
+ (!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
+ (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
+ (vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
+ }
+}
// 13. Vector Floating-Point Instructions
More information about the llvm-commits
mailing list