[llvm] IR: teach implied-by-matching-cmp about samesign (PR #120120)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 17 06:42:48 PST 2024


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/120120

>From c84387c53f19c2df0d054fb13144c192a2445781 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Mon, 16 Dec 2024 17:26:20 +0000
Subject: [PATCH 1/3] IR: teach implied-by-matching-cmp about samesign

Move isImplied{True,False}ByMatchingCmp from CmpInst to ICmpInst, so
that it can operate on CmpPredicate instead of CmpInst::Predicate, and
teach it about samesign. Since all callers of these functions operate on
CmpInst::Predicate, this patch does not introduce functional changes.
---
 llvm/include/llvm/IR/InstrTypes.h         |  8 --------
 llvm/include/llvm/IR/Instructions.h       | 10 ++++++++++
 llvm/include/llvm/SandboxIR/Instruction.h | 16 +++++++++-------
 llvm/lib/Analysis/ValueTracking.cpp       |  4 ++--
 llvm/lib/IR/Instructions.cpp              | 13 ++++++++++---
 llvm/lib/Transforms/Scalar/NewGVN.cpp     |  8 ++++----
 6 files changed, 35 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index e6332a16df7d5f..7ad34e4f223394 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -967,14 +967,6 @@ class CmpInst : public Instruction {
   /// Determine if the predicate is false when comparing a value with itself.
   static bool isFalseWhenEqual(Predicate predicate);
 
-  /// Determine if Pred1 implies Pred2 is true when two compares have matching
-  /// operands.
-  static bool isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2);
-
-  /// Determine if Pred1 implies Pred2 is false when two compares have matching
-  /// operands.
-  static bool isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2);
-
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Instruction *I) {
     return I->getOpcode() == Instruction::ICmp ||
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index a42bf6bca1b9fb..fe51f34e707ead 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1266,6 +1266,16 @@ class ICmpInst: public CmpInst {
     return getFlippedSignednessPredicate(getPredicate());
   }
 
+  /// Determine if Pred1 implies Pred2 is true when two compares have matching
+  /// operands.
+  static bool isImpliedTrueByMatchingCmp(CmpPredicate Pred1,
+                                         CmpPredicate Pred2);
+
+  /// Determine if Pred1 implies Pred2 is false when two compares have matching
+  /// operands.
+  static bool isImpliedFalseByMatchingCmp(CmpPredicate Pred1,
+                                          CmpPredicate Pred2);
+
   void setSameSign(bool B = true) {
     SubclassOptionalData = (SubclassOptionalData & ~SameSign) | (B * SameSign);
   }
diff --git a/llvm/include/llvm/SandboxIR/Instruction.h b/llvm/include/llvm/SandboxIR/Instruction.h
index 4d21c4d3da3556..d7c1eda81c0060 100644
--- a/llvm/include/llvm/SandboxIR/Instruction.h
+++ b/llvm/include/llvm/SandboxIR/Instruction.h
@@ -2511,13 +2511,6 @@ class CmpInst : public SingleLLVMInstructionImpl<llvm::CmpInst> {
   WRAP_STATIC_PREDICATE(isOrdered);
   WRAP_STATIC_PREDICATE(isUnordered);
 
-  static bool isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2) {
-    return llvm::CmpInst::isImpliedTrueByMatchingCmp(Pred1, Pred2);
-  }
-  static bool isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2) {
-    return llvm::CmpInst::isImpliedFalseByMatchingCmp(Pred1, Pred2);
-  }
-
   /// Method for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Value *From) {
     return From->getSubclassID() == ClassID::ICmp ||
@@ -2554,6 +2547,15 @@ class ICmpInst : public CmpInst {
   WRAP_STATIC_PREDICATE(isGE);
   WRAP_STATIC_PREDICATE(isLE);
 
+  static bool isImpliedTrueByMatchingCmp(CmpPredicate Pred1,
+                                         CmpPredicate Pred2) {
+    return llvm::ICmpInst::isImpliedTrueByMatchingCmp(Pred1, Pred2);
+  }
+  static bool isImpliedFalseByMatchingCmp(CmpPredicate Pred1,
+                                          CmpPredicate Pred2) {
+    return llvm::ICmpInst::isImpliedFalseByMatchingCmp(Pred1, Pred2);
+  }
+
   static auto predicates() { return llvm::ICmpInst::predicates(); }
   static bool compare(const APInt &LHS, const APInt &RHS,
                       ICmpInst::Predicate Pred) {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index a43f5b6cec2f4e..dd6ba8d6497f4e 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -9265,9 +9265,9 @@ isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
 static std::optional<bool>
 isImpliedCondMatchingOperands(CmpInst::Predicate LPred,
                               CmpInst::Predicate RPred) {
-  if (CmpInst::isImpliedTrueByMatchingCmp(LPred, RPred))
+  if (ICmpInst::isImpliedTrueByMatchingCmp(LPred, RPred))
     return true;
-  if (CmpInst::isImpliedFalseByMatchingCmp(LPred, RPred))
+  if (ICmpInst::isImpliedFalseByMatchingCmp(LPred, RPred))
     return false;
 
   return std::nullopt;
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 2d6fe40f4c1de0..49c148bb68a4d3 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -3886,12 +3886,18 @@ bool CmpInst::isFalseWhenEqual(Predicate predicate) {
   }
 }
 
-bool CmpInst::isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2) {
+bool ICmpInst::isImpliedTrueByMatchingCmp(CmpPredicate Pred1,
+                                          CmpPredicate Pred2) {
   // If the predicates match, then we know the first condition implies the
   // second is true.
-  if (Pred1 == Pred2)
+  if (CmpPredicate::getMatching(Pred1, Pred2))
     return true;
 
+  if (Pred1.hasSameSign() && CmpInst::isSigned(Pred2))
+    Pred1 = ICmpInst::getFlippedSignednessPredicate(Pred1);
+  else if (Pred2.hasSameSign() && CmpInst::isSigned(Pred1))
+    Pred2 = ICmpInst::getFlippedSignednessPredicate(Pred2);
+
   switch (Pred1) {
   default:
     break;
@@ -3911,7 +3917,8 @@ bool CmpInst::isImpliedTrueByMatchingCmp(Predicate Pred1, Predicate Pred2) {
   return false;
 }
 
-bool CmpInst::isImpliedFalseByMatchingCmp(Predicate Pred1, Predicate Pred2) {
+bool ICmpInst::isImpliedFalseByMatchingCmp(CmpPredicate Pred1,
+                                           CmpPredicate Pred2) {
   return isImpliedTrueByMatchingCmp(Pred1, getInversePredicate(Pred2));
 }
 
diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp
index 0cba8739441bcb..3812e99508f738 100644
--- a/llvm/lib/Transforms/Scalar/NewGVN.cpp
+++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp
@@ -1964,15 +1964,15 @@ NewGVN::ExprResult NewGVN::performSymbolicCmpEvaluation(Instruction *I) const {
         if (PBranch->TrueEdge) {
           // If we know the previous predicate is true and we are in the true
           // edge then we may be implied true or false.
-          if (CmpInst::isImpliedTrueByMatchingCmp(BranchPredicate,
-                                                  OurPredicate)) {
+          if (ICmpInst::isImpliedTrueByMatchingCmp(BranchPredicate,
+                                                   OurPredicate)) {
             return ExprResult::some(
                 createConstantExpression(ConstantInt::getTrue(CI->getType())),
                 PI);
           }
 
-          if (CmpInst::isImpliedFalseByMatchingCmp(BranchPredicate,
-                                                   OurPredicate)) {
+          if (ICmpInst::isImpliedFalseByMatchingCmp(BranchPredicate,
+                                                    OurPredicate)) {
             return ExprResult::some(
                 createConstantExpression(ConstantInt::getFalse(CI->getType())),
                 PI);

>From d04a7a60d0dd733674bea999142d0310f51b0d7f Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Tue, 17 Dec 2024 12:51:18 +0000
Subject: [PATCH 2/3] VT/test: pre-commit tests to enable samesign

---
 .../implied-condition-samesign.ll             | 228 ++++++++++++++++++
 1 file changed, 228 insertions(+)
 create mode 100644 llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll

diff --git a/llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll b/llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll
new file mode 100644
index 00000000000000..1d58ca5ead9cca
--- /dev/null
+++ b/llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll
@@ -0,0 +1,228 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=instsimplify -S %s | FileCheck %s
+
+define i1 @incr_sle(i32 %i, i32 %len) {
+; CHECK-LABEL: define i1 @incr_sle(
+; CHECK-SAME: i32 [[I:%.*]], i32 [[LEN:%.*]]) {
+; CHECK-NEXT:    [[I_INCR:%.*]] = add nuw nsw i32 [[I]], 1
+; CHECK-NEXT:    [[I_LT_LEN:%.*]] = icmp samesign ugt i32 [[I]], [[LEN]]
+; CHECK-NEXT:    [[I_INCR_LT_LEN:%.*]] = icmp sgt i32 [[I_INCR]], [[LEN]]
+; CHECK-NEXT:    [[RES:%.*]] = icmp sle i1 [[I_INCR_LT_LEN]], [[I_LT_LEN]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %i.incr = add nsw nuw i32 %i, 1
+  %i.lt.len = icmp samesign ugt i32 %i, %len
+  %i.incr.lt.len = icmp sgt i32 %i.incr, %len
+  %res = icmp sle i1 %i.incr.lt.len, %i.lt.len
+  ret i1 %res
+}
+
+define i1 @incr_sge(i32 %i, i32 %len) {
+; CHECK-LABEL: define i1 @incr_sge(
+; CHECK-SAME: i32 [[I:%.*]], i32 [[LEN:%.*]]) {
+; CHECK-NEXT:    [[I_INCR:%.*]] = add nuw nsw i32 [[I]], 1
+; CHECK-NEXT:    [[I_LT_LEN:%.*]] = icmp samesign ult i32 [[I]], [[LEN]]
+; CHECK-NEXT:    [[I_INCR_LT_LEN:%.*]] = icmp slt i32 [[I_INCR]], [[LEN]]
+; CHECK-NEXT:    [[RES:%.*]] = icmp sge i1 [[I_INCR_LT_LEN]], [[I_LT_LEN]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %i.incr = add nsw nuw i32 %i, 1
+  %i.lt.len = icmp samesign ult i32 %i, %len
+  %i.incr.lt.len = icmp slt i32 %i.incr, %len
+  %res = icmp sge i1 %i.incr.lt.len, %i.lt.len
+  ret i1 %res
+}
+
+define i1 @incr_ule(i32 %i, i32 %len) {
+; CHECK-LABEL: define i1 @incr_ule(
+; CHECK-SAME: i32 [[I:%.*]], i32 [[LEN:%.*]]) {
+; CHECK-NEXT:    [[I_INCR:%.*]] = add nuw nsw i32 [[I]], 1
+; CHECK-NEXT:    [[I_LT_LEN:%.*]] = icmp samesign ugt i32 [[I]], [[LEN]]
+; CHECK-NEXT:    [[I_INCR_LT_LEN:%.*]] = icmp sgt i32 [[I_INCR]], [[LEN]]
+; CHECK-NEXT:    [[RES:%.*]] = icmp ule i1 [[I_LT_LEN]], [[I_INCR_LT_LEN]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %i.incr = add nsw nuw i32 %i, 1
+  %i.lt.len = icmp samesign ugt i32 %i, %len
+  %i.incr.lt.len = icmp sgt i32 %i.incr, %len
+  %res = icmp ule i1 %i.lt.len, %i.incr.lt.len
+  ret i1 %res
+}
+
+define i1 @incr_uge(i32 %i, i32 %len) {
+; CHECK-LABEL: define i1 @incr_uge(
+; CHECK-SAME: i32 [[I:%.*]], i32 [[LEN:%.*]]) {
+; CHECK-NEXT:    [[I_INCR:%.*]] = add nuw nsw i32 [[I]], 1
+; CHECK-NEXT:    [[I_LT_LEN:%.*]] = icmp samesign ult i32 [[I]], [[LEN]]
+; CHECK-NEXT:    [[I_INCR_LT_LEN:%.*]] = icmp slt i32 [[I_INCR]], [[LEN]]
+; CHECK-NEXT:    [[RES:%.*]] = icmp uge i1 [[I_LT_LEN]], [[I_INCR_LT_LEN]]
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %i.incr = add nsw nuw i32 %i, 1
+  %i.lt.len = icmp samesign ult i32 %i, %len
+  %i.incr.lt.len = icmp slt i32 %i.incr, %len
+  %res = icmp uge i1 %i.lt.len, %i.incr.lt.len
+  ret i1 %res
+}
+
+define i1 @sgt_implies_ge_via_assume(i32 %i, i32 %j) {
+; CHECK-LABEL: define i1 @sgt_implies_ge_via_assume(
+; CHECK-SAME: i32 [[I:%.*]], i32 [[J:%.*]]) {
+; CHECK-NEXT:    [[I_GT_J:%.*]] = icmp sgt i32 [[I]], [[J]]
+; CHECK-NEXT:    call void @llvm.assume(i1 [[I_GT_J]])
+; CHECK-NEXT:    [[I_GE_J:%.*]] = icmp samesign uge i32 [[I]], [[J]]
+; CHECK-NEXT:    ret i1 [[I_GE_J]]
+;
+  %i.gt.j = icmp sgt i32 %i, %j
+  call void @llvm.assume(i1 %i.gt.j)
+  %i.ge.j = icmp samesign uge i32 %i, %j
+  ret i1 %i.ge.j
+}
+
+define i32 @gt_implies_sge_dominating(i32 %a, i32 %len) {
+; CHECK-LABEL: define i32 @gt_implies_sge_dominating(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[LEN:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    [[A_GT_LEN:%.*]] = icmp samesign ugt i32 [[A]], [[LEN]]
+; CHECK-NEXT:    br i1 [[A_GT_LEN]], label %[[TAKEN:.*]], label %[[END:.*]]
+; CHECK:       [[TAKEN]]:
+; CHECK-NEXT:    [[A_SGE_LEN:%.*]] = icmp sge i32 [[A]], [[LEN]]
+; CHECK-NEXT:    [[C:%.*]] = select i1 [[A_SGE_LEN]], i32 30, i32 0
+; CHECK-NEXT:    br label %[[END]]
+; CHECK:       [[END]]:
+; CHECK-NEXT:    [[RES:%.*]] = phi i32 [ -1, %[[ENTRY]] ], [ [[C]], %[[TAKEN]] ]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+entry:
+  %a.gt.len = icmp samesign ugt i32 %a, %len
+  br i1 %a.gt.len, label %taken, label %end
+
+taken:
+  %a.sge.len = icmp sge i32 %a, %len
+  %c = select i1 %a.sge.len, i32 30, i32 0
+  br label %end
+
+end:
+  %res = phi i32 [-1, %entry], [%c, %taken]
+  ret i32 %res
+}
+
+define i32 @gt_implies_sge_dominating_cr(i32 %a, i32 %len) {
+; CHECK-LABEL: define i32 @gt_implies_sge_dominating_cr(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[LEN:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    [[A_GT_20:%.*]] = icmp samesign ugt i32 [[A]], 20
+; CHECK-NEXT:    br i1 [[A_GT_20]], label %[[TAKEN:.*]], label %[[END:.*]]
+; CHECK:       [[TAKEN]]:
+; CHECK-NEXT:    [[A_SGE_MINUS_10:%.*]] = icmp sge i32 [[A]], 10
+; CHECK-NEXT:    [[C:%.*]] = select i1 [[A_SGE_MINUS_10]], i32 30, i32 0
+; CHECK-NEXT:    br label %[[END]]
+; CHECK:       [[END]]:
+; CHECK-NEXT:    [[RES:%.*]] = phi i32 [ -1, %[[ENTRY]] ], [ [[C]], %[[TAKEN]] ]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+entry:
+  %a.gt.20 = icmp samesign ugt i32 %a, 20
+  br i1 %a.gt.20, label %taken, label %end
+
+taken:
+  %a.sge.minus.10 = icmp sge i32 %a, 10
+  %c = select i1 %a.sge.minus.10, i32 30, i32 0
+  br label %end
+
+end:
+  %res = phi i32 [-1, %entry], [%c, %taken]
+  ret i32 %res
+}
+
+define i32 @sgt_implies_ge_dominating_cr(i32 %a, i32 %len) {
+; CHECK-LABEL: define i32 @sgt_implies_ge_dominating_cr(
+; CHECK-SAME: i32 [[A:%.*]], i32 [[LEN:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    [[A_SGT_20:%.*]] = icmp sgt i32 [[A]], -10
+; CHECK-NEXT:    br i1 [[A_SGT_20]], label %[[TAKEN:.*]], label %[[END:.*]]
+; CHECK:       [[TAKEN]]:
+; CHECK-NEXT:    [[A_GE_10:%.*]] = icmp samesign uge i32 [[A]], -20
+; CHECK-NEXT:    [[C:%.*]] = select i1 [[A_GE_10]], i32 30, i32 0
+; CHECK-NEXT:    br label %[[END]]
+; CHECK:       [[END]]:
+; CHECK-NEXT:    [[RES:%.*]] = phi i32 [ -1, %[[ENTRY]] ], [ [[C]], %[[TAKEN]] ]
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+entry:
+  %a.sgt.20 = icmp sgt i32 %a, -10
+  br i1 %a.sgt.20, label %taken, label %end
+
+taken:
+  %a.ge.10 = icmp samesign uge i32 %a, -20
+  %c = select i1 %a.ge.10, i32 30, i32 0
+  br label %end
+
+end:
+  %res = phi i32 [-1, %entry], [%c, %taken]
+  ret i32 %res
+}
+
+define i32 @gt_sub_nsw(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @gt_sub_nsw(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp samesign ugt i32 [[X]], [[Y]]
+; CHECK-NEXT:    br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_END:.*]]
+; CHECK:       [[COND_TRUE]]:
+; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[SUB]], 1
+; CHECK-NEXT:    [[NEG:%.*]] = xor i32 [[SUB]], -1
+; CHECK-NEXT:    [[ABSCOND:%.*]] = icmp samesign ult i32 [[SUB]], -1
+; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[ABSCOND]], i32 [[NEG]], i32 [[ADD]]
+; CHECK-NEXT:    ret i32 [[ABS]]
+; CHECK:       [[COND_END]]:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %cmp = icmp samesign ugt i32 %x, %y
+  br i1 %cmp, label %cond.true, label %cond.end
+
+cond.true:
+  %sub = sub nsw i32 %x, %y
+  %add = add nsw i32 %sub, 1
+  %neg = xor i32 %sub, -1
+  %abscond = icmp samesign ult i32 %sub, -1
+  %abs = select i1 %abscond, i32 %neg, i32 %add
+  ret i32 %abs
+
+cond.end:
+  ret i32 0
+}
+
+define i32 @ge_sub_nsw(i32 %x, i32 %y) {
+; CHECK-LABEL: define i32 @ge_sub_nsw(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp samesign uge i32 [[X]], [[Y]]
+; CHECK-NEXT:    br i1 [[CMP]], label %[[COND_TRUE:.*]], label %[[COND_END:.*]]
+; CHECK:       [[COND_TRUE]]:
+; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[SUB]], 1
+; CHECK-NEXT:    [[NEG:%.*]] = xor i32 [[SUB]], -1
+; CHECK-NEXT:    [[ABSCOND:%.*]] = icmp samesign ult i32 [[SUB]], -1
+; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[ABSCOND]], i32 [[NEG]], i32 [[ADD]]
+; CHECK-NEXT:    ret i32 [[ABS]]
+; CHECK:       [[COND_END]]:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %cmp = icmp samesign uge i32 %x, %y
+  br i1 %cmp, label %cond.true, label %cond.end
+
+cond.true:
+  %sub = sub nsw i32 %x, %y
+  %add = add nsw i32 %sub, 1
+  %neg = xor i32 %sub, -1
+  %abscond = icmp samesign ult i32 %sub, -1
+  %abs = select i1 %abscond, i32 %neg, i32 %add
+  ret i32 %abs
+
+cond.end:
+  ret i32 0
+}

>From 8fcfd95cfaaa9f23b7a0ee1662975589bf33b1fc Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Mon, 16 Dec 2024 17:29:22 +0000
Subject: [PATCH 3/3] VT: teach implied-cond about samesign

Change isImpliedCondICmps() and its callees to operate on a CmpPredicate
instead of a CmpInst::Predicate, and teach them about samesign.

The move of isImplied{True,False}ByMatchingCmp from CmpInst to ICmpInst,
which lets them operate on CmpPredicate and teaches them about samesign,
was done in PATCH 1/3 ("IR: teach implied-by-matching-cmp about
samesign"); this patch builds on it.
---
 llvm/lib/Analysis/ValueTracking.cpp           | 76 +++++++++++--------
 .../implied-condition-samesign.ll             | 49 +++---------
 2 files changed, 55 insertions(+), 70 deletions(-)

diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index dd6ba8d6497f4e..514406e4a786ae 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -9141,7 +9141,7 @@ bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
 }
 
 /// Return true if "icmp Pred LHS RHS" is always true.
-static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
+static bool isTruePredicate(CmpPredicate Pred, const Value *LHS,
                             const Value *RHS) {
   if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS)
     return true;
@@ -9223,8 +9223,8 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
 /// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred
 /// ALHS ARHS" is true.  Otherwise, return std::nullopt.
 static std::optional<bool>
-isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
-                      const Value *ARHS, const Value *BLHS, const Value *BRHS) {
+isImpliedCondOperands(CmpPredicate Pred, const Value *ALHS, const Value *ARHS,
+                      const Value *BLHS, const Value *BRHS) {
   switch (Pred) {
   default:
     return std::nullopt;
@@ -9262,9 +9262,8 @@ isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
 /// Return true if "icmp1 LPred X, Y" implies "icmp2 RPred X, Y" is true.
 /// Return false if "icmp1 LPred X, Y" implies "icmp2 RPred X, Y" is false.
 /// Otherwise, return std::nullopt if we can't infer anything.
-static std::optional<bool>
-isImpliedCondMatchingOperands(CmpInst::Predicate LPred,
-                              CmpInst::Predicate RPred) {
+static std::optional<bool> isImpliedCondMatchingOperands(CmpPredicate LPred,
+                                                         CmpPredicate RPred) {
   if (ICmpInst::isImpliedTrueByMatchingCmp(LPred, RPred))
     return true;
   if (ICmpInst::isImpliedFalseByMatchingCmp(LPred, RPred))
@@ -9276,53 +9275,66 @@ isImpliedCondMatchingOperands(CmpInst::Predicate LPred,
 /// Return true if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is true.
 /// Return false if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is false.
 /// Otherwise, return std::nullopt if we can't infer anything.
-static std::optional<bool> isImpliedCondCommonOperandWithCR(
-    CmpInst::Predicate LPred, const ConstantRange &LCR,
-    CmpInst::Predicate RPred, const ConstantRange &RCR) {
-  ConstantRange DomCR = ConstantRange::makeAllowedICmpRegion(LPred, LCR);
-  // If all true values for lhs and true for rhs, lhs implies rhs
-  if (DomCR.icmp(RPred, RCR))
-    return true;
+static std::optional<bool>
+isImpliedCondCommonOperandWithCR(CmpPredicate LPred, const ConstantRange &LCR,
+                                 CmpPredicate RPred, const ConstantRange &RCR) {
+  auto CRImpliesPred = [&](ConstantRange CR,
+                           CmpInst::Predicate Pred) -> std::optional<bool> {
+    // If all true values for lhs and true for rhs, lhs implies rhs
+    if (CR.icmp(Pred, RCR))
+      return true;
 
-  // If there is no overlap, lhs implies not rhs
-  if (DomCR.icmp(CmpInst::getInversePredicate(RPred), RCR))
-    return false;
+    // If there is no overlap, lhs implies not rhs
+    if (CR.icmp(CmpInst::getInversePredicate(Pred), RCR))
+      return false;
+
+    return std::nullopt;
+  };
+  if (auto Res = CRImpliesPred(ConstantRange::makeAllowedICmpRegion(LPred, LCR),
+                               RPred))
+    return Res;
+  if (LPred.hasSameSign() ^ RPred.hasSameSign()) {
+    LPred = LPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(LPred)
+                                : static_cast<CmpInst::Predicate>(LPred);
+    RPred = RPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(RPred)
+                                : static_cast<CmpInst::Predicate>(RPred);
+    return CRImpliesPred(ConstantRange::makeAllowedICmpRegion(LPred, LCR),
+                         RPred);
+  }
   return std::nullopt;
 }
 
 /// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
 /// is true.  Return false if LHS implies RHS is false. Otherwise, return
 /// std::nullopt if we can't infer anything.
-static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
-                                              CmpInst::Predicate RPred,
-                                              const Value *R0, const Value *R1,
-                                              const DataLayout &DL,
-                                              bool LHSIsTrue) {
+static std::optional<bool>
+isImpliedCondICmps(const ICmpInst *LHS, CmpPredicate RPred, const Value *R0,
+                   const Value *R1, const DataLayout &DL, bool LHSIsTrue) {
   Value *L0 = LHS->getOperand(0);
   Value *L1 = LHS->getOperand(1);
 
   // The rest of the logic assumes the LHS condition is true.  If that's not the
   // case, invert the predicate to make it so.
-  CmpInst::Predicate LPred =
-      LHSIsTrue ? LHS->getPredicate() : LHS->getInversePredicate();
+  CmpPredicate LPred =
+      LHSIsTrue ? LHS->getCmpPredicate() : LHS->getInverseCmpPredicate();
 
   // We can have non-canonical operands, so try to normalize any common operand
   // to L0/R0.
   if (L0 == R1) {
     std::swap(R0, R1);
-    RPred = ICmpInst::getSwappedPredicate(RPred);
+    RPred = ICmpInst::getSwappedCmpPredicate(RPred);
   }
   if (R0 == L1) {
     std::swap(L0, L1);
-    LPred = ICmpInst::getSwappedPredicate(LPred);
+    LPred = ICmpInst::getSwappedCmpPredicate(LPred);
   }
   if (L1 == R1) {
     // If we have L0 == R0 and L1 == R1, then make L1/R1 the constants.
     if (L0 != R0 || match(L0, m_ImmConstant())) {
       std::swap(L0, L1);
-      LPred = ICmpInst::getSwappedPredicate(LPred);
+      LPred = ICmpInst::getSwappedCmpPredicate(LPred);
       std::swap(R0, R1);
-      RPred = ICmpInst::getSwappedPredicate(RPred);
+      RPred = ICmpInst::getSwappedCmpPredicate(RPred);
     }
   }
 
@@ -9358,19 +9370,32 @@ static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
   // must be positive if X >= Y and no overflow".
   // Take SGT as an example:  L0:x > L1:y and C >= 0
   //                      ==> R0:(x -nsw y) < R1:(-C) is false
-  if ((LPred == ICmpInst::ICMP_SGT || LPred == ICmpInst::ICMP_SGE) &&
+  //
+  // Both rules reason about the signed value of (x -nsw y), so the predicates
+  // must be queried in the signed domain. A samesign predicate agrees with its
+  // signed counterpart whenever it is not poison, so it is converted; a plain
+  // unsigned predicate must not be reinterpreted as signed and is kept as-is.
+  CmpInst::Predicate SignedLPred =
+      LPred.hasSameSign() ? ICmpInst::getSignedPredicate(LPred)
+                          : static_cast<CmpInst::Predicate>(LPred);
+  CmpInst::Predicate SignedRPred =
+      RPred.hasSameSign() ? ICmpInst::getSignedPredicate(RPred)
+                          : static_cast<CmpInst::Predicate>(RPred);
+  if ((SignedLPred == ICmpInst::ICMP_SGT ||
+       SignedLPred == ICmpInst::ICMP_SGE) &&
       match(R0, m_NSWSub(m_Specific(L0), m_Specific(L1)))) {
     if (match(R1, m_NonPositive()) &&
-        isImpliedCondMatchingOperands(LPred, RPred) == false)
+        isImpliedCondMatchingOperands(SignedLPred, SignedRPred) == false)
       return false;
   }
 
   // Take SLT as an example:  L0:x < L1:y and C <= 0
   //                      ==> R0:(x -nsw y) < R1:(-C) is true
-  if ((LPred == ICmpInst::ICMP_SLT || LPred == ICmpInst::ICMP_SLE) &&
+  if ((SignedLPred == ICmpInst::ICMP_SLT ||
+       SignedLPred == ICmpInst::ICMP_SLE) &&
       match(R0, m_NSWSub(m_Specific(L0), m_Specific(L1)))) {
     if (match(R1, m_NonNegative()) &&
-        isImpliedCondMatchingOperands(LPred, RPred) == true)
+        isImpliedCondMatchingOperands(SignedLPred, SignedRPred) == true)
       return true;
   }
 
@@ -9381,8 +9406,8 @@ static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
       match(L0, m_c_Add(m_Specific(L1), m_Specific(R1))))
     return CmpPredicate::getMatching(LPred, RPred).has_value();
 
-  if (LPred == RPred)
-    return isImpliedCondOperands(LPred, L0, L1, R0, R1);
+  if (auto P = CmpPredicate::getMatching(LPred, RPred))
+    return isImpliedCondOperands(*P, L0, L1, R0, R1);
 
   return std::nullopt;
 }
diff --git a/llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll b/llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll
index 1d58ca5ead9cca..3985e62d142612 100644
--- a/llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll
+++ b/llvm/test/Analysis/ValueTracking/implied-condition-samesign.ll
@@ -4,11 +4,7 @@
 define i1 @incr_sle(i32 %i, i32 %len) {
 ; CHECK-LABEL: define i1 @incr_sle(
 ; CHECK-SAME: i32 [[I:%.*]], i32 [[LEN:%.*]]) {
-; CHECK-NEXT:    [[I_INCR:%.*]] = add nuw nsw i32 [[I]], 1
-; CHECK-NEXT:    [[I_LT_LEN:%.*]] = icmp samesign ugt i32 [[I]], [[LEN]]
-; CHECK-NEXT:    [[I_INCR_LT_LEN:%.*]] = icmp sgt i32 [[I_INCR]], [[LEN]]
-; CHECK-NEXT:    [[RES:%.*]] = icmp sle i1 [[I_INCR_LT_LEN]], [[I_LT_LEN]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 true
 ;
   %i.incr = add nsw nuw i32 %i, 1
   %i.lt.len = icmp samesign ugt i32 %i, %len
@@ -20,11 +16,7 @@ define i1 @incr_sle(i32 %i, i32 %len) {
 define i1 @incr_sge(i32 %i, i32 %len) {
 ; CHECK-LABEL: define i1 @incr_sge(
 ; CHECK-SAME: i32 [[I:%.*]], i32 [[LEN:%.*]]) {
-; CHECK-NEXT:    [[I_INCR:%.*]] = add nuw nsw i32 [[I]], 1
-; CHECK-NEXT:    [[I_LT_LEN:%.*]] = icmp samesign ult i32 [[I]], [[LEN]]
-; CHECK-NEXT:    [[I_INCR_LT_LEN:%.*]] = icmp slt i32 [[I_INCR]], [[LEN]]
-; CHECK-NEXT:    [[RES:%.*]] = icmp sge i1 [[I_INCR_LT_LEN]], [[I_LT_LEN]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 true
 ;
   %i.incr = add nsw nuw i32 %i, 1
   %i.lt.len = icmp samesign ult i32 %i, %len
@@ -36,11 +28,7 @@ define i1 @incr_sge(i32 %i, i32 %len) {
 define i1 @incr_ule(i32 %i, i32 %len) {
 ; CHECK-LABEL: define i1 @incr_ule(
 ; CHECK-SAME: i32 [[I:%.*]], i32 [[LEN:%.*]]) {
-; CHECK-NEXT:    [[I_INCR:%.*]] = add nuw nsw i32 [[I]], 1
-; CHECK-NEXT:    [[I_LT_LEN:%.*]] = icmp samesign ugt i32 [[I]], [[LEN]]
-; CHECK-NEXT:    [[I_INCR_LT_LEN:%.*]] = icmp sgt i32 [[I_INCR]], [[LEN]]
-; CHECK-NEXT:    [[RES:%.*]] = icmp ule i1 [[I_LT_LEN]], [[I_INCR_LT_LEN]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 true
 ;
   %i.incr = add nsw nuw i32 %i, 1
   %i.lt.len = icmp samesign ugt i32 %i, %len
@@ -52,11 +40,7 @@ define i1 @incr_ule(i32 %i, i32 %len) {
 define i1 @incr_uge(i32 %i, i32 %len) {
 ; CHECK-LABEL: define i1 @incr_uge(
 ; CHECK-SAME: i32 [[I:%.*]], i32 [[LEN:%.*]]) {
-; CHECK-NEXT:    [[I_INCR:%.*]] = add nuw nsw i32 [[I]], 1
-; CHECK-NEXT:    [[I_LT_LEN:%.*]] = icmp samesign ult i32 [[I]], [[LEN]]
-; CHECK-NEXT:    [[I_INCR_LT_LEN:%.*]] = icmp slt i32 [[I_INCR]], [[LEN]]
-; CHECK-NEXT:    [[RES:%.*]] = icmp uge i1 [[I_LT_LEN]], [[I_INCR_LT_LEN]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 true
 ;
   %i.incr = add nsw nuw i32 %i, 1
   %i.lt.len = icmp samesign ult i32 %i, %len
@@ -70,8 +54,7 @@ define i1 @sgt_implies_ge_via_assume(i32 %i, i32 %j) {
 ; CHECK-SAME: i32 [[I:%.*]], i32 [[J:%.*]]) {
 ; CHECK-NEXT:    [[I_GT_J:%.*]] = icmp sgt i32 [[I]], [[J]]
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[I_GT_J]])
-; CHECK-NEXT:    [[I_GE_J:%.*]] = icmp samesign uge i32 [[I]], [[J]]
-; CHECK-NEXT:    ret i1 [[I_GE_J]]
+; CHECK-NEXT:    ret i1 true
 ;
   %i.gt.j = icmp sgt i32 %i, %j
   call void @llvm.assume(i1 %i.gt.j)
@@ -86,11 +69,9 @@ define i32 @gt_implies_sge_dominating(i32 %a, i32 %len) {
 ; CHECK-NEXT:    [[A_GT_LEN:%.*]] = icmp samesign ugt i32 [[A]], [[LEN]]
 ; CHECK-NEXT:    br i1 [[A_GT_LEN]], label %[[TAKEN:.*]], label %[[END:.*]]
 ; CHECK:       [[TAKEN]]:
-; CHECK-NEXT:    [[A_SGE_LEN:%.*]] = icmp sge i32 [[A]], [[LEN]]
-; CHECK-NEXT:    [[C:%.*]] = select i1 [[A_SGE_LEN]], i32 30, i32 0
 ; CHECK-NEXT:    br label %[[END]]
 ; CHECK:       [[END]]:
-; CHECK-NEXT:    [[RES:%.*]] = phi i32 [ -1, %[[ENTRY]] ], [ [[C]], %[[TAKEN]] ]
+; CHECK-NEXT:    [[RES:%.*]] = phi i32 [ -1, %[[ENTRY]] ], [ 30, %[[TAKEN]] ]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -114,11 +95,9 @@ define i32 @gt_implies_sge_dominating_cr(i32 %a, i32 %len) {
 ; CHECK-NEXT:    [[A_GT_20:%.*]] = icmp samesign ugt i32 [[A]], 20
 ; CHECK-NEXT:    br i1 [[A_GT_20]], label %[[TAKEN:.*]], label %[[END:.*]]
 ; CHECK:       [[TAKEN]]:
-; CHECK-NEXT:    [[A_SGE_MINUS_10:%.*]] = icmp sge i32 [[A]], 10
-; CHECK-NEXT:    [[C:%.*]] = select i1 [[A_SGE_MINUS_10]], i32 30, i32 0
 ; CHECK-NEXT:    br label %[[END]]
 ; CHECK:       [[END]]:
-; CHECK-NEXT:    [[RES:%.*]] = phi i32 [ -1, %[[ENTRY]] ], [ [[C]], %[[TAKEN]] ]
+; CHECK-NEXT:    [[RES:%.*]] = phi i32 [ -1, %[[ENTRY]] ], [ 30, %[[TAKEN]] ]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -142,11 +121,9 @@ define i32 @sgt_implies_ge_dominating_cr(i32 %a, i32 %len) {
 ; CHECK-NEXT:    [[A_SGT_20:%.*]] = icmp sgt i32 [[A]], -10
 ; CHECK-NEXT:    br i1 [[A_SGT_20]], label %[[TAKEN:.*]], label %[[END:.*]]
 ; CHECK:       [[TAKEN]]:
-; CHECK-NEXT:    [[A_GE_10:%.*]] = icmp samesign uge i32 [[A]], -20
-; CHECK-NEXT:    [[C:%.*]] = select i1 [[A_GE_10]], i32 30, i32 0
 ; CHECK-NEXT:    br label %[[END]]
 ; CHECK:       [[END]]:
-; CHECK-NEXT:    [[RES:%.*]] = phi i32 [ -1, %[[ENTRY]] ], [ [[C]], %[[TAKEN]] ]
+; CHECK-NEXT:    [[RES:%.*]] = phi i32 [ -1, %[[ENTRY]] ], [ 30, %[[TAKEN]] ]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
 entry:
@@ -172,10 +149,7 @@ define i32 @gt_sub_nsw(i32 %x, i32 %y) {
 ; CHECK:       [[COND_TRUE]]:
 ; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i32 [[X]], [[Y]]
 ; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[SUB]], 1
-; CHECK-NEXT:    [[NEG:%.*]] = xor i32 [[SUB]], -1
-; CHECK-NEXT:    [[ABSCOND:%.*]] = icmp samesign ult i32 [[SUB]], -1
-; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[ABSCOND]], i32 [[NEG]], i32 [[ADD]]
-; CHECK-NEXT:    ret i32 [[ABS]]
+; CHECK-NEXT:    ret i32 [[ADD]]
 ; CHECK:       [[COND_END]]:
 ; CHECK-NEXT:    ret i32 0
 ;
@@ -204,10 +178,7 @@ define i32 @ge_sub_nsw(i32 %x, i32 %y) {
 ; CHECK:       [[COND_TRUE]]:
 ; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i32 [[X]], [[Y]]
 ; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[SUB]], 1
-; CHECK-NEXT:    [[NEG:%.*]] = xor i32 [[SUB]], -1
-; CHECK-NEXT:    [[ABSCOND:%.*]] = icmp samesign ult i32 [[SUB]], -1
-; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[ABSCOND]], i32 [[NEG]], i32 [[ADD]]
-; CHECK-NEXT:    ret i32 [[ABS]]
+; CHECK-NEXT:    ret i32 [[ADD]]
 ; CHECK:       [[COND_END]]:
 ; CHECK-NEXT:    ret i32 0
 ;



More information about the llvm-commits mailing list