[llvm] d334fec - [SCEV] Make SCEVUnionPredicate externally immutable [NFC]

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 9 13:47:37 PST 2022


Author: Philip Reames
Date: 2022-02-09T13:47:28-08:00
New Revision: d334fec1409c5b158bbed2f5694983cbb8a70f11

URL: https://github.com/llvm/llvm-project/commit/d334fec1409c5b158bbed2f5694983cbb8a70f11
DIFF: https://github.com/llvm/llvm-project/commit/d334fec1409c5b158bbed2f5694983cbb8a70f11.diff

LOG: [SCEV] Make SCEVUnionPredicate externally immutable [NFC]

This is the last major stepping stone before SCEVUnionPredicate nodes can be allocated via the folding set allocator. That, in turn, will allow more general SCEV predicate expression trees.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/unittests/Analysis/ScalarEvolutionTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 1af37422d5822..a81c19f7c7171 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -425,16 +425,16 @@ class SCEVUnionPredicate final : public SCEVPredicate {
   /// Maps SCEVs to predicates for quick look-ups.
   PredicateMap SCEVToPreds;
 
+  /// Adds a predicate to this union.
+  void add(const SCEVPredicate *N);
+
 public:
-  SCEVUnionPredicate();
+  SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds);
 
   const SmallVectorImpl<const SCEVPredicate *> &getPredicates() const {
     return Preds;
   }
 
-  /// Adds a predicate to this union.
-  void add(const SCEVPredicate *N);
-
   /// Returns a reference to a vector containing all predicates which apply to
   /// \p Expr.
   ArrayRef<const SCEVPredicate *> getPredicatesForExpr(const SCEV *Expr);
@@ -2254,7 +2254,7 @@ class PredicatedScalarEvolution {
 
   /// The SCEVPredicate that forms our context. We will rewrite all
   /// expressions assuming that this predicate true.
-  SCEVUnionPredicate Preds;
+  std::unique_ptr<SCEVUnionPredicate> Preds;
 
   /// Marks the version of the SCEV predicate used. When rewriting a SCEV
   /// expression we mark it with the version of the predicate. We use this to

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index b0e04bc05188b..546268b4da20f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5489,8 +5489,8 @@ bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
     return true;
 
   auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
-    if (Expr1 != Expr2 && !Preds.implies(SE.getEqualPredicate(Expr1, Expr2)) &&
-        !Preds.implies(SE.getEqualPredicate(Expr2, Expr1)))
+    if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
+        !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
       return false;
     return true;
   };
@@ -12818,9 +12818,7 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
   if (!isa<SCEVCouldNotCompute>(PBT)) {
     OS << "Predicated backedge-taken count is " << *PBT << "\n";
     OS << " Predicates:\n";
-    SCEVUnionPredicate Dedup;
-    for (auto *P : Preds)
-      Dedup.add(P);
+    SCEVUnionPredicate Dedup(Preds);
     Dedup.print(OS, 4);
   } else {
     OS << "Unpredictable predicated backedge-taken count. ";
@@ -13807,8 +13805,11 @@ SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
 }
 
 /// Union predicates don't get cached so create a dummy set ID for it.
-SCEVUnionPredicate::SCEVUnionPredicate()
-    : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {}
+SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds)
+  : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
+  for (auto *P : Preds)
+    add(P);
+}
 
 bool SCEVUnionPredicate::isAlwaysTrue() const {
   return all_of(Preds,
@@ -13864,7 +13865,10 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) {
 
 PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
                                                      Loop &L)
-    : SE(SE), L(L) {}
+    : SE(SE), L(L) {
+  SmallVector<const SCEVPredicate*, 4> Empty;
+  Preds = std::make_unique<SCEVUnionPredicate>(Empty);
+}
 
 void ScalarEvolution::registerUser(const SCEV *User,
                                    ArrayRef<const SCEV *> Ops) {
@@ -13889,7 +13893,7 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
   if (Entry.second)
     Expr = Entry.second;
 
-  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds);
+  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
   Entry = {Generation, NewSCEV};
 
   return NewSCEV;
@@ -13906,14 +13910,18 @@ const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
 }
 
 void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
-  if (Preds.implies(&Pred))
+  if (Preds->implies(&Pred))
     return;
-  Preds.add(&Pred);
+
+  auto &OldPreds = Preds->getPredicates();
+  SmallVector<const SCEVPredicate*, 4> NewPreds(OldPreds.begin(), OldPreds.end());
+  NewPreds.push_back(&Pred);
+  Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
   updateGeneration();
 }
 
 const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const {
-  return Preds;
+  return *Preds;
 }
 
 void PredicatedScalarEvolution::updateGeneration() {
@@ -13921,7 +13929,7 @@ void PredicatedScalarEvolution::updateGeneration() {
   if (++Generation == 0) {
     for (auto &II : RewriteMap) {
       const SCEV *Rewritten = II.second.second;
-      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)};
+      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
     }
   }
 }
@@ -13975,8 +13983,9 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
 
 PredicatedScalarEvolution::PredicatedScalarEvolution(
     const PredicatedScalarEvolution &Init)
-    : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds),
-      Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
+  : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
+    Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
+    Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
   for (auto I : Init.FlagsMap)
     FlagsMap.insert(I);
 }

diff  --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 2101039abc15e..3e02fc7126e9b 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -942,7 +942,6 @@ TEST_F(ScalarEvolutionsTest, SCEVAddRecFromPHIwithLargeConstants) {
 
   // Make sure that SCEV doesn't blow up
   ScalarEvolution SE = buildSE(*F);
-  SCEVUnionPredicate Preds;
   const SCEV *Expr = SE.getSCEV(Phi);
   EXPECT_NE(nullptr, Expr);
   EXPECT_TRUE(isa<SCEVUnknown>(Expr));
@@ -1000,7 +999,6 @@ TEST_F(ScalarEvolutionsTest, SCEVAddRecFromPHIwithLargeConstantAccum) {
 
   // Make sure that SCEV doesn't blow up
   ScalarEvolution SE = buildSE(*F);
-  SCEVUnionPredicate Preds;
   const SCEV *Expr = SE.getSCEV(Phi);
   EXPECT_NE(nullptr, Expr);
   EXPECT_TRUE(isa<SCEVUnknown>(Expr));


        


More information about the llvm-commits mailing list