[llvm] [InstCombine] Remove over-generalization from computeKnownBitsFromCmp() (PR #72637)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 17 02:49:23 PST 2023


llvmbot wrote:


@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Nikita Popov (nikic)

<details>
<summary>Changes</summary>

For practical purposes, the only KnownBits patterns we care about are those involving a constant comparison RHS and a constant mask. However, the current implementation is written in a very general way -- and, of course, with essentially no test coverage for those generalizations.

This patch reduces the implementation to only handle cases with constant operands. The only non-constant cases I've kept are plain `V pred A` comparisons, where I am less confident that the generality is useless.

The test changes are all in "make sure we don't crash" tests.
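
For illustration, here is a hypothetical IR example (not one of the tests touched by this patch, and the function name is made up) of the kind of pattern that remains handled: a constant mask and a constant comparison RHS pin individual bits of `%v`, which later folds can use.

```llvm
declare void @llvm.assume(i1)

define i1 @known_bits_from_mask(i32 %v) {
  ; assume((%v & 12) == 4): constant mask and constant RHS, so the bits of %v
  ; under the mask are pinned: bit 2 is known one, bit 3 is known zero.
  %masked = and i32 %v, 12
  %cmp = icmp eq i32 %masked, 4
  call void @llvm.assume(i1 %cmp)
  ; Bit 3 of %v is known zero, so this comparison can fold to true.
  %bit3 = and i32 %v, 8
  %res = icmp eq i32 %bit3, 0
  ret i1 %res
}
```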

---
Full diff: https://github.com/llvm/llvm-project/pull/72637.diff


4 Files Affected:

- (modified) llvm/lib/Analysis/AssumptionCache.cpp (+11-21) 
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+43-56) 
- (modified) llvm/test/Transforms/InstCombine/assume.ll (+2-2) 
- (modified) llvm/test/Transforms/InstCombine/icmp.ll (+3-2) 


``````````diff
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index 81b26678ae5d790..3139b3e8f319099 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -92,29 +92,19 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
     AddAffected(B);
 
     if (Pred == ICmpInst::ICMP_EQ) {
-      // For equality comparisons, we handle the case of bit inversion.
-      auto AddAffectedFromEq = [&AddAffected](Value *V) {
-        Value *A, *B;
-        // (A & B) or (A | B) or (A ^ B).
-        if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) {
-          AddAffected(A);
-          AddAffected(B);
-          // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
-        } else if (match(V, m_Shift(m_Value(A), m_ConstantInt()))) {
-          AddAffected(A);
-        }
-      };
-
-      AddAffectedFromEq(A);
-      AddAffectedFromEq(B);
+      if (match(B, m_ConstantInt())) {
+        Value *X;
+        // (X & C) or (X | C) or (X ^ C).
+        // (X << C) or (X >>_s C) or (X >>_u C).
+        if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
+            match(A, m_Shift(m_Value(X), m_ConstantInt())))
+          AddAffected(X);
+      }
     } else if (Pred == ICmpInst::ICMP_NE) {
-      Value *X, *Y;
-      // Handle (a & b != 0). If a/b is a power of 2 we can use this
-      // information.
-      if (match(A, m_And(m_Value(X), m_Value(Y))) && match(B, m_Zero())) {
+      Value *X;
+      // Handle (X & pow2 != 0).
+      if (match(A, m_And(m_Value(X), m_Power2())) && match(B, m_Zero()))
         AddAffected(X);
-        AddAffected(Y);
-      }
     } else if (Pred == ICmpInst::ICMP_ULT) {
       Value *X;
       // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e25aa9c6863f335..10a4a6e349c3794 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -640,82 +640,69 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
   QueryNoAC.AC = nullptr;
 
   // Note that ptrtoint may change the bitwidth.
-  Value *A, *B;
   auto m_V =
       m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));
 
   CmpInst::Predicate Pred;
-  uint64_t C;
+  const APInt *Mask, *C;
+  uint64_t ShAmt;
   switch (Cmp->getPredicate()) {
   case ICmpInst::ICMP_EQ:
-    // assume(v = a)
-    if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      Known = Known.unionWith(RHSKnown);
-      // assume(v & b = a)
+    // assume(V = C)
+    if (match(Cmp, m_ICmp(Pred, m_V, m_APInt(C)))) {
+      Known = Known.unionWith(KnownBits::makeConstant(*C));
+      // assume(V & Mask = C)
     } else if (match(Cmp,
-                     m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
-      // For those bits in the mask that are known to be one, we can propagate
-      // known bits from the RHS to V.
-      Known.Zero |= RHSKnown.Zero & MaskKnown.One;
-      Known.One |= RHSKnown.One & MaskKnown.One;
-      // assume(v | b = a)
+                     m_ICmp(Pred, m_And(m_V, m_APInt(Mask)), m_APInt(C)))) {
+      // For one bits in Mask, we can propagate bits from C to V.
+      Known.Zero |= ~*C & *Mask;
+      Known.One |= *C & *Mask;
+      // assume(V | Mask = C)
     } else if (match(Cmp,
-                     m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
-      // For those bits in B that are known to be zero, we can propagate known
-      // bits from the RHS to V.
-      Known.Zero |= RHSKnown.Zero & BKnown.Zero;
-      Known.One |= RHSKnown.One & BKnown.Zero;
-      // assume(v ^ b = a)
+                     m_ICmp(Pred, m_Or(m_V, m_APInt(Mask)), m_APInt(C)))) {
+      // For zero bits in Mask, we can propagate bits from C to V.
+      Known.Zero |= ~*C & ~*Mask;
+      Known.One |= *C & ~*Mask;
+      // assume(V ^ Mask = C)
     } else if (match(Cmp,
-                     m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
-      // For those bits in B that are known to be zero, we can propagate known
-      // bits from the RHS to V. For those bits in B that are known to be one,
-      // we can propagate inverted known bits from the RHS to V.
-      Known.Zero |= RHSKnown.Zero & BKnown.Zero;
-      Known.One |= RHSKnown.One & BKnown.Zero;
-      Known.Zero |= RHSKnown.One & BKnown.One;
-      Known.One |= RHSKnown.Zero & BKnown.One;
-      // assume(v << c = a)
-    } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
-                                   m_Value(A))) &&
-               C < BitWidth) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-
-      // For those bits in RHS that are known, we can propagate them to known
-      // bits in V shifted to the right by C.
-      RHSKnown.Zero.lshrInPlace(C);
-      RHSKnown.One.lshrInPlace(C);
+                     m_ICmp(Pred, m_Xor(m_V, m_APInt(Mask)), m_APInt(C)))) {
+      // For those bits in Mask that are zero, we can propagate known bits
+      // from C to V. For those bits in Mask that are one, we can propagate
+      // inverted bits from C to V.
+      Known.Zero |= ~*C & ~*Mask;
+      Known.One |= *C & ~*Mask;
+      Known.Zero |= *C & *Mask;
+      Known.One |= ~*C & *Mask;
+      // assume(V << ShAmt = C)
+    } else if (match(Cmp, m_ICmp(Pred, m_Shl(m_V, m_ConstantInt(ShAmt)),
+                                 m_APInt(C))) &&
+               ShAmt < BitWidth) {
+      // For those bits in C that are known, we can propagate them to known
+      // bits in V shifted to the right by ShAmt.
+      KnownBits RHSKnown = KnownBits::makeConstant(*C);
+      RHSKnown.Zero.lshrInPlace(ShAmt);
+      RHSKnown.One.lshrInPlace(ShAmt);
       Known = Known.unionWith(RHSKnown);
-      // assume(v >> c = a)
-    } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
-                                   m_Value(A))) &&
-               C < BitWidth) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+      // assume(V >> ShAmt = C)
+    } else if (match(Cmp, m_ICmp(Pred, m_Shr(m_V, m_ConstantInt(ShAmt)),
+                                 m_APInt(C))) &&
+               ShAmt < BitWidth) {
+      KnownBits RHSKnown = KnownBits::makeConstant(*C);
       // For those bits in RHS that are known, we can propagate them to known
       // bits in V shifted to the right by C.
-      Known.Zero |= RHSKnown.Zero << C;
-      Known.One |= RHSKnown.One << C;
+      Known.Zero |= RHSKnown.Zero << ShAmt;
+      Known.One |= RHSKnown.One << ShAmt;
     }
     break;
   case ICmpInst::ICMP_NE: {
-    // assume (v & b != 0) where b is a power of 2
+    // assume (V & B != 0) where B is a power of 2
     const APInt *BPow2;
-    if (match(Cmp, m_ICmp(Pred, m_c_And(m_V, m_Power2(BPow2)), m_Zero()))) {
+    if (match(Cmp, m_ICmp(Pred, m_And(m_V, m_Power2(BPow2)), m_Zero())))
       Known.One |= *BPow2;
-    }
     break;
   }
   default:
+    Value *A;
     const APInt *Offset = nullptr;
     if (match(Cmp, m_ICmp(Pred, m_CombineOr(m_V, m_Add(m_V, m_APInt(Offset))),
                           m_Value(A)))) {
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
index 934e9594f3f7b5a..5b2039b5b480501 100644
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -265,7 +265,7 @@ define i32 @bundle2(ptr %P) {
 
 define i1 @nonnull1(ptr %a) {
 ; CHECK-LABEL: @nonnull1(
-; CHECK-NEXT:    [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull !6, !noundef !6
+; CHECK-NEXT:    [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull [[META6:![0-9]+]], !noundef [[META6]]
 ; CHECK-NEXT:    tail call void @escape(ptr nonnull [[LOAD]])
 ; CHECK-NEXT:    ret i1 false
 ;
@@ -383,7 +383,7 @@ define i1 @nonnull5(ptr %a) {
 define i32 @assumption_conflicts_with_known_bits(i32 %a, i32 %b) {
 ; CHECK-LABEL: @assumption_conflicts_with_known_bits(
 ; CHECK-NEXT:    store i1 true, ptr poison, align 1
-; CHECK-NEXT:    ret i32 poison
+; CHECK-NEXT:    ret i32 1
 ;
   %and1 = and i32 %b, 3
   %B1 = lshr i32 %and1, %and1
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index 78ac730cf026ed9..d49cb79e1e27c98 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -4016,9 +4016,10 @@ define i32 @abs_preserve(i32 %x) {
 declare void @llvm.assume(i1)
 define i1 @PR35794(ptr %a) {
 ; CHECK-LABEL: @PR35794(
-; CHECK-NEXT:    [[MASKCOND:%.*]] = icmp eq ptr [[A:%.*]], null
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt ptr [[A:%.*]], inttoptr (i64 -1 to ptr)
+; CHECK-NEXT:    [[MASKCOND:%.*]] = icmp eq ptr [[A]], null
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[MASKCOND]])
-; CHECK-NEXT:    ret i1 true
+; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %cmp = icmp sgt ptr %a, inttoptr (i64 -1 to ptr)
   %maskcond = icmp eq ptr %a, null

``````````
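
As a companion sketch (again hypothetical, not part of the patch), the `ICMP_NE` path kept above covers the power-of-two mask case, where the assume pins a single bit of `%v` to one:

```llvm
declare void @llvm.assume(i1)

define i1 @pow2_ne_zero(i32 %v) {
  ; assume((%v & 16) != 0) with a power-of-two mask: bit 4 of %v is known one.
  %bit = and i32 %v, 16
  %cmp = icmp ne i32 %bit, 0
  call void @llvm.assume(i1 %cmp)
  ; With bit 4 known one, %v is at least 16, so this can fold to true.
  %res = icmp ugt i32 %v, 15
  ret i1 %res
}
```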

</details>


https://github.com/llvm/llvm-project/pull/72637

