[llvm] f1632d2 - IR: introduce ICmpInst::isImpliedByMatchingCmp (#122597)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 13 08:20:08 PST 2025


Author: Ramkumar Ramachandra
Date: 2025-01-13T16:20:00Z
New Revision: f1632d25db47629221b8a25d79b7993b397f6886

URL: https://github.com/llvm/llvm-project/commit/f1632d25db47629221b8a25d79b7993b397f6886
DIFF: https://github.com/llvm/llvm-project/commit/f1632d25db47629221b8a25d79b7993b397f6886.diff

LOG: IR: introduce ICmpInst::isImpliedByMatchingCmp (#122597)

Create an abstraction over isImplied{True,False}ByMatchingCmp to
faithfully communicate the result of both functions, cleaning up code at
call sites. While at it, fix a bug in the implied-false version of the
function, which was inadvertently dropping samesign information by
inverting Pred2 with getInversePredicate instead of getInverseCmpPredicate.

Added: 
    

Modified: 
    llvm/include/llvm/IR/Instructions.h
    llvm/include/llvm/SandboxIR/Instruction.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/lib/IR/Instructions.cpp
    llvm/lib/Transforms/Scalar/NewGVN.cpp
    llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index 59eb50409883786..9a41971b63373c9 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1266,15 +1266,10 @@ class ICmpInst: public CmpInst {
     return getFlippedSignednessPredicate(getPredicate());
   }
 
-  /// Determine if Pred1 implies Pred2 is true when two compares have matching
-  /// operands.
-  static bool isImpliedTrueByMatchingCmp(CmpPredicate Pred1,
-                                         CmpPredicate Pred2);
-
-  /// Determine if Pred1 implies Pred2 is false when two compares have matching
-  /// operands.
-  static bool isImpliedFalseByMatchingCmp(CmpPredicate Pred1,
-                                          CmpPredicate Pred2);
+  /// Determine if Pred1 implies Pred2 is true, false, or if nothing can be
+  /// inferred about the implication, when two compares have matching operands.
+  static std::optional<bool> isImpliedByMatchingCmp(CmpPredicate Pred1,
+                                                    CmpPredicate Pred2);
 
   void setSameSign(bool B = true) {
     SubclassOptionalData = (SubclassOptionalData & ~SameSign) | (B * SameSign);

diff  --git a/llvm/include/llvm/SandboxIR/Instruction.h b/llvm/include/llvm/SandboxIR/Instruction.h
index d7c1eda81c00607..34a7feb63bec455 100644
--- a/llvm/include/llvm/SandboxIR/Instruction.h
+++ b/llvm/include/llvm/SandboxIR/Instruction.h
@@ -2547,13 +2547,9 @@ class ICmpInst : public CmpInst {
   WRAP_STATIC_PREDICATE(isGE);
   WRAP_STATIC_PREDICATE(isLE);
 
-  static bool isImpliedTrueByMatchingCmp(CmpPredicate Pred1,
-                                         CmpPredicate Pred2) {
-    return llvm::ICmpInst::isImpliedTrueByMatchingCmp(Pred1, Pred2);
-  }
-  static bool isImpliedFalseByMatchingCmp(CmpPredicate Pred1,
-                                          CmpPredicate Pred2) {
-    return llvm::ICmpInst::isImpliedFalseByMatchingCmp(Pred1, Pred2);
+  static std::optional<bool> isImpliedByMatchingCmp(CmpPredicate Pred1,
+                                                    CmpPredicate Pred2) {
+    return llvm::ICmpInst::isImpliedByMatchingCmp(Pred1, Pred2);
   }
 
   static auto predicates() { return llvm::ICmpInst::predicates(); }

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 0e50fc60ce79218..d03e6f5a5754d50 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -9384,19 +9384,6 @@ isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
   }
 }
 
-/// Return true if "icmp1 LPred X, Y" implies "icmp2 RPred X, Y" is true.
-/// Return false if "icmp1 LPred X, Y" implies "icmp2 RPred X, Y" is false.
-/// Otherwise, return std::nullopt if we can't infer anything.
-static std::optional<bool> isImpliedCondMatchingOperands(CmpPredicate LPred,
-                                                         CmpPredicate RPred) {
-  if (ICmpInst::isImpliedTrueByMatchingCmp(LPred, RPred))
-    return true;
-  if (ICmpInst::isImpliedFalseByMatchingCmp(LPred, RPred))
-    return false;
-
-  return std::nullopt;
-}
-
 /// Return true if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is true.
 /// Return false if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is false.
 /// Otherwise, return std::nullopt if we can't infer anything.
@@ -9489,7 +9476,7 @@ isImpliedCondICmps(const ICmpInst *LHS, CmpPredicate RPred, const Value *R0,
 
   // Can we infer anything when the two compares have matching operands?
   if (L0 == R0 && L1 == R1)
-    return isImpliedCondMatchingOperands(LPred, RPred);
+    return ICmpInst::isImpliedByMatchingCmp(LPred, RPred);
 
   // It only really makes sense in the context of signed comparison for "X - Y
   // must be positive if X >= Y and no overflow".
@@ -9499,7 +9486,7 @@ isImpliedCondICmps(const ICmpInst *LHS, CmpPredicate RPred, const Value *R0,
        CmpPredicate::getMatching(LPred, ICmpInst::ICMP_SGE)) &&
       match(R0, m_NSWSub(m_Specific(L0), m_Specific(L1)))) {
     if (match(R1, m_NonPositive()) &&
-        isImpliedCondMatchingOperands(LPred, RPred) == false)
+        ICmpInst::isImpliedByMatchingCmp(LPred, RPred) == false)
       return false;
   }
 
@@ -9509,7 +9496,7 @@ isImpliedCondICmps(const ICmpInst *LHS, CmpPredicate RPred, const Value *R0,
        CmpPredicate::getMatching(LPred, ICmpInst::ICMP_SLE)) &&
       match(R0, m_NSWSub(m_Specific(L0), m_Specific(L1)))) {
     if (match(R1, m_NonNegative()) &&
-        isImpliedCondMatchingOperands(LPred, RPred) == true)
+        ICmpInst::isImpliedByMatchingCmp(LPred, RPred) == true)
       return true;
   }
 

diff  --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 49c148bb68a4d38..b8b2c1d7f9a8598 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3886,8 +3886,7 @@ bool CmpInst::isFalseWhenEqual(Predicate predicate) {
   }
 }
 
-bool ICmpInst::isImpliedTrueByMatchingCmp(CmpPredicate Pred1,
-                                          CmpPredicate Pred2) {
+static bool isImpliedTrueByMatchingCmp(CmpPredicate Pred1, CmpPredicate Pred2) {
   // If the predicates match, then we know the first condition implies the
   // second is true.
   if (CmpPredicate::getMatching(Pred1, Pred2))
@@ -3901,25 +3900,35 @@ bool ICmpInst::isImpliedTrueByMatchingCmp(CmpPredicate Pred1,
   switch (Pred1) {
   default:
     break;
-  case ICMP_EQ:
+  case CmpInst::ICMP_EQ:
     // A == B implies A >=u B, A <=u B, A >=s B, and A <=s B are true.
-    return Pred2 == ICMP_UGE || Pred2 == ICMP_ULE || Pred2 == ICMP_SGE ||
-           Pred2 == ICMP_SLE;
-  case ICMP_UGT: // A >u B implies A != B and A >=u B are true.
-    return Pred2 == ICMP_NE || Pred2 == ICMP_UGE;
-  case ICMP_ULT: // A <u B implies A != B and A <=u B are true.
-    return Pred2 == ICMP_NE || Pred2 == ICMP_ULE;
-  case ICMP_SGT: // A >s B implies A != B and A >=s B are true.
-    return Pred2 == ICMP_NE || Pred2 == ICMP_SGE;
-  case ICMP_SLT: // A <s B implies A != B and A <=s B are true.
-    return Pred2 == ICMP_NE || Pred2 == ICMP_SLE;
+    return Pred2 == CmpInst::ICMP_UGE || Pred2 == CmpInst::ICMP_ULE ||
+           Pred2 == CmpInst::ICMP_SGE || Pred2 == CmpInst::ICMP_SLE;
+  case CmpInst::ICMP_UGT: // A >u B implies A != B and A >=u B are true.
+    return Pred2 == CmpInst::ICMP_NE || Pred2 == CmpInst::ICMP_UGE;
+  case CmpInst::ICMP_ULT: // A <u B implies A != B and A <=u B are true.
+    return Pred2 == CmpInst::ICMP_NE || Pred2 == CmpInst::ICMP_ULE;
+  case CmpInst::ICMP_SGT: // A >s B implies A != B and A >=s B are true.
+    return Pred2 == CmpInst::ICMP_NE || Pred2 == CmpInst::ICMP_SGE;
+  case CmpInst::ICMP_SLT: // A <s B implies A != B and A <=s B are true.
+    return Pred2 == CmpInst::ICMP_NE || Pred2 == CmpInst::ICMP_SLE;
   }
   return false;
 }
 
-bool ICmpInst::isImpliedFalseByMatchingCmp(CmpPredicate Pred1,
-                                           CmpPredicate Pred2) {
-  return isImpliedTrueByMatchingCmp(Pred1, getInversePredicate(Pred2));
+static bool isImpliedFalseByMatchingCmp(CmpPredicate Pred1,
+                                        CmpPredicate Pred2) {
+  return isImpliedTrueByMatchingCmp(Pred1,
+                                    ICmpInst::getInverseCmpPredicate(Pred2));
+}
+
+std::optional<bool> ICmpInst::isImpliedByMatchingCmp(CmpPredicate Pred1,
+                                                     CmpPredicate Pred2) {
+  if (isImpliedTrueByMatchingCmp(Pred1, Pred2))
+    return true;
+  if (isImpliedFalseByMatchingCmp(Pred1, Pred2))
+    return false;
+  return std::nullopt;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp
index 3812e99508f7385..b5ce860d73523e0 100644
--- a/llvm/lib/Transforms/Scalar/NewGVN.cpp
+++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp
@@ -1964,18 +1964,10 @@ NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
         if (PBranch->TrueEdge) {
           // If we know the previous predicate is true and we are in the true
           // edge then we may be implied true or false.
-          if (ICmpInst::isImpliedTrueByMatchingCmp(BranchPredicate,
-                                                   OurPredicate)) {
-            return ExprResult::some(
-                createConstantExpression(ConstantInt::getTrue(CI->getType())),
-                PI);
-          }
-
-          if (ICmpInst::isImpliedFalseByMatchingCmp(BranchPredicate,
-                                                    OurPredicate)) {
-            return ExprResult::some(
-                createConstantExpression(ConstantInt::getFalse(CI->getType())),
-                PI);
+          if (auto R = ICmpInst::isImpliedByMatchingCmp(BranchPredicate,
+                                                        OurPredicate)) {
+            auto *C = ConstantInt::getBool(CI->getType(), *R);
+            return ExprResult::some(createConstantExpression(C), PI);
           }
         } else {
           // Just handle the ne and eq cases, where if we have the same

diff  --git a/llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll b/llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll
index 35cfadaa2965a7b..0e6db403512aee4 100644
--- a/llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll
+++ b/llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll
@@ -126,6 +126,19 @@ define i1 @sgt_implies_ge_via_assume(i32 %i, i32 %j) {
   ret i1 %i.ge.j
 }
 
+define i1 @sgt_implies_false_le_via_assume(i32 %i, i32 %j) {
+; CHECK-LABEL: define i1 @sgt_implies_false_le_via_assume(
+; CHECK-SAME: i32 [[I:%.*]], i32 [[J:%.*]]) {
+; CHECK-NEXT:    [[I_SGT_J:%.*]] = icmp sgt i32 [[I]], [[J]]
+; CHECK-NEXT:    call void @llvm.assume(i1 [[I_SGT_J]])
+; CHECK-NEXT:    ret i1 false
+;
+  %i.sgt.j = icmp sgt i32 %i, %j
+  call void @llvm.assume(i1 %i.sgt.j)
+  %i.le.j = icmp samesign ule i32 %i, %j
+  ret i1 %i.le.j
+}
+
 define i32 @gt_implies_sge_dominating(i32 %a, i32 %len) {
 ; CHECK-LABEL: define i32 @gt_implies_sge_dominating(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[LEN:%.*]]) {
@@ -150,6 +163,30 @@ end:
   ret i32 -1
 }
 
+define i32 @gt_implies_false_sle_dominating(i32 %a, i32 %len) {
+; CHECK-LABEL: define i32 @gt_implies_false_sle_dominating(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[LEN:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[A_GT_LEN:%.*]] = icmp samesign ugt i32 [[A]], [[LEN]]
+; CHECK-NEXT:    br i1 [[A_GT_LEN]], label %[[TAKEN:.*]], label %[[END:.*]]
+; CHECK:       [[TAKEN]]:
+; CHECK-NEXT:    ret i32 0
+; CHECK:       [[END]]:
+; CHECK-NEXT:    ret i32 -1
+;
+entry:
+  %a.gt.len = icmp samesign ugt i32 %a, %len
+  br i1 %a.gt.len, label %taken, label %end
+
+taken:
+  %a.sle.len = icmp sle i32 %a, %len
+  %res = select i1 %a.sle.len, i32 30, i32 0
+  ret i32 %res
+
+end:
+  ret i32 -1
+}
+
 define i32 @gt_implies_sge_dominating_cr(i32 %a, i32 %len) {
 ; CHECK-LABEL: define i32 @gt_implies_sge_dominating_cr(
 ; CHECK-SAME: i32 [[A:%.*]], i32 [[LEN:%.*]]) {


        


More information about the llvm-commits mailing list