[llvm] [InstCombine] Remove over-generalization from computeKnownBitsFromCmp() (PR #72637)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 20 06:52:13 PST 2023
https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/72637
>From c2a32311b1d5d289c43cac59ec82b80497d9a859 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Fri, 17 Nov 2023 11:07:49 +0100
Subject: [PATCH 1/4] wip
---
llvm/lib/Analysis/AssumptionCache.cpp | 32 +++----
llvm/lib/Analysis/ValueTracking.cpp | 100 +++++++++------------
llvm/test/Transforms/InstCombine/assume.ll | 4 +-
llvm/test/Transforms/InstCombine/icmp.ll | 5 +-
4 files changed, 59 insertions(+), 82 deletions(-)
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index 81b26678ae5d790..3139b3e8f319099 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -92,29 +92,19 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
AddAffected(B);
if (Pred == ICmpInst::ICMP_EQ) {
- // For equality comparisons, we handle the case of bit inversion.
- auto AddAffectedFromEq = [&AddAffected](Value *V) {
- Value *A, *B;
- // (A & B) or (A | B) or (A ^ B).
- if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) {
- AddAffected(A);
- AddAffected(B);
- // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
- } else if (match(V, m_Shift(m_Value(A), m_ConstantInt()))) {
- AddAffected(A);
- }
- };
-
- AddAffectedFromEq(A);
- AddAffectedFromEq(B);
+ if (match(B, m_ConstantInt())) {
+ Value *X;
+ // (X & C) or (X | C) or (X ^ C).
+ // (X << C) or (X >>_s C) or (X >>_u C).
+ if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
+ match(A, m_Shift(m_Value(X), m_ConstantInt())))
+ AddAffected(X);
+ }
} else if (Pred == ICmpInst::ICMP_NE) {
- Value *X, *Y;
- // Handle (a & b != 0). If a/b is a power of 2 we can use this
- // information.
- if (match(A, m_And(m_Value(X), m_Value(Y))) && match(B, m_Zero())) {
+ Value *X;
+ // Handle (X & pow2 != 0).
+ if (match(A, m_And(m_Value(X), m_Power2())) && match(B, m_Zero()))
AddAffected(X);
- AddAffected(Y);
- }
} else if (Pred == ICmpInst::ICMP_ULT) {
Value *X;
// Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index d1af0ea35e5e751..7a70ae9ee424d71 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -640,82 +640,68 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
QueryNoAC.AC = nullptr;
// Note that ptrtoint may change the bitwidth.
- Value *A, *B;
auto m_V =
m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));
CmpInst::Predicate Pred;
- uint64_t C;
+ const APInt *Mask, *C;
+ uint64_t ShAmt;
switch (Cmp->getPredicate()) {
case ICmpInst::ICMP_EQ:
- // assume(v = a)
- if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- Known = Known.unionWith(RHSKnown);
- // assume(v & b = a)
- } else if (match(Cmp,
- m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in the mask that are known to be one, we can propagate
- // known bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & MaskKnown.One;
- Known.One |= RHSKnown.One & MaskKnown.One;
- // assume(v | b = a)
+ // assume(V = C)
+ if (match(Cmp, m_ICmp(Pred, m_V, m_APInt(C)))) {
+ Known = Known.unionWith(KnownBits::makeConstant(*C));
+ // assume(V & Mask = C)
} else if (match(Cmp,
- m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in B that are known to be zero, we can propagate known
- // bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & BKnown.Zero;
- Known.One |= RHSKnown.One & BKnown.Zero;
- // assume(v ^ b = a)
+ m_ICmp(Pred, m_And(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // For one bits in Mask, we can propagate bits from C to V.
+ Known.Zero |= ~*C & *Mask;
+ Known.One |= *C & *Mask;
+ // assume(V | Mask = C)
+ } else if (match(Cmp, m_ICmp(Pred, m_Or(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // For zero bits in Mask, we can propagate bits from C to V.
+ Known.Zero |= ~*C & ~*Mask;
+ Known.One |= *C & ~*Mask;
+ // assume(V ^ Mask = C)
} else if (match(Cmp,
- m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in B that are known to be zero, we can propagate known
- // bits from the RHS to V. For those bits in B that are known to be one,
- // we can propagate inverted known bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & BKnown.Zero;
- Known.One |= RHSKnown.One & BKnown.Zero;
- Known.Zero |= RHSKnown.One & BKnown.One;
- Known.One |= RHSKnown.Zero & BKnown.One;
- // assume(v << c = a)
- } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
- m_Value(A))) &&
- C < BitWidth) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-
- // For those bits in RHS that are known, we can propagate them to known
- // bits in V shifted to the right by C.
- RHSKnown.Zero.lshrInPlace(C);
- RHSKnown.One.lshrInPlace(C);
+ m_ICmp(Pred, m_Xor(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // For those bits in Mask that are zero, we can propagate known bits
+ // from C to V. For those bits in Mask that are one, we can propagate
+ // inverted bits from C to V.
+ Known.Zero |= ~*C & ~*Mask;
+ Known.One |= *C & ~*Mask;
+ Known.Zero |= *C & *Mask;
+ Known.One |= ~*C & *Mask;
+ // assume(V << ShAmt = C)
+ } else if (match(Cmp, m_ICmp(Pred, m_Shl(m_V, m_ConstantInt(ShAmt)),
+ m_APInt(C))) &&
+ ShAmt < BitWidth) {
+ // For those bits in C that are known, we can propagate them to known
+ // bits in V shifted to the right by ShAmt.
+ KnownBits RHSKnown = KnownBits::makeConstant(*C);
+ RHSKnown.Zero.lshrInPlace(ShAmt);
+ RHSKnown.One.lshrInPlace(ShAmt);
Known = Known.unionWith(RHSKnown);
- // assume(v >> c = a)
- } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
- m_Value(A))) &&
- C < BitWidth) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+ // assume(V >> ShAmt = C)
+ } else if (match(Cmp, m_ICmp(Pred, m_Shr(m_V, m_ConstantInt(ShAmt)),
+ m_APInt(C))) &&
+ ShAmt < BitWidth) {
+ KnownBits RHSKnown = KnownBits::makeConstant(*C);
// For those bits in RHS that are known, we can propagate them to known
// bits in V shifted to the right by C.
- Known.Zero |= RHSKnown.Zero << C;
- Known.One |= RHSKnown.One << C;
+ Known.Zero |= RHSKnown.Zero << ShAmt;
+ Known.One |= RHSKnown.One << ShAmt;
}
break;
case ICmpInst::ICMP_NE: {
- // assume (v & b != 0) where b is a power of 2
+ // assume (V & B != 0) where B is a power of 2
const APInt *BPow2;
- if (match(Cmp, m_ICmp(Pred, m_c_And(m_V, m_Power2(BPow2)), m_Zero()))) {
+ if (match(Cmp, m_ICmp(Pred, m_And(m_V, m_Power2(BPow2)), m_Zero())))
Known.One |= *BPow2;
- }
break;
}
default:
+ Value *A;
const APInt *Offset = nullptr;
if (match(Cmp, m_ICmp(Pred, m_CombineOr(m_V, m_Add(m_V, m_APInt(Offset))),
m_Value(A)))) {
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
index 934e9594f3f7b5a..5b2039b5b480501 100644
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -265,7 +265,7 @@ define i32 @bundle2(ptr %P) {
define i1 @nonnull1(ptr %a) {
; CHECK-LABEL: @nonnull1(
-; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull !6, !noundef !6
+; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull [[META6:![0-9]+]], !noundef [[META6]]
; CHECK-NEXT: tail call void @escape(ptr nonnull [[LOAD]])
; CHECK-NEXT: ret i1 false
;
@@ -383,7 +383,7 @@ define i1 @nonnull5(ptr %a) {
define i32 @assumption_conflicts_with_known_bits(i32 %a, i32 %b) {
; CHECK-LABEL: @assumption_conflicts_with_known_bits(
; CHECK-NEXT: store i1 true, ptr poison, align 1
-; CHECK-NEXT: ret i32 poison
+; CHECK-NEXT: ret i32 1
;
%and1 = and i32 %b, 3
%B1 = lshr i32 %and1, %and1
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index 78ac730cf026ed9..d49cb79e1e27c98 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -4016,9 +4016,10 @@ define i32 @abs_preserve(i32 %x) {
declare void @llvm.assume(i1)
define i1 @PR35794(ptr %a) {
; CHECK-LABEL: @PR35794(
-; CHECK-NEXT: [[MASKCOND:%.*]] = icmp eq ptr [[A:%.*]], null
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt ptr [[A:%.*]], inttoptr (i64 -1 to ptr)
+; CHECK-NEXT: [[MASKCOND:%.*]] = icmp eq ptr [[A]], null
; CHECK-NEXT: tail call void @llvm.assume(i1 [[MASKCOND]])
-; CHECK-NEXT: ret i1 true
+; CHECK-NEXT: ret i1 [[CMP]]
;
%cmp = icmp sgt ptr %a, inttoptr (i64 -1 to ptr)
%maskcond = icmp eq ptr %a, null
>From c5e3d8cad8e6a632fe85de297825a8acc5a1d40d Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Mon, 20 Nov 2023 15:04:25 +0100
Subject: [PATCH 2/4] Simplify xor case
---
llvm/lib/Analysis/ValueTracking.cpp | 9 ++-------
1 file changed, 2 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 7a70ae9ee424d71..a6084bfceb593b5 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -665,13 +665,8 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
// assume(V ^ Mask = C)
} else if (match(Cmp,
m_ICmp(Pred, m_Xor(m_V, m_APInt(Mask)), m_APInt(C)))) {
- // For those bits in Mask that are zero, we can propagate known bits
- // from C to V. For those bits in Mask that are one, we can propagate
- // inverted bits from C to V.
- Known.Zero |= ~*C & ~*Mask;
- Known.One |= *C & ~*Mask;
- Known.Zero |= *C & *Mask;
- Known.One |= ~*C & *Mask;
+ // Equivalent to assume(V == Mask ^ C)
+ Known = Known.unionWith(KnownBits::makeConstant(*C ^ *Mask));
// assume(V << ShAmt = C)
} else if (match(Cmp, m_ICmp(Pred, m_Shl(m_V, m_ConstantInt(ShAmt)),
m_APInt(C))) &&
>From e3633bfb7c645988eecc46eac242a1d2e4836749 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Mon, 20 Nov 2023 15:07:38 +0100
Subject: [PATCH 3/4] Also drop non-constant support for simple comparison
---
llvm/lib/Analysis/ValueTracking.cpp | 19 ++-----------------
.../Transforms/InstCombine/zext-or-icmp.ll | 2 +-
2 files changed, 3 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index a6084bfceb593b5..137694e235b5928 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -630,16 +630,6 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
KnownBits &Known, unsigned Depth,
const SimplifyQuery &Q) {
unsigned BitWidth = Known.getBitWidth();
- // We are attempting to compute known bits for the operands of an assume.
- // Do not try to use other assumptions for those recursive calls because
- // that can lead to mutual recursion and a compile-time explosion.
- // An example of the mutual recursion: computeKnownBits can call
- // isKnownNonZero which calls computeKnownBitsFromAssume (this function)
- // and so on.
- SimplifyQuery QueryNoAC = Q;
- QueryNoAC.AC = nullptr;
-
- // Note that ptrtoint may change the bitwidth.
auto m_V =
m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));
@@ -696,15 +686,10 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
break;
}
default:
- Value *A;
const APInt *Offset = nullptr;
if (match(Cmp, m_ICmp(Pred, m_CombineOr(m_V, m_Add(m_V, m_APInt(Offset))),
- m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- ConstantRange RHSRange =
- ConstantRange::fromKnownBits(RHSKnown, Cmp->isSigned());
- ConstantRange LHSRange =
- ConstantRange::makeAllowedICmpRegion(Pred, RHSRange);
+ m_APInt(C)))) {
+ ConstantRange LHSRange = ConstantRange::makeAllowedICmpRegion(Pred, *C);
if (Offset)
LHSRange = LHSRange.sub(*Offset);
Known = Known.unionWith(LHSRange.toKnownBits());
diff --git a/llvm/test/Transforms/InstCombine/zext-or-icmp.ll b/llvm/test/Transforms/InstCombine/zext-or-icmp.ll
index dada32d1b744983..bc0e4bdce29b595 100644
--- a/llvm/test/Transforms/InstCombine/zext-or-icmp.ll
+++ b/llvm/test/Transforms/InstCombine/zext-or-icmp.ll
@@ -243,7 +243,7 @@ define i1 @PR51762(ptr %i, i32 %t0, i16 %t1, ptr %p, ptr %d, ptr %f, i32 %p2, i1
; CHECK-NEXT: store i32 [[ADD]], ptr [[F]], align 4
; CHECK-NEXT: [[REM18:%.*]] = srem i32 [[LOR_EXT]], [[ADD]]
; CHECK-NEXT: [[CONV19:%.*]] = zext nneg i32 [[REM18]] to i64
-; CHECK-NEXT: store i32 0, ptr [[D]], align 8
+; CHECK-NEXT: store i32 [[SROA38]], ptr [[D]], align 8
; CHECK-NEXT: [[R:%.*]] = icmp ult i64 [[INSERT_INSERT41]], [[CONV19]]
; CHECK-NEXT: call void @llvm.assume(i1 [[R]])
; CHECK-NEXT: ret i1 [[R]]
>From 3e6f8bde08690c98ede2161c4ce1b23b81f96d89 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Mon, 20 Nov 2023 15:33:04 +0100
Subject: [PATCH 4/4] Explicitly handle pointer comparisons
Pointer comparisons (such as comparisons against null) are not handled by the m_APInt() matching code.
---
llvm/lib/Analysis/ValueTracking.cpp | 22 ++++++++++++++++++++++
llvm/test/Transforms/InstCombine/icmp.ll | 5 ++---
2 files changed, 24 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 137694e235b5928..f285cadc38d18b7 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -629,6 +629,28 @@ static bool isKnownNonZeroFromAssume(const Value *V, const SimplifyQuery &Q) {
static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
KnownBits &Known, unsigned Depth,
const SimplifyQuery &Q) {
+ if (Cmp->getOperand(1)->getType()->isPointerTy()) {
+ // Handle comparison of pointer to null explicitly, as it will not be
+ // covered by the m_APInt() logic below.
+ if (match(Cmp->getOperand(1), m_Zero())) {
+ switch (Cmp->getPredicate()) {
+ case ICmpInst::ICMP_EQ:
+ Known.setAllZero();
+ break;
+ case ICmpInst::ICMP_SGE:
+ case ICmpInst::ICMP_SGT:
+ Known.makeNonNegative();
+ break;
+ case ICmpInst::ICMP_SLT:
+ Known.makeNegative();
+ break;
+ default:
+ break;
+ }
+ }
+ return;
+ }
+
unsigned BitWidth = Known.getBitWidth();
auto m_V =
m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index d49cb79e1e27c98..78ac730cf026ed9 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -4016,10 +4016,9 @@ define i32 @abs_preserve(i32 %x) {
declare void @llvm.assume(i1)
define i1 @PR35794(ptr %a) {
; CHECK-LABEL: @PR35794(
-; CHECK-NEXT: [[CMP:%.*]] = icmp sgt ptr [[A:%.*]], inttoptr (i64 -1 to ptr)
-; CHECK-NEXT: [[MASKCOND:%.*]] = icmp eq ptr [[A]], null
+; CHECK-NEXT: [[MASKCOND:%.*]] = icmp eq ptr [[A:%.*]], null
; CHECK-NEXT: tail call void @llvm.assume(i1 [[MASKCOND]])
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 true
;
%cmp = icmp sgt ptr %a, inttoptr (i64 -1 to ptr)
%maskcond = icmp eq ptr %a, null
More information about the llvm-commits
mailing list