[llvm] 403926a - [WebAssembly] Skip implied bitmask operation in LowerShift
Jun Ma via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 1 17:37:51 PST 2023
Author: Jun Ma
Date: 2023-03-02T09:37:25+08:00
New Revision: 403926aefefb13553f89ad812b1e2385826a82ec
URL: https://github.com/llvm/llvm-project/commit/403926aefefb13553f89ad812b1e2385826a82ec
DIFF: https://github.com/llvm/llvm-project/commit/403926aefefb13553f89ad812b1e2385826a82ec.diff
LOG: [WebAssembly] Skip implied bitmask operation in LowerShift
This patch skips redundant explicit masking of the shift count, since
wasm shift instructions implicitly take the count modulo the lane width.
Differential Revision: https://reviews.llvm.org/D144619
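As an illustration of the pattern this removes, consider the
shl_v16i8_late test updated below. The explicit `and` of the splatted
shift count with LaneBits - 1 is dead weight, because wasm SIMD shift
instructions already interpret the count modulo the lane width. A
minimal IR sketch (the function name @shl_masked is hypothetical; the
body mirrors the existing test):

define <16 x i8> @shl_masked(<16 x i8> %v, i8 %x) {
  ; Splat the scalar shift amount across all lanes.
  %t = insertelement <16 x i8> undef, i8 %x, i32 0
  %s = shufflevector <16 x i8> %t, <16 x i8> undef, <16 x i32> zeroinitializer
  ; Mask each lane with 7. i8x16.shl performs this masking implicitly,
  ; so LowerShift can look through the AND.
  %m = and <16 x i8> %s, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
  %a = shl <16 x i8> %v, %m
  ret <16 x i8> %a
}

With this patch the splat/const/and/extract_lane sequence disappears
and the shift count feeds i8x16.shl directly, as the updated CHECK
lines show.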
Added:
Modified:
llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
llvm/test/CodeGen/WebAssembly/masked-shifts.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 94544800a6fba..32d0b01cb90e8 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -2287,10 +2287,43 @@ SDValue WebAssemblyTargetLowering::LowerShift(SDValue Op,
// Only manually lower vector shifts
assert(Op.getSimpleValueType().isVector());
- auto ShiftVal = DAG.getSplatValue(Op.getOperand(1));
+ uint64_t LaneBits = Op.getValueType().getScalarSizeInBits();
+ auto ShiftVal = Op.getOperand(1);
+
+ // Try to skip the bitmask operation, since it is implied by the shift
+ // instruction
+ auto SkipImpliedMask = [](SDValue MaskOp, uint64_t MaskBits) {
+ if (MaskOp.getOpcode() != ISD::AND)
+ return MaskOp;
+ SDValue LHS = MaskOp.getOperand(0);
+ SDValue RHS = MaskOp.getOperand(1);
+ if (MaskOp.getValueType().isVector()) {
+ APInt MaskVal;
+ if (!ISD::isConstantSplatVector(RHS.getNode(), MaskVal))
+ std::swap(LHS, RHS);
+
+ if (ISD::isConstantSplatVector(RHS.getNode(), MaskVal) &&
+ MaskVal == MaskBits)
+ MaskOp = LHS;
+ } else {
+ if (!isa<ConstantSDNode>(RHS.getNode()))
+ std::swap(LHS, RHS);
+
+ auto ConstantRHS = dyn_cast<ConstantSDNode>(RHS.getNode());
+ if (ConstantRHS && ConstantRHS->getAPIntValue() == MaskBits)
+ MaskOp = LHS;
+ }
+
+ return MaskOp;
+ };
+
+ // Skip the vector AND operation
+ ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
+ ShiftVal = DAG.getSplatValue(ShiftVal);
if (!ShiftVal)
return unrollVectorShift(Op, DAG);
+ // Skip the scalar AND operation
+ ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
// Use anyext because none of the high bits can affect the shift
ShiftVal = DAG.getAnyExtOrTrunc(ShiftVal, DL, MVT::i32);
diff --git a/llvm/test/CodeGen/WebAssembly/masked-shifts.ll b/llvm/test/CodeGen/WebAssembly/masked-shifts.ll
index 56e6119454380..5bcb023e546b5 100644
--- a/llvm/test/CodeGen/WebAssembly/masked-shifts.ll
+++ b/llvm/test/CodeGen/WebAssembly/masked-shifts.ll
@@ -106,10 +106,6 @@ define <16 x i8> @shl_v16i8_late(<16 x i8> %v, i8 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i8x16.splat
-; CHECK-NEXT: v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i8x16.extract_lane_u 0
; CHECK-NEXT: i8x16.shl
; CHECK-NEXT: # fallthrough-return
%t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -145,10 +141,6 @@ define <16 x i8> @ashr_v16i8_late(<16 x i8> %v, i8 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i8x16.splat
-; CHECK-NEXT: v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i8x16.extract_lane_u 0
; CHECK-NEXT: i8x16.shr_s
; CHECK-NEXT: # fallthrough-return
%t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -184,10 +176,6 @@ define <16 x i8> @lshr_v16i8_late(<16 x i8> %v, i8 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i8x16.splat
-; CHECK-NEXT: v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i8x16.extract_lane_u 0
; CHECK-NEXT: i8x16.shr_u
; CHECK-NEXT: # fallthrough-return
%t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -222,10 +210,6 @@ define <8 x i16> @shl_v8i16_late(<8 x i16> %v, i16 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i16x8.splat
-; CHECK-NEXT: v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i16x8.extract_lane_u 0
; CHECK-NEXT: i16x8.shl
; CHECK-NEXT: # fallthrough-return
%t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -259,10 +243,6 @@ define <8 x i16> @ashr_v8i16_late(<8 x i16> %v, i16 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i16x8.splat
-; CHECK-NEXT: v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i16x8.extract_lane_u 0
; CHECK-NEXT: i16x8.shr_s
; CHECK-NEXT: # fallthrough-return
%t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -296,10 +276,6 @@ define <8 x i16> @lshr_v8i16_late(<8 x i16> %v, i16 %x) {
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i16x8.splat
-; CHECK-NEXT: v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i16x8.extract_lane_u 0
; CHECK-NEXT: i16x8.shr_u
; CHECK-NEXT: # fallthrough-return
%t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -519,6 +495,22 @@ define <2 x i64> @shl_v2i64_i32(<2 x i64> %v, i32 %x) {
ret <2 x i64> %a
}
+define <2 x i64> @shl_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: shl_v2i64_i32_late:
+; CHECK: .functype shl_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i64x2.shl
+; CHECK-NEXT: # fallthrough-return
+ %z = zext i32 %x to i64
+ %t = insertelement <2 x i64> undef, i64 %z, i32 0
+ %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
+ %m = and <2 x i64> %s, <i64 63, i64 63>
+ %a = shl <2 x i64> %v, %m
+ ret <2 x i64> %a
+}
+
define <2 x i64> @ashr_v2i64_i32(<2 x i64> %v, i32 %x) {
; CHECK-LABEL: ashr_v2i64_i32:
; CHECK: .functype ashr_v2i64_i32 (v128, i32) -> (v128)
@@ -535,6 +527,22 @@ define <2 x i64> @ashr_v2i64_i32(<2 x i64> %v, i32 %x) {
ret <2 x i64> %a
}
+define <2 x i64> @ashr_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: ashr_v2i64_i32_late:
+; CHECK: .functype ashr_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i64x2.shr_s
+; CHECK-NEXT: # fallthrough-return
+ %z = zext i32 %x to i64
+ %t = insertelement <2 x i64> undef, i64 %z, i32 0
+ %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
+ %m = and <2 x i64> %s, <i64 63, i64 63>
+ %a = ashr <2 x i64> %v, %m
+ ret <2 x i64> %a
+}
+
define <2 x i64> @lshr_v2i64_i32(<2 x i64> %v, i32 %x) {
; CHECK-LABEL: lshr_v2i64_i32:
; CHECK: .functype lshr_v2i64_i32 (v128, i32) -> (v128)
@@ -551,3 +559,18 @@ define <2 x i64> @lshr_v2i64_i32(<2 x i64> %v, i32 %x) {
ret <2 x i64> %a
}
+define <2 x i64> @lshr_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: lshr_v2i64_i32_late:
+; CHECK: .functype lshr_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i64x2.shr_u
+; CHECK-NEXT: # fallthrough-return
+ %z = zext i32 %x to i64
+ %t = insertelement <2 x i64> undef, i64 %z, i32 0
+ %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> <i32 0, i32 0>
+ %m = and <2 x i64> %s, <i64 63, i64 63>
+ %a = lshr <2 x i64> %v, %m
+ ret <2 x i64> %a
+}