[llvm] [VectorCombine] Don't shrink lshr if the shamt is not less than bitwidth (PR #108705)

via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 14 08:31:39 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

Consider the following case:
```
define <2 x i32> @test(<2 x i64> %vec.ind16, <2 x i32> %broadcast.splat20) {
  %19 = icmp eq <2 x i64> %vec.ind16, zeroinitializer
  %20 = zext <2 x i1> %19 to <2 x i32>
  %21 = lshr <2 x i32> %20, %broadcast.splat20
  ret <2 x i32> %21
}
```
After https://github.com/llvm/llvm-project/pull/104606, we shrink the lshr into:
```
define <2 x i32> @test(<2 x i64> %vec.ind16, <2 x i32> %broadcast.splat20) {
  %1 = icmp eq <2 x i64> %vec.ind16, zeroinitializer
  %2 = trunc <2 x i32> %broadcast.splat20 to <2 x i1>
  %3 = lshr <2 x i1> %1, %2
  %4 = zext <2 x i1> %3 to <2 x i32>
  ret <2 x i32> %4
}
```
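This is incorrect: `lshr i1 X, 1` is `poison`, because a shift amount that is not less than the bit width yields `poison`.
This patch adds an additional check on the shift amount (shamt) operand: the lshr is shrunk only if the shamt is known to be less than the bit width of the smaller type. As `computeKnownBits(&I, *DL).countMaxActiveBits() > BW` never holds for `lshr(zext(X), Y)` (a logical right shift cannot set the high bits cleared by the zext), the known-bits check can never reject it, so that check now applies only to bitwise logical instructions.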
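
To make the miscompile concrete, here is a minimal C++ model of the two IR forms above. It is a sketch under stated assumptions: `lshrN`, `original`, and `shrunk` are hypothetical helpers (not LLVM APIs), integers are at most 32 bits wide, and `std::nullopt` stands in for `poison`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical model of `lshr` on a BW-bit integer (1 <= BW <= 32).
// std::nullopt stands in for poison, which lshr produces when ShAmt >= BW.
std::optional<uint32_t> lshrN(uint32_t X, uint32_t ShAmt, unsigned BW) {
  if (ShAmt >= BW)
    return std::nullopt;
  uint32_t Mask = BW == 32 ? 0xFFFFFFFFu : ((1u << BW) - 1);
  return ((X & Mask) >> ShAmt) & Mask;
}

// Original form: lshr (zext i1 X to i32), Y, computed in i32.
std::optional<uint32_t> original(bool X, uint32_t Y) {
  return lshrN(X ? 1u : 0u, Y, 32);
}

// Shrunk form: zext (lshr i1 X, (trunc Y to i1)) to i32.
// `Y & 1u` models the trunc to i1; zext of a 0/1 value leaves it unchanged,
// and zext of poison is still poison, so the narrow result is returned as-is.
std::optional<uint32_t> shrunk(bool X, uint32_t Y) {
  return lshrN(X ? 1u : 0u, Y & 1u, 1);
}
```

With `X = true` and `Y = 1`, `original` yields 0 while `shrunk` yields poison; with `Y = 2`, the `trunc` drops the set bit, so `shrunk` returns 1 instead of 0. With an i1 source, the new guard only permits shrinking when the shift amount is known to be 0, the only value less than 1.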

Alive2: https://alive2.llvm.org/ce/z/j_RmTa
Fixes https://github.com/llvm/llvm-project/issues/108698.



---
Full diff: https://github.com/llvm/llvm-project/pull/108705.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+14-6) 
- (added) llvm/test/Transforms/VectorCombine/X86/pr108698.ll (+16) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d7afe2f426d392..58701bfa60a33e 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2597,11 +2597,19 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
   auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
   unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
 
-  // Check that the expression overall uses at most the same number of bits as
-  // ZExted
-  KnownBits KB = computeKnownBits(&I, *DL);
-  if (KB.countMaxActiveBits() > BW)
-    return false;
+  if (I.getOpcode() == Instruction::LShr) {
+    // Check that the shift amount is less than the number of bits in the
+    // smaller type. Otherwise, the smaller lshr will return a poison value.
+    KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
+    if (ShAmtKB.getMaxValue().uge(BW))
+      return false;
+  } else {
+    // Check that the expression overall uses at most the same number of bits as
+    // ZExted
+    KnownBits KB = computeKnownBits(&I, *DL);
+    if (KB.countMaxActiveBits() > BW)
+      return false;
+  }
 
   // Calculate costs of leaving current IR as it is and moving ZExt operation
   // later, along with adding truncates if needed
@@ -2628,7 +2636,7 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
       return false;
 
     // Check if we can propagate ZExt through its other users
-    KB = computeKnownBits(UI, *DL);
+    KnownBits KB = computeKnownBits(UI, *DL);
     if (KB.countMaxActiveBits() > BW)
       return false;
 
diff --git a/llvm/test/Transforms/VectorCombine/X86/pr108698.ll b/llvm/test/Transforms/VectorCombine/X86/pr108698.ll
new file mode 100644
index 00000000000000..675cf6ed7da51f
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/X86/pr108698.ll
@@ -0,0 +1,16 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=vector-combine -S -mtriple=x86_64 | FileCheck %s
+
+define <2 x i32> @test(<2 x i64> %x, <2 x i32> %y) {
+; CHECK-LABEL: define <2 x i32> @test(
+; CHECK-SAME: <2 x i64> [[X:%.*]], <2 x i32> [[Y:%.*]]) {
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i64> [[X]], zeroinitializer
+; CHECK-NEXT:    [[EXT:%.*]] = zext <2 x i1> [[CMP]] to <2 x i32>
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr <2 x i32> [[EXT]], [[Y]]
+; CHECK-NEXT:    ret <2 x i32> [[LSHR]]
+;
+  %cmp = icmp eq <2 x i64> %x, zeroinitializer
+  %ext = zext <2 x i1> %cmp to <2 x i32>
+  %lshr = lshr <2 x i32> %ext, %y
+  ret <2 x i32> %lshr
+}

``````````
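
The new `LShr` branch can be sketched outside LLVM with a simplified known-bits model. This assumes a 32-bit value and uses hypothetical names (`KnownBits32`, `canShrinkLShr`); it only mirrors the idea behind `KnownBits::getMaxValue()` and the `uge(BW)` bail-out, not the real `llvm::KnownBits` class.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, simplified stand-in for llvm::KnownBits on a 32-bit value:
// a bit set in Zero is known to be 0, a bit set in One is known to be 1.
struct KnownBits32 {
  uint32_t Zero = 0;
  uint32_t One = 0;
  // Largest value consistent with the known bits: every bit not known
  // to be zero might be one.
  uint32_t maxValue() const { return ~Zero; }
};

// Mirrors the new guard: shrink only if the shift amount is provably less
// than BW, the element width of the narrower (pre-zext) type. The patch
// bails out when ShAmtKB.getMaxValue().uge(BW).
bool canShrinkLShr(const KnownBits32 &ShAmtKB, unsigned BW) {
  return ShAmtKB.maxValue() < BW;
}
```

For example, a shift amount with no known bits (as in the test above) is rejected for `BW = 1`, while a shift amount masked by `and %y, 7` (known zero outside the low three bits) passes for `BW = 8` but not for `BW = 4`.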

</details>


https://github.com/llvm/llvm-project/pull/108705