[llvm] [SCEV] Memoize collected loop guards. NFCI (PR #116947)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 21 02:41:00 PST 2024


https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/116947

>From 9a64d8a957d207a143c8b6fd4321af3b21b41422 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 20 Nov 2024 16:55:52 +0800
Subject: [PATCH 1/4] [SCEV] Cache collected loop guards. NFCI

This caches the collected loop guards so that they are no longer recomputed on every query, which gives a 0.07% geomean reduction in instructions:u for stage2-O3: https://llvm-compile-time-tracker.com/compare.php?from=aff98e4be05a1060e489ce62a88ee0ff365e571a&to=198a76db2c0b8fbda5374ffd195731a9d47469e3&stat=instructions:u

LoopAccessAnalysis already had a LoopGuards cache for the innermost loop, so this hoists it up into ScalarEvolution.
---
 .../llvm/Analysis/LoopAccessAnalysis.h        |  3 ---
 llvm/include/llvm/Analysis/ScalarEvolution.h  |  5 +++-
 llvm/lib/Analysis/LoopAccessAnalysis.cpp      | 16 ++++---------
 llvm/lib/Analysis/ScalarEvolution.cpp         | 24 ++++++++++---------
 4 files changed, 22 insertions(+), 26 deletions(-)

diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
index a35bc7402d1a89..872b68f924e654 100644
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -334,9 +334,6 @@ class MemoryDepChecker {
            std::pair<const SCEV *, const SCEV *>>
       PointerBounds;
 
-  /// Cache for the loop guards of InnermostLoop.
-  std::optional<ScalarEvolution::LoopGuards> LoopGuards;
-
   /// Check whether there is a plausible dependence between the two
   /// accesses.
   ///
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 885c5985f9d23a..b7b9384c6e642c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1346,7 +1346,6 @@ class ScalarEvolution {
 
   /// Try to apply information from loop guards for \p L to \p Expr.
   const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
-  const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards);
 
   /// Return true if the loop has no abnormal exits. That is, if the loop
   /// is not infinite, it must exit through an explicit edge in the CFG.
@@ -1651,6 +1650,10 @@ class ScalarEvolution {
   /// function as they are computed.
   DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
 
+  /// Cache the collected loop guards of the loops of this function as they are
+  /// computed.
+  DenseMap<const Loop *, LoopGuards> LoopGuardsCache;
+
   /// Loops whose backedge taken counts directly use this non-constant SCEV.
   DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
       BECountUsers;
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 907bb7875dc807..9dfc3def3140af 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -1945,16 +1945,13 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
         !isa<SCEVCouldNotCompute>(SrcEnd_) &&
         !isa<SCEVCouldNotCompute>(SinkStart_) &&
         !isa<SCEVCouldNotCompute>(SinkEnd_)) {
-      if (!LoopGuards)
-        LoopGuards.emplace(
-            ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
-      auto SrcEnd = SE.applyLoopGuards(SrcEnd_, *LoopGuards);
-      auto SinkStart = SE.applyLoopGuards(SinkStart_, *LoopGuards);
+      auto SrcEnd = SE.applyLoopGuards(SrcEnd_, InnermostLoop);
+      auto SinkStart = SE.applyLoopGuards(SinkStart_, InnermostLoop);
       if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SrcEnd, SinkStart))
         return MemoryDepChecker::Dependence::NoDep;
 
-      auto SinkEnd = SE.applyLoopGuards(SinkEnd_, *LoopGuards);
-      auto SrcStart = SE.applyLoopGuards(SrcStart_, *LoopGuards);
+      auto SinkEnd = SE.applyLoopGuards(SinkEnd_, InnermostLoop);
+      auto SrcStart = SE.applyLoopGuards(SrcStart_, InnermostLoop);
       if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SinkEnd, SrcStart))
         return MemoryDepChecker::Dependence::NoDep;
     }
@@ -2057,10 +2054,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
       return Dependence::NoDep;
     }
   } else {
-    if (!LoopGuards)
-      LoopGuards.emplace(
-          ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
-    Dist = SE.applyLoopGuards(Dist, *LoopGuards);
+    Dist = SE.applyLoopGuards(Dist, InnermostLoop);
   }
 
   // Negative distances are not plausible dependencies.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 46b108606f6a62..70a45ef507e279 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8417,6 +8417,7 @@ void ScalarEvolution::forgetAllLoops() {
   // result.
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
+  LoopGuardsCache.clear();
   BECountUsers.clear();
   LoopPropertiesCache.clear();
   ConstantEvolutionLoopExitValue.clear();
@@ -10551,9 +10552,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   if (!isLoopInvariant(Step, L))
     return getCouldNotCompute();
 
-  LoopGuards Guards = LoopGuards::collect(L, *this);
   // Specialize step for this loop so we get context sensitive facts below.
-  const SCEV *StepWLG = applyLoopGuards(Step, Guards);
+  const SCEV *StepWLG = applyLoopGuards(Step, L);
 
   // For positive steps (counting up until unsigned overflow):
   //   N = -Start/Step (as unsigned)
@@ -10570,7 +10570,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   //   N = Distance (as unsigned)
   if (StepC &&
       (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
-    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
+    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
     MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
 
     // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
@@ -10611,7 +10611,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
     const SCEV *ConstantMax = getCouldNotCompute();
     if (Exact != getCouldNotCompute()) {
-      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
+      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
       ConstantMax =
           getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
     }
@@ -10629,7 +10629,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
 
   const SCEV *M = E;
   if (E != getCouldNotCompute()) {
-    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
+    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
     M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
   }
   auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
@@ -13674,6 +13674,7 @@ ScalarEvolution::~ScalarEvolution() {
   HasRecMap.clear();
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
+  LoopGuardsCache.clear();
 
   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
   assert(PendingPhiRanges.empty() && "getRangeRef garbage");
@@ -15889,10 +15890,11 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
 }
 
 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
-  return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
-}
-
-const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
-                                             const LoopGuards &Guards) {
-  return Guards.rewrite(Expr);
+  auto Itr = LoopGuardsCache.find(L);
+  if (Itr == LoopGuardsCache.end()) {
+    LoopGuards Guard = LoopGuards::collect(L, *this);
+    LoopGuardsCache.insert({L, Guard});
+    return Guard.rewrite(Expr);
+  }
+  return Itr->second.rewrite(Expr);
 }

>From 7a0bf19b515a555f0b80c55e8d700dfcc8e88ff4 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 21 Nov 2024 14:33:06 +0800
Subject: [PATCH 2/4] Rework to memoize loop guards across multiple exits

---
 .../llvm/Analysis/LoopAccessAnalysis.h        |   3 +
 llvm/include/llvm/Analysis/ScalarEvolution.h  | 122 +++++++++---------
 llvm/lib/Analysis/LoopAccessAnalysis.cpp      |  16 ++-
 llvm/lib/Analysis/ScalarEvolution.cpp         | 119 +++++++++--------
 llvm/lib/Transforms/Scalar/IndVarSimplify.cpp |   9 +-
 5 files changed, 150 insertions(+), 119 deletions(-)

diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
index 872b68f924e654..a35bc7402d1a89 100644
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -334,6 +334,9 @@ class MemoryDepChecker {
            std::pair<const SCEV *, const SCEV *>>
       PointerBounds;
 
+  /// Cache for the loop guards of InnermostLoop.
+  std::optional<ScalarEvolution::LoopGuards> LoopGuards;
+
   /// Check whether there is a plausible dependence between the two
   /// accesses.
   ///
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index b7b9384c6e642c..692dd02f4b4324 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1112,6 +1112,46 @@ class ScalarEvolution {
   bool isKnownOnEveryIteration(ICmpInst::Predicate Pred,
                                const SCEVAddRecExpr *LHS, const SCEV *RHS);
 
+  class LoopGuards {
+    DenseMap<const SCEV *, const SCEV *> RewriteMap;
+    bool PreserveNUW = false;
+    bool PreserveNSW = false;
+    ScalarEvolution &SE;
+
+    LoopGuards(ScalarEvolution &SE) : SE(SE) {}
+
+    /// Recursively collect loop guards in \p Guards, starting from
+    /// block \p Block with predecessor \p Pred. The intended starting point
+    /// is to collect from a loop header and its predecessor.
+    static void
+    collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
+                     const BasicBlock *Block, const BasicBlock *Pred,
+                     SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
+                     unsigned Depth = 0);
+
+    /// Collect loop guards in \p Guards, starting from PHINode \p
+    /// Phi, by calling \p collectFromBlock on the incoming blocks of
+    /// \Phi and trying to merge the found constraints into a single
+    /// combined one for \p Phi.
+    static void collectFromPHI(
+        ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
+        const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
+        SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
+        unsigned Depth);
+
+  public:
+    /// Collect rewrite map for loop guards for loop \p L, together with flags
+    /// indicating if NUW and NSW can be preserved during rewriting.
+    static LoopGuards collect(const Loop *L, ScalarEvolution &SE);
+
+    /// Try to apply the collected loop guards to \p Expr.
+    const SCEV *rewrite(const SCEV *Expr) const;
+  };
+
+  /// Try to apply information from loop guards for \p L to \p Expr.
+  const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
+  const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards);
+
   /// Information about the number of loop iterations for which a loop exit's
   /// branch condition evaluates to the not-taken path.  This is a temporary
   /// pair of exact and max expressions that are eventually summarized in
@@ -1167,6 +1207,7 @@ class ScalarEvolution {
   /// If \p AllowPredicates is set, this call will try to use a minimal set of
   /// SCEV predicates in order to return an exact answer.
   ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond,
+                                     std::function<LoopGuards()> GetLoopGuards,
                                      bool ExitIfTrue, bool ControlsOnlyExit,
                                      bool AllowPredicates = false);
 
@@ -1308,45 +1349,6 @@ class ScalarEvolution {
   /// sharpen it.
   void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags);
 
-  class LoopGuards {
-    DenseMap<const SCEV *, const SCEV *> RewriteMap;
-    bool PreserveNUW = false;
-    bool PreserveNSW = false;
-    ScalarEvolution &SE;
-
-    LoopGuards(ScalarEvolution &SE) : SE(SE) {}
-
-    /// Recursively collect loop guards in \p Guards, starting from
-    /// block \p Block with predecessor \p Pred. The intended starting point
-    /// is to collect from a loop header and its predecessor.
-    static void
-    collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
-                     const BasicBlock *Block, const BasicBlock *Pred,
-                     SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
-                     unsigned Depth = 0);
-
-    /// Collect loop guards in \p Guards, starting from PHINode \p
-    /// Phi, by calling \p collectFromBlock on the incoming blocks of
-    /// \Phi and trying to merge the found constraints into a single
-    /// combined one for \p Phi.
-    static void collectFromPHI(
-        ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
-        const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
-        SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
-        unsigned Depth);
-
-  public:
-    /// Collect rewrite map for loop guards for loop \p L, together with flags
-    /// indicating if NUW and NSW can be preserved during rewriting.
-    static LoopGuards collect(const Loop *L, ScalarEvolution &SE);
-
-    /// Try to apply the collected loop guards to \p Expr.
-    const SCEV *rewrite(const SCEV *Expr) const;
-  };
-
-  /// Try to apply information from loop guards for \p L to \p Expr.
-  const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
-
   /// Return true if the loop has no abnormal exits. That is, if the loop
   /// is not infinite, it must exit through an explicit edge in the CFG.
   /// (As opposed to either a) throwing out of the function or b) entering a
@@ -1650,10 +1652,6 @@ class ScalarEvolution {
   /// function as they are computed.
   DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
 
-  /// Cache the collected loop guards of the loops of this function as they are
-  /// computed.
-  DenseMap<const Loop *, LoopGuards> LoopGuardsCache;
-
   /// Loops whose backedge taken counts directly use this non-constant SCEV.
   DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
       BECountUsers;
@@ -1843,6 +1841,7 @@ class ScalarEvolution {
   /// this call will try to use a minimal set of SCEV predicates in order to
   /// return an exact answer.
   ExitLimit computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
+                             std::function<LoopGuards()> GetLoopGuards,
                              bool IsOnlyExit, bool AllowPredicates = false);
 
   // Helper functions for computeExitLimitFromCond to avoid exponential time
@@ -1875,17 +1874,17 @@ class ScalarEvolution {
 
   using ExitLimitCacheTy = ExitLimitCache;
 
-  ExitLimit computeExitLimitFromCondCached(ExitLimitCacheTy &Cache,
-                                           const Loop *L, Value *ExitCond,
-                                           bool ExitIfTrue,
-                                           bool ControlsOnlyExit,
-                                           bool AllowPredicates);
-  ExitLimit computeExitLimitFromCondImpl(ExitLimitCacheTy &Cache, const Loop *L,
-                                         Value *ExitCond, bool ExitIfTrue,
-                                         bool ControlsOnlyExit,
-                                         bool AllowPredicates);
+  ExitLimit computeExitLimitFromCondCached(
+      ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+      std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+      bool ControlsOnlyExit, bool AllowPredicates);
+  ExitLimit computeExitLimitFromCondImpl(
+      ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+      std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+      bool ControlsOnlyExit, bool AllowPredicates);
   std::optional<ScalarEvolution::ExitLimit> computeExitLimitFromCondFromBinOp(
-      ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+      ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+      std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
       bool ControlsOnlyExit, bool AllowPredicates);
 
   /// Compute the number of times the backedge of the specified loop will
@@ -1894,8 +1893,8 @@ class ScalarEvolution {
   /// to use a minimal set of SCEV predicates in order to return an exact
   /// answer.
   ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond,
-                                     bool ExitIfTrue,
-                                     bool IsSubExpr,
+                                     std::function<LoopGuards()> GetLoopGuards,
+                                     bool ExitIfTrue, bool IsSubExpr,
                                      bool AllowPredicates = false);
 
   /// Variant of previous which takes the components representing an ICmp
@@ -1904,16 +1903,16 @@ class ScalarEvolution {
   /// has a materialized ICmp.
   ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred,
                                      const SCEV *LHS, const SCEV *RHS,
+                                     std::function<LoopGuards()> GetLoopGuards,
                                      bool IsSubExpr,
                                      bool AllowPredicates = false);
 
   /// Compute the number of times the backedge of the specified loop will
   /// execute if its exit condition were a switch with a single exiting case
   /// to ExitingBB.
-  ExitLimit computeExitLimitFromSingleExitSwitch(const Loop *L,
-                                                 SwitchInst *Switch,
-                                                 BasicBlock *ExitingBB,
-                                                 bool IsSubExpr);
+  ExitLimit computeExitLimitFromSingleExitSwitch(
+      const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBB,
+      std::function<LoopGuards()> GetLoopGuards, bool IsSubExpr);
 
   /// Compute the exit limit of a loop that is controlled by a
   /// "(IV >> 1) != 0" type comparison.  We cannot compute the exact trip
@@ -1937,8 +1936,9 @@ class ScalarEvolution {
   /// value to zero will execute.  If not computable, return CouldNotCompute.
   /// If AllowPredicates is set, this call will try to use a minimal set of
   /// SCEV predicates in order to return an exact answer.
-  ExitLimit howFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr,
-                         bool AllowPredicates = false);
+  ExitLimit howFarToZero(const SCEV *V, const Loop *L,
+                         std::function<LoopGuards()> GetLoopGuards,
+                         bool IsSubExpr, bool AllowPredicates = false);
 
   /// Return the number of times an exit condition checking the specified
   /// value for nonzero will execute.  If not computable, return
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 9dfc3def3140af..907bb7875dc807 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -1945,13 +1945,16 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
         !isa<SCEVCouldNotCompute>(SrcEnd_) &&
         !isa<SCEVCouldNotCompute>(SinkStart_) &&
         !isa<SCEVCouldNotCompute>(SinkEnd_)) {
-      auto SrcEnd = SE.applyLoopGuards(SrcEnd_, InnermostLoop);
-      auto SinkStart = SE.applyLoopGuards(SinkStart_, InnermostLoop);
+      if (!LoopGuards)
+        LoopGuards.emplace(
+            ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
+      auto SrcEnd = SE.applyLoopGuards(SrcEnd_, *LoopGuards);
+      auto SinkStart = SE.applyLoopGuards(SinkStart_, *LoopGuards);
       if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SrcEnd, SinkStart))
         return MemoryDepChecker::Dependence::NoDep;
 
-      auto SinkEnd = SE.applyLoopGuards(SinkEnd_, InnermostLoop);
-      auto SrcStart = SE.applyLoopGuards(SrcStart_, InnermostLoop);
+      auto SinkEnd = SE.applyLoopGuards(SinkEnd_, *LoopGuards);
+      auto SrcStart = SE.applyLoopGuards(SrcStart_, *LoopGuards);
       if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SinkEnd, SrcStart))
         return MemoryDepChecker::Dependence::NoDep;
     }
@@ -2054,7 +2057,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
       return Dependence::NoDep;
     }
   } else {
-    Dist = SE.applyLoopGuards(Dist, InnermostLoop);
+    if (!LoopGuards)
+      LoopGuards.emplace(
+          ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
+    Dist = SE.applyLoopGuards(Dist, *LoopGuards);
   }
 
   // Negative distances are not plausible dependencies.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 70a45ef507e279..0ff2c486a06611 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8417,7 +8417,6 @@ void ScalarEvolution::forgetAllLoops() {
   // result.
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
-  LoopGuardsCache.clear();
   BECountUsers.clear();
   LoopPropertiesCache.clear();
   ConstantEvolutionLoopExitValue.clear();
@@ -8807,6 +8806,12 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
   const SCEV *MayExitMaxBECount = nullptr;
   bool MustExitMaxOrZero = false;
   bool IsOnlyExit = ExitingBlocks.size() == 1;
+  std::optional<LoopGuards> LoopGuards;
+  auto GetLoopGuards = [&LoopGuards, &L, this]() {
+    if (!LoopGuards)
+      LoopGuards.emplace(LoopGuards::collect(L, *this));
+    return *LoopGuards;
+  };
 
   // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
   // and compute maxBECount.
@@ -8822,7 +8827,8 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
           continue;
       }
 
-    ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
+    ExitLimit EL =
+        computeExitLimit(L, ExitBB, GetLoopGuards, IsOnlyExit, AllowPredicates);
 
     assert((AllowPredicates || EL.Predicates.empty()) &&
            "Predicated exit limit when predicates are not allowed!");
@@ -8897,6 +8903,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
 
 ScalarEvolution::ExitLimit
 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
+                                  std::function<LoopGuards()> GetLoopGuards,
                                   bool IsOnlyExit, bool AllowPredicates) {
   assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
   // If our exiting block does not dominate the latch, then its connection with
@@ -8912,9 +8919,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
     assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
            "It should have one successor in loop and one exit block!");
     // Proceed to the next level to examine the exit condition expression.
-    return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
-                                    /*ControlsOnlyExit=*/IsOnlyExit,
-                                    AllowPredicates);
+    return computeExitLimitFromCond(
+        L, BI->getCondition(), GetLoopGuards, ExitIfTrue,
+        /*ControlsOnlyExit=*/IsOnlyExit, AllowPredicates);
   }
 
   if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8928,18 +8935,19 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
       }
     assert(Exit && "Exiting block must have at least one exit");
     return computeExitLimitFromSingleExitSwitch(
-        L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
+        L, SI, Exit, GetLoopGuards, /*ControlsOnlyExit=*/IsOnlyExit);
   }
 
   return getCouldNotCompute();
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
-    const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
-    bool AllowPredicates) {
+    const Loop *L, Value *ExitCond, std::function<LoopGuards()> GetLoopGuards,
+    bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) {
   ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
-  return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
-                                        ControlsOnlyExit, AllowPredicates);
+  return computeExitLimitFromCondCached(Cache, L, ExitCond, GetLoopGuards,
+                                        ExitIfTrue, ControlsOnlyExit,
+                                        AllowPredicates);
 }
 
 std::optional<ScalarEvolution::ExitLimit>
@@ -8975,37 +8983,41 @@ void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
-    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
 
   if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
                                 AllowPredicates))
     return *MaybeEL;
 
-  ExitLimit EL = computeExitLimitFromCondImpl(
-      Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
+  ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, GetLoopGuards,
+                                              ExitIfTrue, ControlsOnlyExit,
+                                              AllowPredicates);
   Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
   return EL;
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
-    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
   // Handle BinOp conditions (And, Or).
   if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
-          Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
+          Cache, L, ExitCond, GetLoopGuards, ExitIfTrue, ControlsOnlyExit,
+          AllowPredicates))
     return *LimitFromBinOp;
 
   // With an icmp, it may be feasible to compute an exact backedge-taken count.
   // Proceed to the next level to examine the icmp.
   if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
-    ExitLimit EL =
-        computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
+    ExitLimit EL = computeExitLimitFromICmp(L, ExitCondICmp, GetLoopGuards,
+                                            ExitIfTrue, ControlsOnlyExit);
     if (EL.hasFullInfo() || !AllowPredicates)
       return EL;
 
     // Try again, but use SCEV predicates this time.
-    return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
+    return computeExitLimitFromICmp(L, ExitCondICmp, GetLoopGuards, ExitIfTrue,
                                     ControlsOnlyExit,
                                     /*AllowPredicates=*/true);
   }
@@ -9041,7 +9053,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
     if (Offset != 0)
       LHS = getAddExpr(LHS, getConstant(Offset));
     auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
-                                       ControlsOnlyExit, AllowPredicates);
+                                       GetLoopGuards, ControlsOnlyExit,
+                                       AllowPredicates);
     if (EL.hasAnyInfo())
       return EL;
   }
@@ -9052,7 +9065,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
 
 std::optional<ScalarEvolution::ExitLimit>
 ScalarEvolution::computeExitLimitFromCondFromBinOp(
-    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
   // Check if the controlling expression for this loop is an And or Or.
   Value *Op0, *Op1;
@@ -9069,11 +9083,11 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
   //   br (or  Op0 Op1), exit, loop
   bool EitherMayExit = IsAnd ^ ExitIfTrue;
   ExitLimit EL0 = computeExitLimitFromCondCached(
-      Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
-      AllowPredicates);
+      Cache, L, Op0, GetLoopGuards, ExitIfTrue,
+      ControlsOnlyExit && !EitherMayExit, AllowPredicates);
   ExitLimit EL1 = computeExitLimitFromCondCached(
-      Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
-      AllowPredicates);
+      Cache, L, Op1, GetLoopGuards, ExitIfTrue,
+      ControlsOnlyExit && !EitherMayExit, AllowPredicates);
 
   // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
   const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
@@ -9132,8 +9146,9 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
-    const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
-    bool AllowPredicates) {
+    const Loop *L, ICmpInst *ExitCond,
+    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+    bool ControlsOnlyExit, bool AllowPredicates) {
   // If the condition was exit on true, convert the condition to exit on false
   ICmpInst::Predicate Pred;
   if (!ExitIfTrue)
@@ -9145,8 +9160,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
 
-  ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
-                                          AllowPredicates);
+  ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, GetLoopGuards,
+                                          ControlsOnlyExit, AllowPredicates);
   if (EL.hasAnyInfo())
     return EL;
 
@@ -9161,7 +9176,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
 }
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
     const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
-    bool ControlsOnlyExit, bool AllowPredicates) {
+    std::function<LoopGuards()> GetLoopGuards, bool ControlsOnlyExit,
+    bool AllowPredicates) {
 
   // Try to evaluate any dependencies out of the loop.
   LHS = getSCEVAtScope(LHS, L);
@@ -9249,8 +9265,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
       if (isa<SCEVCouldNotCompute>(RHS))
         return RHS;
     }
-    ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
-                                AllowPredicates);
+    ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, GetLoopGuards,
+                                ControlsOnlyExit, AllowPredicates);
     if (EL.hasAnyInfo())
       return EL;
     break;
@@ -9332,10 +9348,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
 }
 
 ScalarEvolution::ExitLimit
-ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
-                                                      SwitchInst *Switch,
-                                                      BasicBlock *ExitingBlock,
-                                                      bool ControlsOnlyExit) {
+ScalarEvolution::computeExitLimitFromSingleExitSwitch(
+    const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock,
+    std::function<LoopGuards()> GetLoopGuards, bool ControlsOnlyExit) {
   assert(!L->contains(ExitingBlock) && "Not an exiting block!");
 
   // Give up if the exit is the default dest of a switch.
@@ -9348,7 +9363,8 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
   const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
 
   // while (X != Y) --> while (X-Y != 0)
-  ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
+  ExitLimit EL =
+      howFarToZero(getMinusSCEV(LHS, RHS), L, GetLoopGuards, ControlsOnlyExit);
   if (EL.hasAnyInfo())
     return EL;
 
@@ -10486,10 +10502,10 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
   return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
 }
 
-ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
-                                                         const Loop *L,
-                                                         bool ControlsOnlyExit,
-                                                         bool AllowPredicates) {
+ScalarEvolution::ExitLimit
+ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L,
+                              std::function<LoopGuards()> GetLoopGuards,
+                              bool ControlsOnlyExit, bool AllowPredicates) {
 
   // This is only used for loops with a "x != y" exit test. The exit condition
   // is now expressed as a single expression, V = x-y. So the exit test is
@@ -10552,8 +10568,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   if (!isLoopInvariant(Step, L))
     return getCouldNotCompute();
 
+  LoopGuards Guards = GetLoopGuards();
   // Specialize step for this loop so we get context sensitive facts below.
-  const SCEV *StepWLG = applyLoopGuards(Step, L);
+  const SCEV *StepWLG = applyLoopGuards(Step, Guards);
 
   // For positive steps (counting up until unsigned overflow):
   //   N = -Start/Step (as unsigned)
@@ -10570,7 +10587,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   //   N = Distance (as unsigned)
   if (StepC &&
       (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
-    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
+    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
     MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
 
     // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
@@ -10611,7 +10628,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
     const SCEV *ConstantMax = getCouldNotCompute();
     if (Exact != getCouldNotCompute()) {
-      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
+      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
       ConstantMax =
           getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
     }
@@ -10629,7 +10646,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
 
   const SCEV *M = E;
   if (E != getCouldNotCompute()) {
-    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
+    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
     M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
   }
   auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
@@ -13674,7 +13691,6 @@ ScalarEvolution::~ScalarEvolution() {
   HasRecMap.clear();
   BackedgeTakenCounts.clear();
   PredicatedBackedgeTakenCounts.clear();
-  LoopGuardsCache.clear();
 
   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
   assert(PendingPhiRanges.empty() && "getRangeRef garbage");
@@ -15890,11 +15906,10 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
 }
 
 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
-  auto Itr = LoopGuardsCache.find(L);
-  if (Itr == LoopGuardsCache.end()) {
-    LoopGuards Guard = LoopGuards::collect(L, *this);
-    LoopGuardsCache.insert({L, Guard});
-    return Guard.rewrite(Expr);
-  }
-  return Itr->second.rewrite(Expr);
+  return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
+}
+
+const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
+                                             const LoopGuards &Guards) {
+  return Guards.rewrite(Expr);
 }
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index 8a3e0bc3eb9712..62e6d541af5c65 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -1335,6 +1335,13 @@ static bool optimizeLoopExitWithUnknownExitCount(
   Visited.insert(OldCond);
   Worklist.push_back(OldCond);
 
+  std::optional<ScalarEvolution::LoopGuards> LoopGuards;
+  auto GetLoopGuards = [&LoopGuards, &L, &SE]() {
+    if (!LoopGuards)
+      LoopGuards.emplace(ScalarEvolution::LoopGuards::collect(L, *SE));
+    return *LoopGuards;
+  };
+
   auto GoThrough = [&](Value *V) {
     Value *LHS = nullptr, *RHS = nullptr;
     if (Inverted) {
@@ -1371,7 +1378,7 @@ static bool optimizeLoopExitWithUnknownExitCount(
                        ScalarEvolution::ExitCountKind::SymbolicMaximum) ==
           MaxIter)
     for (auto *ICmp : LeafConditions) {
-      auto EL = SE->computeExitLimitFromCond(L, ICmp, Inverted,
+      auto EL = SE->computeExitLimitFromCond(L, ICmp, GetLoopGuards, Inverted,
                                              /*ControlsExit*/ false);
       const SCEV *ExitMax = EL.SymbolicMaxNotTaken;
       if (isa<SCEVCouldNotCompute>(ExitMax))

>From 9c8dac0196c3199842606a208e19235bc6505bc2 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 21 Nov 2024 18:05:58 +0800
Subject: [PATCH 3/4] Pass GetLoopGuards as a function_ref returning a const reference

---
 llvm/include/llvm/Analysis/ScalarEvolution.h  | 40 ++++++++++---------
 llvm/lib/Analysis/ScalarEvolution.cpp         | 37 ++++++++---------
 llvm/lib/Transforms/Scalar/IndVarSimplify.cpp | 11 ++---
 3 files changed, 46 insertions(+), 42 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 692dd02f4b4324..c67cbefd7fb927 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1206,10 +1206,11 @@ class ScalarEvolution {
   ///
   /// If \p AllowPredicates is set, this call will try to use a minimal set of
   /// SCEV predicates in order to return an exact answer.
-  ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond,
-                                     std::function<LoopGuards()> GetLoopGuards,
-                                     bool ExitIfTrue, bool ControlsOnlyExit,
-                                     bool AllowPredicates = false);
+  ExitLimit
+  computeExitLimitFromCond(const Loop *L, Value *ExitCond,
+                           function_ref<const LoopGuards &()> GetLoopGuards,
+                           bool ExitIfTrue, bool ControlsOnlyExit,
+                           bool AllowPredicates = false);
 
   /// A predicate is said to be monotonically increasing if may go from being
   /// false to being true as the loop iterates, but never the other way
@@ -1841,7 +1842,7 @@ class ScalarEvolution {
   /// this call will try to use a minimal set of SCEV predicates in order to
   /// return an exact answer.
   ExitLimit computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
-                             std::function<LoopGuards()> GetLoopGuards,
+                             function_ref<const LoopGuards &()> GetLoopGuards,
                              bool IsOnlyExit, bool AllowPredicates = false);
 
   // Helper functions for computeExitLimitFromCond to avoid exponential time
@@ -1876,15 +1877,15 @@ class ScalarEvolution {
 
   ExitLimit computeExitLimitFromCondCached(
       ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
-      std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+      function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
       bool ControlsOnlyExit, bool AllowPredicates);
   ExitLimit computeExitLimitFromCondImpl(
       ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
-      std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+      function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
       bool ControlsOnlyExit, bool AllowPredicates);
   std::optional<ScalarEvolution::ExitLimit> computeExitLimitFromCondFromBinOp(
       ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
-      std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+      function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
       bool ControlsOnlyExit, bool AllowPredicates);
 
   /// Compute the number of times the backedge of the specified loop will
@@ -1892,27 +1893,28 @@ class ScalarEvolution {
   /// ExitCond and ExitIfTrue. If AllowPredicates is set, this call will try
   /// to use a minimal set of SCEV predicates in order to return an exact
   /// answer.
-  ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond,
-                                     std::function<LoopGuards()> GetLoopGuards,
-                                     bool ExitIfTrue, bool IsSubExpr,
-                                     bool AllowPredicates = false);
+  ExitLimit
+  computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond,
+                           function_ref<const LoopGuards &()> GetLoopGuards,
+                           bool ExitIfTrue, bool IsSubExpr,
+                           bool AllowPredicates = false);
 
   /// Variant of previous which takes the components representing an ICmp
   /// as opposed to the ICmpInst itself.  Note that the prior version can
   /// return more precise results in some cases and is preferred when caller
   /// has a materialized ICmp.
-  ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred,
-                                     const SCEV *LHS, const SCEV *RHS,
-                                     std::function<LoopGuards()> GetLoopGuards,
-                                     bool IsSubExpr,
-                                     bool AllowPredicates = false);
+  ExitLimit
+  computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred,
+                           const SCEV *LHS, const SCEV *RHS,
+                           function_ref<const LoopGuards &()> GetLoopGuards,
+                           bool IsSubExpr, bool AllowPredicates = false);
 
   /// Compute the number of times the backedge of the specified loop will
   /// execute if its exit condition were a switch with a single exiting case
   /// to ExitingBB.
   ExitLimit computeExitLimitFromSingleExitSwitch(
       const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBB,
-      std::function<LoopGuards()> GetLoopGuards, bool IsSubExpr);
+      function_ref<const LoopGuards &()> GetLoopGuards, bool IsSubExpr);
 
   /// Compute the exit limit of a loop that is controlled by a
   /// "(IV >> 1) != 0" type comparison.  We cannot compute the exact trip
@@ -1937,7 +1939,7 @@ class ScalarEvolution {
   /// If AllowPredicates is set, this call will try to use a minimal set of
   /// SCEV predicates in order to return an exact answer.
   ExitLimit howFarToZero(const SCEV *V, const Loop *L,
-                         std::function<LoopGuards()> GetLoopGuards,
+                         function_ref<const LoopGuards &()> GetLoopGuards,
                          bool IsSubExpr, bool AllowPredicates = false);
 
   /// Return the number of times an exit condition checking the specified
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 0ff2c486a06611..d4c4a127868729 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8806,11 +8806,11 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
   const SCEV *MayExitMaxBECount = nullptr;
   bool MustExitMaxOrZero = false;
   bool IsOnlyExit = ExitingBlocks.size() == 1;
-  std::optional<LoopGuards> LoopGuards;
-  auto GetLoopGuards = [&LoopGuards, &L, this]() {
-    if (!LoopGuards)
-      LoopGuards.emplace(LoopGuards::collect(L, *this));
-    return *LoopGuards;
+  std::optional<LoopGuards> CachedLoopGuards;
+  auto GetLoopGuards = [&CachedLoopGuards, &L, this]() -> const LoopGuards & {
+    if (!CachedLoopGuards)
+      CachedLoopGuards.emplace(LoopGuards::collect(L, *this));
+    return *CachedLoopGuards;
   };
 
   // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
@@ -8901,10 +8901,10 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
                            MaxBECount, MaxOrZero);
 }
 
-ScalarEvolution::ExitLimit
-ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
-                                  std::function<LoopGuards()> GetLoopGuards,
-                                  bool IsOnlyExit, bool AllowPredicates) {
+ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimit(
+    const Loop *L, BasicBlock *ExitingBlock,
+    function_ref<const LoopGuards &()> GetLoopGuards, bool IsOnlyExit,
+    bool AllowPredicates) {
   assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
   // If our exiting block does not dominate the latch, then its connection with
   // loop's exit limit may be far from trivial.
@@ -8942,8 +8942,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
-    const Loop *L, Value *ExitCond, std::function<LoopGuards()> GetLoopGuards,
-    bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) {
+    const Loop *L, Value *ExitCond,
+    function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
+    bool ControlsOnlyExit, bool AllowPredicates) {
   ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
   return computeExitLimitFromCondCached(Cache, L, ExitCond, GetLoopGuards,
                                         ExitIfTrue, ControlsOnlyExit,
@@ -8984,7 +8985,7 @@ void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
     ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
-    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+    function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
 
   if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
@@ -9000,7 +9001,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
     ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
-    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+    function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
   // Handle BinOp conditions (And, Or).
   if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
@@ -9066,7 +9067,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
 std::optional<ScalarEvolution::ExitLimit>
 ScalarEvolution::computeExitLimitFromCondFromBinOp(
     ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
-    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+    function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
   // Check if the controlling expression for this loop is an And or Or.
   Value *Op0, *Op1;
@@ -9147,7 +9148,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
     const Loop *L, ICmpInst *ExitCond,
-    std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+    function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
   // If the condition was exit on true, convert the condition to exit on false
   ICmpInst::Predicate Pred;
@@ -9176,7 +9177,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
 }
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
     const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
-    std::function<LoopGuards()> GetLoopGuards, bool ControlsOnlyExit,
+    function_ref<const LoopGuards &()> GetLoopGuards, bool ControlsOnlyExit,
     bool AllowPredicates) {
 
   // Try to evaluate any dependencies out of the loop.
@@ -9350,7 +9351,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
 ScalarEvolution::ExitLimit
 ScalarEvolution::computeExitLimitFromSingleExitSwitch(
     const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock,
-    std::function<LoopGuards()> GetLoopGuards, bool ControlsOnlyExit) {
+    function_ref<const LoopGuards &()> GetLoopGuards, bool ControlsOnlyExit) {
   assert(!L->contains(ExitingBlock) && "Not an exiting block!");
 
   // Give up if the exit is the default dest of a switch.
@@ -10504,7 +10505,7 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
 
 ScalarEvolution::ExitLimit
 ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L,
-                              std::function<LoopGuards()> GetLoopGuards,
+                              function_ref<const LoopGuards &()> GetLoopGuards,
                               bool ControlsOnlyExit, bool AllowPredicates) {
 
   // This is only used for loops with a "x != y" exit test. The exit condition
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index 62e6d541af5c65..1cf5aca2266c61 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -1335,11 +1335,12 @@ static bool optimizeLoopExitWithUnknownExitCount(
   Visited.insert(OldCond);
   Worklist.push_back(OldCond);
 
-  std::optional<ScalarEvolution::LoopGuards> LoopGuards;
-  auto GetLoopGuards = [&LoopGuards, &L, &SE]() {
-    if (!LoopGuards)
-      LoopGuards.emplace(ScalarEvolution::LoopGuards::collect(L, *SE));
-    return *LoopGuards;
+  std::optional<ScalarEvolution::LoopGuards> CachedLoopGuards;
+  auto GetLoopGuards = [&CachedLoopGuards, &L,
+                        &SE]() -> const ScalarEvolution::LoopGuards & {
+    if (!CachedLoopGuards)
+      CachedLoopGuards.emplace(ScalarEvolution::LoopGuards::collect(L, *SE));
+    return *CachedLoopGuards;
   };
 
   auto GoThrough = [&](Value *V) {

>From a10aad8c6acfcb879f56ab9267f61d0b2c899939 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 21 Nov 2024 18:40:21 +0800
Subject: [PATCH 4/4] Bind loop guards by const reference and use cached guards in howManyLessThans

---
 llvm/include/llvm/Analysis/ScalarEvolution.h |  4 +++-
 llvm/lib/Analysis/ScalarEvolution.cpp        | 21 ++++++++++----------
 2 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index c67cbefd7fb927..756fd0eeeef74a 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1960,7 +1960,9 @@ class ScalarEvolution {
   /// If \p AllowPredicates is set, this call will try to use a minimal set of
   /// SCEV predicates in order to return an exact answer.
   ExitLimit howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L,
-                             bool isSigned, bool ControlsOnlyExit,
+                             bool isSigned,
+                             function_ref<const LoopGuards &()> GetLoopGuards,
+                             bool ControlsOnlyExit,
                              bool AllowPredicates = false);
 
   ExitLimit howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L,
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index d4c4a127868729..51ac06121c9c12 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9317,8 +9317,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
   case ICmpInst::ICMP_SLT:
   case ICmpInst::ICMP_ULT: { // while (X < Y)
     bool IsSigned = ICmpInst::isSigned(Pred);
-    ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
-                                    AllowPredicates);
+    ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, GetLoopGuards,
+                                    ControlsOnlyExit, AllowPredicates);
     if (EL.hasAnyInfo())
       return EL;
     break;
@@ -10569,7 +10569,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L,
   if (!isLoopInvariant(Step, L))
     return getCouldNotCompute();
 
-  LoopGuards Guards = GetLoopGuards();
+  const LoopGuards &Guards = GetLoopGuards();
   // Specialize step for this loop so we get context sensitive facts below.
   const SCEV *StepWLG = applyLoopGuards(Step, Guards);
 
@@ -12928,10 +12928,10 @@ const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
                          getConstant(StrideForMaxBECount) /* Step */);
 }
 
-ScalarEvolution::ExitLimit
-ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
-                                  const Loop *L, bool IsSigned,
-                                  bool ControlsOnlyExit, bool AllowPredicates) {
+ScalarEvolution::ExitLimit ScalarEvolution::howManyLessThans(
+    const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
+    function_ref<const LoopGuards &()> GetLoopGuards, bool ControlsOnlyExit,
+    bool AllowPredicates) {
   SmallVector<const SCEVPredicate *> Predicates;
 
   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
@@ -12965,7 +12965,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
           APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
           APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
           Limit = Limit.zext(OuterBitWidth);
-          return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
+          return getUnsignedRangeMax(applyLoopGuards(RHS, GetLoopGuards()))
+              .ule(Limit);
         };
         auto Flags = AR->getNoWrapFlags();
         if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
@@ -13216,8 +13217,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     if (!BECount) {
       auto canProveRHSGreaterThanEqualStart = [&]() {
         auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
-        const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
-        const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
+        const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, GetLoopGuards());
+        const SCEV *GuardedStart = applyLoopGuards(OrigStart, GetLoopGuards());
 
         if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
             isKnownPredicate(CondGE, GuardedRHS, GuardedStart))



More information about the llvm-commits mailing list