[llvm] 87663fd - [VectorCombine] Don't shrink lshr if the shamt is not less than bitwidth (#108705)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Sep 15 03:38:09 PDT 2024
Author: Yingwei Zheng
Date: 2024-09-15T18:38:06+08:00
New Revision: 87663fdab9d0e7bcc0b963ea078da9e2eb574908
URL: https://github.com/llvm/llvm-project/commit/87663fdab9d0e7bcc0b963ea078da9e2eb574908
DIFF: https://github.com/llvm/llvm-project/commit/87663fdab9d0e7bcc0b963ea078da9e2eb574908.diff
LOG: [VectorCombine] Don't shrink lshr if the shamt is not less than bitwidth (#108705)
Consider the following case:
```
define <2 x i32> @test(<2 x i64> %vec.ind16, <2 x i32> %broadcast.splat20) {
%19 = icmp eq <2 x i64> %vec.ind16, zeroinitializer
%20 = zext <2 x i1> %19 to <2 x i32>
%21 = lshr <2 x i32> %20, %broadcast.splat20
ret <2 x i32> %21
}
```
After https://github.com/llvm/llvm-project/pull/104606, we shrink the
lshr into:
```
define <2 x i32> @test(<2 x i64> %vec.ind16, <2 x i32> %broadcast.splat20) {
%1 = icmp eq <2 x i64> %vec.ind16, zeroinitializer
%2 = trunc <2 x i32> %broadcast.splat20 to <2 x i1>
%3 = lshr <2 x i1> %1, %2
%4 = zext <2 x i1> %3 to <2 x i32>
ret <2 x i32> %4
}
```
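This is incorrect. `lshr i1 X, 1` returns `poison` because the shift amount is not less than the bit width. Truncating the shift amount can also change its value: a shift by 2 becomes a shift by 0, so the shrunk code returns the comparison result where the original returns 0.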
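This patch adds an additional check on the shift-amount operand. The lshr is shrunk only if we can prove that the shift amount is less than the bit width of the smaller type. The existing check, `computeKnownBits(&I, *DL).countMaxActiveBits() > BW`, always evaluates to false for `lshr(zext(X), Y)`. A logical right shift cannot set the zero-extended high bits, so the check never rejected such an lshr. The known-bits check is therefore now applied only to bitwise logical instructions.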
Alive2: https://alive2.llvm.org/ce/z/j_RmTa
Fixes https://github.com/llvm/llvm-project/issues/108698.
Added:
Modified:
llvm/lib/Transforms/Vectorize/VectorCombine.cpp
llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d7afe2f426d392..58701bfa60a33e 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2597,11 +2597,19 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
- // Check that the expression overall uses at most the same number of bits as
- // ZExted
- KnownBits KB = computeKnownBits(&I, *DL);
- if (KB.countMaxActiveBits() > BW)
- return false;
+ if (I.getOpcode() == Instruction::LShr) {
+ // Check that the shift amount is less than the number of bits in the
+ // smaller type. Otherwise, the smaller lshr will return a poison value.
+ KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
+ if (ShAmtKB.getMaxValue().uge(BW))
+ return false;
+ } else {
+ // Check that the expression overall uses at most the same number of bits as
+ // ZExted
+ KnownBits KB = computeKnownBits(&I, *DL);
+ if (KB.countMaxActiveBits() > BW)
+ return false;
+ }
// Calculate costs of leaving current IR as it is and moving ZExt operation
// later, along with adding truncates if needed
@@ -2628,7 +2636,7 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
return false;
// Check if we can propagate ZExt through its other users
- KB = computeKnownBits(UI, *DL);
+ KnownBits KB = computeKnownBits(UI, *DL);
if (KB.countMaxActiveBits() > BW)
return false;
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
index 1a23f0a0ac142f..4216b0e643bb68 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
@@ -100,4 +100,17 @@ vector.body:
ret i32 %2
}
+define <2 x i32> @pr108698(<2 x i64> %x, <2 x i32> %y) {
+; CHECK-LABEL: @pr108698(
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i64> [[X:%.*]], zeroinitializer
+; CHECK-NEXT: [[EXT:%.*]] = zext <2 x i1> [[CMP]] to <2 x i32>
+; CHECK-NEXT: [[LSHR:%.*]] = lshr <2 x i32> [[EXT]], [[Y:%.*]]
+; CHECK-NEXT: ret <2 x i32> [[LSHR]]
+;
+ %cmp = icmp eq <2 x i64> %x, zeroinitializer
+ %ext = zext <2 x i1> %cmp to <2 x i32>
+ %lshr = lshr <2 x i32> %ext, %y
+ ret <2 x i32> %lshr
+}
+
declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
More information about the llvm-commits
mailing list