[llvm] [InstSimplify] Generalize simplification of icmps with monotonic operands (PR #69471)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 18 07:52:34 PDT 2023


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/69471

InstSimplify currently folds patterns like `(x | y) uge x` and `(x & y) ule x` to true. However, it cannot handle combinations of these patterns, such as `(x | y) uge (x & z)`.

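To support this, recursively collect the operands of monotonic instructions (instructions whose result is always unsigned greater-or-equal, or always unsigned less-or-equal, to their operands): the uge side for the LHS of the comparison and the ule side for the RHS. If the two collected sets share any value, the comparison can be folded.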

From f8b434f2ad17477358a6290e866645a98c4bd61f Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Wed, 18 Oct 2023 16:21:54 +0200
Subject: [PATCH] wip

---
 llvm/lib/Analysis/InstructionSimplify.cpp | 116 ++++++++++------------
 1 file changed, 50 insertions(+), 66 deletions(-)

diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index b3feb2470e58efd..04a73cbb9c628df 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3103,6 +3103,54 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   return nullptr;
 }
 
+/// Collect V and values V_i with V uge V_i (Greater) or V ule V_i (!Greater).
+static void getUnsignedMonotonicValues(SmallPtrSetImpl<Value *> &Res, Value *V,
+                                       bool Greater, unsigned Depth = 0) {
+  if (!Res.insert(V).second)
+    return;
+
+  // Only one level of operands is examined; can be increased if useful.
+  if (++Depth > 1)
+    return;
+  // Recurse into operands X for which V uge X (Greater) or V ule X (!Greater).
+  Value *X, *Y;
+  if (Greater) {
+    if (match(V, m_Or(m_Value(X), m_Value(Y))) ||
+        match(V, m_Intrinsic<Intrinsic::uadd_sat>(m_Value(X), m_Value(Y)))) {
+      getUnsignedMonotonicValues(Res, X, Greater, Depth);
+      getUnsignedMonotonicValues(Res, Y, Greater, Depth);
+    }
+  } else {
+    if (match(V, m_And(m_Value(X), m_Value(Y)))) {
+      getUnsignedMonotonicValues(Res, X, Greater, Depth);
+      getUnsignedMonotonicValues(Res, Y, Greater, Depth);
+    } else if (match(V, m_URem(m_Value(X), m_Value())) ||
+               match(V, m_UDiv(m_Value(X), m_Value())) ||
+               match(V, m_LShr(m_Value(X), m_Value())) ||
+               match(V, m_Intrinsic<Intrinsic::usub_sat>(m_Value(X)))) {
+      getUnsignedMonotonicValues(Res, X, Greater, Depth);
+    }
+  }
+}
+
+static Value *simplifyICmpUsingMonotonicValues(ICmpInst::Predicate Pred,
+                                               Value *LHS, Value *RHS) {
+  if (Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_ULT)
+    return nullptr;
+  // ULE/UGT are covered by the caller retrying with swapped operands.
+  // LHS uge every GreaterValue and every LowerValue uge RHS, so if the sets
+  // share a value, LHS uge RHS holds: UGE folds to true and ULT to false.
+  SmallPtrSet<Value *, 4> GreaterValues;
+  SmallPtrSet<Value *, 4> LowerValues;
+  getUnsignedMonotonicValues(GreaterValues, LHS, /*Greater*/ true);
+  getUnsignedMonotonicValues(LowerValues, RHS, /*Greater*/ false);
+  for (Value *GV : GreaterValues)
+    if (LowerValues.contains(GV))
+      return ConstantInt::getBool(getCompareTy(LHS),
+                                  Pred == ICmpInst::ICMP_UGE);
+  return nullptr;
+}
+
 static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
                                          BinaryOperator *LBO, Value *RHS,
                                          const SimplifyQuery &Q,
@@ -3112,11 +3160,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
   Value *Y = nullptr;
   // icmp pred (or X, Y), X
   if (match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) {
-    if (Pred == ICmpInst::ICMP_ULT)
-      return getFalse(ITy);
-    if (Pred == ICmpInst::ICMP_UGE)
-      return getTrue(ITy);
-
     if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
       KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
       KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
@@ -3127,14 +3170,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
     }
   }
 
-  // icmp pred (and X, Y), X
-  if (match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) {
-    if (Pred == ICmpInst::ICMP_UGT)
-      return getFalse(ITy);
-    if (Pred == ICmpInst::ICMP_ULE)
-      return getTrue(ITy);
-  }
-
   // icmp pred (urem X, Y), Y
   if (match(LBO, m_URem(m_Value(), m_Specific(RHS)))) {
     switch (Pred) {
@@ -3165,27 +3200,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
     }
   }
 
-  // icmp pred (urem X, Y), X
-  if (match(LBO, m_URem(m_Specific(RHS), m_Value()))) {
-    if (Pred == ICmpInst::ICMP_ULE)
-      return getTrue(ITy);
-    if (Pred == ICmpInst::ICMP_UGT)
-      return getFalse(ITy);
-  }
-
-  // x >>u y <=u x --> true.
-  // x >>u y >u  x --> false.
-  // x udiv y <=u x --> true.
-  // x udiv y >u  x --> false.
-  if (match(LBO, m_LShr(m_Specific(RHS), m_Value())) ||
-      match(LBO, m_UDiv(m_Specific(RHS), m_Value()))) {
-    // icmp pred (X op Y), X
-    if (Pred == ICmpInst::ICMP_UGT)
-      return getFalse(ITy);
-    if (Pred == ICmpInst::ICMP_ULE)
-      return getTrue(ITy);
-  }
-
   // If x is nonzero:
   // x >>u C <u  x --> true  for C != 0.
   // x >>u C !=  x --> true  for C != 0.
@@ -3727,36 +3741,6 @@ static Value *simplifyICmpWithDominatingAssume(CmpInst::Predicate Predicate,
   return nullptr;
 }
 
-static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred,
-                                             Value *LHS, Value *RHS) {
-  auto *II = dyn_cast<IntrinsicInst>(LHS);
-  if (!II)
-    return nullptr;
-
-  switch (II->getIntrinsicID()) {
-  case Intrinsic::uadd_sat:
-    // uadd.sat(X, Y) uge X, uadd.sat(X, Y) uge Y
-    if (II->getArgOperand(0) == RHS || II->getArgOperand(1) == RHS) {
-      if (Pred == ICmpInst::ICMP_UGE)
-        return ConstantInt::getTrue(getCompareTy(II));
-      if (Pred == ICmpInst::ICMP_ULT)
-        return ConstantInt::getFalse(getCompareTy(II));
-    }
-    return nullptr;
-  case Intrinsic::usub_sat:
-    // usub.sat(X, Y) ule X
-    if (II->getArgOperand(0) == RHS) {
-      if (Pred == ICmpInst::ICMP_ULE)
-        return ConstantInt::getTrue(getCompareTy(II));
-      if (Pred == ICmpInst::ICMP_UGT)
-        return ConstantInt::getFalse(getCompareTy(II));
-    }
-    return nullptr;
-  default:
-    return nullptr;
-  }
-}
-
 /// Given operands for an ICmpInst, see if we can fold the result.
 /// If not, this returns null.
 static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
@@ -4034,9 +4018,9 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   if (Value *V = simplifyICmpWithMinMax(Pred, LHS, RHS, Q, MaxRecurse))
     return V;
 
-  if (Value *V = simplifyICmpWithIntrinsicOnLHS(Pred, LHS, RHS))
+  if (Value *V = simplifyICmpUsingMonotonicValues(Pred, LHS, RHS))
     return V;
-  if (Value *V = simplifyICmpWithIntrinsicOnLHS(
+  if (Value *V = simplifyICmpUsingMonotonicValues(
           ICmpInst::getSwappedPredicate(Pred), RHS, LHS))
     return V;
 



More information about the llvm-commits mailing list