[llvm] [VectorCombine] Don't shrink lshr if the shamt is not less than bitwidth (PR #108705)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 14 19:14:21 PDT 2024


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/108705

>From 96f3521a66135b695c07734b552406126e9fec80 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 15 Sep 2024 10:07:22 +0800
Subject: [PATCH 1/2] [VectorCombine] Add pre-commit tests. NFC.

---
 .../VectorCombine/AArch64/shrink-types.ll          | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
index 1a23f0a0ac142f..d201580bb6dc71 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
@@ -100,4 +100,18 @@ vector.body:
   ret i32 %2
 }
 
+define <2 x i32> @pr108698(<2 x i64> %x, <2 x i32> %y) {
+; CHECK-LABEL: @pr108698(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i64> [[X:%.*]], zeroinitializer
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i32> [[Y:%.*]] to <2 x i1>
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr <2 x i1> [[CMP]], [[TMP1]]
+; CHECK-NEXT:    [[LSHR:%.*]] = zext <2 x i1> [[TMP2]] to <2 x i32>
+; CHECK-NEXT:    ret <2 x i32> [[LSHR]]
+;
+  %cmp = icmp eq <2 x i64> %x, zeroinitializer
+  %ext = zext <2 x i1> %cmp to <2 x i32>
+  %lshr = lshr <2 x i32> %ext, %y
+  ret <2 x i32> %lshr
+}
+
 declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

>From fe91054f9a15f7d4db7888b55679ff2316ae4d41 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Sun, 15 Sep 2024 10:08:02 +0800
Subject: [PATCH 2/2] [VectorCombine] Don't shrink lshr if the shamt is not
 less than bitwidth

---
 .../Transforms/Vectorize/VectorCombine.cpp    | 20 +++++++++++++------
 .../VectorCombine/AArch64/shrink-types.ll     |  5 ++---
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d7afe2f426d392..58701bfa60a33e 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2597,11 +2597,19 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
   auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
   unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
 
-  // Check that the expression overall uses at most the same number of bits as
-  // ZExted
-  KnownBits KB = computeKnownBits(&I, *DL);
-  if (KB.countMaxActiveBits() > BW)
-    return false;
+  if (I.getOpcode() == Instruction::LShr) {
+    // Check that the shift amount is less than the number of bits in the
+    // smaller type. Otherwise, the smaller lshr will return a poison value.
+    KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
+    if (ShAmtKB.getMaxValue().uge(BW))
+      return false;
+  } else {
+    // Check that the expression overall uses at most the same number of bits as
+    // ZExted
+    KnownBits KB = computeKnownBits(&I, *DL);
+    if (KB.countMaxActiveBits() > BW)
+      return false;
+  }
 
   // Calculate costs of leaving current IR as it is and moving ZExt operation
   // later, along with adding truncates if needed
@@ -2628,7 +2636,7 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
       return false;
 
     // Check if we can propagate ZExt through its other users
-    KB = computeKnownBits(UI, *DL);
+    KnownBits KB = computeKnownBits(UI, *DL);
     if (KB.countMaxActiveBits() > BW)
       return false;
 
diff --git a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
index d201580bb6dc71..4216b0e643bb68 100644
--- a/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
+++ b/llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll
@@ -103,9 +103,8 @@ vector.body:
 define <2 x i32> @pr108698(<2 x i64> %x, <2 x i32> %y) {
 ; CHECK-LABEL: @pr108698(
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i64> [[X:%.*]], zeroinitializer
-; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i32> [[Y:%.*]] to <2 x i1>
-; CHECK-NEXT:    [[TMP2:%.*]] = lshr <2 x i1> [[CMP]], [[TMP1]]
-; CHECK-NEXT:    [[LSHR:%.*]] = zext <2 x i1> [[TMP2]] to <2 x i32>
+; CHECK-NEXT:    [[EXT:%.*]] = zext <2 x i1> [[CMP]] to <2 x i32>
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr <2 x i32> [[EXT]], [[Y:%.*]]
 ; CHECK-NEXT:    ret <2 x i32> [[LSHR]]
 ;
   %cmp = icmp eq <2 x i64> %x, zeroinitializer



More information about the llvm-commits mailing list