[llvm] ce8f094 - [RISCV] Add patterns for vnsrl.vx where shift amount is truncated

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 26 12:26:51 PDT 2023


Author: Luke Lau
Date: 2023-07-26T20:26:32+01:00
New Revision: ce8f094da8b2224b46e4f7192502d38a28f6aabd

URL: https://github.com/llvm/llvm-project/commit/ce8f094da8b2224b46e4f7192502d38a28f6aabd
DIFF: https://github.com/llvm/llvm-project/commit/ce8f094da8b2224b46e4f7192502d38a28f6aabd.diff

LOG: [RISCV] Add patterns for vnsrl.vx where shift amount is truncated

Similar to D155698 where the shift amount is extended, this patch extends the
ComplexPattern to handle the case where the shift amount has been truncated.
Truncations are custom lowered to truncate_vector_vl, and in cases like i64 ->
i16 they are split into a chain of truncates that each halve the element width
(i64 -> i32 -> i16), so we need to unravel nested layers of them.

The pattern can also be reused for Zvbb's vwsll.vx in an upcoming patch.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D155928

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
    llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
    llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
    llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index cafce628cf6a22..ddc42b4e13c401 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3017,13 +3017,29 @@ bool RISCVDAGToDAGISel::selectVSplatUimm(SDValue N, unsigned Bits,
   return true;
 }
 
-bool RISCVDAGToDAGISel::selectExtOneUseVSplat(SDValue N, SDValue &SplatVal) {
-  if (N->getOpcode() == ISD::SIGN_EXTEND ||
-      N->getOpcode() == ISD::ZERO_EXTEND) {
-    if (!N.hasOneUse())
+bool RISCVDAGToDAGISel::selectLow8BitsVSplat(SDValue N, SDValue &SplatVal) {
+  // Truncates are custom lowered to TRUNCATE_VECTOR_VL during legalization.
+  auto IsTrunc = [this](SDValue V) {
+    if (V->getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
+      return false;
+    SDValue VL;
+    selectVLOp(V->getOperand(2), VL);
+    // Any vmset_vl mask is ok since lanes past VL are undefined and can be
+    // treated as active; the truncate itself must still run at VLMAX.
+    return V->getOperand(1).getOpcode() == RISCVISD::VMSET_VL &&
+           isa<ConstantSDNode>(VL) &&
+           cast<ConstantSDNode>(VL)->getSExtValue() == RISCV::VLMaxSentinel;
+  };
+
+  // Unravel nested one-use extends/truncates. Compare the element width, not
+  // the whole vector's size, so every step keeps at least the low 8 bits.
+  while (N->getOpcode() == ISD::SIGN_EXTEND ||
+         N->getOpcode() == ISD::ZERO_EXTEND || IsTrunc(N)) {
+    if (!N.hasOneUse() || N.getScalarValueSizeInBits() < 8)
       return false;
     N = N->getOperand(0);
   }
+
   return selectVSplat(N, SplatVal);
 }
 

diff  --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
index 281719c12e7032..d7ee20eb4eedc1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -134,7 +134,9 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
   }
   bool selectVSplatSimm5Plus1(SDValue N, SDValue &SplatVal);
   bool selectVSplatSimm5Plus1NonZero(SDValue N, SDValue &SplatVal);
-  bool selectExtOneUseVSplat(SDValue N, SDValue &SplatVal);
+  // Matches a splat, possibly behind one-use extends/truncates, where only the
+  // low 8 bits of each element matter; the splat is matched by selectVSplat.
+  bool selectLow8BitsVSplat(SDValue N, SDValue &SplatVal);
   bool selectFPImm(SDValue N, SDValue &Imm);
 
   bool selectRVVSimm5(SDValue N, unsigned Width, SDValue &Imm);

diff  --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 900f9dd1be0535..90b863318b3397 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -577,8 +577,10 @@ def SplatPat_simm5_plus1
 def SplatPat_simm5_plus1_nonzero
     : ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1NonZero", [], [], 3>;
 
-def ext_oneuse_SplatPat
-    : ComplexPattern<vAny, 1, "selectExtOneUseVSplat", [], [], 2>;
+// Matches a splat, possibly behind extends/truncates, whose elements are only
+// used for their lowest 8 bits (e.g. shift amounts); yields the scalar.
+def Low8BitsSplatPat
+    : ComplexPattern<vAny, 1, "selectLow8BitsVSplat", [], [], 2>;
 
 def SelectFPImm : ComplexPattern<fAny, 1, "selectFPImm", [], [], 1>;
 
@@ -1453,7 +1455,7 @@ multiclass VPatBinaryVL_WV_WX_WI<SDNode op, string instruction_name> {
         (vti.Vector
           (riscv_trunc_vector_vl
             (op (wti.Vector wti.RegClass:$rs2),
-                (wti.Vector (ext_oneuse_SplatPat (XLenVT GPR:$rs1)))),
+                (wti.Vector (Low8BitsSplatPat (XLenVT GPR:$rs1)))),
             (vti.Mask true_mask),
             VLOpFrag)),
         (!cast<Instruction>(instruction_name#"_WX_"#vti.LMul.MX)

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll
index 1a80e0a7c65795..f19f0addd87c7a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vnsrl-sdnode.ll
@@ -652,13 +652,8 @@ define <vscale x 1 x i16> @vnsrl_wx_i64_nxv1i16(<vscale x 1 x i32> %va, i64 %b)
 ;
 ; RV64-LABEL: vnsrl_wx_i64_nxv1i16:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; RV64-NEXT:    vmv.v.x v9, a0
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; RV64-NEXT:    vnsrl.wi v9, v9, 0
-; RV64-NEXT:    vsrl.vv v8, v8, v9
-; RV64-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
-; RV64-NEXT:    vnsrl.wi v8, v8, 0
+; RV64-NEXT:    vsetvli a1, zero, e16, mf4, ta, ma
+; RV64-NEXT:    vnsrl.wx v8, v8, a0
 ; RV64-NEXT:    ret
   %head = insertelement <vscale x 1 x i64> poison, i64 %b, i32 0
   %splat = shufflevector <vscale x 1 x i64> %head, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
@@ -689,15 +684,8 @@ define <vscale x 1 x i8> @vnsrl_wx_i64_nxv1i8(<vscale x 1 x i16> %va, i64 %b) {
 ;
 ; RV64-LABEL: vnsrl_wx_i64_nxv1i8:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
-; RV64-NEXT:    vmv.v.x v9, a0
-; RV64-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; RV64-NEXT:    vnsrl.wi v9, v9, 0
-; RV64-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
-; RV64-NEXT:    vnsrl.wi v9, v9, 0
-; RV64-NEXT:    vsrl.vv v8, v8, v9
-; RV64-NEXT:    vsetvli zero, zero, e8, mf8, ta, ma
-; RV64-NEXT:    vnsrl.wi v8, v8, 0
+; RV64-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
+; RV64-NEXT:    vnsrl.wx v8, v8, a0
 ; RV64-NEXT:    ret
   %head = insertelement <vscale x 1 x i64> poison, i64 %b, i32 0
   %splat = shufflevector <vscale x 1 x i64> %head, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer


        


More information about the llvm-commits mailing list