[llvm] [InstSimplify] Generalize simplification of icmps with monotonic operands (PR #69471)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 28 08:07:41 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Nikita Popov (nikic)
<details>
<summary>Changes</summary>
InstSimplify currently folds patterns like `(x | y) uge x` and `(x & y) ule x` to true. However, it cannot handle combinations of these patterns, such as `(x | y) uge (x & z)`.
To support this, recursively collect the operands of monotonic instructions on each side of the comparison (instructions whose result is unsigned greater-or-equal, or less-or-equal, to their operands) and then check whether the two sides share any value.
Fixes https://github.com/llvm/llvm-project/issues/69333.
---
Full diff: https://github.com/llvm/llvm-project/pull/69471.diff
2 Files Affected:
- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+69-48)
- (modified) llvm/test/Transforms/InstSimplify/icmp-monotonic.ll (+8-32)
``````````diff
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 01b0a089aab718..be8c2e3c520d3b 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3070,6 +3070,69 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
return nullptr;
}
+/// Which unsigned relation getUnsignedMonotonicValues collects for a value V:
+/// GreaterEq collects values V_i with V uge V_i, LowerEq collects values V_i
+/// with V ule V_i.
+enum class MonotonicType { GreaterEq, LowerEq };
+
+/// Get values V_i such that V uge V_i (GreaterEq) or V ule V_i (LowerEq).
+///
+/// Results are accumulated into \p Res, which always receives V itself (both
+/// relations are reflexive). For GreaterEq, both operands of `or` and
+/// `uadd.sat` are followed, since those results are uge each operand. For
+/// LowerEq, both operands of `and`, and the first operand of `urem`, `udiv`,
+/// `lshr` and `usub.sat`, are followed, since those results are ule it.
+/// Only V and its direct operands are collected (see the depth limit below).
+static void getUnsignedMonotonicValues(SmallPtrSetImpl<Value *> &Res, Value *V,
+ MonotonicType Type, unsigned Depth = 0) {
+ // Already collected: do not expand the same value twice.
+ if (!Res.insert(V).second)
+ return;
+
+ // Only look through one level of operands. Can be increased if useful.
+ if (++Depth > 1)
+ return;
+
+ // Non-instruction values (e.g. arguments, constants) are not looked through.
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return;
+
+ Value *X, *Y;
+ if (Type == MonotonicType::GreaterEq) {
+ // (X | Y) uge X, Y and uadd.sat(X, Y) uge X, Y.
+ if (match(I, m_Or(m_Value(X), m_Value(Y))) ||
+ match(I, m_Intrinsic<Intrinsic::uadd_sat>(m_Value(X), m_Value(Y)))) {
+ getUnsignedMonotonicValues(Res, X, Type, Depth);
+ getUnsignedMonotonicValues(Res, Y, Type, Depth);
+ }
+ } else {
+ assert(Type == MonotonicType::LowerEq);
+ switch (I->getOpcode()) {
+ // (X & Y) ule X, Y.
+ case Instruction::And:
+ getUnsignedMonotonicValues(Res, I->getOperand(0), Type, Depth);
+ getUnsignedMonotonicValues(Res, I->getOperand(1), Type, Depth);
+ break;
+ // X urem Y, X udiv Y and X >>u Y are all ule X; Y is not followed.
+ case Instruction::URem:
+ case Instruction::UDiv:
+ case Instruction::LShr:
+ getUnsignedMonotonicValues(Res, I->getOperand(0), Type, Depth);
+ break;
+ // usub.sat(X, Y) ule X; only the first argument is bound and followed.
+ case Instruction::Call:
+ if (match(I, m_Intrinsic<Intrinsic::usub_sat>(m_Value(X))))
+ getUnsignedMonotonicValues(Res, X, Type, Depth);
+ break;
+ default:
+ break;
+ }
+ }
+}
+
+/// Fold `icmp uge LHS, RHS` to true and `icmp ult LHS, RHS` to false when some
+/// value C satisfies LHS uge C and C uge RHS, using the values collected by
+/// getUnsignedMonotonicValues for each side. Only UGE/ULT are accepted here;
+/// ULE/UGT are handled by calling this again with swapped operands and
+/// predicate. Returns nullptr when no fold applies.
+static Value *simplifyICmpUsingMonotonicValues(ICmpInst::Predicate Pred,
+ Value *LHS, Value *RHS) {
+ if (Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_ULT)
+ return nullptr;
+
+ // We have LHS uge GreaterValues and LowerValues uge RHS. If any of the
+ // GreaterValues and LowerValues are the same, it follows that LHS uge RHS.
+ SmallPtrSet<Value *, 4> GreaterValues;
+ SmallPtrSet<Value *, 4> LowerValues;
+ getUnsignedMonotonicValues(GreaterValues, LHS, MonotonicType::GreaterEq);
+ getUnsignedMonotonicValues(LowerValues, RHS, MonotonicType::LowerEq);
+ // A single shared value C proves LHS uge C uge RHS.
+ for (Value *GV : GreaterValues)
+ if (LowerValues.contains(GV))
+ return ConstantInt::getBool(getCompareTy(LHS),
+ Pred == ICmpInst::ICMP_UGE);
+ return nullptr;
+}
+
static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
BinaryOperator *LBO, Value *RHS,
const SimplifyQuery &Q,
@@ -3079,11 +3142,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
Value *Y = nullptr;
// icmp pred (or X, Y), X
if (match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) {
- if (Pred == ICmpInst::ICMP_ULT)
- return getFalse(ITy);
- if (Pred == ICmpInst::ICMP_UGE)
- return getTrue(ITy);
-
if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
KnownBits RHSKnown = computeKnownBits(RHS, /* Depth */ 0, Q);
KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
@@ -3094,14 +3152,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
}
}
- // icmp pred (and X, Y), X
- if (match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) {
- if (Pred == ICmpInst::ICMP_UGT)
- return getFalse(ITy);
- if (Pred == ICmpInst::ICMP_ULE)
- return getTrue(ITy);
- }
-
// icmp pred (urem X, Y), Y
if (match(LBO, m_URem(m_Value(), m_Specific(RHS)))) {
switch (Pred) {
@@ -3132,27 +3182,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
}
}
- // icmp pred (urem X, Y), X
- if (match(LBO, m_URem(m_Specific(RHS), m_Value()))) {
- if (Pred == ICmpInst::ICMP_ULE)
- return getTrue(ITy);
- if (Pred == ICmpInst::ICMP_UGT)
- return getFalse(ITy);
- }
-
- // x >>u y <=u x --> true.
- // x >>u y >u x --> false.
- // x udiv y <=u x --> true.
- // x udiv y >u x --> false.
- if (match(LBO, m_LShr(m_Specific(RHS), m_Value())) ||
- match(LBO, m_UDiv(m_Specific(RHS), m_Value()))) {
- // icmp pred (X op Y), X
- if (Pred == ICmpInst::ICMP_UGT)
- return getFalse(ITy);
- if (Pred == ICmpInst::ICMP_ULE)
- return getTrue(ITy);
- }
-
// If x is nonzero:
// x >>u C <u x --> true for C != 0.
// x >>u C != x --> true for C != 0.
@@ -3702,13 +3731,6 @@ static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred,
switch (II->getIntrinsicID()) {
case Intrinsic::uadd_sat:
- // uadd.sat(X, Y) uge X, uadd.sat(X, Y) uge Y
- if (II->getArgOperand(0) == RHS || II->getArgOperand(1) == RHS) {
- if (Pred == ICmpInst::ICMP_UGE)
- return ConstantInt::getTrue(getCompareTy(II));
- if (Pred == ICmpInst::ICMP_ULT)
- return ConstantInt::getFalse(getCompareTy(II));
- }
// uadd.sat(X, Y) uge X + Y
if (match(RHS, m_c_Add(m_Specific(II->getArgOperand(0)),
m_Specific(II->getArgOperand(1))))) {
@@ -3719,13 +3741,6 @@ static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred,
}
return nullptr;
case Intrinsic::usub_sat:
- // usub.sat(X, Y) ule X
- if (II->getArgOperand(0) == RHS) {
- if (Pred == ICmpInst::ICMP_ULE)
- return ConstantInt::getTrue(getCompareTy(II));
- if (Pred == ICmpInst::ICMP_UGT)
- return ConstantInt::getFalse(getCompareTy(II));
- }
// usub.sat(X, Y) ule X - Y
if (match(RHS, m_Sub(m_Specific(II->getArgOperand(0)),
m_Specific(II->getArgOperand(1))))) {
@@ -4030,6 +4045,12 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
ICmpInst::getSwappedPredicate(Pred), RHS, LHS))
return V;
+ if (Value *V = simplifyICmpUsingMonotonicValues(Pred, LHS, RHS))
+ return V;
+ if (Value *V = simplifyICmpUsingMonotonicValues(
+ ICmpInst::getSwappedPredicate(Pred), RHS, LHS))
+ return V;
+
if (Value *V = simplifyICmpWithDominatingAssume(Pred, LHS, RHS, Q))
return V;
diff --git a/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll b/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll
index a1daa6bd7b4021..e1a4ee91bd15c5 100644
--- a/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll
+++ b/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll
@@ -4,10 +4,7 @@
define i1 @lshr_or_ule(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: define i1 @lshr_or_ule(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT: [[OP1:%.*]] = lshr i32 [[X]], [[Y]]
-; CHECK-NEXT: [[OP2:%.*]] = or i32 [[X]], [[Z]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 true
;
%op1 = lshr i32 %x, %y
%op2 = or i32 %x, %z
@@ -18,10 +15,7 @@ define i1 @lshr_or_ule(i32 %x, i32 %y, i32 %z) {
define i1 @lshr_or_uge_swapped(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: define i1 @lshr_or_uge_swapped(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT: [[OP1:%.*]] = lshr i32 [[X]], [[Y]]
-; CHECK-NEXT: [[OP2:%.*]] = or i32 [[X]], [[Z]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp uge i32 [[OP2]], [[OP1]]
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 true
;
%op1 = lshr i32 %x, %y
%op2 = or i32 %x, %z
@@ -32,10 +26,7 @@ define i1 @lshr_or_uge_swapped(i32 %x, i32 %y, i32 %z) {
define i1 @lshr_or_ugt(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: define i1 @lshr_or_ugt(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT: [[OP1:%.*]] = lshr i32 [[X]], [[Y]]
-; CHECK-NEXT: [[OP2:%.*]] = or i32 [[X]], [[Z]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[OP1]], [[OP2]]
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 false
;
%op1 = lshr i32 %x, %y
%op2 = or i32 %x, %z
@@ -74,10 +65,7 @@ define i1 @lshr_or_sle_wrong_pred(i32 %x, i32 %y, i32 %z) {
define i1 @lshr_or_swapped_ule(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: define i1 @lshr_or_swapped_ule(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT: [[OP1:%.*]] = lshr i32 [[X]], [[Y]]
-; CHECK-NEXT: [[OP2:%.*]] = or i32 [[Z]], [[X]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 true
;
%op1 = lshr i32 %x, %y
%op2 = or i32 %z, %x
@@ -102,10 +90,7 @@ define i1 @lshr_or_ule_invalid_swapped(i32 %x, i32 %y, i32 %z) {
define i1 @and_uadd_sat_ule(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: define i1 @and_uadd_sat_ule(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT: [[OP1:%.*]] = and i32 [[X]], [[Y]]
-; CHECK-NEXT: [[OP2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 true
;
%op1 = and i32 %x, %y
%op2 = call i32 @llvm.uadd.sat(i32 %x, i32 %z)
@@ -116,10 +101,7 @@ define i1 @and_uadd_sat_ule(i32 %x, i32 %y, i32 %z) {
define i1 @urem_or_ule(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: define i1 @urem_or_ule(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT: [[OP1:%.*]] = urem i32 [[X]], [[Y]]
-; CHECK-NEXT: [[OP2:%.*]] = or i32 [[X]], [[Z]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 true
;
%op1 = urem i32 %x, %y
%op2 = or i32 %x, %z
@@ -144,10 +126,7 @@ define i1 @urem_or_ule_invalid_swapped(i32 %x, i32 %y, i32 %z) {
define i1 @udiv_or_ule(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: define i1 @udiv_or_ule(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT: [[OP1:%.*]] = udiv i32 [[X]], [[Y]]
-; CHECK-NEXT: [[OP2:%.*]] = or i32 [[X]], [[Z]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 true
;
%op1 = udiv i32 %x, %y
%op2 = or i32 %x, %z
@@ -172,10 +151,7 @@ define i1 @udiv_or_ule_invalid_swapped(i32 %x, i32 %y, i32 %z) {
define i1 @usub_sat_uadd_sat_ule(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: define i1 @usub_sat_uadd_sat_ule(
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT: [[OP1:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[X]], i32 [[Y]])
-; CHECK-NEXT: [[OP2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT: ret i1 [[CMP]]
+; CHECK-NEXT: ret i1 true
;
%op1 = call i32 @llvm.usub.sat(i32 %x, i32 %y)
%op2 = call i32 @llvm.uadd.sat(i32 %x, i32 %z)
``````````
</details>
https://github.com/llvm/llvm-project/pull/69471
More information about the llvm-commits mailing list