[llvm] [InstSimplify] Generalize simplification of icmps with monotonic operands (PR #69471)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 18 07:52:34 PDT 2023
https://github.com/nikic created https://github.com/llvm/llvm-project/pull/69471
InstSimplify currently folds patterns like `(x | y) uge x` and `(x & y) ule x` to true. However, it cannot handle combinations of these patterns, such as `(x | y) uge (x & z)`.
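To support this, recursively collect the operands of monotonic instructions on both sides of the comparison: for the left-hand side, values the result is always unsigned-greater-or-equal to; for the right-hand side, values the result is always unsigned-less-or-equal to. If the two collected sets share a value, that value lies between the operands and the comparison folds. The recursion is currently limited to one level of operands.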
From f8b434f2ad17477358a6290e866645a98c4bd61f Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Wed, 18 Oct 2023 16:21:54 +0200
Subject: [PATCH] wip
---
llvm/lib/Analysis/InstructionSimplify.cpp | 116 ++++++++++------------
1 file changed, 50 insertions(+), 66 deletions(-)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index b3feb2470e58efd..04a73cbb9c628df 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3103,6 +3103,54 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
return nullptr;
}
+/// Collect values that bound \p V from one side in the unsigned order.
+///
+/// If \p Greater is true, every value added to \p Res satisfies V uge V_i;
+/// otherwise every value added satisfies V ule V_i. \p V itself is always
+/// added, since both relations are reflexive. Operands are looked through
+/// only up to a fixed depth (currently one level), tracked by \p Depth.
+static void getUnsignedMonotonicValues(SmallPtrSetImpl<Value *> &Res, Value *V,
+                                       bool Greater, unsigned Depth = 0) {
+  if (!Res.insert(V).second)
+    return;
+
+  // Can be increased if useful.
+  if (++Depth > 1)
+    return;
+
+  Value *X, *Y;
+  if (Greater) {
+    // The result of or / uadd.sat / umax / add nuw is never unsigned-smaller
+    // than either operand (for add nuw, an overflowing add is poison).
+    if (match(V, m_Or(m_Value(X), m_Value(Y))) ||
+        match(V, m_Intrinsic<Intrinsic::uadd_sat>(m_Value(X), m_Value(Y))) ||
+        match(V, m_Intrinsic<Intrinsic::umax>(m_Value(X), m_Value(Y))) ||
+        match(V, m_NUWAdd(m_Value(X), m_Value(Y)))) {
+      getUnsignedMonotonicValues(Res, X, Greater, Depth);
+      getUnsignedMonotonicValues(Res, Y, Greater, Depth);
+    }
+  } else {
+    // The result of and / umin is never unsigned-greater than either operand.
+    if (match(V, m_And(m_Value(X), m_Value(Y))) ||
+        match(V, m_Intrinsic<Intrinsic::umin>(m_Value(X), m_Value(Y)))) {
+      getUnsignedMonotonicValues(Res, X, Greater, Depth);
+      getUnsignedMonotonicValues(Res, Y, Greater, Depth);
+    } else if (match(V, m_URem(m_Value(X), m_Value())) ||
+               match(V, m_UDiv(m_Value(X), m_Value())) ||
+               match(V, m_LShr(m_Value(X), m_Value())) ||
+               match(V, m_NUWSub(m_Value(X), m_Value())) ||
+               match(V, m_Intrinsic<Intrinsic::usub_sat>(m_Value(X)))) {
+      // These are only bounded by their first operand (for sub nuw, a
+      // wrapping subtraction is poison).
+      getUnsignedMonotonicValues(Res, X, Greater, Depth);
+    }
+  }
+}
+
+/// Try to fold an unsigned comparison by relating both operands through a
+/// shared monotonic value.
+///
+/// Only ICMP_UGE and ICMP_ULT are handled here; any other predicate yields
+/// nullptr. On success, returns a true (UGE) or false (ULT) constant of the
+/// compare type of \p LHS; otherwise returns nullptr.
+static Value *simplifyICmpUsingMonotonicValues(ICmpInst::Predicate Pred,
+                                               Value *LHS, Value *RHS) {
+  const bool IsUGE = Pred == ICmpInst::ICMP_UGE;
+  if (!IsUGE && Pred != ICmpInst::ICMP_ULT)
+    return nullptr;
+
+  // Every value in LHSLowerBounds is ule LHS, and every value in
+  // RHSUpperBounds is uge RHS. A value present in both sets sits between the
+  // two operands, which proves LHS uge RHS (so LHS ult RHS is false).
+  SmallPtrSet<Value *, 4> LHSLowerBounds;
+  SmallPtrSet<Value *, 4> RHSUpperBounds;
+  getUnsignedMonotonicValues(LHSLowerBounds, LHS, /*Greater*/ true);
+  getUnsignedMonotonicValues(RHSUpperBounds, RHS, /*Greater*/ false);
+
+  for (Value *Candidate : RHSUpperBounds)
+    if (LHSLowerBounds.contains(Candidate))
+      return ConstantInt::getBool(getCompareTy(LHS), IsUGE);
+  return nullptr;
+}
+
static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
BinaryOperator *LBO, Value *RHS,
const SimplifyQuery &Q,
@@ -3112,11 +3160,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
Value *Y = nullptr;
// icmp pred (or X, Y), X
if (match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) {
- if (Pred == ICmpInst::ICMP_ULT)
- return getFalse(ITy);
- if (Pred == ICmpInst::ICMP_UGE)
- return getTrue(ITy);
-
if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
@@ -3127,14 +3170,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
}
}
- // icmp pred (and X, Y), X
- if (match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) {
- if (Pred == ICmpInst::ICMP_UGT)
- return getFalse(ITy);
- if (Pred == ICmpInst::ICMP_ULE)
- return getTrue(ITy);
- }
-
// icmp pred (urem X, Y), Y
if (match(LBO, m_URem(m_Value(), m_Specific(RHS)))) {
switch (Pred) {
@@ -3165,27 +3200,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
}
}
- // icmp pred (urem X, Y), X
- if (match(LBO, m_URem(m_Specific(RHS), m_Value()))) {
- if (Pred == ICmpInst::ICMP_ULE)
- return getTrue(ITy);
- if (Pred == ICmpInst::ICMP_UGT)
- return getFalse(ITy);
- }
-
- // x >>u y <=u x --> true.
- // x >>u y >u x --> false.
- // x udiv y <=u x --> true.
- // x udiv y >u x --> false.
- if (match(LBO, m_LShr(m_Specific(RHS), m_Value())) ||
- match(LBO, m_UDiv(m_Specific(RHS), m_Value()))) {
- // icmp pred (X op Y), X
- if (Pred == ICmpInst::ICMP_UGT)
- return getFalse(ITy);
- if (Pred == ICmpInst::ICMP_ULE)
- return getTrue(ITy);
- }
-
// If x is nonzero:
// x >>u C <u x --> true for C != 0.
// x >>u C != x --> true for C != 0.
@@ -3727,36 +3741,6 @@ static Value *simplifyICmpWithDominatingAssume(CmpInst::Predicate Predicate,
return nullptr;
}
-static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred,
- Value *LHS, Value *RHS) {
- auto *II = dyn_cast<IntrinsicInst>(LHS);
- if (!II)
- return nullptr;
-
- switch (II->getIntrinsicID()) {
- case Intrinsic::uadd_sat:
- // uadd.sat(X, Y) uge X, uadd.sat(X, Y) uge Y
- if (II->getArgOperand(0) == RHS || II->getArgOperand(1) == RHS) {
- if (Pred == ICmpInst::ICMP_UGE)
- return ConstantInt::getTrue(getCompareTy(II));
- if (Pred == ICmpInst::ICMP_ULT)
- return ConstantInt::getFalse(getCompareTy(II));
- }
- return nullptr;
- case Intrinsic::usub_sat:
- // usub.sat(X, Y) ule X
- if (II->getArgOperand(0) == RHS) {
- if (Pred == ICmpInst::ICMP_ULE)
- return ConstantInt::getTrue(getCompareTy(II));
- if (Pred == ICmpInst::ICMP_UGT)
- return ConstantInt::getFalse(getCompareTy(II));
- }
- return nullptr;
- default:
- return nullptr;
- }
-}
-
/// Given operands for an ICmpInst, see if we can fold the result.
/// If not, this returns null.
static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
@@ -4034,9 +4018,9 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
if (Value *V = simplifyICmpWithMinMax(Pred, LHS, RHS, Q, MaxRecurse))
return V;
- if (Value *V = simplifyICmpWithIntrinsicOnLHS(Pred, LHS, RHS))
+ if (Value *V = simplifyICmpUsingMonotonicValues(Pred, LHS, RHS))
return V;
- if (Value *V = simplifyICmpWithIntrinsicOnLHS(
+ if (Value *V = simplifyICmpUsingMonotonicValues(
ICmpInst::getSwappedPredicate(Pred), RHS, LHS))
return V;
More information about the llvm-commits
mailing list