[llvm] [RISCV] Move vnclipu patterns into DAGCombiner. (PR #93596)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed May 29 12:08:22 PDT 2024


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/93596

From 20f81ebbbfcabd65f079fb507fd8281ce07f23af Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Tue, 28 May 2024 11:29:22 -0700
Subject: [PATCH] [RISCV] Move vnclipu patterns into DAGCombiner.

I plan to add support for multiple layers of vnclipu. For example,
i32->i8 using 2 vnclipu instructions: first clipping to 65535, then
clipping to 255. The same applies to signed vnclip.
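
As a scalar sketch of that two-step narrowing (the helper name and the
use of a plain C++ function are illustrative only, not part of this
patch):

  #include <algorithm>
  #include <cstdint>

  // Unsigned i32 -> i8 saturation modeled as two clips, matching two
  // back-to-back vnclipu instructions with a zero shift amount.
  static uint8_t clipTwice(uint32_t X) {
    uint32_t Clip16 = std::min(X, 65535u);               // 1st vnclipu: i32 -> i16
    return static_cast<uint8_t>(std::min(Clip16, 255u)); // 2nd vnclipu: i16 -> i8
  }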

This scales poorly if we need to add patterns with 2 or 3 truncates.
Instead, move the code to DAGCombiner with new ISD opcodes to represent
VNCLIP(U).

This patch just moves the existing patterns into DAG combine. Support
for multiple truncates will be added as a follow up. A similar patch
will be made for the signed vnclip patterns.
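
In DAG terms, the rewrite performed here is roughly (informal notation):

  (riscv_trunc_vector_vl (umin X, splat(2^SEW - 1)), Mask, VL)
    -> (riscv_vnclipu_vl X, splat(0), undef, Mask, RNU, VL)

where SEW is the element width of the truncated type. The shift amount
is zero, so no bits are shifted out and the rounding mode is irrelevant;
RNU is used arbitrarily.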
---
 .../Target/RISCV/MCTargetDesc/RISCVBaseInfo.h |   9 ++
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  70 +++++++++++-
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |   4 +
 .../Target/RISCV/RISCVInstrInfoVSDPatterns.td |   7 --
 .../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 104 ++++++++++++++++--
 5 files changed, 174 insertions(+), 20 deletions(-)

diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
index 08f056f78979a..550904516ac8e 100644
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -373,6 +373,15 @@ inline static bool isValidRoundingMode(unsigned Mode) {
 }
 } // namespace RISCVFPRndMode
 
+namespace RISCVVXRndMode {
+enum RoundingMode {
+  RNU = 0,
+  RNE = 1,
+  RDN = 2,
+  ROD = 3,
+};
+} // namespace RISCVVXRndMode
+
 //===----------------------------------------------------------------------===//
 // Floating-point Immediates
 //
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f4da46f82a810..0242cfe178524 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5960,7 +5960,7 @@ static bool hasMergeOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    128 &&
+                    130 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -5986,7 +5986,7 @@ static bool hasMaskOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    128 &&
+                    130 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -16183,6 +16183,66 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::SRA, SDLoc(N), N->getValueType(0), N00, SMin);
 }
 
+// Combine (truncate_vector_vl (umin X, C)) -> (vnclipu_vl X) if C is the
+// maximum value for the truncated type.
+static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
+                                     const RISCVSubtarget &Subtarget) {
+  assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
+
+  MVT VT = N->getSimpleValueType(0);
+
+  SDValue Mask = N->getOperand(1);
+  SDValue VL = N->getOperand(2);
+
+  SDValue Src = N->getOperand(0);
+
+  // Src must be a UMIN or UMIN_VL.
+  if (Src.getOpcode() != ISD::UMIN &&
+      !(Src.getOpcode() == RISCVISD::UMIN_VL && Src.getOperand(2).isUndef() &&
+        Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
+    return SDValue();
+
+  auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
+    // Peek through conversion between fixed and scalable vectors.
+    if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
+        isNullConstant(Op.getOperand(2)) &&
+        Op.getOperand(1).getValueType().isFixedLengthVector() &&
+        Op.getOperand(1).getOpcode() == ISD::EXTRACT_SUBVECTOR &&
+        Op.getOperand(1).getOperand(0).getValueType() == Op.getValueType() &&
+        isNullConstant(Op.getOperand(1).getOperand(1)))
+      Op = Op.getOperand(1).getOperand(0);
+
+    if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
+      return true;
+
+    if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
+        Op.getOperand(2) == VL) {
+      if (auto *Op1 = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
+        SplatVal =
+            Op1->getAPIntValue().sextOrTrunc(Op.getScalarValueSizeInBits());
+        return true;
+      }
+    }
+
+    return false;
+  };
+
+  APInt C;
+  if (!IsSplat(Src.getOperand(1), C))
+    return SDValue();
+
+  if (!C.isMask(VT.getScalarSizeInBits()))
+    return SDValue();
+
+  SDLoc DL(N);
+  // Rounding mode here is arbitrary since we aren't shifting out any bits.
+  return DAG.getNode(
+      RISCVISD::VNCLIPU_VL, DL, VT,
+      {Src.getOperand(0), DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
+       DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
+       VL});
+}
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -16400,7 +16460,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
     return SDValue();
   case RISCVISD::TRUNCATE_VECTOR_VL:
-    return combineTruncOfSraSext(N, DAG);
+    if (SDValue V = combineTruncOfSraSext(N, DAG))
+      return V;
+    return combineTruncToVnclipu(N, DAG, Subtarget);
   case ISD::TRUNCATE:
     return performTRUNCATECombine(N, DAG, Subtarget);
   case ISD::SELECT:
@@ -20019,6 +20081,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(UADDSAT_VL)
   NODE_NAME_CASE(SSUBSAT_VL)
   NODE_NAME_CASE(USUBSAT_VL)
+  NODE_NAME_CASE(VNCLIP_VL)
+  NODE_NAME_CASE(VNCLIPU_VL)
   NODE_NAME_CASE(FADD_VL)
   NODE_NAME_CASE(FSUB_VL)
   NODE_NAME_CASE(FMUL_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 856ce06ba1c4f..3b8eb3c88901a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -273,6 +273,10 @@ enum NodeType : unsigned {
   // Rounding averaging adds of unsigned integers.
   AVGCEILU_VL,
 
+  // Operands are (source, shift, merge, mask, roundmode, vl)
+  VNCLIPU_VL,
+  VNCLIP_VL,
+
   MULHS_VL,
   MULHU_VL,
   FADD_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 66df24f2a458d..691f2052ab29d 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -1196,13 +1196,6 @@ multiclass VPatTruncSatClipSDNode<VTypeInfo vti, VTypeInfo wti> {
       (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
         (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
         (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
-    def : Pat<(vti.Vector (riscv_trunc_vector_vl
-        (wti.Vector (umin (wti.Vector wti.RegClass:$rs1),
-          (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))))), (vti.Mask V0), VLOpFrag)),
-      (!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
-        (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
-        (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
   }
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 91f3abe22331e..610a72dd02b38 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -36,6 +36,18 @@ def SDT_RISCVIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisSameAs<0, 1>,
                                                 SDTCisSameNumEltsAs<0, 4>,
                                                 SDTCisVT<5, XLenVT>]>;
 
+// Input: (vector, vector/scalar, merge, mask, roundmode, vl)
+def SDT_RISCVVNBinOp_RM_VL : SDTypeProfile<1, 6, [SDTCisVec<0>, SDTCisInt<0>,
+                                                  SDTCisSameAs<0, 3>,
+                                                  SDTCisSameNumEltsAs<0, 1>,
+                                                  SDTCisVec<1>,
+                                                  SDTCisOpSmallerThanOp<2, 1>,
+                                                  SDTCisSameAs<0, 2>,
+                                                  SDTCisSameNumEltsAs<0, 4>,
+                                                  SDTCVecEltisVT<4, i1>,
+                                                  SDTCisVT<5, XLenVT>,
+                                                  SDTCisVT<6, XLenVT>]>;
+
 def SDT_RISCVFPUnOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
                                               SDTCisVec<0>, SDTCisFP<0>,
                                               SDTCVecEltisVT<2, i1>,
@@ -120,6 +132,9 @@ def riscv_uaddsat_vl   : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
 def riscv_ssubsat_vl   : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
 def riscv_usubsat_vl   : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
 
+def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
+def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;
+
 def riscv_fadd_vl  : SDNode<"RISCVISD::FADD_VL",  SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
 def riscv_fsub_vl  : SDNode<"RISCVISD::FSUB_VL",  SDT_RISCVFPBinOp_VL>;
 def riscv_fmul_vl  : SDNode<"RISCVISD::FMUL_VL",  SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
@@ -635,6 +650,34 @@ class VPatBinaryVL_V<SDPatternOperator vop,
                    op2_reg_class:$rs2,
                    (mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
 
+multiclass VPatBinaryRM_VL_V<SDNode vop,
+                             string instruction_name,
+                             string suffix,
+                             ValueType result_type,
+                             ValueType op1_type,
+                             ValueType op2_type,
+                             ValueType mask_type,
+                             int sew,
+                             LMULInfo vlmul,
+                             VReg result_reg_class,
+                             VReg op1_reg_class,
+                             VReg op2_reg_class> {
+  def : Pat<(result_type (vop
+                         (op1_type op1_reg_class:$rs1),
+                         (op2_type op2_reg_class:$rs2),
+                         (result_type result_reg_class:$merge),
+                         (mask_type V0),
+                         (XLenVT timm:$roundmode),
+                         VLOpFrag)),
+        (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
+                     result_reg_class:$merge,
+                     op1_reg_class:$rs1,
+                     op2_reg_class:$rs2,
+                     (mask_type V0),
+                     (XLenVT timm:$roundmode),
+                     GPR:$vl, sew, TAIL_AGNOSTIC)>;
+}
+
 class VPatBinaryVL_V_RM<SDPatternOperator vop,
                         string instruction_name,
                         string suffix,
@@ -795,6 +838,35 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
                    xop_kind:$rs2,
                    (mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
 
+multiclass VPatBinaryRM_VL_XI<SDNode vop,
+                              string instruction_name,
+                              string suffix,
+                              ValueType result_type,
+                              ValueType vop1_type,
+                              ValueType vop2_type,
+                              ValueType mask_type,
+                              int sew,
+                              LMULInfo vlmul,
+                              VReg result_reg_class,
+                              VReg vop_reg_class,
+                              ComplexPattern SplatPatKind,
+                              DAGOperand xop_kind> {
+  def : Pat<(result_type (vop
+                     (vop1_type vop_reg_class:$rs1),
+                     (vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
+                     (result_type result_reg_class:$merge),
+                     (mask_type V0),
+                     (XLenVT timm:$roundmode),
+                     VLOpFrag)),
+        (!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
+                     result_reg_class:$merge,
+                     vop_reg_class:$rs1,
+                     xop_kind:$rs2,
+                     (mask_type V0),
+                     (XLenVT timm:$roundmode),
+                     GPR:$vl, sew, TAIL_AGNOSTIC)>;
+}
+
 multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
                               list<VTypeInfo> vtilist = AllIntegerVectors,
                               bit isSEWAware = 0> {
@@ -893,6 +965,24 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
   }
 }
 
+multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
+  foreach VtiToWti = AllWidenableIntVectors in {
+    defvar vti = VtiToWti.Vti;
+    defvar wti = VtiToWti.Wti;
+    defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
+                             vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+                             vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
+    defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
+                              vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+                              vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
+    defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
+                              vti.Vector, wti.Vector, vti.Vector, vti.Mask,
+                              vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
+                              !cast<ComplexPattern>(SplatPat#_#uimm5),
+                              uimm5>;
+  }
+}
+
 class VPatBinaryVL_VF<SDPatternOperator vop,
                       string instruction_name,
                       ValueType result_type,
@@ -2376,6 +2466,10 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgflooru_vl, 0b10, suffix="U">;
 defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
 defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
 
+// 12.5. Vector Narrowing Fixed-Point Clip Instructions
+defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
+defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
+
 // 12.5. Vector Narrowing Fixed-Point Clip Instructions
 multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
   defvar sew = vti.SEW;
@@ -2410,16 +2504,6 @@ multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
       (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
         (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
         (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
-    def : Pat<(vti.Vector (riscv_trunc_vector_vl
-        (wti.Vector (riscv_umin_vl
-          (wti.Vector wti.RegClass:$rs1),
-          (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), uminval, (XLenVT srcvalue))),
-          (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
-        (vti.Mask V0), VLOpFrag)),
-      (!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
-        (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
-        (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
   }
 }
 


