[llvm] [InstCombine] Fold zext-of-icmp with no shift (PR #68503)

via llvm-commits llvm-commits at lists.llvm.org
Sat Oct 7 21:17:39 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

<details>
<summary>Changes</summary>

This regression was introduced by commit f400daa, which fixed an infinite-loop issue.

In this case we know the shift amount is 0, so no shift is emitted and the
result cannot take the (iN (~X) u>> (N - 1)) form handled by commit 21d3871
(where N is the bit width of X's type); the infinite loop is therefore not triggered.

Fixes https://github.com/llvm/llvm-project/issues/68465

---
Full diff: https://github.com/llvm/llvm-project/pull/68503.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+24-27) 
- (modified) llvm/test/Transforms/InstCombine/zext.ll (+13) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 20c13de33f8189d..f7c6fb2a2bf2958 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -904,37 +904,34 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp,
     // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set.
     // zext (X != 0) to i32 --> X        iff X has only the low bit set.
     // zext (X != 0) to i32 --> X>>1     iff X has only the 2nd bit set.
-    if (Op1CV->isZero() && Cmp->isEquality() &&
-        (Cmp->getOperand(0)->getType() == Zext.getType() ||
-         Cmp->getPredicate() == ICmpInst::ICMP_NE)) {
-      // If Op1C some other power of two, convert:
-      KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext);
 
-      // Exactly 1 possible 1? But not the high-bit because that is
-      // canonicalized to this form.
-      APInt KnownZeroMask(~Known.Zero);
-      if (KnownZeroMask.isPowerOf2() &&
-          (Zext.getType()->getScalarSizeInBits() !=
-           KnownZeroMask.logBase2() + 1)) {
-        uint32_t ShAmt = KnownZeroMask.logBase2();
-        Value *In = Cmp->getOperand(0);
-        if (ShAmt) {
-          // Perform a logical shr by shiftamt.
-          // Insert the shift to put the result in the low bit.
-          In = Builder.CreateLShr(In, ConstantInt::get(In->getType(), ShAmt),
-                                  In->getName() + ".lobit");
-        }
+    // Exactly 1 possible 1? But not the high-bit because that is
+    // canonicalized to this form.
+    KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext);
+    APInt KnownZeroMask(~Known.Zero);
+    uint32_t ShAmt = KnownZeroMask.logBase2();
+    bool isExpectShAmt = KnownZeroMask.isPowerOf2() &&
+                         (Zext.getType()->getScalarSizeInBits() != ShAmt + 1);
+    if (Op1CV->isZero() && Cmp->isEquality() && isExpectShAmt &&
+        (Cmp->getOperand(0)->getType() == Zext.getType() ||
+         Cmp->getPredicate() == ICmpInst::ICMP_NE || ShAmt == 0)) {
+      Value *In = Cmp->getOperand(0);
+      if (ShAmt) {
+        // Perform a logical shr by shiftamt.
+        // Insert the shift to put the result in the low bit.
+        In = Builder.CreateLShr(In, ConstantInt::get(In->getType(), ShAmt),
+                                In->getName() + ".lobit");
+      }
 
-        // Toggle the low bit for "X == 0".
-        if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
-          In = Builder.CreateXor(In, ConstantInt::get(In->getType(), 1));
+      // Toggle the low bit for "X == 0".
+      if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
+        In = Builder.CreateXor(In, ConstantInt::get(In->getType(), 1));
 
-        if (Zext.getType() == In->getType())
-          return replaceInstUsesWith(Zext, In);
+      if (Zext.getType() == In->getType())
+        return replaceInstUsesWith(Zext, In);
 
-        Value *IntCast = Builder.CreateIntCast(In, Zext.getType(), false);
-        return replaceInstUsesWith(Zext, IntCast);
-      }
+      Value *IntCast = Builder.CreateIntCast(In, Zext.getType(), false);
+      return replaceInstUsesWith(Zext, IntCast);
     }
   }
 
diff --git a/llvm/test/Transforms/InstCombine/zext.ll b/llvm/test/Transforms/InstCombine/zext.ll
index 8aa2a10e6abb2ea..29b6601774a89b3 100644
--- a/llvm/test/Transforms/InstCombine/zext.ll
+++ b/llvm/test/Transforms/InstCombine/zext.ll
@@ -748,3 +748,16 @@ define i64 @zext_icmp_ne_bool_1(ptr %ptr) {
   %len = zext i1 %cmp to i64
   ret i64 %len
 }
+
+define i32  @zext_icmp_eq0_no_shift(ptr %ptr ) {
+; CHECK-LABEL: @zext_icmp_eq0_no_shift(
+; CHECK-NEXT:    [[X:%.*]] = load i8, ptr [[PTR:%.*]], align 1, !range [[RNG1:![0-9]+]]
+; CHECK-NEXT:    [[TMP1:%.*]] = xor i8 [[X]], 1
+; CHECK-NEXT:    [[RES:%.*]] = zext i8 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %X = load i8, ptr %ptr,align 1, !range !{i8 0, i8 2} ; range [0, 2)
+  %cmp = icmp eq i8 %X, 0
+  %res = zext i1 %cmp to i32
+  ret i32 %res
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/68503


More information about the llvm-commits mailing list