[llvm] IR: teach implied-by-matching-cmp about samesign (PR #120120)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 16 09:49:03 PST 2024


https://github.com/artagnon created https://github.com/llvm/llvm-project/pull/120120

Move isImplied{True,False}ByMatchingCmp from CmpInst to ICmpInst, so that they can operate on CmpPredicate instead of CmpInst::Predicate, and teach them about samesign. Since all callers of these functions still pass a plain CmpInst::Predicate, which carries no samesign information, this patch introduces no functional changes.

>From c84387c53f19c2df0d054fb13144c192a2445781 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Mon, 16 Dec 2024 17:26:20 +0000
Subject: [PATCH] IR: teach implied-by-matching-cmp about samesign

Move isImplied{True,False}ByMatchingCmp from CmpInst to ICmpInst, so
that they can operate on CmpPredicate instead of CmpInst::Predicate, and
teach them about samesign. Since all callers of these functions still
pass a plain CmpInst::Predicate, which carries no samesign information,
this patch introduces no functional changes.
---
 llvm/include/llvm/IR/InstrTypes.h         |  8 --------
 llvm/include/llvm/IR/Instructions.h       | 10 ++++++++++
 llvm/include/llvm/SandboxIR/Instruction.h | 16 +++++++++-------
 llvm/lib/Analysis/ValueTracking.cpp       |  4 ++--
 llvm/lib/IR/Instructions.cpp              | 13 ++++++++++---
 llvm/lib/Transforms/Scalar/NewGVN.cpp     |  8 ++++----
 6 files changed, 35 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index e6332a16df7d5f..7ad34e4f223394 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -967,14 +967,6 @@ class CmpInst : public Instruction {
   /// Determine if the predicate is false when comparing a value with itself.
   static bool isFalseWhenEqual(Predicate predicate);
 
-  /// Determine if Pred1 implies Pred2 is true when two compares have matching
-  /// operands.
-  static bool isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2);
-
-  /// Determine if Pred1 implies Pred2 is false when two compares have matching
-  /// operands.
-  static bool isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2);
-
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Instruction *I) {
     return I->getOpcode() == Instruction::ICmp ||
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index a42bf6bca1b9fb..fe51f34e707ead 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1266,6 +1266,16 @@ class ICmpInst: public CmpInst {
     return getFlippedSignednessPredicate(getPredicate());
   }
 
+  /// Determine if Pred1 implies Pred2 is true when two compares have matching
+  /// operands.
+  static bool isImpliedTrueByMatchingCmp(CmpPredicate Pred1,
+                                         CmpPredicate Pred2);
+
+  /// Determine if Pred1 implies Pred2 is false when two compares have matching
+  /// operands.
+  static bool isImpliedFalseByMatchingCmp(CmpPredicate Pred1,
+                                          CmpPredicate Pred2);
+
   void setSameSign(bool B = true) {
     SubclassOptionalData = (SubclassOptionalData & ~SameSign) | (B * SameSign);
   }
diff --git a/llvm/include/llvm/SandboxIR/Instruction.h b/llvm/include/llvm/SandboxIR/Instruction.h
index 4d21c4d3da3556..d7c1eda81c0060 100644
--- a/llvm/include/llvm/SandboxIR/Instruction.h
+++ b/llvm/include/llvm/SandboxIR/Instruction.h
@@ -2511,13 +2511,6 @@ class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
   WRAP_STATIC_PREDICATE(isOrdered);
   WRAP_STATIC_PREDICATE(isUnordered);
 
-  static bool isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2) {
-    return llvm::CmpInst::isImpliedTrueByMatchingCmp(Pred1, Pred2);
-  }
-  static bool isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2) {
-    return llvm::CmpInst::isImpliedFalseByMatchingCmp(Pred1, Pred2);
-  }
-
   /// Method for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Value *From) {
     return From->getSubclassID() == ClassID::ICmp ||
@@ -2554,6 +2547,15 @@ class ICmpInst : public CmpInst {
   WRAP_STATIC_PREDICATE(isGE);
   WRAP_STATIC_PREDICATE(isLE);
 
+  static bool isImpliedTrueByMatchingCmp(CmpPredicate Pred1,
+                                         CmpPredicate Pred2) {
+    return llvm::ICmpInst::isImpliedTrueByMatchingCmp(Pred1, Pred2);
+  }
+  static bool isImpliedFalseByMatchingCmp(CmpPredicate Pred1,
+                                          CmpPredicate Pred2) {
+    return llvm::ICmpInst::isImpliedFalseByMatchingCmp(Pred1, Pred2);
+  }
+
   static auto predicates() { return llvm::ICmpInst::predicates(); }
   static bool compare(const APInt &LHS, const APInt &RHS,
                       ICmpInst::Predicate Pred) {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index a43f5b6cec2f4e..dd6ba8d6497f4e 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -9265,9 +9265,9 @@ isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
 static std::optional<bool>
 isImpliedCondMatchingOperands(CmpInst::Predicate LPred,
                               CmpInst::Predicate RPred) {
-  if (CmpInst::isImpliedTrueByMatchingCmp(LPred, RPred))
+  if (ICmpInst::isImpliedTrueByMatchingCmp(LPred, RPred))
     return true;
-  if (CmpInst::isImpliedFalseByMatchingCmp(LPred, RPred))
+  if (ICmpInst::isImpliedFalseByMatchingCmp(LPred, RPred))
     return false;
 
   return std::nullopt;
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 2d6fe40f4c1de0..49c148bb68a4d3 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3886,12 +3886,18 @@ bool CmpInst::isFalseWhenEqual(Predicate predicate) {
   }
 }
 
-bool CmpInst::isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2) {
+bool ICmpInst::isImpliedTrueByMatchingCmp(CmpPredicate Pred1,
+                                          CmpPredicate Pred2) {
   // If the predicates match, then we know the first condition implies the
   // second is true.
-  if (Pred1 == Pred2)
+  if (CmpPredicate::getMatching(Pred1, Pred2))
     return true;
 
+  if (Pred1.hasSameSign() && CmpInst::isSigned(Pred2))
+    Pred1 = ICmpInst::getFlippedSignednessPredicate(Pred1);
+  else if (Pred2.hasSameSign() && CmpInst::isSigned(Pred1))
+    Pred2 = ICmpInst::getFlippedSignednessPredicate(Pred2);
+
   switch (Pred1) {
   default:
     break;
@@ -3911,7 +3917,8 @@ bool CmpInst::isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2) {
   return false;
 }
 
-bool CmpInst::isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2) {
+bool ICmpInst::isImpliedFalseByMatchingCmp(CmpPredicate Pred1,
+                                           CmpPredicate Pred2) {
   return isImpliedTrueByMatchingCmp(Pred1, getInversePredicate(Pred2));
 }
 
diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp
index 0cba8739441bcb..3812e99508f738 100644
--- a/llvm/lib/Transforms/Scalar/NewGVN.cpp
+++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp
@@ -1964,15 +1964,15 @@ NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
         if (PBranch->TrueEdge) {
           // If we know the previous predicate is true and we are in the true
           // edge then we may be implied true or false.
-          if (CmpInst::isImpliedTrueByMatchingCmp(BranchPredicate,
-                                                  OurPredicate)) {
+          if (ICmpInst::isImpliedTrueByMatchingCmp(BranchPredicate,
+                                                   OurPredicate)) {
             return ExprResult::some(
                 createConstantExpression(ConstantInt::getTrue(CI->getType())),
                 PI);
           }
 
-          if (CmpInst::isImpliedFalseByMatchingCmp(BranchPredicate,
-                                                   OurPredicate)) {
+          if (ICmpInst::isImpliedFalseByMatchingCmp(BranchPredicate,
+                                                    OurPredicate)) {
             return ExprResult::some(
                 createConstantExpression(ConstantInt::getFalse(CI->getType())),
                 PI);



More information about the llvm-commits mailing list