[llvm] [InstCombine] Use KnownBits predicate helpers (PR #115874)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 06:23:36 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Nikita Popov (nikic)

<details>
<summary>Changes</summary>

Inside foldICmpUsingKnownBits(), instead of rolling our own logic based on min/max values, make use of KnownBits::eq() etc. This gives better results for the equality predicates.

I've adjusted some tests to prevent the new fold from triggering, to retain their original intent of testing constant expressions.

---
Full diff: https://github.com/llvm/llvm-project/pull/115874.diff


5 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+40-61) 
- (modified) llvm/test/Transforms/InstCombine/icmp-gep.ll (+1-3) 
- (modified) llvm/test/Transforms/InstCombine/mul-inseltpoison.ll (+2-2) 
- (modified) llvm/test/Transforms/InstCombine/mul.ll (+2-2) 
- (modified) llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll (+1-1) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 5a8814dfd6b3d3..975abf027f6c54 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6544,6 +6544,35 @@ bool InstCombinerImpl::replacedSelectWithOperand(SelectInst *SI,
   return false;
 }
 
+static std::optional<bool> compareKnownBits(ICmpInst::Predicate Pred,
+                                            const KnownBits &Op0,
+                                            const KnownBits &Op1) {
+  switch (Pred) {
+  case ICmpInst::ICMP_EQ:
+    return KnownBits::eq(Op0, Op1);
+  case ICmpInst::ICMP_NE:
+    return KnownBits::ne(Op0, Op1);
+  case ICmpInst::ICMP_ULT:
+    return KnownBits::ult(Op0, Op1);
+  case ICmpInst::ICMP_ULE:
+    return KnownBits::ule(Op0, Op1);
+  case ICmpInst::ICMP_UGT:
+    return KnownBits::ugt(Op0, Op1);
+  case ICmpInst::ICMP_UGE:
+    return KnownBits::uge(Op0, Op1);
+  case ICmpInst::ICMP_SLT:
+    return KnownBits::slt(Op0, Op1);
+  case ICmpInst::ICMP_SLE:
+    return KnownBits::sle(Op0, Op1);
+  case ICmpInst::ICMP_SGT:
+    return KnownBits::sgt(Op0, Op1);
+  case ICmpInst::ICMP_SGE:
+    return KnownBits::sge(Op0, Op1);
+  default:
+    llvm_unreachable("Unknown predicate");
+  }
+}
+
 /// Try to fold the comparison based on range information we can get by checking
 /// whether bits are known to be zero or one in the inputs.
 Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
@@ -6576,6 +6605,16 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
       return &I;
   }
 
+  if (!isa<Constant>(Op0) && Op0Known.isConstant())
+    return new ICmpInst(
+        Pred, ConstantExpr::getIntegerValue(Ty, Op0Known.getConstant()), Op1);
+  if (!isa<Constant>(Op1) && Op1Known.isConstant())
+    return new ICmpInst(
+        Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Known.getConstant()));
+
+  if (std::optional<bool> Res = compareKnownBits(Pred, Op0Known, Op1Known))
+    return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *Res));
+
   // Given the known and unknown bits, compute a range that the LHS could be
   // in.  Compute the Min, Max and RHS values based on the known bits. For the
   // EQ and NE we use unsigned values.
@@ -6593,14 +6632,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
     Op1Max = Op1Known.getMaxValue();
   }
 
-  // If Min and Max are known to be the same, then SimplifyDemandedBits figured
-  // out that the LHS or RHS is a constant. Constant fold this now, so that
-  // code below can assume that Min != Max.
-  if (!isa<Constant>(Op0) && Op0Min == Op0Max)
-    return new ICmpInst(Pred, ConstantExpr::getIntegerValue(Ty, Op0Min), Op1);
-  if (!isa<Constant>(Op1) && Op1Min == Op1Max)
-    return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
-
   // Don't break up a clamp pattern -- (min(max X, Y), Z) -- by replacing a
   // min/max canonical compare with some other compare. That could lead to
   // conflict with select canonicalization and infinite looping.
@@ -6682,13 +6713,9 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
   // simplify this comparison.  For example, (x&4) < 8 is always true.
   switch (Pred) {
   default:
-    llvm_unreachable("Unknown icmp opcode!");
+    break;
   case ICmpInst::ICMP_EQ:
   case ICmpInst::ICMP_NE: {
-    if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
-      return replaceInstUsesWith(
-          I, ConstantInt::getBool(I.getType(), Pred == CmpInst::ICMP_NE));
-
     // If all bits are known zero except for one, then we know at most one bit
     // is set. If the comparison is against zero, then this is a check to see if
     // *that* bit is set.
@@ -6728,67 +6755,19 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
                           ConstantInt::getNullValue(Op1->getType()));
     break;
   }
-  case ICmpInst::ICMP_ULT: {
-    if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    break;
-  }
-  case ICmpInst::ICMP_UGT: {
-    if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    break;
-  }
-  case ICmpInst::ICMP_SLT: {
-    if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    break;
-  }
-  case ICmpInst::ICMP_SGT: {
-    if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    break;
-  }
   case ICmpInst::ICMP_SGE:
-    assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
-    if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
     if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B)
       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
   case ICmpInst::ICMP_SLE:
-    assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
-    if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
     if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B)
       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
   case ICmpInst::ICMP_UGE:
-    assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
-    if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
     if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B)
       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
   case ICmpInst::ICMP_ULE:
-    assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
-    if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
-      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-    if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
-      return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
     if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B)
       return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
diff --git a/llvm/test/Transforms/InstCombine/icmp-gep.ll b/llvm/test/Transforms/InstCombine/icmp-gep.ll
index 887cf1162319bc..776716fe908733 100644
--- a/llvm/test/Transforms/InstCombine/icmp-gep.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-gep.ll
@@ -583,9 +583,7 @@ define i1 @gep_nusw(ptr %p, i64 %a, i64 %b, i64 %c, i64 %d) {
 
 define i1 @pointer_icmp_aligned_with_offset(ptr align 8 %a, ptr align 8 %a2) {
 ; CHECK-LABEL: @pointer_icmp_aligned_with_offset(
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[A:%.*]], i64 4
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq ptr [[GEP]], [[A2:%.*]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
   %gep = getelementptr i8, ptr %a, i64 4
   %cmp = icmp eq ptr %gep, %a2
diff --git a/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll b/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll
index 997758af62a543..8baf6a70fdd5d1 100644
--- a/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/mul-inseltpoison.ll
@@ -570,12 +570,12 @@ define i64 @test30(i32 %A, i32 %B) {
 @PR22087 = external global i32
 define i32 @test31(i32 %V) {
 ; CHECK-LABEL: @test31(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
 ; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[CMP]] to i32
 ; CHECK-NEXT:    [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]]
 ; CHECK-NEXT:    ret i32 [[MUL1]]
 ;
-  %cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+  %cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
   %ext = zext i1 %cmp to i32
   %shl = shl i32 1, %ext
   %mul = mul i32 %V, %shl
diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
index e38ab1b9622b2c..340828a8d3f9dd 100644
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -1152,12 +1152,12 @@ define i64 @test30(i32 %A, i32 %B) {
 @PR22087 = external global i32
 define i32 @test31(i32 %V) {
 ; CHECK-LABEL: @test31(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
 ; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[CMP]] to i32
 ; CHECK-NEXT:    [[MUL1:%.*]] = shl i32 [[V:%.*]], [[EXT]]
 ; CHECK-NEXT:    ret i32 [[MUL1]]
 ;
-  %cmp = icmp ne ptr inttoptr (i64 1 to ptr), @PR22087
+  %cmp = icmp ne ptr inttoptr (i64 4 to ptr), @PR22087
   %ext = zext i1 %cmp to i32
   %shl = shl i32 1, %ext
   %mul = mul i32 %V, %shl
diff --git a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
index 070a3b03302124..e95955da1b8728 100644
--- a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
+++ b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
@@ -669,7 +669,7 @@ define <2 x i1> @n38_overshift(<2 x i32> %x, <2 x i32> %y) {
 }
 
 ; As usual, don't crash given constantexpr's :/
-@f.a = internal global i16 0
+@f.a = internal global i16 0, align 1
 define i1 @constantexpr() {
 ; CHECK-LABEL: @constantexpr(
 ; CHECK-NEXT:  entry:

``````````

</details>


https://github.com/llvm/llvm-project/pull/115874


More information about the llvm-commits mailing list