[llvm] 6022a3a - [SCEV] Store predicates for EL/ENT in SmallVector.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 28 13:24:32 PDT 2024


Author: Florian Hahn
Date: 2024-09-28T21:24:17+01:00
New Revision: 6022a3a05f951632022c84416209fe6d70d9105c

URL: https://github.com/llvm/llvm-project/commit/6022a3a05f951632022c84416209fe6d70d9105c
DIFF: https://github.com/llvm/llvm-project/commit/6022a3a05f951632022c84416209fe6d70d9105c.diff

LOG: [SCEV] Store predicates for EL/ENT in SmallVector.

Store predicates in ExitLimit and ExitNotTakenInfo in a SmallVector
instead of a SmallPtrSet. A SmallVector preserves insertion order, which
guarantees the predicates are iterated on, printed and generated in a
predictable order.

This moves de-duplication of predicates to the point where an ExitLimit
is constructed. ExitNotTakenInfo just takes its predicates from an
ExitLimit, so they should also be free of duplicates.

This was exposed by 2f7ccaf4a8565628a4c7d2b5a49bb45478940be6
(https://github.com/llvm/llvm-project/pull/108777).

Should fix https://lab.llvm.org/buildbot/#/builders/110/builds/1494.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 68b860725752d0..179a2c38d9d3c2 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1125,15 +1125,10 @@ class ScalarEvolution {
     // Not taken either exactly ConstantMaxNotTaken or zero times
     bool MaxOrZero = false;
 
-    /// A set of predicate guards for this ExitLimit. The result is only valid
-    /// if all of the predicates in \c Predicates evaluate to 'true' at
+    /// A vector of predicate guards for this ExitLimit. The result is only
+    /// valid if all of the predicates in \c Predicates evaluate to 'true' at
     /// run-time.
-    SmallPtrSet<const SCEVPredicate *, 4> Predicates;
-
-    void addPredicate(const SCEVPredicate *P) {
-      assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
-      Predicates.insert(P);
-    }
+    SmallVector<const SCEVPredicate *, 4> Predicates;
 
     /// Construct either an exact exit limit from a constant, or an unknown
     /// one from a SCEVCouldNotCompute.  No other types of SCEVs are allowed
@@ -1142,12 +1137,11 @@ class ScalarEvolution {
 
     ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken,
               const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
-              ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *>
-                  PredSetList = {});
+              ArrayRef<ArrayRef<const SCEVPredicate *>> PredLists = {});
 
     ExitLimit(const SCEV *E, const SCEV *ConstantMaxNotTaken,
               const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
-              const SmallPtrSetImpl<const SCEVPredicate *> &PredSet);
+              ArrayRef<const SCEVPredicate *> PredList);
 
     /// Test whether this ExitLimit contains any computed information, or
     /// whether it's all SCEVCouldNotCompute values.
@@ -1297,7 +1291,7 @@ class ScalarEvolution {
   /// adding additional predicates to \p Preds as required.
   const SCEVAddRecExpr *convertSCEVToAddRecWithPredicates(
       const SCEV *S, const Loop *L,
-      SmallPtrSetImpl<const SCEVPredicate *> &Preds);
+      SmallVectorImpl<const SCEVPredicate *> &Preds);
 
   /// Compute \p LHS - \p RHS and returns the result as an APInt if it is a
   /// constant, and std::nullopt if it isn't.
@@ -1489,12 +1483,13 @@ class ScalarEvolution {
     const SCEV *ExactNotTaken;
     const SCEV *ConstantMaxNotTaken;
     const SCEV *SymbolicMaxNotTaken;
-    SmallPtrSet<const SCEVPredicate *, 4> Predicates;
+    SmallVector<const SCEVPredicate *, 4> Predicates;
 
-    explicit ExitNotTakenInfo(
-        PoisoningVH<BasicBlock> ExitingBlock, const SCEV *ExactNotTaken,
-        const SCEV *ConstantMaxNotTaken, const SCEV *SymbolicMaxNotTaken,
-        const SmallPtrSet<const SCEVPredicate *, 4> &Predicates)
+    explicit ExitNotTakenInfo(PoisoningVH<BasicBlock> ExitingBlock,
+                              const SCEV *ExactNotTaken,
+                              const SCEV *ConstantMaxNotTaken,
+                              const SCEV *SymbolicMaxNotTaken,
+                              ArrayRef<const SCEVPredicate *> Predicates)
         : ExitingBlock(ExitingBlock), ExactNotTaken(ExactNotTaken),
           ConstantMaxNotTaken(ConstantMaxNotTaken),
           SymbolicMaxNotTaken(SymbolicMaxNotTaken), Predicates(Predicates) {}

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index d5db3263294a6d..c939270ed39a65 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8693,12 +8693,12 @@ bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
 }
 
 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
-    : ExitLimit(E, E, E, false, {}) {}
+    : ExitLimit(E, E, E, false) {}
 
 ScalarEvolution::ExitLimit::ExitLimit(
     const SCEV *E, const SCEV *ConstantMaxNotTaken,
     const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
-    ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
+    ArrayRef<ArrayRef<const SCEVPredicate *>> PredLists)
     : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
       SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
   // If we prove the max count is zero, so is the symbolic bound.  This happens
@@ -8721,9 +8721,15 @@ ScalarEvolution::ExitLimit::ExitLimit(
   assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
           isa<SCEVConstant>(ConstantMaxNotTaken)) &&
          "No point in having a non-constant max backedge taken count!");
-  for (const auto *PredSet : PredSetList)
-    for (const auto *P : *PredSet)
-      addPredicate(P);
+  SmallPtrSet<const SCEVPredicate *, 4> SeenPreds;
+  for (const auto PredList : PredLists)
+    for (const auto *P : PredList) {
+      if (SeenPreds.contains(P))
+        continue;
+      assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
+      SeenPreds.insert(P);
+      Predicates.push_back(P);
+    }
   assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
          "Backedge count should be int");
   assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
@@ -8731,12 +8737,13 @@ ScalarEvolution::ExitLimit::ExitLimit(
          "Max backedge count should be int");
 }
 
-ScalarEvolution::ExitLimit::ExitLimit(
-    const SCEV *E, const SCEV *ConstantMaxNotTaken,
-    const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
-    const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
+ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E,
+                                      const SCEV *ConstantMaxNotTaken,
+                                      const SCEV *SymbolicMaxNotTaken,
+                                      bool MaxOrZero,
+                                      ArrayRef<const SCEVPredicate *> PredList)
     : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
-                { &PredSet }) {}
+                ArrayRef({PredList})) {}
 
 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
 /// computable exit into a persistent ExitNotTakenInfo array.
@@ -9098,7 +9105,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
     SymbolicMaxBECount =
         isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
   return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
-                   { &EL0.Predicates, &EL1.Predicates });
+                   {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
@@ -10131,7 +10138,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
 /// If the equation does not have a solution, SCEVCouldNotCompute is returned.
 static const SCEV *
 SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
-                             SmallPtrSetImpl<const SCEVPredicate *> *Predicates,
+                             SmallVectorImpl<const SCEVPredicate *> *Predicates,
 
                              ScalarEvolution &SE) {
   uint32_t BW = A.getBitWidth();
@@ -10162,7 +10169,7 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
       // Avoid adding a predicate that is known to be false.
       if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
         return SE.getCouldNotCompute();
-      Predicates->insert(SE.getEqualPredicate(URem, Zero));
+      Predicates->push_back(SE.getEqualPredicate(URem, Zero));
     }
   }
 
@@ -10466,7 +10473,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   // effectively V != 0.  We know and take advantage of the fact that this
   // expression only being used in a comparison by zero context.
 
-  SmallPtrSet<const SCEVPredicate *, 4> Predicates;
+  SmallVector<const SCEVPredicate *> Predicates;
   // If the value is a constant
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
     // If the value is already zero, the branch will execute zero times.
@@ -12885,7 +12892,7 @@ ScalarEvolution::ExitLimit
 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                   const Loop *L, bool IsSigned,
                                   bool ControlsOnlyExit, bool AllowPredicates) {
-  SmallPtrSet<const SCEVPredicate *, 4> Predicates;
+  SmallVector<const SCEVPredicate *> Predicates;
 
   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
   bool PredicatedIV = false;
@@ -13325,7 +13332,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
 ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
     const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
     bool ControlsOnlyExit, bool AllowPredicates) {
-  SmallPtrSet<const SCEVPredicate *, 4> Predicates;
+  SmallVector<const SCEVPredicate *> Predicates;
   // We handle only IV > Invariant
   if (!isLoopInvariant(RHS, L))
     return getCouldNotCompute();
@@ -13695,7 +13702,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
       PrintSCEVWithTypeHint(OS, EC);
       if (isa<SCEVCouldNotCompute>(EC)) {
         // Retry with predicates.
-        SmallVector<const SCEVPredicate *, 4> Predicates;
+        SmallVector<const SCEVPredicate *> Predicates;
         EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
         if (!isa<SCEVCouldNotCompute>(EC)) {
           OS << "\n  predicated exit count for " << ExitingBlock->getName()
@@ -13747,7 +13754,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
       PrintSCEVWithTypeHint(OS, ExitBTC);
       if (isa<SCEVCouldNotCompute>(ExitBTC)) {
         // Retry with predicates.
-        SmallVector<const SCEVPredicate *, 4> Predicates;
+        SmallVector<const SCEVPredicate *> Predicates;
         ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
                                              ScalarEvolution::SymbolicMaximum);
         if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
@@ -14727,7 +14734,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
   /// If \p NewPreds is non-null, rewrite is free to add further predicates to
   /// \p NewPreds such that the result will be an AddRecExpr.
   static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
-                             SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
+                             SmallVectorImpl<const SCEVPredicate *> *NewPreds,
                              const SCEVPredicate *Pred) {
     SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
     return Rewriter.visit(S);
@@ -14783,9 +14790,10 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
   }
 
 private:
-  explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
-                        SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
-                        const SCEVPredicate *Pred)
+  explicit SCEVPredicateRewriter(
+      const Loop *L, ScalarEvolution &SE,
+      SmallVectorImpl<const SCEVPredicate *> *NewPreds,
+      const SCEVPredicate *Pred)
       : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
 
   bool addOverflowAssumption(const SCEVPredicate *P) {
@@ -14793,7 +14801,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
       // Check if we've already made this assumption.
       return Pred && Pred->implies(P);
     }
-    NewPreds->insert(P);
+    NewPreds->push_back(P);
     return true;
   }
 
@@ -14829,7 +14837,7 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
     return PredicatedRewrite->first;
   }
 
-  SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
+  SmallVectorImpl<const SCEVPredicate *> *NewPreds;
   const SCEVPredicate *Pred;
   const Loop *L;
 };
@@ -14844,8 +14852,8 @@ ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
 
 const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
     const SCEV *S, const Loop *L,
-    SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
-  SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
+    SmallVectorImpl<const SCEVPredicate *> &Preds) {
+  SmallVector<const SCEVPredicate *> TransformPreds;
   S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
 
@@ -14854,7 +14862,7 @@ const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
 
   // Since the transformation was successful, we can now transfer the SCEV
   // predicates.
-  Preds.insert(TransformPreds.begin(), TransformPreds.end());
+  Preds.append(TransformPreds.begin(), TransformPreds.end());
 
   return AddRec;
 }
@@ -15101,7 +15109,7 @@ bool PredicatedScalarEvolution::hasNoOverflow(
 
 const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
   const SCEV *Expr = this->getSCEV(V);
-  SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
+  SmallVector<const SCEVPredicate *, 4> NewPreds;
   auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
 
   if (!New)


        


More information about the llvm-commits mailing list