[llvm] [InstSimplify] Generalize simplification of icmps with monotonic operands (PR #69471)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 28 08:07:41 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Nikita Popov (nikic)

<details>
<summary>Changes</summary>

InstSimplify currently folds patterns like `(x | y) uge x` and `(x & y) ule x` to true. However, it cannot handle combinations of such patterns, for example `(x | y) uge (x & z)`.

To support this, recursively collect operands of monotonic instructions (that preserve either a greater-or-equal or less-or-equal relationship) and then check whether any of them match.

Fixes https://github.com/llvm/llvm-project/issues/69333.

---
Full diff: https://github.com/llvm/llvm-project/pull/69471.diff


2 Files Affected:

- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+69-48) 
- (modified) llvm/test/Transforms/InstSimplify/icmp-monotonic.ll (+8-32) 


``````````diff
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 01b0a089aab718..be8c2e3c520d3b 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3070,6 +3070,69 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   return nullptr;
 }
 
+enum class MonotonicType { GreaterEq, LowerEq };
+
+/// Get values V_i such that V uge V_i (GreaterEq) or V ule V_i (LowerEq).
+static void getUnsignedMonotonicValues(SmallPtrSetImpl<Value *> &Res, Value *V,
+                                       MonotonicType Type, unsigned Depth = 0) {
+  if (!Res.insert(V).second)
+    return;
+
+  // Can be increased if useful.
+  if (++Depth > 1)
+    return;
+
+  auto *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return;
+
+  Value *X, *Y;
+  if (Type == MonotonicType::GreaterEq) {
+    if (match(I, m_Or(m_Value(X), m_Value(Y))) ||
+        match(I, m_Intrinsic<Intrinsic::uadd_sat>(m_Value(X), m_Value(Y)))) {
+      getUnsignedMonotonicValues(Res, X, Type, Depth);
+      getUnsignedMonotonicValues(Res, Y, Type, Depth);
+    }
+  } else {
+    assert(Type == MonotonicType::LowerEq);
+    switch (I->getOpcode()) {
+    case Instruction::And:
+      getUnsignedMonotonicValues(Res, I->getOperand(0), Type, Depth);
+      getUnsignedMonotonicValues(Res, I->getOperand(1), Type, Depth);
+      break;
+    case Instruction::URem:
+    case Instruction::UDiv:
+    case Instruction::LShr:
+      getUnsignedMonotonicValues(Res, I->getOperand(0), Type, Depth);
+      break;
+    case Instruction::Call:
+      if (match(I, m_Intrinsic<Intrinsic::usub_sat>(m_Value(X))))
+        getUnsignedMonotonicValues(Res, X, Type, Depth);
+      break;
+    default:
+      break;
+    }
+  }
+}
+
+static Value *simplifyICmpUsingMonotonicValues(ICmpInst::Predicate Pred,
+                                               Value *LHS, Value *RHS) {
+  if (Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_ULT)
+    return nullptr;
+
+  // We have LHS uge GreaterValues and LowerValues uge RHS. If any of the
+  // GreaterValues and LowerValues are the same, it follows that LHS uge RHS.
+  SmallPtrSet<Value *, 4> GreaterValues;
+  SmallPtrSet<Value *, 4> LowerValues;
+  getUnsignedMonotonicValues(GreaterValues, LHS, MonotonicType::GreaterEq);
+  getUnsignedMonotonicValues(LowerValues, RHS, MonotonicType::LowerEq);
+  for (Value *GV : GreaterValues)
+    if (LowerValues.contains(GV))
+      return ConstantInt::getBool(getCompareTy(LHS),
+                                  Pred == ICmpInst::ICMP_UGE);
+  return nullptr;
+}
+
 static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
                                          BinaryOperator *LBO, Value *RHS,
                                          const SimplifyQuery &Q,
@@ -3079,11 +3142,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
   Value *Y = nullptr;
   // icmp pred (or X, Y), X
   if (match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) {
-    if (Pred == ICmpInst::ICMP_ULT)
-      return getFalse(ITy);
-    if (Pred == ICmpInst::ICMP_UGE)
-      return getTrue(ITy);
-
     if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
       KnownBits RHSKnown = computeKnownBits(RHS, /* Depth */ 0, Q);
       KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
@@ -3094,14 +3152,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
     }
   }
 
-  // icmp pred (and X, Y), X
-  if (match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) {
-    if (Pred == ICmpInst::ICMP_UGT)
-      return getFalse(ITy);
-    if (Pred == ICmpInst::ICMP_ULE)
-      return getTrue(ITy);
-  }
-
   // icmp pred (urem X, Y), Y
   if (match(LBO, m_URem(m_Value(), m_Specific(RHS)))) {
     switch (Pred) {
@@ -3132,27 +3182,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
     }
   }
 
-  // icmp pred (urem X, Y), X
-  if (match(LBO, m_URem(m_Specific(RHS), m_Value()))) {
-    if (Pred == ICmpInst::ICMP_ULE)
-      return getTrue(ITy);
-    if (Pred == ICmpInst::ICMP_UGT)
-      return getFalse(ITy);
-  }
-
-  // x >>u y <=u x --> true.
-  // x >>u y >u  x --> false.
-  // x udiv y <=u x --> true.
-  // x udiv y >u  x --> false.
-  if (match(LBO, m_LShr(m_Specific(RHS), m_Value())) ||
-      match(LBO, m_UDiv(m_Specific(RHS), m_Value()))) {
-    // icmp pred (X op Y), X
-    if (Pred == ICmpInst::ICMP_UGT)
-      return getFalse(ITy);
-    if (Pred == ICmpInst::ICMP_ULE)
-      return getTrue(ITy);
-  }
-
   // If x is nonzero:
   // x >>u C <u  x --> true  for C != 0.
   // x >>u C !=  x --> true  for C != 0.
@@ -3702,13 +3731,6 @@ static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred,
 
   switch (II->getIntrinsicID()) {
   case Intrinsic::uadd_sat:
-    // uadd.sat(X, Y) uge X, uadd.sat(X, Y) uge Y
-    if (II->getArgOperand(0) == RHS || II->getArgOperand(1) == RHS) {
-      if (Pred == ICmpInst::ICMP_UGE)
-        return ConstantInt::getTrue(getCompareTy(II));
-      if (Pred == ICmpInst::ICMP_ULT)
-        return ConstantInt::getFalse(getCompareTy(II));
-    }
     // uadd.sat(X, Y) uge X + Y
     if (match(RHS, m_c_Add(m_Specific(II->getArgOperand(0)),
                            m_Specific(II->getArgOperand(1))))) {
@@ -3719,13 +3741,6 @@ static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred,
     }
     return nullptr;
   case Intrinsic::usub_sat:
-    // usub.sat(X, Y) ule X
-    if (II->getArgOperand(0) == RHS) {
-      if (Pred == ICmpInst::ICMP_ULE)
-        return ConstantInt::getTrue(getCompareTy(II));
-      if (Pred == ICmpInst::ICMP_UGT)
-        return ConstantInt::getFalse(getCompareTy(II));
-    }
     // usub.sat(X, Y) ule X - Y
     if (match(RHS, m_Sub(m_Specific(II->getArgOperand(0)),
                          m_Specific(II->getArgOperand(1))))) {
@@ -4030,6 +4045,12 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
           ICmpInst::getSwappedPredicate(Pred), RHS, LHS))
     return V;
 
+  if (Value *V = simplifyICmpUsingMonotonicValues(Pred, LHS, RHS))
+    return V;
+  if (Value *V = simplifyICmpUsingMonotonicValues(
+          ICmpInst::getSwappedPredicate(Pred), RHS, LHS))
+    return V;
+
   if (Value *V = simplifyICmpWithDominatingAssume(Pred, LHS, RHS, Q))
     return V;
 
diff --git a/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll b/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll
index a1daa6bd7b4021..e1a4ee91bd15c5 100644
--- a/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll
+++ b/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll
@@ -4,10 +4,7 @@
 define i1 @lshr_or_ule(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i1 @lshr_or_ule(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[OP1:%.*]] = lshr i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[OP2:%.*]] = or i32 [[X]], [[Z]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
   %op1 = lshr i32 %x, %y
   %op2 = or i32 %x, %z
@@ -18,10 +15,7 @@ define i1 @lshr_or_ule(i32 %x, i32 %y, i32 %z) {
 define i1 @lshr_or_uge_swapped(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i1 @lshr_or_uge_swapped(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[OP1:%.*]] = lshr i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[OP2:%.*]] = or i32 [[X]], [[Z]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp uge i32 [[OP2]], [[OP1]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
   %op1 = lshr i32 %x, %y
   %op2 = or i32 %x, %z
@@ -32,10 +26,7 @@ define i1 @lshr_or_uge_swapped(i32 %x, i32 %y, i32 %z) {
 define i1 @lshr_or_ugt(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i1 @lshr_or_ugt(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[OP1:%.*]] = lshr i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[OP2:%.*]] = or i32 [[X]], [[Z]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[OP1]], [[OP2]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 false
 ;
   %op1 = lshr i32 %x, %y
   %op2 = or i32 %x, %z
@@ -74,10 +65,7 @@ define i1 @lshr_or_sle_wrong_pred(i32 %x, i32 %y, i32 %z) {
 define i1 @lshr_or_swapped_ule(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i1 @lshr_or_swapped_ule(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[OP1:%.*]] = lshr i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[OP2:%.*]] = or i32 [[Z]], [[X]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
   %op1 = lshr i32 %x, %y
   %op2 = or i32 %z, %x
@@ -102,10 +90,7 @@ define i1 @lshr_or_ule_invalid_swapped(i32 %x, i32 %y, i32 %z) {
 define i1 @and_uadd_sat_ule(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i1 @and_uadd_sat_ule(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[OP1:%.*]] = and i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[OP2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
   %op1 = and i32 %x, %y
   %op2 = call i32 @llvm.uadd.sat(i32 %x, i32 %z)
@@ -116,10 +101,7 @@ define i1 @and_uadd_sat_ule(i32 %x, i32 %y, i32 %z) {
 define i1 @urem_or_ule(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i1 @urem_or_ule(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[OP1:%.*]] = urem i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[OP2:%.*]] = or i32 [[X]], [[Z]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
   %op1 = urem i32 %x, %y
   %op2 = or i32 %x, %z
@@ -144,10 +126,7 @@ define i1 @urem_or_ule_invalid_swapped(i32 %x, i32 %y, i32 %z) {
 define i1 @udiv_or_ule(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i1 @udiv_or_ule(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[OP1:%.*]] = udiv i32 [[X]], [[Y]]
-; CHECK-NEXT:    [[OP2:%.*]] = or i32 [[X]], [[Z]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
   %op1 = udiv i32 %x, %y
   %op2 = or i32 %x, %z
@@ -172,10 +151,7 @@ define i1 @udiv_or_ule_invalid_swapped(i32 %x, i32 %y, i32 %z) {
 define i1 @usub_sat_uadd_sat_ule(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: define i1 @usub_sat_uadd_sat_ule(
 ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) {
-; CHECK-NEXT:    [[OP1:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[X]], i32 [[Y]])
-; CHECK-NEXT:    [[OP2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Z]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]]
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    ret i1 true
 ;
   %op1 = call i32 @llvm.usub.sat(i32 %x, i32 %y)
   %op2 = call i32 @llvm.uadd.sat(i32 %x, i32 %z)

``````````

</details>


https://github.com/llvm/llvm-project/pull/69471


More information about the llvm-commits mailing list