[llvm] 0d7286a - [WebAssembly] Avoid scalarizing vector shifts in more cases

Thomas Lively via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 7 10:45:35 PDT 2020


Author: Thomas Lively
Date: 2020-07-07T10:45:26-07:00
New Revision: 0d7286a652371bca460357348f3b4828cd4ca214

URL: https://github.com/llvm/llvm-project/commit/0d7286a652371bca460357348f3b4828cd4ca214
DIFF: https://github.com/llvm/llvm-project/commit/0d7286a652371bca460357348f3b4828cd4ca214.diff

LOG: [WebAssembly] Avoid scalarizing vector shifts in more cases

Since WebAssembly's vector shift instructions take a scalar shift
amount rather than a vector shift amount, we have to check in ISel
that the vector shift amount is a splat. Previously, we were checking
explicitly for splat BUILD_VECTOR nodes, but this change uses the
standard SelectionDAG utilities for detecting splat values
(isSplatValue and getSplatValue), which can handle more
complex splat patterns. Since the C++ ISel lowering is now more
general than the ISel patterns, this change also simplifies shift
lowering by using the C++ lowering for all SIMD shifts rather than
mixing C++ and normal pattern-based lowering.

This change improves ISel for shifts to the point that the
simd-shift-unroll.ll regression test no longer tests the code path it
was originally meant to test. The bug corresponding to that regression
test (PR45178) is no longer reproducible with its original reported
reproducer, so rather than trying to fix the regression test, this
change just removes it.

Differential Revision: https://reviews.llvm.org/D83278

Added: 
    llvm/test/CodeGen/WebAssembly/simd-shift-complex-splats.ll

Modified: 
    llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
    llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td

Removed: 
    llvm/test/CodeGen/WebAssembly/simd-shift-unroll.ll


################################################################################
diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 651c504efc06..3f4ebd501595 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -1677,19 +1677,13 @@ SDValue WebAssemblyTargetLowering::LowerShift(SDValue Op,
   // Only manually lower vector shifts
   assert(Op.getSimpleValueType().isVector());
 
-  // Unroll non-splat vector shifts
-  BuildVectorSDNode *ShiftVec;
-  SDValue SplatVal;
-  if (!(ShiftVec = dyn_cast<BuildVectorSDNode>(Op.getOperand(1).getNode())) ||
-      !(SplatVal = ShiftVec->getSplatValue()))
+  auto ShiftVal = Op.getOperand(1);
+  if (!DAG.isSplatValue(ShiftVal, /*AllowUndefs=*/true))
     return unrollVectorShift(Op, DAG);
 
-  // All splats except i64x2 const splats are handled by patterns
-  auto *SplatConst = dyn_cast<ConstantSDNode>(SplatVal);
-  if (!SplatConst || Op.getSimpleValueType() != MVT::v2i64)
-    return Op;
+  auto SplatVal = DAG.getSplatValue(ShiftVal);
+  assert(SplatVal != SDValue());
 
-  // i64x2 const splats are custom lowered to avoid unnecessary wraps
   unsigned Opcode;
   switch (Op.getOpcode()) {
   case ISD::SHL:
@@ -1704,9 +1698,11 @@ SDValue WebAssemblyTargetLowering::LowerShift(SDValue Op,
   default:
     llvm_unreachable("unexpected opcode");
   }
-  APInt Shift = SplatConst->getAPIntValue().zextOrTrunc(32);
+
+  // Use anyext because none of the high bits can affect the shift
+  auto ScalarShift = DAG.getAnyExtOrTrunc(SplatVal, DL, MVT::i32);
   return DAG.getNode(Opcode, DL, Op.getValueType(), Op.getOperand(0),
-                     DAG.getConstant(Shift, DL, MVT::i32));
+                     ScalarShift);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index b4a8a7bc42ae..814bb80fb693 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -654,55 +654,36 @@ defm BITMASK : SIMDBitmask<v4i32, "i32x4", 164>;
 // Bit shifts
 //===----------------------------------------------------------------------===//
 
-multiclass SIMDShift<ValueType vec_t, string vec, SDNode node, dag shift_vec,
-                     string name, bits<32> simdop> {
+multiclass SIMDShift<ValueType vec_t, string vec, SDNode node, string name,
+                     bits<32> simdop> {
   defm _#vec_t : SIMD_I<(outs V128:$dst), (ins V128:$vec, I32:$x),
                         (outs), (ins),
-                        [(set (vec_t V128:$dst),
-                          (node V128:$vec, (vec_t shift_vec)))],
+                        [(set (vec_t V128:$dst), (node V128:$vec, I32:$x))],
                         vec#"."#name#"\t$dst, $vec, $x", vec#"."#name, simdop>;
 }
 
 multiclass SIMDShiftInt<SDNode node, string name, bits<32> baseInst> {
-  defm "" : SIMDShift<v16i8, "i8x16", node, (splat16 I32:$x), name, baseInst>;
-  defm "" : SIMDShift<v8i16, "i16x8", node, (splat8 I32:$x), name,
-                      !add(baseInst, 32)>;
-  defm "" : SIMDShift<v4i32, "i32x4", node, (splat4 I32:$x), name,
-                      !add(baseInst, 64)>;
-  defm "" : SIMDShift<v2i64, "i64x2", node, (splat2 (i64 (zext I32:$x))),
-                      name, !add(baseInst, 96)>;
+  defm "" : SIMDShift<v16i8, "i8x16", node, name, baseInst>;
+  defm "" : SIMDShift<v8i16, "i16x8", node, name, !add(baseInst, 32)>;
+  defm "" : SIMDShift<v4i32, "i32x4", node, name, !add(baseInst, 64)>;
+  defm "" : SIMDShift<v2i64, "i64x2", node, name, !add(baseInst, 96)>;
 }
 
-// Left shift by scalar: shl
-defm SHL : SIMDShiftInt<shl, "shl", 107>;
-
-// Right shift by scalar: shr_s / shr_u
-defm SHR_S : SIMDShiftInt<sra, "shr_s", 108>;
-defm SHR_U : SIMDShiftInt<srl, "shr_u", 109>;
-
-// Truncate i64 shift operands to i32s, except if they are already i32s
-foreach shifts = [[shl, SHL_v2i64], [sra, SHR_S_v2i64], [srl, SHR_U_v2i64]] in {
-def : Pat<(v2i64 (shifts[0]
-            (v2i64 V128:$vec),
-            (v2i64 (splat2 (i64 (sext I32:$x))))
-          )),
-          (v2i64 (shifts[1] (v2i64 V128:$vec), (i32 I32:$x)))>;
-def : Pat<(v2i64 (shifts[0] (v2i64 V128:$vec), (v2i64 (splat2 I64:$x)))),
-          (v2i64 (shifts[1] (v2i64 V128:$vec), (I32_WRAP_I64 I64:$x)))>;
-}
-
-// 2xi64 shifts with constant shift amounts are custom lowered to avoid wrapping
+// WebAssembly SIMD shifts are nonstandard in that the shift amount is
+// an i32 rather than a vector, so they need custom nodes.
 def wasm_shift_t : SDTypeProfile<1, 2,
   [SDTCisVec<0>, SDTCisSameAs<0, 1>, SDTCisVT<2, i32>]
 >;
 def wasm_shl : SDNode<"WebAssemblyISD::VEC_SHL", wasm_shift_t>;
 def wasm_shr_s : SDNode<"WebAssemblyISD::VEC_SHR_S", wasm_shift_t>;
 def wasm_shr_u : SDNode<"WebAssemblyISD::VEC_SHR_U", wasm_shift_t>;
-foreach shifts = [[wasm_shl, SHL_v2i64],
-                  [wasm_shr_s, SHR_S_v2i64],
-                  [wasm_shr_u, SHR_U_v2i64]] in
-def : Pat<(v2i64 (shifts[0] (v2i64 V128:$vec), I32:$x)),
-          (v2i64 (shifts[1] (v2i64 V128:$vec), I32:$x))>;
+
+// Left shift by scalar: shl
+defm SHL : SIMDShiftInt<wasm_shl, "shl", 107>;
+
+// Right shift by scalar: shr_s / shr_u
+defm SHR_S : SIMDShiftInt<wasm_shr_s, "shr_s", 108>;
+defm SHR_U : SIMDShiftInt<wasm_shr_u, "shr_u", 109>;
 
 //===----------------------------------------------------------------------===//
 // Integer binary arithmetic
@@ -766,12 +747,12 @@ def add_nuw : PatFrag<(ops node:$lhs, node:$rhs),
                       "return N->getFlags().hasNoUnsignedWrap();">;
 
 foreach nodes = [[v16i8, splat16], [v8i16, splat8]] in
-def : Pat<(srl
+def : Pat<(wasm_shr_u
             (add_nuw
               (add_nuw (nodes[0] V128:$lhs), (nodes[0] V128:$rhs)),
               (nodes[1] (i32 1))
             ),
-            (nodes[0] (nodes[1] (i32 1)))
+            (i32 1)
           ),
           (!cast<NI>("AVGR_U_"#nodes[0]) V128:$lhs, V128:$rhs)>;
 

diff  --git a/llvm/test/CodeGen/WebAssembly/simd-shift-complex-splats.ll b/llvm/test/CodeGen/WebAssembly/simd-shift-complex-splats.ll
new file mode 100644
index 000000000000..ded430f89545
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-shift-complex-splats.ll
@@ -0,0 +1,27 @@
+; RUN: llc < %s -asm-verbose=false -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
+
+; Test that SIMD shifts can be lowered correctly even with shift
+; values that are more complex than plain splats.
+
+target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128"
+target triple = "wasm32-unknown-unknown"
+
+;; TODO: Optimize this further by scalarizing the add
+
+; CHECK-LABEL: shl_add:
+; CHECK-NEXT: .functype shl_add (v128, i32, i32) -> (v128)
+; CHECK-NEXT: i8x16.splat $push1=, $1
+; CHECK-NEXT: i8x16.splat $push0=, $2
+; CHECK-NEXT: i8x16.add $push2=, $pop1, $pop0
+; CHECK-NEXT: i8x16.extract_lane_u $push3=, $pop2, 0
+; CHECK-NEXT: i8x16.shl $push4=, $0, $pop3
+; CHECK-NEXT: return $pop4
+define <16 x i8> @shl_add(<16 x i8> %v, i8 %a, i8 %b) {
+  %t1 = insertelement <16 x i8> undef, i8 %a, i32 0
+  %va = shufflevector <16 x i8> %t1, <16 x i8> undef, <16 x i32> zeroinitializer
+  %t2 = insertelement <16 x i8> undef, i8 %b, i32 0
+  %vb = shufflevector <16 x i8> %t2, <16 x i8> undef, <16 x i32> zeroinitializer
+  %shift = add <16 x i8> %va, %vb
+  %r = shl <16 x i8> %v, %shift
+  ret <16 x i8> %r
+}

diff  --git a/llvm/test/CodeGen/WebAssembly/simd-shift-unroll.ll b/llvm/test/CodeGen/WebAssembly/simd-shift-unroll.ll
deleted file mode 100644
index 2a5422cb0110..000000000000
--- a/llvm/test/CodeGen/WebAssembly/simd-shift-unroll.ll
+++ /dev/null
@@ -1,128 +0,0 @@
-; RUN: llc < %s -asm-verbose=false -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+unimplemented-simd128 | FileCheck %s --check-prefixes CHECK,SIMD128,SIMD128-SLOW
-
-;; Test that the custom shift unrolling works correctly in cases that
-;; cause assertion failures due to illegal types when using
-;; DAG.UnrollVectorOp. Regression test for PR45178.
-
-target datalayout = "e-m:e-p:32:32-i64:64-n32:64-S128"
-target triple = "wasm32-unknown-unknown"
-
-; CHECK-LABEL: shl_v16i8:
-; CHECK-NEXT: .functype       shl_v16i8 (v128) -> (v128)
-; CHECK-NEXT: i8x16.extract_lane_u    $push0=, $0, 0
-; CHECK-NEXT: i32.const       $push1=, 3
-; CHECK-NEXT: i32.shl         $push2=, $pop0, $pop1
-; CHECK-NEXT: i8x16.splat     $push3=, $pop2
-; CHECK-NEXT: i8x16.extract_lane_u    $push4=, $0, 1
-; CHECK-NEXT: i8x16.replace_lane      $push5=, $pop3, 1, $pop4
-; ...
-; CHECK:      i8x16.extract_lane_u    $push32=, $0, 15
-; CHECK-NEXT: i8x16.replace_lane      $push33=, $pop31, 15, $pop32
-; CHECK-NEXT: v8x16.shuffle   $push34=, $pop33, $0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
-; CHECK-NEXT: return  $pop34
-define <16 x i8> @shl_v16i8(<16 x i8> %in) {
-  %out = shl <16 x i8> %in,
-    <i8 3, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0,
-     i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>
-  %ret = shufflevector <16 x i8> %out, <16 x i8> undef, <16 x i32> zeroinitializer
-  ret <16 x i8> %ret
-}
-
-; CHECK-LABEL: shr_s_v16i8:
-; CHECK-NEXT: functype       shr_s_v16i8 (v128) -> (v128)
-; CHECK-NEXT: i8x16.extract_lane_s    $push0=, $0, 0
-; CHECK-NEXT: i32.const       $push1=, 3
-; CHECK-NEXT: i32.shr_s       $push2=, $pop0, $pop1
-; CHECK-NEXT: i8x16.splat     $push3=, $pop2
-; CHECK-NEXT: i8x16.extract_lane_s    $push4=, $0, 1
-; CHECK-NEXT: i8x16.replace_lane      $push5=, $pop3, 1, $pop4
-; ...
-; CHECK:      i8x16.extract_lane_s    $push32=, $0, 15
-; CHECK-NEXT: i8x16.replace_lane      $push33=, $pop31, 15, $pop32
-; CHECK-NEXT: v8x16.shuffle   $push34=, $pop33, $0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
-; CHECK-NEXT: return  $pop34
-define <16 x i8> @shr_s_v16i8(<16 x i8> %in) {
-  %out = ashr <16 x i8> %in,
-    <i8 3, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0,
-     i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>
-  %ret = shufflevector <16 x i8> %out, <16 x i8> undef, <16 x i32> zeroinitializer
-  ret <16 x i8> %ret
-}
-
-; CHECK-LABEL: shr_u_v16i8:
-; CHECK-NEXT: functype       shr_u_v16i8 (v128) -> (v128)
-; CHECK-NEXT: i8x16.extract_lane_u    $push0=, $0, 0
-; CHECK-NEXT: i32.const       $push1=, 3
-; CHECK-NEXT: i32.shr_u       $push2=, $pop0, $pop1
-; CHECK-NEXT: i8x16.splat     $push3=, $pop2
-; CHECK-NEXT: i8x16.extract_lane_u    $push4=, $0, 1
-; CHECK-NEXT: i8x16.replace_lane      $push5=, $pop3, 1, $pop4
-; ...
-; CHECK:      i8x16.extract_lane_u    $push32=, $0, 15
-; CHECK-NEXT: i8x16.replace_lane      $push33=, $pop31, 15, $pop32
-; CHECK-NEXT: v8x16.shuffle   $push34=, $pop33, $0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
-; CHECK-NEXT: return  $pop34
-define <16 x i8> @shr_u_v16i8(<16 x i8> %in) {
-  %out = lshr <16 x i8> %in,
-    <i8 3, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0,
-     i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0, i8 0>
-  %ret = shufflevector <16 x i8> %out, <16 x i8> undef, <16 x i32> zeroinitializer
-  ret <16 x i8> %ret
-}
-
-; CHECK-LABEL: shl_v8i16:
-; CHECK-NEXT: functype       shl_v8i16 (v128) -> (v128)
-; CHECK-NEXT: i16x8.extract_lane_u    $push0=, $0, 0
-; CHECK-NEXT: i32.const       $push1=, 9
-; CHECK-NEXT: i32.shl         $push2=, $pop0, $pop1
-; CHECK-NEXT: i16x8.splat     $push3=, $pop2
-; CHECK-NEXT: i16x8.extract_lane_u    $push4=, $0, 1
-; CHECK-NEXT: i16x8.replace_lane      $push5=, $pop3, 1, $pop4
-; ...
-; CHECK:      i16x8.extract_lane_u    $push16=, $0, 7
-; CHECK-NEXT: i16x8.replace_lane      $push17=, $pop15, 7, $pop16
-; CHECK-NEXT: v8x16.shuffle   $push18=, $pop17, $0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
-; CHECK-NEXT: return  $pop18
-define <8 x i16> @shl_v8i16(<8 x i16> %in) {
-  %out = shl <8 x i16> %in, <i16 9, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>
-  %ret = shufflevector <8 x i16> %out, <8 x i16> undef, <8 x i32> zeroinitializer
-  ret <8 x i16> %ret
-}
-
-; CHECK-LABEL: shr_s_v8i16:
-; CHECK-NEXT: functype       shr_s_v8i16 (v128) -> (v128)
-; CHECK-NEXT: i16x8.extract_lane_s    $push0=, $0, 0
-; CHECK-NEXT: i32.const       $push1=, 9
-; CHECK-NEXT: i32.shr_s       $push2=, $pop0, $pop1
-; CHECK-NEXT: i16x8.splat     $push3=, $pop2
-; CHECK-NEXT: i16x8.extract_lane_s    $push4=, $0, 1
-; CHECK-NEXT: i16x8.replace_lane      $push5=, $pop3, 1, $pop4
-; ...
-; CHECK:      i16x8.extract_lane_s    $push16=, $0, 7
-; CHECK-NEXT: i16x8.replace_lane      $push17=, $pop15, 7, $pop16
-; CHECK-NEXT: v8x16.shuffle   $push18=, $pop17, $0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
-; CHECK-NEXT: return  $pop18
-define <8 x i16> @shr_s_v8i16(<8 x i16> %in) {
-  %out = ashr <8 x i16> %in, <i16 9, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>
-  %ret = shufflevector <8 x i16> %out, <8 x i16> undef, <8 x i32> zeroinitializer
-  ret <8 x i16> %ret
-}
-
-; CHECK-LABEL: shr_u_v8i16:
-; CHECK-NEXT: functype       shr_u_v8i16 (v128) -> (v128)
-; CHECK-NEXT: i16x8.extract_lane_u    $push0=, $0, 0
-; CHECK-NEXT: i32.const       $push1=, 9
-; CHECK-NEXT: i32.shr_u       $push2=, $pop0, $pop1
-; CHECK-NEXT: i16x8.splat     $push3=, $pop2
-; CHECK-NEXT: i16x8.extract_lane_u    $push4=, $0, 1
-; CHECK-NEXT: i16x8.replace_lane      $push5=, $pop3, 1, $pop4
-; ...
-; CHECK:      i16x8.extract_lane_u    $push16=, $0, 7
-; CHECK-NEXT: i16x8.replace_lane      $push17=, $pop15, 7, $pop16
-; CHECK-NEXT: v8x16.shuffle   $push18=, $pop17, $0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1
-; CHECK-NEXT: return  $pop18
-define <8 x i16> @shr_u_v8i16(<8 x i16> %in) {
-  %out = lshr <8 x i16> %in, <i16 9, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0, i16 0>
-  %ret = shufflevector <8 x i16> %out, <8 x i16> undef, <8 x i32> zeroinitializer
-  ret <8 x i16> %ret
-}


        


More information about the llvm-commits mailing list