[llvm] r282606 - [SCEV] Use a SmallPtrSet as a temporary union predicate; NFC

Sanjoy Das via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 28 10:14:59 PDT 2016


Author: sanjoy
Date: Wed Sep 28 12:14:58 2016
New Revision: 282606

URL: http://llvm.org/viewvc/llvm-project?rev=282606&view=rev
Log:
[SCEV] Use a SmallPtrSet as a temporary union predicate; NFC

Summary:
Instead of creating and destroying SCEVUnionPredicate instances (each of which
internally creates and destroys a DenseMap), use temporary SmallPtrSet
instances to remember the set of predicates that will get reified into a
SCEVUnionPredicate.

Reviewers: silviu.baranga, sbaranga

Subscribers: sanjoy, mcrosier, llvm-commits, mzolotukhin

Differential Revision: https://reviews.llvm.org/D25000

Modified:
    llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp

Modified: llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ScalarEvolution.h?rev=282606&r1=282605&r2=282606&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/ScalarEvolution.h (original)
+++ llvm/trunk/include/llvm/Analysis/ScalarEvolution.h Wed Sep 28 12:14:58 2016
@@ -551,19 +551,36 @@ private:
     const SCEV *ExactNotTaken;
     const SCEV *MaxNotTaken;
 
-    /// A predicate union guard for this ExitLimit. The result is only
-    /// valid if this predicate evaluates to 'true' at run-time.
-    SCEVUnionPredicate Predicate;
+    /// A set of predicate guards for this ExitLimit. The result is only valid
+    /// if all of the predicates in \c Predicates evaluate to 'true' at
+    /// run-time.
+    SmallPtrSet<const SCEVPredicate *, 4> Predicates;
+
+    void addPredicate(const SCEVPredicate *P) {
+      assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
+      Predicates.insert(P);
+    }
 
     /*implicit*/ ExitLimit(const SCEV *E) : ExactNotTaken(E), MaxNotTaken(E) {}
 
-    ExitLimit(const SCEV *E, const SCEV *M, SCEVUnionPredicate &P)
-        : ExactNotTaken(E), MaxNotTaken(M), Predicate(P) {
+    ExitLimit(
+        const SCEV *E, const SCEV *M,
+        ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
+        : ExactNotTaken(E), MaxNotTaken(M) {
       assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
               !isa<SCEVCouldNotCompute>(MaxNotTaken)) &&
              "Exact is not allowed to be less precise than Max");
+      for (auto *PredSet : PredSetList)
+        for (auto *P : *PredSet)
+          addPredicate(P);
     }
 
+    ExitLimit(const SCEV *E, const SCEV *M,
+              const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
+        : ExitLimit(E, M, {&PredSet}) {}
+
+    ExitLimit(const SCEV *E, const SCEV *M) : ExitLimit(E, M, None) {}
+
     /// Test whether this ExitLimit contains any computed information, or
     /// whether it's all SCEVCouldNotCompute values.
     bool hasAnyInfo() const {
@@ -1581,9 +1598,9 @@ public:
                                     SCEVUnionPredicate &A);
   /// Tries to convert the \p S expression to an AddRec expression,
   /// adding additional predicates to \p Preds as required.
-  const SCEVAddRecExpr *
-  convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
-                                    SCEVUnionPredicate &Preds);
+  const SCEVAddRecExpr *convertSCEVToAddRecWithPredicates(
+      const SCEV *S, const Loop *L,
+      SmallPtrSetImpl<const SCEVPredicate *> &Preds);
 
 private:
   /// Compute the backedge taken count knowing the interval difference, the

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=282606&r1=282605&r2=282606&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Wed Sep 28 12:14:58 2016
@@ -5656,11 +5656,14 @@ ScalarEvolution::BackedgeTakenInfo::Back
       [&](const EdgeExitInfo &EEI) {
         BasicBlock *ExitBB = EEI.first;
         const ExitLimit &EL = EEI.second;
-        if (EL.Predicate.isAlwaysTrue())
+        if (EL.Predicates.empty())
           return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr);
-        return ExitNotTakenInfo(
-            ExitBB, EL.ExactNotTaken,
-            llvm::make_unique<SCEVUnionPredicate>(std::move(EL.Predicate)));
+
+        std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate);
+        for (auto *Pred : EL.Predicates)
+          Predicate->add(Pred);
+
+        return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate));
       });
 }
 
@@ -5691,7 +5694,7 @@ ScalarEvolution::computeBackedgeTakenCou
     BasicBlock *ExitBB = ExitingBlocks[i];
     ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
 
-    assert((AllowPredicates || EL.Predicate.isAlwaysTrue()) &&
+    assert((AllowPredicates || EL.Predicates.empty()) &&
            "Predicated exit limit when predicates are not allowed!");
 
     // 1. For each exit that can be computed, add an entry to ExitCounts.
@@ -5861,9 +5864,6 @@ ScalarEvolution::computeExitLimitFromCon
           BECount = EL0.ExactNotTaken;
       }
 
-      SCEVUnionPredicate NP;
-      NP.add(&EL0.Predicate);
-      NP.add(&EL1.Predicate);
       // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
       // to be more aggressive when computing BECount than when computing
       // MaxBECount.  In these cases it is possible for EL0.ExactNotTaken and
@@ -5873,7 +5873,7 @@ ScalarEvolution::computeExitLimitFromCon
           !isa<SCEVCouldNotCompute>(BECount))
         MaxBECount = BECount;
 
-      return ExitLimit(BECount, MaxBECount, NP);
+      return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
     }
     if (BO->getOpcode() == Instruction::Or) {
       // Recurse on the operands of the or.
@@ -5912,10 +5912,7 @@ ScalarEvolution::computeExitLimitFromCon
           BECount = EL0.ExactNotTaken;
       }
 
-      SCEVUnionPredicate NP;
-      NP.add(&EL0.Predicate);
-      NP.add(&EL1.Predicate);
-      return ExitLimit(BECount, MaxBECount, NP);
+      return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates});
     }
   }
 
@@ -6300,8 +6297,7 @@ ScalarEvolution::ExitLimit ScalarEvoluti
     unsigned BitWidth = getTypeSizeInBits(RHS->getType());
     const SCEV *UpperBound =
         getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
-    SCEVUnionPredicate P;
-    return ExitLimit(getCouldNotCompute(), UpperBound, P);
+    return ExitLimit(getCouldNotCompute(), UpperBound);
   }
 
   return getCouldNotCompute();
@@ -7062,7 +7058,7 @@ ScalarEvolution::howFarToZero(const SCEV
   // effectively V != 0.  We know and take advantage of the fact that this
   // expression only being used in a comparison by zero context.
 
-  SCEVUnionPredicate P;
+  SmallPtrSet<const SCEVPredicate *, 4> Predicates;
   // If the value is a constant
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
     // If the value is already zero, the branch will execute zero times.
@@ -7075,7 +7071,7 @@ ScalarEvolution::howFarToZero(const SCEV
     // Try to make this an AddRec using runtime tests, in the first X
     // iterations of this loop, where X is the SCEV expression found by the
     // algorithm below.
-    AddRec = convertSCEVToAddRecWithPredicates(V, L, P);
+    AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
 
   if (!AddRec || AddRec->getLoop() != L)
     return getCouldNotCompute();
@@ -7097,7 +7093,7 @@ ScalarEvolution::howFarToZero(const SCEV
         // should not accept a root of 2.
         const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
         if (Val->isZero())
-          return ExitLimit(R1, R1, P); // We found a quadratic root!
+          return ExitLimit(R1, R1, Predicates); // We found a quadratic root!
       }
     }
     return getCouldNotCompute();
@@ -7154,7 +7150,7 @@ ScalarEvolution::howFarToZero(const SCEV
     else
       MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
                                          : -CR.getUnsignedMin());
-    return ExitLimit(Distance, MaxBECount, P);
+    return ExitLimit(Distance, MaxBECount, Predicates);
   }
 
   // As a special case, handle the instance where Step is a positive power of
@@ -7209,7 +7205,7 @@ ScalarEvolution::howFarToZero(const SCEV
 
       const SCEV *Limit =
           getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy);
-      return ExitLimit(Limit, Limit, P);
+      return ExitLimit(Limit, Limit, Predicates);
     }
   }
 
@@ -7222,14 +7218,14 @@ ScalarEvolution::howFarToZero(const SCEV
       loopHasNoAbnormalExits(AddRec->getLoop())) {
     const SCEV *Exact =
         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
-    return ExitLimit(Exact, Exact, P);
+    return ExitLimit(Exact, Exact, Predicates);
   }
 
   // Then, try to solve the above equation provided that Start is constant.
   if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {
     const SCEV *E = SolveLinEquationWithOverflow(
         StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this);
-    return ExitLimit(E, E, P);
+    return ExitLimit(E, E, Predicates);
   }
   return getCouldNotCompute();
 }
@@ -8634,7 +8630,7 @@ ScalarEvolution::ExitLimit
 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                   const Loop *L, bool IsSigned,
                                   bool ControlsExit, bool AllowPredicates) {
-  SCEVUnionPredicate P;
+  SmallPtrSet<const SCEVPredicate *, 4> Predicates;
   // We handle only IV < Invariant
   if (!isLoopInvariant(RHS, L))
     return getCouldNotCompute();
@@ -8646,7 +8642,7 @@ ScalarEvolution::howManyLessThans(const
     // Try to make this an AddRec using runtime tests, in the first X
     // iterations of this loop, where X is the SCEV expression found by the
     // algorithm below.
-    IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
+    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
     PredicatedIV = true;
   }
 
@@ -8762,14 +8758,14 @@ ScalarEvolution::howManyLessThans(const
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     MaxBECount = BECount;
 
-  return ExitLimit(BECount, MaxBECount, P);
+  return ExitLimit(BECount, MaxBECount, Predicates);
 }
 
 ScalarEvolution::ExitLimit
 ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
                                      const Loop *L, bool IsSigned,
                                      bool ControlsExit, bool AllowPredicates) {
-  SCEVUnionPredicate P;
+  SmallPtrSet<const SCEVPredicate *, 4> Predicates;
   // We handle only IV > Invariant
   if (!isLoopInvariant(RHS, L))
     return getCouldNotCompute();
@@ -8779,7 +8775,7 @@ ScalarEvolution::howManyGreaterThans(con
     // Try to make this an AddRec using runtime tests, in the first X
     // iterations of this loop, where X is the SCEV expression found by the
     // algorithm below.
-    IV = convertSCEVToAddRecWithPredicates(LHS, L, P);
+    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
 
   // Avoid weird loops
   if (!IV || IV->getLoop() != L || !IV->isAffine())
@@ -8839,7 +8835,7 @@ ScalarEvolution::howManyGreaterThans(con
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     MaxBECount = BECount;
 
-  return ExitLimit(BECount, MaxBECount, P);
+  return ExitLimit(BECount, MaxBECount, Predicates);
 }
 
 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
@@ -10161,25 +10157,34 @@ namespace {
 
 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
 public:
-  // Rewrites \p S in the context of a loop L and the predicate A.
-  // If Assume is true, rewrite is free to add further predicates to A
-  // such that the result will be an AddRecExpr.
+  /// Rewrites \p S in the context of a loop L and the SCEV predication
+  /// infrastructure.
+  ///
+  /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
+  /// equivalences present in \p Pred.
+  ///
+  /// If \p NewPreds is non-null, rewrite is free to add further predicates to
+  /// \p NewPreds such that the result will be an AddRecExpr.
   static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
-                             SCEVUnionPredicate &A, bool Assume) {
-    SCEVPredicateRewriter Rewriter(L, SE, A, Assume);
+                             SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
+                             SCEVUnionPredicate *Pred) {
+    SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
     return Rewriter.visit(S);
   }
 
   SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
-                        SCEVUnionPredicate &P, bool Assume)
-      : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {}
+                        SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
+                        SCEVUnionPredicate *Pred)
+      : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
 
   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
-    auto ExprPreds = P.getPredicatesForExpr(Expr);
-    for (auto *Pred : ExprPreds)
-      if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
-        if (IPred->getLHS() == Expr)
-          return IPred->getRHS();
+    if (Pred) {
+      auto ExprPreds = Pred->getPredicatesForExpr(Expr);
+      for (auto *Pred : ExprPreds)
+        if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred))
+          if (IPred->getLHS() == Expr)
+            return IPred->getRHS();
+    }
 
     return Expr;
   }
@@ -10220,32 +10225,31 @@ private:
   bool addOverflowAssumption(const SCEVAddRecExpr *AR,
                              SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
     auto *A = SE.getWrapPredicate(AR, AddedFlags);
-    if (!Assume) {
+    if (!NewPreds) {
       // Check if we've already made this assumption.
-      if (P.implies(A))
-        return true;
-      return false;
+      return Pred && Pred->implies(A);
     }
-    P.add(A);
+    NewPreds->insert(A);
     return true;
   }
 
-  SCEVUnionPredicate &P;
+  SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
+  SCEVUnionPredicate *Pred;
   const Loop *L;
-  bool Assume;
 };
 } // end anonymous namespace
 
 const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
                                                    SCEVUnionPredicate &Preds) {
-  return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false);
+  return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
 }
 
-const SCEVAddRecExpr *
-ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
-                                                   SCEVUnionPredicate &Preds) {
-  SCEVUnionPredicate TransformPreds;
-  S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true);
+const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
+    const SCEV *S, const Loop *L,
+    SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
+
+  SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
+  S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
 
   if (!AddRec)
@@ -10253,7 +10257,9 @@ ScalarEvolution::convertSCEVToAddRecWith
 
   // Since the transformation was successful, we can now transfer the SCEV
   // predicates.
-  Preds.add(&TransformPreds);
+  for (auto *P : TransformPreds)
+    Preds.insert(P);
+
   return AddRec;
 }
 
@@ -10480,11 +10486,15 @@ bool PredicatedScalarEvolution::hasNoOve
 
 const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
   const SCEV *Expr = this->getSCEV(V);
-  auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds);
+  SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
+  auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
 
   if (!New)
     return nullptr;
 
+  for (auto *P : NewPreds)
+    Preds.add(P);
+
   updateGeneration();
   RewriteMap[SE.getSCEV(V)] = {Generation, New};
   return New;




More information about the llvm-commits mailing list