[llvm] [RISCV] Replace VNCLIP RISCVISD opcodes with TRUNCATE_VECTOR_VL_SSAT/… (PR #100173)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 23 10:58:30 PDT 2024


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/100173

…USAT opcodes.

These new opcodes drop the shift amount, rounding mode, and passthru operands, making them exactly like TRUNCATE_VECTOR_VL. The shift amount, rounding mode, and passthru are instead supplied by the isel patterns, similar to how TRUNCATE_VECTOR_VL is translated to vnsrl with a shift of 0.

This should simplify #99418 a little.

>From e781297b635ee2f1abcca60e8a5ea778228f5523 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 23 Jul 2024 10:25:21 -0700
Subject: [PATCH] [RISCV] Replace VNCLIP RISCVISD opcodes with
 TRUNCATE_VECTOR_VL_SSAT/USAT opcodes.

These new opcodes drop the shift amount, rounding mode, and passthru
operands, making them exactly like TRUNCATE_VECTOR_VL. The shift
amount, rounding mode, and passthru are instead supplied by the isel
patterns, similar to how TRUNCATE_VECTOR_VL is translated to vnsrl
with a shift of 0.

This should simplify #99418 a little.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  25 ++--
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  10 +-
 .../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 115 +++++-------------
 3 files changed, 44 insertions(+), 106 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 22cdfdcfd80d9..dd7b0b4ed5ef7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2997,13 +2997,9 @@ static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
     CvtEltVT = MVT::getIntegerVT(CvtEltVT.getSizeInBits() / 2);
     CvtContainerVT = CvtContainerVT.changeVectorElementType(CvtEltVT);
     // Rounding mode here is arbitrary since we aren't shifting out any bits.
-    unsigned ClipOpc = IsSigned ? RISCVISD::VNCLIP_VL : RISCVISD::VNCLIPU_VL;
-    Res = DAG.getNode(
-        ClipOpc, DL, CvtContainerVT,
-        {Res, DAG.getConstant(0, DL, CvtContainerVT),
-         DAG.getUNDEF(CvtContainerVT), Mask,
-         DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
-         VL});
+    unsigned ClipOpc = IsSigned ? RISCVISD::TRUNCATE_VECTOR_VL_SSAT
+                                : RISCVISD::TRUNCATE_VECTOR_VL_USAT;
+    Res = DAG.getNode(ClipOpc, DL, CvtContainerVT, Res, Mask, VL);
   }
 
   SDValue SplatZero = DAG.getNode(
@@ -16643,9 +16639,9 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
   SDValue Val;
   unsigned ClipOpc;
   if ((Val = DetectUSatPattern(Src)))
-    ClipOpc = RISCVISD::VNCLIPU_VL;
+    ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
   else if ((Val = DetectSSatPattern(Src)))
-    ClipOpc = RISCVISD::VNCLIP_VL;
+    ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
   else
     return SDValue();
 
@@ -16654,12 +16650,7 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
   do {
     MVT ValEltVT = MVT::getIntegerVT(ValVT.getScalarSizeInBits() / 2);
     ValVT = ValVT.changeVectorElementType(ValEltVT);
-    // Rounding mode here is arbitrary since we aren't shifting out any bits.
-    Val = DAG.getNode(
-        ClipOpc, DL, ValVT,
-        {Val, DAG.getConstant(0, DL, ValVT), DAG.getUNDEF(VT), Mask,
-         DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
-         VL});
+    Val = DAG.getNode(ClipOpc, DL, ValVT, Val, Mask, VL);
   } while (ValVT != VT);
 
   return Val;
@@ -20463,6 +20454,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
   NODE_NAME_CASE(READ_VLENB)
   NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
   NODE_NAME_CASE(VSLIDEUP_VL)
   NODE_NAME_CASE(VSLIDE1UP_VL)
   NODE_NAME_CASE(VSLIDEDOWN_VL)
@@ -20506,8 +20499,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(UADDSAT_VL)
   NODE_NAME_CASE(SSUBSAT_VL)
   NODE_NAME_CASE(USUBSAT_VL)
-  NODE_NAME_CASE(VNCLIP_VL)
-  NODE_NAME_CASE(VNCLIPU_VL)
   NODE_NAME_CASE(FADD_VL)
   NODE_NAME_CASE(FSUB_VL)
   NODE_NAME_CASE(FMUL_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 0b0ad9229f0b3..e469a4b1238c7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -181,6 +181,12 @@ enum NodeType : unsigned {
   // Truncates a RVV integer vector by one power-of-two. Carries both an extra
   // mask and VL operand.
   TRUNCATE_VECTOR_VL,
+  // Truncates a RVV integer vector by one power-of-two. If the value doesn't
+  // fit in the destination type, the result is saturated. These correspond to
+  // vnclip and vnclipu with a shift of 0. Carries both an extra mask and VL
+  // operand.
+  TRUNCATE_VECTOR_VL_SSAT,
+  TRUNCATE_VECTOR_VL_USAT,
   // Matches the semantics of vslideup/vslidedown. The first operand is the
   // pass-thru operand, the second is the source vector, the third is the XLenVT
   // index (either constant or non-constant), the fourth is the mask, the fifth
@@ -273,10 +279,6 @@ enum NodeType : unsigned {
   // Rounding averaging adds of unsigned integers.
   AVGCEILU_VL,
 
-  // Operands are (source, shift, merge, mask, roundmode, vl)
-  VNCLIPU_VL,
-  VNCLIP_VL,
-
   MULHS_VL,
   MULHU_VL,
   FADD_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index cc294bf9254e8..2ed71f6b88974 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -132,9 +132,6 @@ def riscv_uaddsat_vl   : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
 def riscv_ssubsat_vl   : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
 def riscv_usubsat_vl   : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
 
-def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
-def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;
-
 def riscv_fadd_vl  : SDNode<"RISCVISD::FADD_VL",  SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
 def riscv_fsub_vl  : SDNode<"RISCVISD::FSUB_VL",  SDT_RISCVFPBinOp_VL>;
 def riscv_fmul_vl  : SDNode<"RISCVISD::FMUL_VL",  SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
@@ -408,12 +405,17 @@ def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C),
                             [(riscv_sext_vl node:$A, node:$B, node:$C),
                              (riscv_zext_vl node:$A, node:$B, node:$C)]>;
 
+def SDT_RISCVVTRUNCATE_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
+                                                 SDTCisSameNumEltsAs<0, 1>,
+                                                 SDTCisSameNumEltsAs<0, 2>,
+                                                 SDTCVecEltisVT<2, i1>,
+                                                 SDTCisVT<3, XLenVT>]>;
 def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
-                                   SDTypeProfile<1, 3, [SDTCisVec<0>,
-                                                        SDTCisSameNumEltsAs<0, 1>,
-                                                        SDTCisSameNumEltsAs<0, 2>,
-                                                        SDTCVecEltisVT<2, i1>,
-                                                        SDTCisVT<3, XLenVT>]>>;
+                                   SDT_RISCVVTRUNCATE_VL>;
+def riscv_trunc_vector_vl_ssat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_SSAT",
+                                        SDT_RISCVVTRUNCATE_VL>;
+def riscv_trunc_vector_vl_usat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_USAT",
+                                        SDT_RISCVVTRUNCATE_VL>;
 
 def SDT_RISCVVWIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
                                                   SDTCisInt<1>,
@@ -650,34 +652,6 @@ class VPatBinaryVL_V<SDPatternOperator vop,
                    op2_reg_class:$rs2,
                    (mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
 
-multiclass VPatBinaryRM_VL_V<SDNode vop,
-                             string instruction_name,
-                             string suffix,
-                             ValueType result_type,
-                             ValueType op1_type,
-                             ValueType op2_type,
-                             ValueType mask_type,
-                             int sew,
-                             LMULInfo vlmul,
-                             VReg result_reg_class,
-                             VReg op1_reg_class,
-                             VReg op2_reg_class> {
-  def : Pat<(result_type (vop
-                         (op1_type op1_reg_class:$rs1),
-                         (op2_type op2_reg_class:$rs2),
-                         (result_type result_reg_class:$merge),
-                         (mask_type V0),
-                         (XLenVT timm:$roundmode),
-                         VLOpFrag)),
-        (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
-                     result_reg_class:$merge,
-                     op1_reg_class:$rs1,
-                     op2_reg_class:$rs2,
-                     (mask_type V0),
-                     (XLenVT timm:$roundmode),
-                     GPR:$vl, sew, TAIL_AGNOSTIC)>;
-}
-
 class VPatBinaryVL_V_RM<SDPatternOperator vop,
                         string instruction_name,
                         string suffix,
@@ -838,35 +812,6 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
                    xop_kind:$rs2,
                    (mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
 
-multiclass VPatBinaryRM_VL_XI<SDNode vop,
-                              string instruction_name,
-                              string suffix,
-                              ValueType result_type,
-                              ValueType vop1_type,
-                              ValueType vop2_type,
-                              ValueType mask_type,
-                              int sew,
-                              LMULInfo vlmul,
-                              VReg result_reg_class,
-                              VReg vop_reg_class,
-                              ComplexPattern SplatPatKind,
-                              DAGOperand xop_kind> {
-  def : Pat<(result_type (vop
-                     (vop1_type vop_reg_class:$rs1),
-                     (vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
-                     (result_type result_reg_class:$merge),
-                     (mask_type V0),
-                     (XLenVT timm:$roundmode),
-                     VLOpFrag)),
-        (!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
-                     result_reg_class:$merge,
-                     vop_reg_class:$rs1,
-                     xop_kind:$rs2,
-                     (mask_type V0),
-                     (XLenVT timm:$roundmode),
-                     GPR:$vl, sew, TAIL_AGNOSTIC)>;
-}
-
 multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
                               list<VTypeInfo> vtilist = AllIntegerVectors,
                               bit isSEWAware = 0> {
@@ -965,24 +910,6 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
   }
 }
 
-multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
-  foreach VtiToWti = AllWidenableIntVectors in {
-    defvar vti = VtiToWti.Vti;
-    defvar wti = VtiToWti.Wti;
-    defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
-                             vti.Vector, wti.Vector, vti.Vector, vti.Mask,
-                             vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
-    defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
-                              vti.Vector, wti.Vector, vti.Vector, vti.Mask,
-                              vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
-    defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
-                              vti.Vector, wti.Vector, vti.Vector, vti.Mask,
-                              vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
-                              !cast<ComplexPattern>(SplatPat#_#uimm5),
-                              uimm5>;
-  }
-}
-
 class VPatBinaryVL_VF<SDPatternOperator vop,
                       string instruction_name,
                       ValueType result_type,
@@ -2468,8 +2395,26 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
 defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
 
 // 12.5. Vector Narrowing Fixed-Point Clip Instructions
-defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
-defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
+foreach vtiTowti = AllWidenableIntVectors in {
+  defvar vti = vtiTowti.Vti;
+  defvar wti = vtiTowti.Wti;
+  let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+                               GetVTypePredicates<wti>.Predicates) in {
+    // Rounding mode here is arbitrary since we aren't shifting out any bits.
+    def : Pat<(vti.Vector (riscv_trunc_vector_vl_ssat (wti.Vector wti.RegClass:$rs1),
+                                                      (vti.Mask V0),
+                                                      VLOpFrag)),
+              (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
+                  (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
+                  (vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
+    def : Pat<(vti.Vector (riscv_trunc_vector_vl_usat (wti.Vector wti.RegClass:$rs1),
+                                                      (vti.Mask V0),
+                                                      VLOpFrag)),
+              (!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
+                  (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
+                  (vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
+  }
+}
 
 // 13. Vector Floating-Point Instructions
 



More information about the llvm-commits mailing list