[llvm] [RISCV] Remove RISCVISD::VNSRL_VL and adjust deinterleave lowering to match (PR #118391)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 2 11:41:17 PST 2024


https://github.com/preames created https://github.com/llvm/llvm-project/pull/118391

Instead of directly lowering to vnsrl_vl and having custom pattern matching for that case, we can just lower to a (legal) shift and truncate, and let generic pattern matching produce the vnsrl.

The major motivation for this is that I'm going to reuse this logic to handle, e.g., deinterleave4 with an i8 result.

The test changes aren't particularly interesting.  They're minor codegen improvements - I think because we do slightly better with the insert_subvector patterns - but the exact reason is mostly irrelevant.

>From 1aa53646afdb6d8393d0a0dfbb2d263b29a72222 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Mon, 2 Dec 2024 11:15:33 -0800
Subject: [PATCH] [RISCV] Remove RISCVISD::VNSRL_VL and adjust deinterleave
 lowering to match

Instead of directly lowering to vnsrl_vl and having custom pattern matching
for that case, we can just lower to a (legal) shift and truncate, and let
generic pattern matching produce the vnsrl.

The major motivation for this is that I'm going to reuse this logic to
handle, e.g., deinterleave4 with an i8 result.

The test changes aren't particularly interesting.  They're minor codegen
improvements - I think because we do slightly better with the
insert_subvector patterns - but the exact reason is mostly irrelevant.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 66 +++++++------------
 llvm/lib/Target/RISCV/RISCVISelLowering.h     |  4 --
 .../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 36 ----------
 .../fixed-vectors-shuffle-changes-length.ll   | 29 ++++----
 4 files changed, 36 insertions(+), 99 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 78dc3cb27a6988..c68abcd5916049 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4618,51 +4618,32 @@ static int isElementRotate(int &LoSrc, int &HiSrc, ArrayRef<int> Mask) {
 // VT is the type of the vector to return, <[vscale x ]n x ty>
 // Src is the vector to deinterleave of type <[vscale x ]n*2 x ty>
 static SDValue getDeinterleaveViaVNSRL(const SDLoc &DL, MVT VT, SDValue Src,
-                                       bool EvenElts,
-                                       const RISCVSubtarget &Subtarget,
-                                       SelectionDAG &DAG) {
-  // The result is a vector of type <m x n x ty>
-  MVT ContainerVT = VT;
-  // Convert fixed vectors to scalable if needed
-  if (ContainerVT.isFixedLengthVector()) {
-    assert(Src.getSimpleValueType().isFixedLengthVector());
-    ContainerVT = getContainerForFixedLengthVector(DAG, ContainerVT, Subtarget);
-
-    // The source is a vector of type <m x n*2 x ty> (For the single source
-    // case, the high half is undefined)
-    MVT SrcContainerVT =
-        MVT::getVectorVT(ContainerVT.getVectorElementType(),
-                         ContainerVT.getVectorElementCount() * 2);
-    Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
+                                       bool EvenElts, SelectionDAG &DAG) {
+  // FLAGIT
+  // The result is a vector of type <m x n x ty>.  The source is a vector of
+  // type <m x n*2 x ty> (For the single source case, the high half is undef)
+  if (Src.getValueType() == VT) {
+    EVT WideVT = VT.getDoubleNumVectorElementsVT();
+    Src = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideVT, DAG.getUNDEF(WideVT),
+                      Src, DAG.getVectorIdxConstant(0, DL));
   }
 
-  auto [TrueMask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
-
   // Bitcast the source vector from <m x n*2 x ty> -> <m x n x ty*2>
   // This also converts FP to int.
-  unsigned EltBits = ContainerVT.getScalarSizeInBits();
-  MVT WideSrcContainerVT = MVT::getVectorVT(
-      MVT::getIntegerVT(EltBits * 2), ContainerVT.getVectorElementCount());
-  Src = DAG.getBitcast(WideSrcContainerVT, Src);
+  unsigned EltBits = VT.getScalarSizeInBits();
+  MVT WideSrcVT = MVT::getVectorVT(MVT::getIntegerVT(EltBits * 2),
+                                   VT.getVectorElementCount());
+  Src = DAG.getBitcast(WideSrcVT, Src);
 
-  // The integer version of the container type.
-  MVT IntContainerVT = ContainerVT.changeVectorElementTypeToInteger();
+  MVT IntVT = VT.changeVectorElementTypeToInteger();
 
   // If we want even elements, then the shift amount is 0. Otherwise, shift by
   // the original element size.
   unsigned Shift = EvenElts ? 0 : EltBits;
-  SDValue SplatShift = DAG.getNode(
-      RISCVISD::VMV_V_X_VL, DL, IntContainerVT, DAG.getUNDEF(ContainerVT),
-      DAG.getConstant(Shift, DL, Subtarget.getXLenVT()), VL);
-  SDValue Res =
-      DAG.getNode(RISCVISD::VNSRL_VL, DL, IntContainerVT, Src, SplatShift,
-                  DAG.getUNDEF(IntContainerVT), TrueMask, VL);
-  // Cast back to FP if needed.
-  Res = DAG.getBitcast(ContainerVT, Res);
-
-  if (VT.isFixedLengthVector())
-    Res = convertFromScalableVector(VT, Res, DAG, Subtarget);
-  return Res;
+  SDValue Res = DAG.getNode(ISD::SRL, DL, WideSrcVT, Src,
+                            DAG.getConstant(Shift, DL, WideSrcVT));
+  Res = DAG.getNode(ISD::TRUNCATE, DL, IntVT, Res);
+  return DAG.getBitcast(VT, Res);
 }
 
 // Lower the following shuffle to vslidedown.
@@ -5356,7 +5337,7 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
   // vnsrl to deinterleave.
   if (SDValue Src =
           isDeinterleaveShuffle(VT, ContainerVT, V1, V2, Mask, Subtarget))
-    return getDeinterleaveViaVNSRL(DL, VT, Src, Mask[0] == 0, Subtarget, DAG);
+    return getDeinterleaveViaVNSRL(DL, VT, Src, Mask[0] == 0, DAG);
 
   if (SDValue V =
           lowerVECTOR_SHUFFLEAsVSlideup(DL, VT, V1, V2, Mask, Subtarget, DAG))
@@ -6258,7 +6239,7 @@ static bool hasPassthruOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    128 &&
+                    127 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -6284,7 +6265,7 @@ static bool hasMaskOp(unsigned Opcode) {
          Opcode <= RISCVISD::LAST_RISCV_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP ==
-                    128 &&
+                    127 &&
                 RISCVISD::LAST_RISCV_STRICTFP_OPCODE -
                         ISD::FIRST_TARGET_STRICTFP_OPCODE ==
                     21 &&
@@ -10763,10 +10744,8 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
   // We can deinterleave through vnsrl.wi if the element type is smaller than
   // ELEN
   if (VecVT.getScalarSizeInBits() < Subtarget.getELen()) {
-    SDValue Even =
-        getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, Subtarget, DAG);
-    SDValue Odd =
-        getDeinterleaveViaVNSRL(DL, VecVT, Concat, false, Subtarget, DAG);
+    SDValue Even = getDeinterleaveViaVNSRL(DL, VecVT, Concat, true, DAG);
+    SDValue Odd = getDeinterleaveViaVNSRL(DL, VecVT, Concat, false, DAG);
     return DAG.getMergeValues({Even, Odd}, DL);
   }
 
@@ -20494,7 +20473,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VWMACC_VL)
   NODE_NAME_CASE(VWMACCU_VL)
   NODE_NAME_CASE(VWMACCSU_VL)
-  NODE_NAME_CASE(VNSRL_VL)
   NODE_NAME_CASE(SETCC_VL)
   NODE_NAME_CASE(VMERGE_VL)
   NODE_NAME_CASE(VMAND_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 7ada941563c1ff..c753469562ebac 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -369,10 +369,6 @@ enum NodeType : unsigned {
   VWMACCU_VL,
   VWMACCSU_VL,
 
-  // Narrowing logical shift right.
-  // Operands are (source, shift, passthru, mask, vl)
-  VNSRL_VL,
-
   // Vector compare producing a mask. Fourth operand is input mask. Fifth
   // operand is VL.
   SETCC_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 4b938fc734e5c1..e48a6f9309294b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -459,17 +459,6 @@ def riscv_vfwmul_vl : SDNode<"RISCVISD::VFWMUL_VL", SDT_RISCVVWFPBinOp_VL, [SDNP
 def riscv_vfwadd_vl : SDNode<"RISCVISD::VFWADD_VL", SDT_RISCVVWFPBinOp_VL, [SDNPCommutative]>;
 def riscv_vfwsub_vl : SDNode<"RISCVISD::VFWSUB_VL", SDT_RISCVVWFPBinOp_VL, []>;
 
-def SDT_RISCVVNIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
-                                                  SDTCisInt<1>,
-                                                  SDTCisSameNumEltsAs<0, 1>,
-                                                  SDTCisOpSmallerThanOp<0, 1>,
-                                                  SDTCisSameAs<0, 2>,
-                                                  SDTCisSameAs<0, 3>,
-                                                  SDTCisSameNumEltsAs<0, 4>,
-                                                  SDTCVecEltisVT<4, i1>,
-                                                  SDTCisVT<5, XLenVT>]>;
-def riscv_vnsrl_vl : SDNode<"RISCVISD::VNSRL_VL", SDT_RISCVVNIntBinOp_VL>;
-
 def SDT_RISCVVWIntBinOpW_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
                                                    SDTCisSameAs<0, 1>,
                                                    SDTCisInt<2>,
@@ -885,29 +874,6 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
   }
 }
 
-multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name> {
-  foreach VtiToWti = AllWidenableIntVectors in {
-    defvar vti = VtiToWti.Vti;
-    defvar wti = VtiToWti.Wti;
-    let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
-                                 GetVTypePredicates<wti>.Predicates) in {
-      def : VPatBinaryVL_V<vop, instruction_name, "WV",
-                           vti.Vector, wti.Vector, vti.Vector, vti.Mask,
-                           vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
-                           vti.RegClass>;
-      def : VPatBinaryVL_XI<vop, instruction_name, "WX",
-                            vti.Vector, wti.Vector, vti.Vector, vti.Mask,
-                            vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
-                            SplatPat, GPR>;
-      def : VPatBinaryVL_XI<vop, instruction_name, "WI",
-                            vti.Vector, wti.Vector, vti.Vector, vti.Mask,
-                            vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
-                            !cast<ComplexPattern>(SplatPat#_#uimm5),
-                            uimm5>;
-    }
-  }
-}
-
 class VPatBinaryVL_VF<SDPatternOperator vop,
                       string instruction_name,
                       ValueType result_type,
@@ -2166,8 +2132,6 @@ defm : VPatNarrowShiftSplatExt_WX<riscv_srl_vl, riscv_zext_vl_oneuse, "PseudoVNS
 defm : VPatNarrowShiftVL_WV<riscv_srl_vl, "PseudoVNSRL">;
 defm : VPatNarrowShiftVL_WV<riscv_sra_vl, "PseudoVNSRA">;
 
-defm : VPatBinaryNVL_WV_WX_WI<riscv_vnsrl_vl, "PseudoVNSRL">;
-
 foreach vtiTowti = AllWidenableIntVectors in {
   defvar vti = vtiTowti.Vti;
   defvar wti = vtiTowti.Wti;
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shuffle-changes-length.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shuffle-changes-length.ll
index 8b18be908089f2..9d2c722334b080 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shuffle-changes-length.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-shuffle-changes-length.ll
@@ -97,33 +97,31 @@ define <4 x i32> @v4i32_v8i32(<8 x i32>) {
 define <4 x i32> @v4i32_v16i32(<16 x i32>) {
 ; RV32-LABEL: v4i32_v16i32:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    vsetivli zero, 8, e32, m4, ta, ma
-; RV32-NEXT:    vslidedown.vi v16, v8, 8
-; RV32-NEXT:    vmv4r.v v20, v8
 ; RV32-NEXT:    vsetivli zero, 8, e16, m1, ta, ma
-; RV32-NEXT:    vmv.v.i v8, 1
-; RV32-NEXT:    vmv2r.v v22, v12
-; RV32-NEXT:    vmv.v.i v10, 6
+; RV32-NEXT:    vmv.v.i v12, 1
+; RV32-NEXT:    vmv.v.i v14, 6
 ; RV32-NEXT:    li a0, 32
 ; RV32-NEXT:    vmv.v.i v0, 10
 ; RV32-NEXT:    vsetivli zero, 2, e16, m1, tu, ma
-; RV32-NEXT:    vslideup.vi v10, v8, 1
+; RV32-NEXT:    vslideup.vi v14, v12, 1
+; RV32-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
+; RV32-NEXT:    vnsrl.wx v12, v8, a0
+; RV32-NEXT:    vsetivli zero, 8, e32, m4, ta, ma
+; RV32-NEXT:    vslidedown.vi v8, v8, 8
 ; RV32-NEXT:    vsetivli zero, 8, e32, m2, ta, mu
-; RV32-NEXT:    vnsrl.wx v8, v20, a0
-; RV32-NEXT:    vrgatherei16.vv v8, v16, v10, v0.t
+; RV32-NEXT:    vrgatherei16.vv v12, v8, v14, v0.t
+; RV32-NEXT:    vmv1r.v v8, v12
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: v4i32_v16i32:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetivli zero, 8, e32, m4, ta, ma
-; RV64-NEXT:    vslidedown.vi v16, v8, 8
-; RV64-NEXT:    vmv4r.v v20, v8
 ; RV64-NEXT:    li a0, 32
-; RV64-NEXT:    vmv2r.v v22, v12
 ; RV64-NEXT:    vsetivli zero, 1, e8, mf8, ta, ma
 ; RV64-NEXT:    vmv.v.i v0, 10
 ; RV64-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
-; RV64-NEXT:    vnsrl.wx v8, v20, a0
+; RV64-NEXT:    vnsrl.wx v12, v8, a0
+; RV64-NEXT:    vsetivli zero, 8, e32, m4, ta, ma
+; RV64-NEXT:    vslidedown.vi v8, v8, 8
 ; RV64-NEXT:    li a0, 3
 ; RV64-NEXT:    slli a0, a0, 33
 ; RV64-NEXT:    addi a0, a0, 1
@@ -131,7 +129,8 @@ define <4 x i32> @v4i32_v16i32(<16 x i32>) {
 ; RV64-NEXT:    vsetivli zero, 2, e64, m1, ta, ma
 ; RV64-NEXT:    vmv.v.x v10, a0
 ; RV64-NEXT:    vsetivli zero, 8, e32, m2, ta, mu
-; RV64-NEXT:    vrgatherei16.vv v8, v16, v10, v0.t
+; RV64-NEXT:    vrgatherei16.vv v12, v8, v10, v0.t
+; RV64-NEXT:    vmv1r.v v8, v12
 ; RV64-NEXT:    ret
   %2 = shufflevector <16 x i32> %0, <16 x i32> poison, <4 x i32> <i32 1, i32 9, i32 5, i32 14>
   ret <4 x i32> %2



More information about the llvm-commits mailing list