[llvm] 403926a - [WebAssembly] Skip implied bitmask operation in LowerShift

Jun Ma via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 1 17:37:51 PST 2023


Author: Jun Ma
Date: 2023-03-02T09:37:25+08:00
New Revision: 403926aefefb13553f89ad812b1e2385826a82ec

URL: https://github.com/llvm/llvm-project/commit/403926aefefb13553f89ad812b1e2385826a82ec
DIFF: https://github.com/llvm/llvm-project/commit/403926aefefb13553f89ad812b1e2385826a82ec.diff

LOG: [WebAssembly] Skip implied bitmask operation in LowerShift

This patch skips redundant explicit masks of the shift count, since the
masking is already implied by the wasm SIMD shift instructions, which take
the shift count modulo the lane width.

Differential Revision: https://reviews.llvm.org/D144619

Added: 
    

Modified: 
    llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
    llvm/test/CodeGen/WebAssembly/masked-shifts.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 94544800a6fba..32d0b01cb90e8 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -2287,10 +2287,43 @@ SDValue WebAssemblyTargetLowering::LowerShift(SDValue Op,
   // Only manually lower vector shifts
   assert(Op.getSimpleValueType().isVector());
 
-  auto ShiftVal = DAG.getSplatValue(Op.getOperand(1));
+  uint64_t LaneBits = Op.getValueType().getScalarSizeInBits();
+  auto ShiftVal = Op.getOperand(1);
+
+  // Try to skip bitmask operation since it is implied inside shift instruction
+  auto SkipImpliedMask = [](SDValue MaskOp, uint64_t MaskBits) {
+    if (MaskOp.getOpcode() != ISD::AND)
+      return MaskOp;
+    SDValue LHS = MaskOp.getOperand(0);
+    SDValue RHS = MaskOp.getOperand(1);
+    if (MaskOp.getValueType().isVector()) {
+      APInt MaskVal;
+      if (!ISD::isConstantSplatVector(RHS.getNode(), MaskVal))
+        std::swap(LHS, RHS);
+
+      if (ISD::isConstantSplatVector(RHS.getNode(), MaskVal) &&
+          MaskVal == MaskBits)
+        MaskOp = LHS;
+    } else {
+      if (!isa<ConstantSDNode>(RHS.getNode()))
+        std::swap(LHS, RHS);
+
+      auto ConstantRHS = dyn_cast<ConstantSDNode>(RHS.getNode());
+      if (ConstantRHS && ConstantRHS->getAPIntValue() == MaskBits)
+        MaskOp = LHS;
+    }
+
+    return MaskOp;
+  };
+
+  // Skip vector and operation
+  ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
+  ShiftVal = DAG.getSplatValue(ShiftVal);
   if (!ShiftVal)
     return unrollVectorShift(Op, DAG);
 
+  // Skip scalar and operation
+  ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
   // Use anyext because none of the high bits can affect the shift
   ShiftVal = DAG.getAnyExtOrTrunc(ShiftVal, DL, MVT::i32);
 

diff  --git a/llvm/test/CodeGen/WebAssembly/masked-shifts.ll b/llvm/test/CodeGen/WebAssembly/masked-shifts.ll
index 56e6119454380..5bcb023e546b5 100644
--- a/llvm/test/CodeGen/WebAssembly/masked-shifts.ll
+++ b/llvm/test/CodeGen/WebAssembly/masked-shifts.ll
@@ -106,10 +106,6 @@ define <16 x i8> @shl_v16i8_late(<16 x i8> %v, i8 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i8x16.splat
-; CHECK-NEXT:    v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i8x16.extract_lane_u 0
 ; CHECK-NEXT:    i8x16.shl
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -145,10 +141,6 @@ define <16 x i8> @ashr_v16i8_late(<16 x i8> %v, i8 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i8x16.splat
-; CHECK-NEXT:    v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i8x16.extract_lane_u 0
 ; CHECK-NEXT:    i8x16.shr_s
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -184,10 +176,6 @@ define <16 x i8> @lshr_v16i8_late(<16 x i8> %v, i8 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i8x16.splat
-; CHECK-NEXT:    v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i8x16.extract_lane_u 0
 ; CHECK-NEXT:    i8x16.shr_u
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -222,10 +210,6 @@ define <8 x i16> @shl_v8i16_late(<8 x i16> %v, i16 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i16x8.splat
-; CHECK-NEXT:    v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i16x8.extract_lane_u 0
 ; CHECK-NEXT:    i16x8.shl
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -259,10 +243,6 @@ define <8 x i16> @ashr_v8i16_late(<8 x i16> %v, i16 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i16x8.splat
-; CHECK-NEXT:    v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i16x8.extract_lane_u 0
 ; CHECK-NEXT:    i16x8.shr_s
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -296,10 +276,6 @@ define <8 x i16> @lshr_v8i16_late(<8 x i16> %v, i16 %x) {
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i16x8.splat
-; CHECK-NEXT:    v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT:    v128.and
-; CHECK-NEXT:    i16x8.extract_lane_u 0
 ; CHECK-NEXT:    i16x8.shr_u
 ; CHECK-NEXT:    # fallthrough-return
   %t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -519,6 +495,22 @@ define <2 x i64> @shl_v2i64_i32(<2 x i64> %v, i32 %x) {
   ret <2 x i64> %a
 }
 
+define <2 x i64> @shl_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: shl_v2i64_i32_late:
+; CHECK:         .functype shl_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i64x2.shl
+; CHECK-NEXT:    # fallthrough-return
+  %z = zext i32 %x to i64
+  %t = insertelement <2 x i64> undef, i64 %z, i32 0
+  %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
+  %m = and <2 x i64> %s, <i64 63, i64 63>
+  %a = shl <2 x i64> %v, %m
+  ret <2 x i64> %a
+}
+
 define <2 x i64> @ashr_v2i64_i32(<2 x i64> %v, i32 %x) {
 ; CHECK-LABEL: ashr_v2i64_i32:
 ; CHECK:         .functype ashr_v2i64_i32 (v128, i32) -> (v128)
@@ -535,6 +527,22 @@ define <2 x i64> @ashr_v2i64_i32(<2 x i64> %v, i32 %x) {
   ret <2 x i64> %a
 }
 
+define <2 x i64> @ashr_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: ashr_v2i64_i32_late:
+; CHECK:         .functype ashr_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i64x2.shr_s
+; CHECK-NEXT:    # fallthrough-return
+  %z = zext i32 %x to i64
+  %t = insertelement <2 x i64> undef, i64 %z, i32 0
+  %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
+  %m = and <2 x i64> %s, <i64 63, i64 63>
+  %a = ashr <2 x i64> %v, %m
+  ret <2 x i64> %a
+}
+
 define <2 x i64> @lshr_v2i64_i32(<2 x i64> %v, i32 %x) {
 ; CHECK-LABEL: lshr_v2i64_i32:
 ; CHECK:         .functype lshr_v2i64_i32 (v128, i32) -> (v128)
@@ -551,3 +559,18 @@ define <2 x i64> @lshr_v2i64_i32(<2 x i64> %v, i32 %x) {
   ret <2 x i64> %a
 }
 
+define <2 x i64> @lshr_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: lshr_v2i64_i32_late:
+; CHECK:         .functype lshr_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i64x2.shr_u
+; CHECK-NEXT:    # fallthrough-return
+  %z = zext i32 %x to i64
+  %t = insertelement <2 x i64> undef, i64 %z, i32 0
+  %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
+  %m = and <2 x i64> %s, <i64 63, i64 63>
+  %a = lshr <2 x i64> %v, %m
+  ret <2 x i64> %a
+}


        


More information about the llvm-commits mailing list