[llvm] [InstCombine] Remove over-generalization from computeKnownBitsFromCmp() (PR #72637)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 17 04:31:07 PST 2023


================
@@ -640,82 +640,68 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
   QueryNoAC.AC = nullptr;
 
   // Note that ptrtoint may change the bitwidth.
-  Value *A, *B;
   auto m_V =
       m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));
 
   CmpInst::Predicate Pred;
-  uint64_t C;
+  const APInt *Mask, *C;
+  uint64_t ShAmt;
   switch (Cmp->getPredicate()) {
   case ICmpInst::ICMP_EQ:
-    // assume(v = a)
-    if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      Known = Known.unionWith(RHSKnown);
-      // assume(v & b = a)
-    } else if (match(Cmp,
-                     m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
-      // For those bits in the mask that are known to be one, we can propagate
-      // known bits from the RHS to V.
-      Known.Zero |= RHSKnown.Zero & MaskKnown.One;
-      Known.One |= RHSKnown.One & MaskKnown.One;
-      // assume(v | b = a)
+    // assume(V = C)
+    if (match(Cmp, m_ICmp(Pred, m_V, m_APInt(C)))) {
+      Known = Known.unionWith(KnownBits::makeConstant(*C));
+      // assume(V & Mask = C)
     } else if (match(Cmp,
-                     m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
-      // For those bits in B that are known to be zero, we can propagate known
-      // bits from the RHS to V.
-      Known.Zero |= RHSKnown.Zero & BKnown.Zero;
-      Known.One |= RHSKnown.One & BKnown.Zero;
-      // assume(v ^ b = a)
+                     m_ICmp(Pred, m_And(m_V, m_APInt(Mask)), m_APInt(C)))) {
+      // For one bits in Mask, we can propagate bits from C to V.
+      Known.Zero |= ~*C & *Mask;
+      Known.One |= *C & *Mask;
+      // assume(V | Mask = C)
+    } else if (match(Cmp, m_ICmp(Pred, m_Or(m_V, m_APInt(Mask)), m_APInt(C)))) {
+      // For zero bits in Mask, we can propagate bits from C to V.
+      Known.Zero |= ~*C & ~*Mask;
+      Known.One |= *C & ~*Mask;
+      // assume(V ^ Mask = C)
     } else if (match(Cmp,
-                     m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
-      // For those bits in B that are known to be zero, we can propagate known
-      // bits from the RHS to V. For those bits in B that are known to be one,
-      // we can propagate inverted known bits from the RHS to V.
-      Known.Zero |= RHSKnown.Zero & BKnown.Zero;
-      Known.One |= RHSKnown.One & BKnown.Zero;
-      Known.Zero |= RHSKnown.One & BKnown.One;
-      Known.One |= RHSKnown.Zero & BKnown.One;
-      // assume(v << c = a)
-    } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
-                                   m_Value(A))) &&
-               C < BitWidth) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-
-      // For those bits in RHS that are known, we can propagate them to known
-      // bits in V shifted to the right by C.
-      RHSKnown.Zero.lshrInPlace(C);
-      RHSKnown.One.lshrInPlace(C);
+                     m_ICmp(Pred, m_Xor(m_V, m_APInt(Mask)), m_APInt(C)))) {
+      // For those bits in Mask that are zero, we can propagate known bits
+      // from C to V. For those bits in Mask that are one, we can propagate
+      // inverted bits from C to V.
+      Known.Zero |= ~*C & ~*Mask;
+      Known.One |= *C & ~*Mask;
+      Known.Zero |= *C & *Mask;
+      Known.One |= ~*C & *Mask;
+      // assume(V << ShAmt = C)
+    } else if (match(Cmp, m_ICmp(Pred, m_Shl(m_V, m_ConstantInt(ShAmt)),
+                                 m_APInt(C))) &&
+               ShAmt < BitWidth) {
+      // For those bits in C that are known, we can propagate them to known
+      // bits in V shifted to the right by ShAmt.
+      KnownBits RHSKnown = KnownBits::makeConstant(*C);
+      RHSKnown.Zero.lshrInPlace(ShAmt);
+      RHSKnown.One.lshrInPlace(ShAmt);
       Known = Known.unionWith(RHSKnown);
----------------
dtcxzyw wrote:

```suggestion
      Known = Known.unionWith(KnownBits::makeConstant(C->lshr(ShAmt)));
```

https://github.com/llvm/llvm-project/pull/72637


More information about the llvm-commits mailing list