[llvm] 6770480 - [SCEV] Track backedge taken count users (NFCI)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 1 01:16:56 PST 2021


Author: Nikita Popov
Date: 2021-12-01T10:16:47+01:00
New Revision: 67704801c679a59c552dfade92a3d7c53979e2c8

URL: https://github.com/llvm/llvm-project/commit/67704801c679a59c552dfade92a3d7c53979e2c8
DIFF: https://github.com/llvm/llvm-project/commit/67704801c679a59c552dfade92a3d7c53979e2c8.diff

LOG: [SCEV] Track backedge taken count users (NFCI)

Track which SCEVs are used as ExactNotTaken counts in
BackedgeTakenInfo structures, so we can directly determine which
loops need to be invalidated, rather than iterating over all BECounts.

This gives a small compile-time improvement on average, but the main
motivation is to avoid degenerate cases when the number of backedge
taken counts is large.

Differential Revision: https://reviews.llvm.org/D114784

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 73faa0afcdc8f..df50611832cee 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1378,6 +1378,8 @@ class ScalarEvolution {
   /// includes an exact count and a maximum count.
   ///
   class BackedgeTakenInfo {
+    friend class ScalarEvolution;
+
     /// A list of computable exits and their not-taken counts.  Loops almost
     /// never have more than one computable exit.
     SmallVector<ExitNotTakenInfo, 1> ExitNotTaken;
@@ -1398,9 +1400,6 @@ class ScalarEvolution {
     /// True iff the backedge is taken either exactly Max or zero times.
     bool MaxOrZero = false;
 
-    /// SCEV expressions used in any of the ExitNotTakenInfo counts.
-    SmallPtrSet<const SCEV *, 4> Operands;
-
     bool isComplete() const { return IsComplete; }
     const SCEV *getConstantMax() const { return ConstantMax; }
 
@@ -1466,10 +1465,6 @@ class ScalarEvolution {
     /// Return true if the number of times this backedge is taken is either the
     /// value returned by getConstantMax or zero.
     bool isConstantMaxOrZero(ScalarEvolution *SE) const;
-
-    /// Return true if any backedge taken count expressions refer to the given
-    /// subexpression.
-    bool hasOperand(const SCEV *S) const;
   };
 
   /// Cache the backedge-taken count of the loops for this function as they
@@ -1480,6 +1475,10 @@ class ScalarEvolution {
   /// function as they are computed.
   DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
 
+  /// Loops whose backedge taken counts directly use this non-constant SCEV.
+  DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
+      BECountUsers;
+
   /// This map contains entries for all of the PHI instructions that we
   /// attempt to compute constant evolutions for.  This allows us to avoid
   /// potentially expensive recomputation of these properties.  An instruction
@@ -1911,6 +1910,9 @@ class ScalarEvolution {
   bool splitBinaryAdd(const SCEV *Expr, const SCEV *&L, const SCEV *&R,
                       SCEV::NoWrapFlags &Flags);
 
+  /// Forget predicated/non-predicated backedge taken counts for the given loop.
+  void forgetBackedgeTakenCounts(const Loop *L, bool Predicated);
+
   /// Drop memoized information for all \p SCEVs.
   void forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs);
 

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index ece19099d65eb..7dc7f9904c709 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7603,6 +7603,7 @@ void ScalarEvolution::forgetAllLoops() {
   // result.
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
+  BECountUsers.clear();
   LoopPropertiesCache.clear();
   ConstantEvolutionLoopExitValue.clear();
   ValueExprMap.clear();
@@ -7629,8 +7630,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
     auto *CurrL = LoopWorklist.pop_back_val();
 
     // Drop any stored trip count value.
-    BackedgeTakenCounts.erase(CurrL);
-    PredicatedBackedgeTakenCounts.erase(CurrL);
+    forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
+    forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
 
     // Drop information about predicated SCEV rewrites for this loop.
     for (auto I = PredicatedSCEVRewrites.begin();
@@ -7804,10 +7805,6 @@ bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
   return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
 }
 
-bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S) const {
-  return Operands.contains(S);
-}
-
 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
     : ExitLimit(E, E, false, None) {
 }
@@ -7848,19 +7845,6 @@ ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M,
     : ExitLimit(E, M, MaxOrZero, None) {
 }
 
-class SCEVRecordOperands {
-  SmallPtrSetImpl<const SCEV *> &Operands;
-
-public:
-  SCEVRecordOperands(SmallPtrSetImpl<const SCEV *> &Operands)
-    : Operands(Operands) {}
-  bool follow(const SCEV *S) {
-    Operands.insert(S);
-    return true;
-  }
-  bool isDone() { return false; }
-};
-
 /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
 /// computable exit into a persistent ExitNotTakenInfo array.
 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
@@ -7889,14 +7873,6 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
   assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
           isa<SCEVConstant>(ConstantMax)) &&
          "No point in having a non-constant max backedge taken count!");
-
-  SCEVRecordOperands RecordOperands(Operands);
-  SCEVTraversal<SCEVRecordOperands> ST(RecordOperands);
-  if (!isa<SCEVCouldNotCompute>(ConstantMax))
-    ST.visitAll(ConstantMax);
-  for (auto &ENT : ExitNotTaken)
-    if (!isa<SCEVCouldNotCompute>(ENT.ExactNotTaken))
-      ST.visitAll(ENT.ExactNotTaken);
 }
 
 /// Compute the number of times the backedge of the specified loop will execute.
@@ -7978,6 +7954,13 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
   // The loop backedge will be taken the maximum or zero times if there's
   // a single exit that must be taken the maximum or zero times.
   bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
+
+  // Remember which SCEVs are used in exit limits for invalidation purposes.
+  // We only care about non-constant SCEVs here, so we can ignore EL.MaxNotTaken
+  // and MaxBECount, which must be SCEVConstant.
+  for (const auto &Pair : ExitCounts)
+    if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
+      BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
   return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
                            MaxBECount, MaxOrZero);
 }
@@ -12466,6 +12449,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
       PredicatedBackedgeTakenCounts(
           std::move(Arg.PredicatedBackedgeTakenCounts)),
+      BECountUsers(std::move(Arg.BECountUsers)),
       ConstantEvolutionLoopExitValue(
           std::move(Arg.ConstantEvolutionLoopExitValue)),
       ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
@@ -12882,6 +12866,23 @@ bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
   return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
 }
 
+void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
+                                                bool Predicated) {
+  auto &BECounts =
+      Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
+  auto It = BECounts.find(L);
+  if (It != BECounts.end()) {
+    for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
+      if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
+        auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
+        assert(UserIt != BECountUsers.end());
+        UserIt->second.erase({L, Predicated});
+      }
+    }
+    BECounts.erase(It);
+  }
+}
+
 void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
   SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
   SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
@@ -12906,21 +12907,6 @@ void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
     else
       ++I;
   }
-
-  auto RemoveSCEVFromBackedgeMap = [&ToForget](
-      DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
-        for (auto I = Map.begin(), E = Map.end(); I != E;) {
-          BackedgeTakenInfo &BEInfo = I->second;
-          if (any_of(ToForget,
-                     [&BEInfo](const SCEV *S) { return BEInfo.hasOperand(S); }))
-            Map.erase(I++);
-          else
-            ++I;
-        }
-  };
-
-  RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
-  RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
 }
 
 void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
@@ -12958,6 +12944,15 @@ void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
       erase_value(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
     ValuesAtScopesUsers.erase(ScopeUserIt);
   }
+
+  auto BEUsersIt = BECountUsers.find(S);
+  if (BEUsersIt != BECountUsers.end()) {
+    // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
+    auto Copy = BEUsersIt->second;
+    for (const auto &Pair : Copy)
+      forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
+    BECountUsers.erase(BEUsersIt);
+  }
 }
 
 void
@@ -13144,10 +13139,31 @@ void ScalarEvolution::verify() const {
           is_contained(It->second, std::make_pair(L, ValueAtScope)))
         continue;
       dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
-             << ValueAtScope << " missing in ValuesAtScopes";
+             << ValueAtScope << " missing in ValuesAtScopes\n";
       std::abort();
     }
   }
+
+  // Verify integrity of BECountUsers.
+  auto VerifyBECountUsers = [&](bool Predicated) {
+    auto &BECounts =
+        Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
+    for (const auto &LoopAndBEInfo : BECounts) {
+      for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
+        if (!isa<SCEVConstant>(ENT.ExactNotTaken)) {
+          auto UserIt = BECountUsers.find(ENT.ExactNotTaken);
+          if (UserIt != BECountUsers.end() &&
+              UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
+            continue;
+          dbgs() << "Value " << *ENT.ExactNotTaken << " for loop "
+                 << *LoopAndBEInfo.first << " missing from BECountUsers\n";
+          std::abort();
+        }
+      }
+    }
+  };
+  VerifyBECountUsers(/* Predicated */ false);
+  VerifyBECountUsers(/* Predicated */ true);
 }
 
 bool ScalarEvolution::invalidate(


        


More information about the llvm-commits mailing list