[llvm] 69acdfe - [SCEV] Prove implications via AddRec start

Max Kazantsev via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 1 03:09:58 PDT 2020


Author: Max Kazantsev
Date: 2020-10-01T17:09:38+07:00
New Revision: 69acdfe075fa8eb18781f88f4d0cd1ea40fa6e48

URL: https://github.com/llvm/llvm-project/commit/69acdfe075fa8eb18781f88f4d0cd1ea40fa6e48
DIFF: https://github.com/llvm/llvm-project/commit/69acdfe075fa8eb18781f88f4d0cd1ea40fa6e48.diff

LOG: [SCEV] Prove implications via AddRec start

If we know that some predicate is true for an AddRec and a value that is
invariant w.r.t. this AddRec's loop, then this fact holds, in particular, on
the first iteration. We can therefore try to prove the facts we need using
the AddRec's start value.

The motivating example is proving things like
```
  isImpliedCondOperands(>=, X, 0, {X,+,-1}, 0)
```

Differential Revision: https://reviews.llvm.org/D88208
Reviewed By: reames

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index febca473776a..158257a5aa9a 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1677,23 +1677,30 @@ class ScalarEvolution {
   getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB) const;
 
   /// Test whether the condition described by Pred, LHS, and RHS is true
-  /// whenever the given FoundCondValue value evaluates to true.
+  /// whenever the given FoundCondValue value evaluates to true in given
+  /// Context. If Context is nullptr, then the found predicate is true
+  /// everywhere.
   bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
-                     const Value *FoundCondValue, bool Inverse);
+                     const Value *FoundCondValue, bool Inverse,
+                     const Instruction *Context = nullptr);
 
   /// Test whether the condition described by Pred, LHS, and RHS is true
   /// whenever the condition described by FoundPred, FoundLHS, FoundRHS is
-  /// true.
+  /// true in given Context. If Context is nullptr, then the found predicate is
+  /// true everywhere.
   bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
                      ICmpInst::Predicate FoundPred, const SCEV *FoundLHS,
-                     const SCEV *FoundRHS);
+                     const SCEV *FoundRHS,
+                     const Instruction *Context = nullptr);
 
   /// Test whether the condition described by Pred, LHS, and RHS is true
   /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
-  /// true.
+  /// true in given Context. If Context is nullptr, then the found predicate is
+  /// true everywhere.
   bool isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS,
                              const SCEV *RHS, const SCEV *FoundLHS,
-                             const SCEV *FoundRHS);
+                             const SCEV *FoundRHS,
+                             const Instruction *Context = nullptr);
 
   /// Test whether the condition described by Pred, LHS, and RHS is true
   /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
@@ -1740,6 +1747,18 @@ class ScalarEvolution {
                                           const SCEV *FoundLHS,
                                           const SCEV *FoundRHS);
 
+  /// Test whether the condition described by Pred, LHS, and RHS is true
+  /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
+  /// true in the given Context. Returns false if Context is nullptr.
+  ///
+  /// This routine tries to weaken the known condition based on the fact that
+  /// FoundLHS or FoundRHS is an AddRec, replacing it with its start value.
+  bool isImpliedCondOperandsViaAddRecStart(ICmpInst::Predicate Pred,
+                                           const SCEV *LHS, const SCEV *RHS,
+                                           const SCEV *FoundLHS,
+                                           const SCEV *FoundRHS,
+                                           const Instruction *Context);
+
   /// Test whether the condition described by Pred, LHS, and RHS is true
   /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
   /// true.

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e51b31673105..a3e454fefcf0 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9549,15 +9549,16 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
 
   // Try to prove (Pred, LHS, RHS) using isImpliedCond.
   auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
-    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse))
+    const Instruction *Context = &BB->front();
+    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, Context))
       return true;
     if (ProvingStrictComparison) {
       if (!ProvedNonStrictComparison)
-        ProvedNonStrictComparison =
-            isImpliedCond(NonStrictPredicate, LHS, RHS, Condition, Inverse);
+        ProvedNonStrictComparison = isImpliedCond(NonStrictPredicate, LHS, RHS,
+                                                  Condition, Inverse, Context);
       if (!ProvedNonEquality)
-        ProvedNonEquality =
-            isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse);
+        ProvedNonEquality = isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS,
+                                          Condition, Inverse, Context);
       if (ProvedNonStrictComparison && ProvedNonEquality)
         return true;
     }
@@ -9623,7 +9624,8 @@ bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
 
 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                     const SCEV *RHS,
-                                    const Value *FoundCondValue, bool Inverse) {
+                                    const Value *FoundCondValue, bool Inverse,
+                                    const Instruction *Context) {
   if (!PendingLoopPredicates.insert(FoundCondValue).second)
     return false;
 
@@ -9634,12 +9636,16 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
   if (const BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
     if (BO->getOpcode() == Instruction::And) {
       if (!Inverse)
-        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
-               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
+                             Context) ||
+               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
+                             Context);
     } else if (BO->getOpcode() == Instruction::Or) {
       if (Inverse)
-        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
-               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
+                             Context) ||
+               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
+                             Context);
     }
   }
 
@@ -9657,14 +9663,14 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
   const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
   const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
 
-  return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS);
+  return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, Context);
 }
 
 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                     const SCEV *RHS,
                                     ICmpInst::Predicate FoundPred,
-                                    const SCEV *FoundLHS,
-                                    const SCEV *FoundRHS) {
+                                    const SCEV *FoundLHS, const SCEV *FoundRHS,
+                                    const Instruction *Context) {
   // Balance the types.
   if (getTypeSizeInBits(LHS->getType()) <
       getTypeSizeInBits(FoundLHS->getType())) {
@@ -9708,16 +9714,16 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
 
   // Check whether the found predicate is the same as the desired predicate.
   if (FoundPred == Pred)
-    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
 
   // Check whether swapping the found predicate makes it the same as the
   // desired predicate.
   if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
     if (isa<SCEVConstant>(RHS))
-      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
+      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, Context);
     else
-      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
-                                   RHS, LHS, FoundLHS, FoundRHS);
+      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), RHS,
+                                   LHS, FoundLHS, FoundRHS, Context);
   }
 
   // Unsigned comparison is the same as signed comparison when both the operands
@@ -9725,7 +9731,7 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
   if (CmpInst::isUnsigned(FoundPred) &&
       CmpInst::getSignedPredicate(FoundPred) == Pred &&
       isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
-    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
 
   // Check if we can make progress by sharpening ranges.
   if (FoundPred == ICmpInst::ICMP_NE &&
@@ -9762,8 +9768,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
         case ICmpInst::ICMP_UGE:
           // We know V `Pred` SharperMin.  If this implies LHS `Pred`
           // RHS, we're done.
-          if (isImpliedCondOperands(Pred, LHS, RHS, V,
-                                    getConstant(SharperMin)))
+          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
+                                    Context))
             return true;
           LLVM_FALLTHROUGH;
 
@@ -9778,7 +9784,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
           //
           // If V `Pred` Min implies LHS `Pred` RHS, we're done.
 
-          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
+          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min),
+                                    Context))
             return true;
           break;
 
@@ -9786,14 +9793,14 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
         case ICmpInst::ICMP_SLE:
         case ICmpInst::ICMP_ULE:
           if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
-                                    LHS, V, getConstant(SharperMin)))
+                                    LHS, V, getConstant(SharperMin), Context))
             return true;
           LLVM_FALLTHROUGH;
 
         case ICmpInst::ICMP_SLT:
         case ICmpInst::ICMP_ULT:
           if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
-                                    LHS, V, getConstant(Min)))
+                                    LHS, V, getConstant(Min), Context))
             return true;
           break;
 
@@ -9807,11 +9814,12 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
   // Check whether the actual condition is beyond sufficient.
   if (FoundPred == ICmpInst::ICMP_EQ)
     if (ICmpInst::isTrueWhenEqual(Pred))
-      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
+      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context))
         return true;
   if (Pred == ICmpInst::ICMP_NE)
     if (!ICmpInst::isTrueWhenEqual(FoundPred))
-      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
+      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS,
+                                Context))
         return true;
 
   // Otherwise assume the worst.
@@ -9890,6 +9898,44 @@ Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
   return None;
 }
 
+bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
+    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
+    const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *Context) {
+  // Try to recognize the following pattern (or its mirror image, in which
+  // FoundRHS is the AddRec and FoundLHS is available at loop entry):
+  //
+  //   FoundRHS = ...
+  // loop:
+  //   FoundLHS = {Start,+,W}
+  // context_bb: // Basic block from the same loop
+  //   known(Pred, FoundLHS, FoundRHS)
+  //
+  // If some predicate is known in the context of a loop, it is also known on
+  // each iteration of this loop, including the first iteration. Therefore, in
+  // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
+  // prove the original predicate from this weaker fact, without a Context.
+  if (!Context)
+    return false;
+  // Context must be in AR's loop; other operand available at loop entry.
+  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
+    if (!AR->getLoop()->contains(Context->getParent()))
+      return false;
+    if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
+      return false;
+    return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
+  }
+
+  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
+    if (!AR->getLoop()->contains(Context)) // NOTE(review): presumably same check as above; confirm
+      return false;
+    if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
+      return false;
+    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
+  }
+
+  return false;
+}
+
 bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
     const SCEV *FoundLHS, const SCEV *FoundRHS) {
@@ -10080,13 +10126,18 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS,
                                             const SCEV *FoundLHS,
-                                            const SCEV *FoundRHS) {
+                                            const SCEV *FoundRHS,
+                                            const Instruction *Context) {
   if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
     return true;
 
   if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
     return true;
 
+  if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
+                                          Context))
+    return true;
+
   return isImpliedCondOperandsHelper(Pred, LHS, RHS,
                                      FoundLHS, FoundRHS) ||
          // ~x < ~y --> x > y

diff  --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index ff33495f2271..e5ffc21fb664 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1251,4 +1251,36 @@ TEST_F(ScalarEvolutionsTest, SCEVgetExitLimitForGuardedLoop) {
   });
 }
 
+TEST_F(ScalarEvolutionsTest, ImpliedViaAddRecStart) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "define void @foo(i32* %p) { "
+      "entry: "
+      "  %x = load i32, i32* %p, !range !0 "
+      "  br label %loop "
+      "loop: "
+      "  %iv = phi i32 [ %x, %entry], [%iv.next, %backedge] "
+      "  %ne.check = icmp ne i32 %iv, 0 "
+      "  br i1 %ne.check, label %backedge, label %exit "
+      "backedge: "
+      "  %iv.next = add i32 %iv, -1 "
+      "  br label %loop "
+      "exit:"
+      "  ret void "
+      "} "
+      "!0 = !{i32 0, i32 2147483647}",
+      Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto *X = SE.getSCEV(getInstructionByName(F, "x")); // %x is the start value of %iv
+    auto *Context = getInstructionByName(F, "iv.next"); // in %backedge, reached only when %ne.check is true
+    EXPECT_TRUE(SE.isKnownPredicateAt(ICmpInst::ICMP_NE, X, // expect x != 0 at Context
+                                      SE.getZero(X->getType()), Context));
+  });
+}
+
 }  // end namespace llvm


        


More information about the llvm-commits mailing list