[llvm] [SCEV] Cache collected loop guards. NFCI (PR #116947)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 20 22:33:54 PST 2024


https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/116947

>From 9a64d8a957d207a143c8b6fd4321af3b21b41422 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 20 Nov 2024 16:55:52 +0800
Subject: [PATCH 1/2] [SCEV] Cache collected loop guards. NFCI

This tries to compensate for the compile-time cost of collecting loop guards by caching the collected guards, which gives a -0.07% geomean reduction for stage2-O3: https://llvm-compile-time-tracker.com/compare.php?from=aff98e4be05a1060e489ce62a88ee0ff365e571a&to=198a76db2c0b8fbda5374ffd195731a9d47469e3&stat=instructions:u

LoopAccessAnalysis already had a LoopGuards cache for the innermost loop, so this hoists it up into ScalarEvolution.
---
 .../llvm/Analysis/LoopAccessAnalysis.h        |  3 ---
 llvm/include/llvm/Analysis/ScalarEvolution.h  |  5 +++-
 llvm/lib/Analysis/LoopAccessAnalysis.cpp      | 16 ++++---------
 llvm/lib/Analysis/ScalarEvolution.cpp         | 24 ++++++++++---------
 4 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
index a35bc7402d1a89..872b68f924e654 100644
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -334,9 +334,6 @@ class MemoryDepChecker {
            std::pair<const SCEV *, const SCEV *>>
       PointerBounds;
 
-  /// Cache for the loop guards of InnermostLoop.
-  std::optional<ScalarEvolution::LoopGuards> LoopGuards;
-
   /// Check whether there is a plausible dependence between the two
   /// accesses.
   ///
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 885c5985f9d23a..b7b9384c6e642c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1346,7 +1346,6 @@ class ScalarEvolution {
 
   /// Try to apply information from loop guards for \p L to \p Expr.
   const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
-  const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards);
 
   /// Return true if the loop has no abnormal exits. That is, if the loop
   /// is not infinite, it must exit through an explicit edge in the CFG.
@@ -1651,6 +1650,10 @@ class ScalarEvolution {
   /// function as they are computed.
   DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
 
+  /// Cache the collected loop guards of the loops of this function as they are
+  /// computed.
+  DenseMap<const Loop *, LoopGuards> LoopGuardsCache;
+
   /// Loops whose backedge taken counts directly use this non-constant SCEV.
   DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
       BECountUsers;
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 907bb7875dc807..9dfc3def3140af 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -1945,16 +1945,13 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
         !isa<SCEVCouldNotCompute>(SrcEnd_) &&
         !isa<SCEVCouldNotCompute>(SinkStart_) &&
         !isa<SCEVCouldNotCompute>(SinkEnd_)) {
-      if (!LoopGuards)
-        LoopGuards.emplace(
-            ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
-      auto SrcEnd = SE.applyLoopGuards(SrcEnd_, *LoopGuards);
-      auto SinkStart = SE.applyLoopGuards(SinkStart_, *LoopGuards);
+      auto SrcEnd = SE.applyLoopGuards(SrcEnd_, InnermostLoop);
+      auto SinkStart = SE.applyLoopGuards(SinkStart_, InnermostLoop);
       if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SrcEnd, SinkStart))
         return MemoryDepChecker::Dependence::NoDep;
 
-      auto SinkEnd = SE.applyLoopGuards(SinkEnd_, *LoopGuards);
-      auto SrcStart = SE.applyLoopGuards(SrcStart_, *LoopGuards);
+      auto SinkEnd = SE.applyLoopGuards(SinkEnd_, InnermostLoop);
+      auto SrcStart = SE.applyLoopGuards(SrcStart_, InnermostLoop);
       if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SinkEnd, SrcStart))
         return MemoryDepChecker::Dependence::NoDep;
     }
@@ -2057,10 +2054,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
       return Dependence::NoDep;
     }
   } else {
-    if (!LoopGuards)
-      LoopGuards.emplace(
-          ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
-    Dist = SE.applyLoopGuards(Dist, *LoopGuards);
+    Dist = SE.applyLoopGuards(Dist, InnermostLoop);
   }
 
   // Negative distances are not plausible dependencies.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 46b108606f6a62..70a45ef507e279 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8417,6 +8417,7 @@ void ScalarEvolution::forgetAllLoops() {
   // result.
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
+  LoopGuardsCache.clear();
   BECountUsers.clear();
   LoopPropertiesCache.clear();
   ConstantEvolutionLoopExitValue.clear();
@@ -10551,9 +10552,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   if (!isLoopInvariant(Step, L))
     return getCouldNotCompute();
 
-  LoopGuards Guards = LoopGuards::collect(L, *this);
   // Specialize step for this loop so we get context sensitive facts below.
-  const SCEV *StepWLG = applyLoopGuards(Step, Guards);
+  const SCEV *StepWLG = applyLoopGuards(Step, L);
 
   // For positive steps (counting up until unsigned overflow):
   //   N = -Start/Step (as unsigned)
@@ -10570,7 +10570,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   //   N = Distance (as unsigned)
   if (StepC &&
       (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
-    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
+    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
     MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
 
     // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
@@ -10611,7 +10611,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
     const SCEV *ConstantMax = getCouldNotCompute();
     if (Exact != getCouldNotCompute()) {
-      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
+      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
       ConstantMax =
           getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
     }
@@ -10629,7 +10629,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
 
   const SCEV *M = E;
   if (E != getCouldNotCompute()) {
-    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
+    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
     M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
   }
   auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
@@ -13674,6 +13674,7 @@ ScalarEvolution::~ScalarEvolution() {
   HasRecMap.clear();
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
+  LoopGuardsCache.clear();
 
   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
   assert(PendingPhiRanges.empty() && "getRangeRef garbage");
@@ -15889,10 +15890,11 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
 }
 
 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
-  return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
-}
-
-const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
-                                             const LoopGuards &Guards) {
-  return Guards.rewrite(Expr);
+  auto Itr = LoopGuardsCache.find(L);
+  if (Itr == LoopGuardsCache.end()) {
+    LoopGuards Guard = LoopGuards::collect(L, *this);
+    LoopGuardsCache.insert({L, Guard});
+    return Guard.rewrite(Expr);
+  }
+  return Itr->second.rewrite(Expr);
 }

>From 7a0bf19b515a555f0b80c55e8d700dfcc8e88ff4 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 21 Nov 2024 14:33:06 +0800
Subject: [PATCH 2/2] Rework to memoize loop guards across multiple exits

---
 .../llvm/Analysis/LoopAccessAnalysis.h        |   3 +
 llvm/include/llvm/Analysis/ScalarEvolution.h  | 122 +++++++++---------
 llvm/lib/Analysis/LoopAccessAnalysis.cpp      |  16 ++-
 llvm/lib/Analysis/ScalarEvolution.cpp         | 119 +++++++++--------
 llvm/lib/Transforms/Scalar/IndVarSimplify.cpp |   9 +-
 5 files changed, 150 insertions(+), 119 deletions(-)

diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
index 872b68f924e654..a35bc7402d1a89 100644
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -334,6 +334,9 @@ class MemoryDepChecker {
            std::pair<const SCEV *, const SCEV *>>
       PointerBounds;
 
+  /// Cache for the loop guards of InnermostLoop.
+  std::optional<ScalarEvolution::LoopGuards> LoopGuards;
+
   /// Check whether there is a plausible dependence between the two
   /// accesses.
   ///
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index b7b9384c6e642c..692dd02f4b4324 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1112,6 +1112,46 @@ class ScalarEvolution {
   bool isKnownOnEveryIteration(ICmpInst::Predicate Pred,
                                const SCEVAddRecExpr *LHS, const SCEV *RHS);
 
+  class LoopGuards {
+    DenseMap<const SCEV *, const SCEV *> RewriteMap;
+    bool PreserveNUW = false;
+    bool PreserveNSW = false;
+    ScalarEvolution &SE;
+
+    LoopGuards(ScalarEvolution &SE) : SE(SE) {}
+
+    /// Recursively collect loop guards in \p Guards, starting from
+    /// block \p Block with predecessor \p Pred. The intended starting point
+    /// is to collect from a loop header and its predecessor.
+    static void
+    collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
+                     const BasicBlock *Block, const BasicBlock *Pred,
+                     SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
+                     unsigned Depth = 0);
+
+    /// Collect loop guards in \p Guards, starting from PHINode \p
+    /// Phi, by calling \p collectFromBlock on the incoming blocks of
+    /// \p Phi and trying to merge the found constraints into a single
+    /// combined one for \p Phi.
+    static void collectFromPHI(
+        ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
+        const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
+        SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
+        unsigned Depth);
+
+  public:
+    /// Collect rewrite map for loop guards for loop \p L, together with flags
+    /// indicating if NUW and NSW can be preserved during rewriting.
+    static LoopGuards collect(const Loop *L, ScalarEvolution &SE);
+
+    /// Try to apply the collected loop guards to \p Expr.
+    const SCEV *rewrite(const SCEV *Expr) const;
+  };
+
+  /// Try to apply information from loop guards for \p L to \p Expr.
+  const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
+  const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards);
+
   /// Information about the number of loop iterations for which a loop exit's
   /// branch condition evaluates to the not-taken path.  This is a temporary
   /// pair of exact and max expressions that are eventually summarized in
@@ -1167,6 +1207,7 @@ class ScalarEvolution {
   /// If \p AllowPredicates is set, this call will try to use a minimal set of
   /// SCEV predicates in order to return an exact answer.
   ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond,
+                                     std::function<LoopGuards()> GetLoopGuards,
                                      bool ExitIfTrue, bool ControlsOnlyExit,
                                      bool AllowPredicates = false);
 
@@ -1308,45 +1349,6 @@ class ScalarEvolution {
   /// sharpen it.
   void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags);
 
-  class LoopGuards {
-    DenseMap<const SCEV *, const SCEV *> RewriteMap;
-    bool PreserveNUW = false;
-    bool PreserveNSW = false;
-    ScalarEvolution &SE;
-
-    LoopGuards(ScalarEvolution &SE) : SE(SE) {}
-
-    /// Recursively collect loop guards in \p Guards, starting from
-    /// block \p Block with predecessor \p Pred. The intended starting point
-    /// is to collect from a loop header and its predecessor.
-    static void
-    collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
-                     const BasicBlock *Block, const BasicBlock *Pred,
-                     SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
-                     unsigned Depth = 0);
-
-    /// Collect loop guards in \p Guards, starting from PHINode \p
-    /// Phi, by calling \p collectFromBlock on the incoming blocks of
-    /// \Phi and trying to merge the found constraints into a single
-    /// combined one for \p Phi.
-    static void collectFromPHI(
-        ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
-        const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
-        SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
-        unsigned Depth);
-
-  public:
-    /// Collect rewrite map for loop guards for loop \p L, together with flags
-    /// indicating if NUW and NSW can be preserved during rewriting.
-    static LoopGuards collect(const Loop *L, ScalarEvolution &SE);
-
-    /// Try to apply the collected loop guards to \p Expr.
-    const SCEV *rewrite(const SCEV *Expr) const;
-  };
-
-  /// Try to apply information from loop guards for \p L to \p Expr.
-  const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
-
   /// Return true if the loop has no abnormal exits. That is, if the loop
   /// is not infinite, it must exit through an explicit edge in the CFG.
   /// (As opposed to either a) throwing out of the function or b) entering a
@@ -1650,10 +1652,6 @@ class ScalarEvolution {
   /// function as they are computed.
   DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
 
-  /// Cache the collected loop guards of the loops of this function as they are
-  /// computed.
-  DenseMap<const Loop *, LoopGuards> LoopGuardsCache;
-
   /// Loops whose backedge taken counts directly use this non-constant SCEV.
   DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
       BECountUsers;
@@ -1843,6 +1841,7 @@ class ScalarEvolution {
   /// this call will try to use a minimal set of SCEV predicates in order to
   /// return an exact answer.
   ExitLimit computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
+                             std::function<LoopGuards()> GetLoopGuards,
                              bool IsOnlyExit, bool AllowPredicates = false);
 
   // Helper functions for computeExitLimitFromCond to avoid exponential time
@@ -1875,17 +1874,17 @@ class ScalarEvolution {
 
   using ExitLimitCacheTy = ExitLimitCache;
 
-  ExitLimit computeExitLimitFromCondCached(ExitLimitCacheTy &Cache,
-                                           const Loop *L, Value *ExitCond,
-                                           bool ExitIfTrue,
-                                           bool ControlsOnlyExit,
-                                           bool AllowPredicates);
-  ExitLimit computeExitLimitFromCondImpl(ExitLimitCacheTy &Cache, const Loop *L,
-                                         Value *ExitCond, bool ExitIfTrue,
-                                         bool ControlsOnlyExit,
-                                         bool AllowPredicates);
+  ExitLimit computeExitLimitFromCondCached(
+      ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+      std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+      bool ControlsOnlyExit, bool AllowPredicates);
+  ExitLimit computeExitLimitFromCondImpl(
+      ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+      std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+      bool ControlsOnlyExit, bool AllowPredicates);
   std::optional<ScalarEvolution::ExitLimit> computeExitLimitFromCondFromBinOp(
-      ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+      ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+      std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
       bool ControlsOnlyExit, bool AllowPredicates);
 
   /// Compute the number of times the backedge of the specified loop will
@@ -1894,8 +1893,8 @@ class ScalarEvolution {
   /// to use a minimal set of SCEV predicates in order to return an exact
   /// answer.
   ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond,
-                                     bool ExitIfTrue,
-                                     bool IsSubExpr,
+                                     std::function<LoopGuards()> GetLoopGuards,
+                                     bool ExitIfTrue, bool IsSubExpr,
                                      bool AllowPredicates = false);
 
   /// Variant of previous which takes the components representing an ICmp
@@ -1904,16 +1903,16 @@ class ScalarEvolution {
   /// has a materialized ICmp.
   ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred,
                                      const SCEV *LHS, const SCEV *RHS,
+                                     std::function<LoopGuards()> GetLoopGuards,
                                      bool IsSubExpr,
                                      bool AllowPredicates = false);
 
   /// Compute the number of times the backedge of the specified loop will
   /// execute if its exit condition were a switch with a single exiting case
   /// to ExitingBB.
-  ExitLimit computeExitLimitFromSingleExitSwitch(const Loop *L,
-                                                 SwitchInst *Switch,
-                                                 BasicBlock *ExitingBB,
-                                                 bool IsSubExpr);
+  ExitLimit computeExitLimitFromSingleExitSwitch(
+      const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBB,
+      std::function<LoopGuards()> GetLoopGuards, bool IsSubExpr);
 
   /// Compute the exit limit of a loop that is controlled by a
   /// "(IV >> 1) != 0" type comparison.  We cannot compute the exact trip
@@ -1937,8 +1936,9 @@ class ScalarEvolution {
   /// value to zero will execute.  If not computable, return CouldNotCompute.
   /// If AllowPredicates is set, this call will try to use a minimal set of
   /// SCEV predicates in order to return an exact answer.
-  ExitLimit howFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr,
-                         bool AllowPredicates = false);
+  ExitLimit howFarToZero(const SCEV *V, const Loop *L,
+                         std::function<LoopGuards()> GetLoopGuards,
+                         bool IsSubExpr, bool AllowPredicates = false);
 
   /// Return the number of times an exit condition checking the specified
   /// value for nonzero will execute.  If not computable, return
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 9dfc3def3140af..907bb7875dc807 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -1945,13 +1945,16 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
         !isa<SCEVCouldNotCompute>(SrcEnd_) &&
         !isa<SCEVCouldNotCompute>(SinkStart_) &&
         !isa<SCEVCouldNotCompute>(SinkEnd_)) {
-      auto SrcEnd = SE.applyLoopGuards(SrcEnd_, InnermostLoop);
-      auto SinkStart = SE.applyLoopGuards(SinkStart_, InnermostLoop);
+      if (!LoopGuards)
+        LoopGuards.emplace(
+            ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
+      auto SrcEnd = SE.applyLoopGuards(SrcEnd_, *LoopGuards);
+      auto SinkStart = SE.applyLoopGuards(SinkStart_, *LoopGuards);
       if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SrcEnd, SinkStart))
         return MemoryDepChecker::Dependence::NoDep;
 
-      auto SinkEnd = SE.applyLoopGuards(SinkEnd_, InnermostLoop);
-      auto SrcStart = SE.applyLoopGuards(SrcStart_, InnermostLoop);
+      auto SinkEnd = SE.applyLoopGuards(SinkEnd_, *LoopGuards);
+      auto SrcStart = SE.applyLoopGuards(SrcStart_, *LoopGuards);
       if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SinkEnd, SrcStart))
         return MemoryDepChecker::Dependence::NoDep;
     }
@@ -2054,7 +2057,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
       return Dependence::NoDep;
     }
   } else {
-    Dist = SE.applyLoopGuards(Dist, InnermostLoop);
+    if (!LoopGuards)
+      LoopGuards.emplace(
+          ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
+    Dist = SE.applyLoopGuards(Dist, *LoopGuards);
   }
 
   // Negative distances are not plausible dependencies.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 70a45ef507e279..0ff2c486a06611 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8417,7 +8417,6 @@ void ScalarEvolution::forgetAllLoops() {
   // result.
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
-  LoopGuardsCache.clear();
   BECountUsers.clear();
   LoopPropertiesCache.clear();
   ConstantEvolutionLoopExitValue.clear();
@@ -8807,6 +8806,12 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
   const SCEV *MayExitMaxBECount = nullptr;
   bool MustExitMaxOrZero = false;
   bool IsOnlyExit = ExitingBlocks.size() == 1;
+  std::optional<LoopGuards> LoopGuards;
+  auto GetLoopGuards = [&LoopGuards, &L, this]() {
+    if (!LoopGuards)
+      LoopGuards.emplace(LoopGuards::collect(L, *this));
+    return *LoopGuards;
+  };
 
   // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
   // and compute maxBECount.
@@ -8822,7 +8827,8 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
           continue;
       }
 
-    ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
+    ExitLimit EL =
+        computeExitLimit(L, ExitBB, GetLoopGuards, IsOnlyExit, AllowPredicates);
 
     assert((AllowPredicates || EL.Predicates.empty()) &&
            "Predicated exit limit when predicates are not allowed!");
@@ -8897,6 +8903,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
 
 ScalarEvolution::ExitLimit
 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
+                                  std::function<LoopGuards()> GetLoopGuards,
                                   bool IsOnlyExit, bool AllowPredicates) {
   assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
   // If our exiting block does not dominate the latch, then its connection with
@@ -8912,9 +8919,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
     assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
            "It should have one successor in loop and one exit block!");
     // Proceed to the next level to examine the exit condition expression.
-    return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
-                                    /*ControlsOnlyExit=*/IsOnlyExit,
-                                    AllowPredicates);
+    return computeExitLimitFromCond(
+        L, BI->getCondition(), GetLoopGuards, ExitIfTrue,
+        /*ControlsOnlyExit=*/IsOnlyExit, AllowPredicates);
   }
 
   if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8928,18 +8935,19 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
       }
     assert(Exit && "Exiting block must have at least one exit");
     return computeExitLimitFromSingleExitSwitch(
-        L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
+        L, SI, Exit, GetLoopGuards, /*ControlsOnlyExit=*/IsOnlyExit);
   }
 
   return getCouldNotCompute();
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
-    const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
-    bool AllowPredicates) {
+    const Loop *L, Value *ExitCond, std::function<LoopGuards()> GetLoopGuards,
+    bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) {
   ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
-  return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
-                                        ControlsOnlyExit, AllowPredicates);
+  return computeExitLimitFromCondCached(Cache, L, ExitCond, GetLoopGuards,
+                                        ExitIfTrue, ControlsOnlyExit,
+                                        AllowPredicates);
 }
 
 std::optional<ScalarEvolution::ExitLimit>
@@ -8975,37 +8983,41 @@ void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
-    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
 
   if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
                                 AllowPredicates))
     return *MaybeEL;
 
-  ExitLimit EL = computeExitLimitFromCondImpl(
-      Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
+  ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, GetLoopGuards,
+                                              ExitIfTrue, ControlsOnlyExit,
+                                              AllowPredicates);
   Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
   return EL;
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
-    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
   // Handle BinOp conditions (And, Or).
   if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
-          Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
+          Cache, L, ExitCond, GetLoopGuards, ExitIfTrue, ControlsOnlyExit,
+          AllowPredicates))
     return *LimitFromBinOp;
 
   // With an icmp, it may be feasible to compute an exact backedge-taken count.
   // Proceed to the next level to examine the icmp.
   if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
-    ExitLimit EL =
-        computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
+    ExitLimit EL = computeExitLimitFromICmp(L, ExitCondICmp, GetLoopGuards,
+                                            ExitIfTrue, ControlsOnlyExit);
     if (EL.hasFullInfo() || !AllowPredicates)
       return EL;
 
     // Try again, but use SCEV predicates this time.
-    return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
+    return computeExitLimitFromICmp(L, ExitCondICmp, GetLoopGuards, ExitIfTrue,
                                     ControlsOnlyExit,
                                     /*AllowPredicates=*/true);
   }
@@ -9041,7 +9053,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
     if (Offset != 0)
       LHS = getAddExpr(LHS, getConstant(Offset));
     auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
-                                       ControlsOnlyExit, AllowPredicates);
+                                       GetLoopGuards, ControlsOnlyExit,
+                                       AllowPredicates);
     if (EL.hasAnyInfo())
       return EL;
   }
@@ -9052,7 +9065,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
 
 std::optional<ScalarEvolution::ExitLimit>
 ScalarEvolution::computeExitLimitFromCondFromBinOp(
-    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
   // Check if the controlling expression for this loop is an And or Or.
   Value *Op0, *Op1;
@@ -9069,11 +9083,11 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
   //   br (or  Op0 Op1), exit, loop
   bool EitherMayExit = IsAnd ^ ExitIfTrue;
   ExitLimit EL0 = computeExitLimitFromCondCached(
-      Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
-      AllowPredicates);
+      Cache, L, Op0, GetLoopGuards, ExitIfTrue,
+      ControlsOnlyExit && !EitherMayExit, AllowPredicates);
   ExitLimit EL1 = computeExitLimitFromCondCached(
-      Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
-      AllowPredicates);
+      Cache, L, Op1, GetLoopGuards, ExitIfTrue,
+      ControlsOnlyExit && !EitherMayExit, AllowPredicates);
 
   // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
   const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
@@ -9132,8 +9146,9 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
-    const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
-    bool AllowPredicates) {
+    const Loop *L, ICmpInst *ExitCond,
+    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+    bool ControlsOnlyExit, bool AllowPredicates) {
   // If the condition was exit on true, convert the condition to exit on false
   ICmpInst::Predicate Pred;
   if (!ExitIfTrue)
@@ -9145,8 +9160,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
 
-  ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
-                                          AllowPredicates);
+  ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, GetLoopGuards,
+                                          ControlsOnlyExit, AllowPredicates);
   if (EL.hasAnyInfo())
     return EL;
 
@@ -9161,7 +9176,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
 }
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
     const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
-    bool ControlsOnlyExit, bool AllowPredicates) {
+    std::function<LoopGuards()> GetLoopGuards, bool ControlsOnlyExit,
+    bool AllowPredicates) {
 
   // Try to evaluate any dependencies out of the loop.
   LHS = getSCEVAtScope(LHS, L);
@@ -9249,8 +9265,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
       if (isa<SCEVCouldNotCompute>(RHS))
         return RHS;
     }
-    ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
-                                AllowPredicates);
+    ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, GetLoopGuards,
+                                ControlsOnlyExit, AllowPredicates);
     if (EL.hasAnyInfo())
       return EL;
     break;
@@ -9332,10 +9348,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
 }
 
 ScalarEvolution::ExitLimit
-ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
-                                                      SwitchInst *Switch,
-                                                      BasicBlock *ExitingBlock,
-                                                      bool ControlsOnlyExit) {
+ScalarEvolution::computeExitLimitFromSingleExitSwitch(
+    const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock,
+    std::function<LoopGuards()> GetLoopGuards, bool ControlsOnlyExit) {
   assert(!L->contains(ExitingBlock) && "Not an exiting block!");
 
   // Give up if the exit is the default dest of a switch.
@@ -9348,7 +9363,8 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
   const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
 
   // while (X != Y) --> while (X-Y != 0)
-  ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
+  ExitLimit EL =
+      howFarToZero(getMinusSCEV(LHS, RHS), L, GetLoopGuards, ControlsOnlyExit);
   if (EL.hasAnyInfo())
     return EL;
 
@@ -10486,10 +10502,10 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
   return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
 }
 
-ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
-                                                         const Loop *L,
-                                                         bool ControlsOnlyExit,
-                                                         bool AllowPredicates) {
+ScalarEvolution::ExitLimit
+ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L,
+                              std::function<LoopGuards()> GetLoopGuards,
+                              bool ControlsOnlyExit, bool AllowPredicates) {
 
   // This is only used for loops with a "x != y" exit test. The exit condition
   // is now expressed as a single expression, V = x-y. So the exit test is
@@ -10552,8 +10568,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   if (!isLoopInvariant(Step, L))
     return getCouldNotCompute();
 
+  LoopGuards Guards = GetLoopGuards();
   // Specialize step for this loop so we get context sensitive facts below.
-  const SCEV *StepWLG = applyLoopGuards(Step, L);
+  const SCEV *StepWLG = applyLoopGuards(Step, Guards);
 
   // For positive steps (counting up until unsigned overflow):
   //   N = -Start/Step (as unsigned)
@@ -10570,7 +10587,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   //   N = Distance (as unsigned)
   if (StepC &&
       (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
-    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
+    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
     MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
 
     // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
@@ -10611,7 +10628,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
     const SCEV *ConstantMax = getCouldNotCompute();
     if (Exact != getCouldNotCompute()) {
-      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
+      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
       ConstantMax =
           getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
     }
@@ -10629,7 +10646,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
 
   const SCEV *M = E;
   if (E != getCouldNotCompute()) {
-    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
+    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
     M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
   }
   auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
@@ -13674,7 +13691,6 @@ ScalarEvolution::~ScalarEvolution() {
   HasRecMap.clear();
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
-  LoopGuardsCache.clear();
 
   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
   assert(PendingPhiRanges.empty() && "getRangeRef garbage");
@@ -15890,11 +15906,10 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
 }
 
 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
-  auto Itr = LoopGuardsCache.find(L);
-  if (Itr == LoopGuardsCache.end()) {
-    LoopGuards Guard = LoopGuards::collect(L, *this);
-    LoopGuardsCache.insert({L, Guard});
-    return Guard.rewrite(Expr);
-  }
-  return Itr->second.rewrite(Expr);
+  return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
+}
+
+const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
+                                             const LoopGuards &Guards) {
+  return Guards.rewrite(Expr);
 }
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index 8a3e0bc3eb9712..62e6d541af5c65 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -1335,6 +1335,13 @@ static bool optimizeLoopExitWithUnknownExitCount(
   Visited.insert(OldCond);
   Worklist.push_back(OldCond);
 
+  std::optional<ScalarEvolution::LoopGuards> LoopGuards;
+  auto GetLoopGuards = [&LoopGuards, &L, &SE]() {
+    if (!LoopGuards)
+      LoopGuards.emplace(ScalarEvolution::LoopGuards::collect(L, *SE));
+    return *LoopGuards;
+  };
+
   auto GoThrough = [&](Value *V) {
     Value *LHS = nullptr, *RHS = nullptr;
     if (Inverted) {
@@ -1371,7 +1378,7 @@ static bool optimizeLoopExitWithUnknownExitCount(
                        ScalarEvolution::ExitCountKind::SymbolicMaximum) ==
           MaxIter)
     for (auto *ICmp : LeafConditions) {
-      auto EL = SE->computeExitLimitFromCond(L, ICmp, Inverted,
+      auto EL = SE->computeExitLimitFromCond(L, ICmp, GetLoopGuards, Inverted,
                                              /*ControlsExit*/ false);
       const SCEV *ExitMax = EL.SymbolicMaxNotTaken;
       if (isa<SCEVCouldNotCompute>(ExitMax))



More information about the llvm-commits mailing list