[llvm] 678f32a - [ValueTracking] Add more conditions to `isTruePredicate`

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 4 10:43:15 PDT 2024


Author: Noah Goldstein
Date: 2024-04-04T12:42:58-05:00
New Revision: 678f32ab66508aea2068a5e4e07d53b71ce5cf31

URL: https://github.com/llvm/llvm-project/commit/678f32ab66508aea2068a5e4e07d53b71ce5cf31
DIFF: https://github.com/llvm/llvm-project/commit/678f32ab66508aea2068a5e4e07d53b71ce5cf31.diff

LOG: [ValueTracking] Add more conditions to `isTruePredicate`

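There is one notable "regression". This patch replaces the bespoke `or
disjoint` logic with a direct match: the old code used known bits to
prove that an `or` was disjoint, whereas the new code relies on the
`disjoint` flag already being present. This means some simplifications
are missed during `instsimplify`.
All of the cases that `instsimplify` now misses are still handled by
`instcombine`, because `instcombine` adds the `disjoint` flags.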

Other than that, the patch just adds some basic cases.

See proofs: https://alive2.llvm.org/ce/z/_-g7C8

Closes #86083

Added: 
    

Modified: 
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/test/Transforms/InstCombine/implies.ll
    llvm/test/Transforms/InstSimplify/implies.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 33a69861cc3c57..5ad4da43bca7db 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8393,8 +8393,7 @@ bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
 
 /// Return true if "icmp Pred LHS RHS" is always true.
 static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
-                            const Value *RHS, const DataLayout &DL,
-                            unsigned Depth) {
+                            const Value *RHS) {
   if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS)
     return true;
 
@@ -8406,8 +8405,26 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
     const APInt *C;
 
     // LHS s<= LHS +_{nsw} C   if C >= 0
-    if (match(RHS, m_NSWAdd(m_Specific(LHS), m_APInt(C))))
+    // LHS s<= LHS | C         if C >= 0
+    if (match(RHS, m_NSWAdd(m_Specific(LHS), m_APInt(C))) ||
+        match(RHS, m_Or(m_Specific(LHS), m_APInt(C))))
       return !C->isNegative();
+
+    // LHS s<= smax(LHS, V) for any V
+    if (match(RHS, m_c_SMax(m_Specific(LHS), m_Value())))
+      return true;
+
+    // smin(RHS, V) s<= RHS for any V
+    if (match(LHS, m_c_SMin(m_Specific(RHS), m_Value())))
+      return true;
+
+    // Match A to (X +_{nsw} CA) and B to (X +_{nsw} CB)
+    const Value *X;
+    const APInt *CLHS, *CRHS;
+    if (match(LHS, m_NSWAddLike(m_Value(X), m_APInt(CLHS))) &&
+        match(RHS, m_NSWAddLike(m_Specific(X), m_APInt(CRHS))))
+      return CLHS->sle(*CRHS);
+
     return false;
   }
 
@@ -8417,34 +8434,36 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
         cast<OverflowingBinaryOperator>(RHS)->hasNoUnsignedWrap())
       return true;
 
+    // LHS u<= LHS | V for any V
+    if (match(RHS, m_c_Or(m_Specific(LHS), m_Value())))
+      return true;
+
+    // LHS u<= umax(LHS, V) for any V
+    if (match(RHS, m_c_UMax(m_Specific(LHS), m_Value())))
+      return true;
+
     // RHS >> V u<= RHS for any V
     if (match(LHS, m_LShr(m_Specific(RHS), m_Value())))
       return true;
 
-    // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB)
-    auto MatchNUWAddsToSameValue = [&](const Value *A, const Value *B,
-                                       const Value *&X,
-                                       const APInt *&CA, const APInt *&CB) {
-      if (match(A, m_NUWAdd(m_Value(X), m_APInt(CA))) &&
-          match(B, m_NUWAdd(m_Specific(X), m_APInt(CB))))
-        return true;
+    // RHS u/ C_ugt_1 u<= RHS
+    const APInt *C;
+    if (match(LHS, m_UDiv(m_Specific(RHS), m_APInt(C))) && C->ugt(1))
+      return true;
 
-      // If X & C == 0 then (X | C) == X +_{nuw} C
-      if (match(A, m_Or(m_Value(X), m_APInt(CA))) &&
-          match(B, m_Or(m_Specific(X), m_APInt(CB)))) {
-        KnownBits Known(CA->getBitWidth());
-        computeKnownBits(X, Known, DL, Depth + 1, /*AC*/ nullptr,
-                         /*CxtI*/ nullptr, /*DT*/ nullptr);
-        if (CA->isSubsetOf(Known.Zero) && CB->isSubsetOf(Known.Zero))
-          return true;
-      }
+    // RHS & V u<= RHS for any V
+    if (match(LHS, m_c_And(m_Specific(RHS), m_Value())))
+      return true;
 
-      return false;
-    };
+    // umin(RHS, V) u<= RHS for any V
+    if (match(LHS, m_c_UMin(m_Specific(RHS), m_Value())))
+      return true;
 
+    // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB)
     const Value *X;
     const APInt *CLHS, *CRHS;
-    if (MatchNUWAddsToSameValue(LHS, RHS, X, CLHS, CRHS))
+    if (match(LHS, m_NUWAddLike(m_Value(X), m_APInt(CLHS))) &&
+        match(RHS, m_NUWAddLike(m_Specific(X), m_APInt(CRHS))))
       return CLHS->ule(*CRHS);
 
     return false;
@@ -8456,37 +8475,36 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
 /// ALHS ARHS" is true.  Otherwise, return std::nullopt.
 static std::optional<bool>
 isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
-                      const Value *ARHS, const Value *BLHS, const Value *BRHS,
-                      const DataLayout &DL, unsigned Depth) {
+                      const Value *ARHS, const Value *BLHS, const Value *BRHS) {
   switch (Pred) {
   default:
     return std::nullopt;
 
   case CmpInst::ICMP_SLT:
   case CmpInst::ICMP_SLE:
-    if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS, DL, Depth) &&
-        isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS, DL, Depth))
+    if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS) &&
+        isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS))
       return true;
     return std::nullopt;
 
   case CmpInst::ICMP_SGT:
   case CmpInst::ICMP_SGE:
-    if (isTruePredicate(CmpInst::ICMP_SLE, ALHS, BLHS, DL, Depth) &&
-        isTruePredicate(CmpInst::ICMP_SLE, BRHS, ARHS, DL, Depth))
+    if (isTruePredicate(CmpInst::ICMP_SLE, ALHS, BLHS) &&
+        isTruePredicate(CmpInst::ICMP_SLE, BRHS, ARHS))
       return true;
     return std::nullopt;
 
   case CmpInst::ICMP_ULT:
   case CmpInst::ICMP_ULE:
-    if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL, Depth) &&
-        isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL, Depth))
+    if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS) &&
+        isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS))
       return true;
     return std::nullopt;
 
   case CmpInst::ICMP_UGT:
   case CmpInst::ICMP_UGE:
-    if (isTruePredicate(CmpInst::ICMP_ULE, ALHS, BLHS, DL, Depth) &&
-        isTruePredicate(CmpInst::ICMP_ULE, BRHS, ARHS, DL, Depth))
+    if (isTruePredicate(CmpInst::ICMP_ULE, ALHS, BLHS) &&
+        isTruePredicate(CmpInst::ICMP_ULE, BRHS, ARHS))
       return true;
     return std::nullopt;
   }
@@ -8530,7 +8548,7 @@ static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
                                               CmpInst::Predicate RPred,
                                               const Value *R0, const Value *R1,
                                               const DataLayout &DL,
-                                              bool LHSIsTrue, unsigned Depth) {
+                                              bool LHSIsTrue) {
   Value *L0 = LHS->getOperand(0);
   Value *L1 = LHS->getOperand(1);
 
@@ -8577,7 +8595,7 @@ static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
     return LPred == RPred;
 
   if (LPred == RPred)
-    return isImpliedCondOperands(LPred, L0, L1, R0, R1, DL, Depth);
+    return isImpliedCondOperands(LPred, L0, L1, R0, R1);
 
   return std::nullopt;
 }
@@ -8639,8 +8657,7 @@ llvm::isImpliedCondition(const Value *LHS, CmpInst::Predicate RHSPred,
   // Both LHS and RHS are icmps.
   const ICmpInst *LHSCmp = dyn_cast<ICmpInst>(LHS);
   if (LHSCmp)
-    return isImpliedCondICmps(LHSCmp, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue,
-                              Depth);
+    return isImpliedCondICmps(LHSCmp, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue);
 
   /// The LHS should be an 'or', 'and', or a 'select' instruction.  We expect
   /// the RHS to be an icmp.

diff  --git a/llvm/test/Transforms/InstCombine/implies.ll b/llvm/test/Transforms/InstCombine/implies.ll
index 6741d59f4fccfa..c02d84d3f83711 100644
--- a/llvm/test/Transforms/InstCombine/implies.ll
+++ b/llvm/test/Transforms/InstCombine/implies.ll
@@ -7,8 +7,7 @@ define i1 @or_implies_sle(i8 %x, i8 %y, i1 %other) {
 ; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp sgt i8 [[OR]], [[Y:%.*]]
 ; CHECK-NEXT:    br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]]
 ; CHECK:       T:
-; CHECK-NEXT:    [[R:%.*]] = icmp sle i8 [[X]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       F:
 ; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
 ;
@@ -49,9 +48,7 @@ define i1 @or_distjoint_implies_ule(i8 %x, i8 %y, i1 %other) {
 ; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp ugt i8 [[X2]], [[Y:%.*]]
 ; CHECK-NEXT:    br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]]
 ; CHECK:       T:
-; CHECK-NEXT:    [[X1:%.*]] = or disjoint i8 [[X]], 23
-; CHECK-NEXT:    [[R:%.*]] = icmp ule i8 [[X1]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       F:
 ; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
 ;
@@ -121,9 +118,7 @@ define i1 @src_or_distjoint_implies_sle(i8 %x, i8 %y, i1 %other) {
 ; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp sgt i8 [[X2]], [[Y:%.*]]
 ; CHECK-NEXT:    br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]]
 ; CHECK:       T:
-; CHECK-NEXT:    [[X1:%.*]] = or disjoint i8 [[X]], 23
-; CHECK-NEXT:    [[R:%.*]] = icmp sle i8 [[X1]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       F:
 ; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
 ;
@@ -169,9 +164,7 @@ define i1 @src_addnsw_implies_sle(i8 %x, i8 %y, i1 %other) {
 ; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp sgt i8 [[X2]], [[Y:%.*]]
 ; CHECK-NEXT:    br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]]
 ; CHECK:       T:
-; CHECK-NEXT:    [[X1:%.*]] = add nsw i8 [[X]], 23
-; CHECK-NEXT:    [[R:%.*]] = icmp sle i8 [[X1]], [[Y]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       F:
 ; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
 ;
@@ -216,9 +209,7 @@ define i1 @src_and_implies_ult(i8 %x, i8 %y, i8 %z, i1 %other) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ult i8 [[X:%.*]], [[Z:%.*]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[T:%.*]], label [[F:%.*]]
 ; CHECK:       T:
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[Z]], [[X]]
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i8 [[AND]], [[Z]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       F:
 ; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
 ;
@@ -280,8 +271,7 @@ define i1 @src_or_implies_ule(i8 %x, i8 %y, i8 %z, i1 %other) {
 ; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp ugt i8 [[OR]], [[Z:%.*]]
 ; CHECK-NEXT:    br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]]
 ; CHECK:       T:
-; CHECK-NEXT:    [[R:%.*]] = icmp ule i8 [[X]], [[Z]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       F:
 ; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
 ;
@@ -322,9 +312,7 @@ define i1 @src_udiv_implies_ult(i8 %x, i8 %z, i1 %other) {
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ugt i8 [[Z:%.*]], [[X:%.*]]
 ; CHECK-NEXT:    br i1 [[COND]], label [[T:%.*]], label [[F:%.*]]
 ; CHECK:       T:
-; CHECK-NEXT:    [[AND:%.*]] = udiv i8 [[X]], 3
-; CHECK-NEXT:    [[R:%.*]] = icmp ult i8 [[AND]], [[Z]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       F:
 ; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
 ;
@@ -345,9 +333,7 @@ define i1 @src_udiv_implies_ult2(i8 %x, i8 %z, i1 %other) {
 ; CHECK:       T:
 ; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
 ; CHECK:       F:
-; CHECK-NEXT:    [[AND:%.*]] = udiv i8 [[X]], 3
-; CHECK-NEXT:    [[R:%.*]] = icmp ult i8 [[AND]], [[Z]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 true
 ;
   %cond = icmp ule i8 %z, %x
   br i1 %cond, label %T, label %F
@@ -403,8 +389,7 @@ define i1 @src_umax_implies_ule(i8 %x, i8 %y, i8 %z, i1 %other) {
 ; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp ugt i8 [[UM]], [[Z:%.*]]
 ; CHECK-NEXT:    br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]]
 ; CHECK:       T:
-; CHECK-NEXT:    [[R:%.*]] = icmp ule i8 [[X]], [[Z]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       F:
 ; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
 ;
@@ -424,8 +409,7 @@ define i1 @src_smax_implies_sle(i8 %x, i8 %y, i8 %z, i1 %other) {
 ; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp sgt i8 [[UM]], [[Z:%.*]]
 ; CHECK-NEXT:    br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]]
 ; CHECK:       T:
-; CHECK-NEXT:    [[R:%.*]] = icmp sle i8 [[X]], [[Z]]
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 true
 ; CHECK:       F:
 ; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
 ;

diff  --git a/llvm/test/Transforms/InstSimplify/implies.ll b/llvm/test/Transforms/InstSimplify/implies.ll
index 8a011908dd38ab..7e3cb656bce158 100644
--- a/llvm/test/Transforms/InstSimplify/implies.ll
+++ b/llvm/test/Transforms/InstSimplify/implies.ll
@@ -155,7 +155,13 @@ define i1 @test9(i32 %length.i, i32 %i) {
 
 define i1 @test10(i32 %length.i, i32 %x.full) {
 ; CHECK-LABEL: @test10(
-; CHECK-NEXT:    ret i1 true
+; CHECK-NEXT:    [[X:%.*]] = and i32 [[X_FULL:%.*]], -65536
+; CHECK-NEXT:    [[LARGE:%.*]] = or i32 [[X]], 100
+; CHECK-NEXT:    [[SMALL:%.*]] = or i32 [[X]], 90
+; CHECK-NEXT:    [[KNOWN:%.*]] = icmp ult i32 [[LARGE]], [[LENGTH_I:%.*]]
+; CHECK-NEXT:    [[TO_PROVE:%.*]] = icmp ult i32 [[SMALL]], [[LENGTH_I]]
+; CHECK-NEXT:    [[RES:%.*]] = icmp ule i1 [[KNOWN]], [[TO_PROVE]]
+; CHECK-NEXT:    ret i1 [[RES]]
 ;
   %x = and i32 %x.full, 4294901760  ;; 4294901760 == 0xffff0000
   %large = or i32 %x, 100
@@ -229,7 +235,13 @@ define i1 @test13(i32 %length.i, i32 %x) {
 
 define i1 @test14(i32 %length.i, i32 %x.full) {
 ; CHECK-LABEL: @test14(
-; CHECK-NEXT:    ret i1 true
+; CHECK-NEXT:    [[X:%.*]] = and i32 [[X_FULL:%.*]], -61681
+; CHECK-NEXT:    [[LARGE:%.*]] = or i32 [[X]], 8224
+; CHECK-NEXT:    [[SMALL:%.*]] = or i32 [[X]], 4112
+; CHECK-NEXT:    [[KNOWN:%.*]] = icmp ult i32 [[LARGE]], [[LENGTH_I:%.*]]
+; CHECK-NEXT:    [[TO_PROVE:%.*]] = icmp ult i32 [[SMALL]], [[LENGTH_I]]
+; CHECK-NEXT:    [[RES:%.*]] = icmp ule i1 [[KNOWN]], [[TO_PROVE]]
+; CHECK-NEXT:    ret i1 [[RES]]
 ;
   %x = and i32 %x.full, 4294905615  ;; 4294905615 == 0xffff0f0f
   %large = or i32 %x, 8224 ;; == 0x2020


        


More information about the llvm-commits mailing list