[llvm] 28a5e6b - [InstCombine] Remove over-generalization from computeKnownBitsFromCmp() (#72637)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 27 07:50:10 PST 2023
Author: Nikita Popov
Date: 2023-11-27T16:50:05+01:00
New Revision: 28a5e6b069ee7540cd0965a0ad6cf0475db8d68c
URL: https://github.com/llvm/llvm-project/commit/28a5e6b069ee7540cd0965a0ad6cf0475db8d68c
DIFF: https://github.com/llvm/llvm-project/commit/28a5e6b069ee7540cd0965a0ad6cf0475db8d68c.diff
LOG: [InstCombine] Remove over-generalization from computeKnownBitsFromCmp() (#72637)
For most practical purposes, the only KnownBits patterns we care about are
those involving a constant comparison RHS and constant mask. However,
the actual implementation is written in a very general way, and of course
those generalizations have basically no test coverage.
This patch reduces the implementation to only handle cases with constant
operands. The test changes are all in "make sure we don't crash" tests.
The motivation for this change is an upcoming patch to handle dominating
conditions in computeKnownBits(). Handling a non-constant RHS would add
significant additional compile-time overhead in that case, without a
meaningful impact on optimization quality.
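To make the retained pattern concrete, here is a minimal hypothetical IR
sketch (the function name and the specific constants are illustrative, not
taken from this patch or its tests): with a constant mask and a constant
comparison RHS, assume((%x & 12) == 4) tells the analysis that bit 2 of %x
is one and bit 3 is zero, so the final 'and' below should fold to 4.

define i8 @known_bit_from_assume(i8 %x) {
  %masked = and i8 %x, 12
  %cmp = icmp eq i8 %masked, 4
  call void @llvm.assume(i1 %cmp)
  ; Under the assumption, bit 2 of %x is known one and bit 3 known zero,
  ; so this 'and' should simplify to the constant 4.
  %bit = and i8 %x, 4
  ret i8 %bit
}

declare void @llvm.assume(i1)

With a non-constant mask or RHS, the same propagation would require recursive
computeKnownBits() calls on the operands, which is exactly the compile-time
cost this patch avoids.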
Added:
Modified:
llvm/lib/Analysis/AssumptionCache.cpp
llvm/lib/Analysis/ValueTracking.cpp
llvm/test/Transforms/InstCombine/assume.ll
llvm/test/Transforms/InstCombine/zext-or-icmp.ll
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index 81b26678ae5d790..3139b3e8f319099 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -92,29 +92,19 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
AddAffected(B);
if (Pred == ICmpInst::ICMP_EQ) {
- // For equality comparisons, we handle the case of bit inversion.
- auto AddAffectedFromEq = [&AddAffected](Value *V) {
- Value *A, *B;
- // (A & B) or (A | B) or (A ^ B).
- if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) {
- AddAffected(A);
- AddAffected(B);
- // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
- } else if (match(V, m_Shift(m_Value(A), m_ConstantInt()))) {
- AddAffected(A);
- }
- };
-
- AddAffectedFromEq(A);
- AddAffectedFromEq(B);
+ if (match(B, m_ConstantInt())) {
+ Value *X;
+ // (X & C) or (X | C) or (X ^ C).
+ // (X << C) or (X >>_s C) or (X >>_u C).
+ if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
+ match(A, m_Shift(m_Value(X), m_ConstantInt())))
+ AddAffected(X);
+ }
} else if (Pred == ICmpInst::ICMP_NE) {
- Value *X, *Y;
- // Handle (a & b != 0). If a/b is a power of 2 we can use this
- // information.
- if (match(A, m_And(m_Value(X), m_Value(Y))) && match(B, m_Zero())) {
+ Value *X;
+ // Handle (X & pow2 != 0).
+ if (match(A, m_And(m_Value(X), m_Power2())) && match(B, m_Zero()))
AddAffected(X);
- AddAffected(Y);
- }
} else if (Pred == ICmpInst::ICMP_ULT) {
Value *X;
// Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index baf9f488d47d614..cda9c12c2dc6845 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -633,101 +633,89 @@ static bool isKnownNonZeroFromAssume(const Value *V, const SimplifyQuery &Q) {
static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
KnownBits &Known, unsigned Depth,
const SimplifyQuery &Q) {
+ if (Cmp->getOperand(1)->getType()->isPointerTy()) {
+ // Handle comparison of pointer to null explicitly, as it will not be
+ // covered by the m_APInt() logic below.
+ if (match(Cmp->getOperand(1), m_Zero())) {
+ switch (Cmp->getPredicate()) {
+ case ICmpInst::ICMP_EQ:
+ Known.setAllZero();
+ break;
+ case ICmpInst::ICMP_SGE:
+ case ICmpInst::ICMP_SGT:
+ Known.makeNonNegative();
+ break;
+ case ICmpInst::ICMP_SLT:
+ Known.makeNegative();
+ break;
+ default:
+ break;
+ }
+ }
+ return;
+ }
+
unsigned BitWidth = Known.getBitWidth();
- // We are attempting to compute known bits for the operands of an assume.
- // Do not try to use other assumptions for those recursive calls because
- // that can lead to mutual recursion and a compile-time explosion.
- // An example of the mutual recursion: computeKnownBits can call
- // isKnownNonZero which calls computeKnownBitsFromAssume (this function)
- // and so on.
- SimplifyQuery QueryNoAC = Q;
- QueryNoAC.AC = nullptr;
-
- // Note that ptrtoint may change the bitwidth.
- Value *A, *B;
auto m_V =
m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));
CmpInst::Predicate Pred;
- uint64_t C;
+ const APInt *Mask, *C;
+ uint64_t ShAmt;
switch (Cmp->getPredicate()) {
case ICmpInst::ICMP_EQ:
- // assume(v = a)
- if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- Known = Known.unionWith(RHSKnown);
- // assume(v & b = a)
- } else if (match(Cmp,
- m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in the mask that are known to be one, we can propagate
- // known bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & MaskKnown.One;
- Known.One |= RHSKnown.One & MaskKnown.One;
- // assume(v | b = a)
+ // assume(V = C)
+ if (match(Cmp, m_ICmp(Pred, m_V, m_APInt(C)))) {
+ Known = Known.unionWith(KnownBits::makeConstant(*C));
+ // assume(V & Mask = C)
} else if (match(Cmp,
- m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in B that are known to be zero, we can propagate known
- // bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & BKnown.Zero;
- Known.One |= RHSKnown.One & BKnown.Zero;
- // assume(v ^ b = a)
+ m_ICmp(Pred, m_And(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // For one bits in Mask, we can propagate bits from C to V.
+ Known.Zero |= ~*C & *Mask;
+ Known.One |= *C & *Mask;
+ // assume(V | Mask = C)
+ } else if (match(Cmp, m_ICmp(Pred, m_Or(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // For zero bits in Mask, we can propagate bits from C to V.
+ Known.Zero |= ~*C & ~*Mask;
+ Known.One |= *C & ~*Mask;
+ // assume(V ^ Mask = C)
} else if (match(Cmp,
- m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in B that are known to be zero, we can propagate known
- // bits from the RHS to V. For those bits in B that are known to be one,
- // we can propagate inverted known bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & BKnown.Zero;
- Known.One |= RHSKnown.One & BKnown.Zero;
- Known.Zero |= RHSKnown.One & BKnown.One;
- Known.One |= RHSKnown.Zero & BKnown.One;
- // assume(v << c = a)
- } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
- m_Value(A))) &&
- C < BitWidth) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-
- // For those bits in RHS that are known, we can propagate them to known
- // bits in V shifted to the right by C.
- RHSKnown.Zero.lshrInPlace(C);
- RHSKnown.One.lshrInPlace(C);
+ m_ICmp(Pred, m_Xor(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // Equivalent to assume(V == Mask ^ C)
+ Known = Known.unionWith(KnownBits::makeConstant(*C ^ *Mask));
+ // assume(V << ShAmt = C)
+ } else if (match(Cmp, m_ICmp(Pred, m_Shl(m_V, m_ConstantInt(ShAmt)),
+ m_APInt(C))) &&
+ ShAmt < BitWidth) {
+ // For those bits in C that are known, we can propagate them to known
+ // bits in V shifted to the right by ShAmt.
+ KnownBits RHSKnown = KnownBits::makeConstant(*C);
+ RHSKnown.Zero.lshrInPlace(ShAmt);
+ RHSKnown.One.lshrInPlace(ShAmt);
Known = Known.unionWith(RHSKnown);
- // assume(v >> c = a)
- } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
- m_Value(A))) &&
- C < BitWidth) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+ // assume(V >> ShAmt = C)
+ } else if (match(Cmp, m_ICmp(Pred, m_Shr(m_V, m_ConstantInt(ShAmt)),
+ m_APInt(C))) &&
+ ShAmt < BitWidth) {
+ KnownBits RHSKnown = KnownBits::makeConstant(*C);
// For those bits in RHS that are known, we can propagate them to known
// bits in V shifted to the right by C.
- Known.Zero |= RHSKnown.Zero << C;
- Known.One |= RHSKnown.One << C;
+ Known.Zero |= RHSKnown.Zero << ShAmt;
+ Known.One |= RHSKnown.One << ShAmt;
}
break;
case ICmpInst::ICMP_NE: {
- // assume (v & b != 0) where b is a power of 2
+ // assume (V & B != 0) where B is a power of 2
const APInt *BPow2;
- if (match(Cmp, m_ICmp(Pred, m_c_And(m_V, m_Power2(BPow2)), m_Zero()))) {
+ if (match(Cmp, m_ICmp(Pred, m_And(m_V, m_Power2(BPow2)), m_Zero())))
Known.One |= *BPow2;
- }
break;
}
default:
const APInt *Offset = nullptr;
if (match(Cmp, m_ICmp(Pred, m_CombineOr(m_V, m_Add(m_V, m_APInt(Offset))),
- m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- ConstantRange RHSRange =
- ConstantRange::fromKnownBits(RHSKnown, Cmp->isSigned());
- ConstantRange LHSRange =
- ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
+ m_APInt(C)))) {
+ ConstantRange LHSRange = ConstantRange::makeAllowedICmpRegion(Pred, *C);
if (Offset)
LHSRange = LHSRange.sub(*Offset);
Known = Known.unionWith(LHSRange.toKnownBits());
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
index 0ae558c6a21da60..86d45bc70640998 100644
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -268,7 +268,7 @@ define i32 @bundle2(ptr %P) {
define i1 @nonnull1(ptr %a) {
; CHECK-LABEL: @nonnull1(
-; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull !6, !noundef !6
+; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull [[META6:![0-9]+]], !noundef [[META6]]
; CHECK-NEXT: tail call void @escape(ptr nonnull [[LOAD]])
; CHECK-NEXT: ret i1 false
;
@@ -386,7 +386,7 @@ define i1 @nonnull5(ptr %a) {
define i32 @assumption_conflicts_with_known_bits(i32 %a, i32 %b) {
; CHECK-LABEL: @assumption_conflicts_with_known_bits(
; CHECK-NEXT: store i1 true, ptr poison, align 1
-; CHECK-NEXT: ret i32 poison
+; CHECK-NEXT: ret i32 1
;
%and1 = and i32 %b, 3
%B1 = lshr i32 %and1, %and1
diff --git a/llvm/test/Transforms/InstCombine/zext-or-icmp.ll b/llvm/test/Transforms/InstCombine/zext-or-icmp.ll
index dada32d1b744983..bc0e4bdce29b595 100644
--- a/llvm/test/Transforms/InstCombine/zext-or-icmp.ll
+++ b/llvm/test/Transforms/InstCombine/zext-or-icmp.ll
@@ -243,7 +243,7 @@ define i1 @PR51762(ptr %i, i32 %t0, i16 %t1, ptr %p, ptr %d, ptr %f, i32 %p2, i1
; CHECK-NEXT: store i32 [[ADD]], ptr [[F]], align 4
; CHECK-NEXT: [[REM18:%.*]] = srem i32 [[LOR_EXT]], [[ADD]]
; CHECK-NEXT: [[CONV19:%.*]] = zext nneg i32 [[REM18]] to i64
-; CHECK-NEXT: store i32 0, ptr [[D]], align 8
+; CHECK-NEXT: store i32 [[SROA38]], ptr [[D]], align 8
; CHECK-NEXT: [[R:%.*]] = icmp ult i64 [[INSERT_INSERT41]], [[CONV19]]
; CHECK-NEXT: call void @llvm.assume(i1 [[R]])
; CHECK-NEXT: ret i1 [[R]]