[llvm] [InstCombine] Remove over-generalization from computeKnownBitsFromCmp() (PR #72637)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 17 02:49:23 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-transforms
Author: Nikita Popov (nikic)
<details>
<summary>Changes</summary>
For practical purposes, the only KnownBits patterns we care about are those involving a constant comparison RHS and a constant mask. However, the current implementation is written in a very general way -- and, of course, with essentially no test coverage of those generalizations.
This patch reduces the implementation to handle only cases with constant operands. The only non-constant case I've kept is the plain `V pred A` comparison, where I am less confident that the generality is useless.
The test changes are all in "make sure we don't crash" tests.
---
Full diff: https://github.com/llvm/llvm-project/pull/72637.diff
4 Files Affected:
- (modified) llvm/lib/Analysis/AssumptionCache.cpp (+11-21)
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+43-56)
- (modified) llvm/test/Transforms/InstCombine/assume.ll (+2-2)
- (modified) llvm/test/Transforms/InstCombine/icmp.ll (+3-2)
``````````diff
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index 81b26678ae5d790..3139b3e8f319099 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -92,29 +92,19 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
AddAffected(B);
if (Pred == ICmpInst::ICMP_EQ) {
- // For equality comparisons, we handle the case of bit inversion.
- auto AddAffectedFromEq = [&AddAffected](Value *V) {
- Value *A, *B;
- // (A & B) or (A | B) or (A ^ B).
- if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) {
- AddAffected(A);
- AddAffected(B);
- // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
- } else if (match(V, m_Shift(m_Value(A), m_ConstantInt()))) {
- AddAffected(A);
- }
- };
-
- AddAffectedFromEq(A);
- AddAffectedFromEq(B);
+ if (match(B, m_ConstantInt())) {
+ Value *X;
+ // (X & C) or (X | C) or (X ^ C).
+ // (X << C) or (X >>_s C) or (X >>_u C).
+ if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
+ match(A, m_Shift(m_Value(X), m_ConstantInt())))
+ AddAffected(X);
+ }
} else if (Pred == ICmpInst::ICMP_NE) {
- Value *X, *Y;
- // Handle (a & b != 0). If a/b is a power of 2 we can use this
- // information.
- if (match(A, m_And(m_Value(X), m_Value(Y))) && match(B, m_Zero())) {
+ Value *X;
+ // Handle (X & pow2 != 0).
+ if (match(A, m_And(m_Value(X), m_Power2())) && match(B, m_Zero()))
AddAffected(X);
- AddAffected(Y);
- }
} else if (Pred == ICmpInst::ICMP_ULT) {
Value *X;
// Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e25aa9c6863f335..10a4a6e349c3794 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -640,82 +640,69 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
QueryNoAC.AC = nullptr;
// Note that ptrtoint may change the bitwidth.
- Value *A, *B;
auto m_V =
m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));
CmpInst::Predicate Pred;
- uint64_t C;
+ const APInt *Mask, *C;
+ uint64_t ShAmt;
switch (Cmp->getPredicate()) {
case ICmpInst::ICMP_EQ:
- // assume(v = a)
- if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- Known = Known.unionWith(RHSKnown);
- // assume(v & b = a)
+ // assume(V = C)
+ if (match(Cmp, m_ICmp(Pred, m_V, m_APInt(C)))) {
+ Known = Known.unionWith(KnownBits::makeConstant(*C));
+ // assume(V & Mask = C)
} else if (match(Cmp,
- m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in the mask that are known to be one, we can propagate
- // known bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & MaskKnown.One;
- Known.One |= RHSKnown.One & MaskKnown.One;
- // assume(v | b = a)
+ m_ICmp(Pred, m_And(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // For one bits in Mask, we can propagate bits from C to V.
+ Known.Zero |= ~*C & *Mask;
+ Known.One |= *C & *Mask;
+ // assume(V | Mask = C)
} else if (match(Cmp,
- m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in B that are known to be zero, we can propagate known
- // bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & BKnown.Zero;
- Known.One |= RHSKnown.One & BKnown.Zero;
- // assume(v ^ b = a)
+ m_ICmp(Pred, m_Or(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // For zero bits in Mask, we can propagate bits from C to V.
+ Known.Zero |= ~*C & ~*Mask;
+ Known.One |= *C & ~*Mask;
+ // assume(V ^ Mask = C)
} else if (match(Cmp,
- m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in B that are known to be zero, we can propagate known
- // bits from the RHS to V. For those bits in B that are known to be one,
- // we can propagate inverted known bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & BKnown.Zero;
- Known.One |= RHSKnown.One & BKnown.Zero;
- Known.Zero |= RHSKnown.One & BKnown.One;
- Known.One |= RHSKnown.Zero & BKnown.One;
- // assume(v << c = a)
- } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
- m_Value(A))) &&
- C < BitWidth) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-
- // For those bits in RHS that are known, we can propagate them to known
- // bits in V shifted to the right by C.
- RHSKnown.Zero.lshrInPlace(C);
- RHSKnown.One.lshrInPlace(C);
+ m_ICmp(Pred, m_Xor(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // For those bits in Mask that are zero, we can propagate known bits
+ // from C to V. For those bits in Mask that are one, we can propagate
+ // inverted bits from C to V.
+ Known.Zero |= ~*C & ~*Mask;
+ Known.One |= *C & ~*Mask;
+ Known.Zero |= *C & *Mask;
+ Known.One |= ~*C & *Mask;
+ // assume(V << ShAmt = C)
+ } else if (match(Cmp, m_ICmp(Pred, m_Shl(m_V, m_ConstantInt(ShAmt)),
+ m_APInt(C))) &&
+ ShAmt < BitWidth) {
+ // For those bits in C that are known, we can propagate them to known
+ // bits in V shifted to the right by ShAmt.
+ KnownBits RHSKnown = KnownBits::makeConstant(*C);
+ RHSKnown.Zero.lshrInPlace(ShAmt);
+ RHSKnown.One.lshrInPlace(ShAmt);
Known = Known.unionWith(RHSKnown);
- // assume(v >> c = a)
- } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
- m_Value(A))) &&
- C < BitWidth) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+ // assume(V >> ShAmt = C)
+ } else if (match(Cmp, m_ICmp(Pred, m_Shr(m_V, m_ConstantInt(ShAmt)),
+ m_APInt(C))) &&
+ ShAmt < BitWidth) {
+ KnownBits RHSKnown = KnownBits::makeConstant(*C);
// For those bits in RHS that are known, we can propagate them to known
// bits in V shifted to the right by C.
- Known.Zero |= RHSKnown.Zero << C;
- Known.One |= RHSKnown.One << C;
+ Known.Zero |= RHSKnown.Zero << ShAmt;
+ Known.One |= RHSKnown.One << ShAmt;
}
break;
case ICmpInst::ICMP_NE: {
- // assume (v & b != 0) where b is a power of 2
+ // assume (V & B != 0) where B is a power of 2
const APInt *BPow2;
- if (match(Cmp, m_ICmp(Pred, m_c_And(m_V, m_Power2(BPow2)), m_Zero()))) {
+ if (match(Cmp, m_ICmp(Pred, m_And(m_V, m_Power2(BPow2)), m_Zero())))
Known.One |= *BPow2;
- }
break;
}
default:
+ Value *A;
const APInt *Offset = nullptr;
if (match(Cmp, m_ICmp(Pred, m_CombineOr(m_V, m_Add(m_V, m_APInt(Offset))),
m_Value(A)))) {
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
index 934e9594f3f7b5a..5b2039b5b480501 100644
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -265,7 +265,7 @@ define i32 @bundle2(ptr %P) {
define i1 @nonnull1(ptr %a) {
; CHECK-LABEL: @nonnull1(
-; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull !6, !noundef !6
+; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull [[META6:![0-9]+]], !noundef [[META6]]
; CHECK-NEXT: tail call void @escape(ptr nonnull [[LOAD]])
; CHECK-NEXT: ret i1 false
;
@@ -383,7 +383,7 @@ define i1 @nonnull5(ptr %a) {
define i32 @assumption_conflicts_with_known_bits(i32 %a, i32 %b) {
; CHECK-LABEL: @assumption_conflicts_with_known_bits(
; CHECK-NEXT: store i1 true, ptr poison, align 1
-; CHECK-NEXT: ret i32 poison
+; CHECK-NEXT: ret i32 1
;
%and1 = and i32 %b, 3
%B1 = lshr i32 %and1, %and1
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index 78ac730cf026ed9..d49cb79e1e27c98 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -4016,9 +4016,10 @@ define i32 @abs_preserve(i32 %x) {
declare void @llvm.assume(i1)
define i1 @PR35794(ptr %a) {
; CHECK-LABEL: @PR35794(
-; CHECK-NEXT: [[MASKCOND:%.*]] = icmp eq ptr [[A:%.*]], null
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt ptr [[A:%.*]], inttoptr (i64 -1 to ptr)
+; CHECK-NEXT: [[MASKCOND:%.*]] = icmp eq ptr [[A]], null
; CHECK-NEXT: tail call void @llvm.assume(i1 [[MASKCOND]])
-; CHECK-NEXT: ret i1 true
+; CHECK-NEXT: ret i1 [[CMP]]
;
%cmp = icmp sgt ptr %a, inttoptr (i64 -1 to ptr)
%maskcond = icmp eq ptr %a, null
``````````
</details>
https://github.com/llvm/llvm-project/pull/72637
More information about the llvm-commits
mailing list