[llvm] [SCEV] Memoize collected loop guards. NFCI (PR #116947)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 21 02:41:00 PST 2024
https://github.com/lukel97 updated https://github.com/llvm/llvm-project/pull/116947
>From 9a64d8a957d207a143c8b6fd4321af3b21b41422 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Wed, 20 Nov 2024 16:55:52 +0800
Subject: [PATCH 1/4] [SCEV] Cache collected loop guards. NFCI
This tries to reduce the compile-time cost of collecting loop guards by caching the collected guards, which gives a -0.07% geomean reduction in instructions:u for stage2-O3: https://llvm-compile-time-tracker.com/compare.php?from=aff98e4be05a1060e489ce62a88ee0ff365e571a&to=198a76db2c0b8fbda5374ffd195731a9d47469e3&stat=instructions:u
LoopAccessAnalysis already had a LoopGuards cache for the innermost loop, so this patch hoists that cache up into ScalarEvolution.
---
.../llvm/Analysis/LoopAccessAnalysis.h | 3 ---
llvm/include/llvm/Analysis/ScalarEvolution.h | 5 +++-
llvm/lib/Analysis/LoopAccessAnalysis.cpp | 16 ++++---------
llvm/lib/Analysis/ScalarEvolution.cpp | 24 ++++++++++---------
4 files changed, 22 insertions(+), 26 deletions(-)
diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
index a35bc7402d1a89..872b68f924e654 100644
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -334,9 +334,6 @@ class MemoryDepChecker {
std::pair<const SCEV *, const SCEV *>>
PointerBounds;
- /// Cache for the loop guards of InnermostLoop.
- std::optional<ScalarEvolution::LoopGuards> LoopGuards;
-
/// Check whether there is a plausible dependence between the two
/// accesses.
///
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 885c5985f9d23a..b7b9384c6e642c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1346,7 +1346,6 @@ class ScalarEvolution {
/// Try to apply information from loop guards for \p L to \p Expr.
const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
- const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards);
/// Return true if the loop has no abnormal exits. That is, if the loop
/// is not infinite, it must exit through an explicit edge in the CFG.
@@ -1651,6 +1650,10 @@ class ScalarEvolution {
/// function as they are computed.
DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
+ /// Cache the collected loop guards of the loops of this function as they are
+ /// computed.
+ DenseMap<const Loop *, LoopGuards> LoopGuardsCache;
+
/// Loops whose backedge taken counts directly use this non-constant SCEV.
DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
BECountUsers;
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 907bb7875dc807..9dfc3def3140af 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -1945,16 +1945,13 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
!isa<SCEVCouldNotCompute>(SrcEnd_) &&
!isa<SCEVCouldNotCompute>(SinkStart_) &&
!isa<SCEVCouldNotCompute>(SinkEnd_)) {
- if (!LoopGuards)
- LoopGuards.emplace(
- ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
- auto SrcEnd = SE.applyLoopGuards(SrcEnd_, *LoopGuards);
- auto SinkStart = SE.applyLoopGuards(SinkStart_, *LoopGuards);
+ auto SrcEnd = SE.applyLoopGuards(SrcEnd_, InnermostLoop);
+ auto SinkStart = SE.applyLoopGuards(SinkStart_, InnermostLoop);
if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SrcEnd, SinkStart))
return MemoryDepChecker::Dependence::NoDep;
- auto SinkEnd = SE.applyLoopGuards(SinkEnd_, *LoopGuards);
- auto SrcStart = SE.applyLoopGuards(SrcStart_, *LoopGuards);
+ auto SinkEnd = SE.applyLoopGuards(SinkEnd_, InnermostLoop);
+ auto SrcStart = SE.applyLoopGuards(SrcStart_, InnermostLoop);
if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SinkEnd, SrcStart))
return MemoryDepChecker::Dependence::NoDep;
}
@@ -2057,10 +2054,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
return Dependence::NoDep;
}
} else {
- if (!LoopGuards)
- LoopGuards.emplace(
- ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
- Dist = SE.applyLoopGuards(Dist, *LoopGuards);
+ Dist = SE.applyLoopGuards(Dist, InnermostLoop);
}
// Negative distances are not plausible dependencies.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 46b108606f6a62..70a45ef507e279 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8417,6 +8417,7 @@ void ScalarEvolution::forgetAllLoops() {
// result.
BackedgeTakenCounts.clear();
PredicatedBackedgeTakenCounts.clear();
+ LoopGuardsCache.clear();
BECountUsers.clear();
LoopPropertiesCache.clear();
ConstantEvolutionLoopExitValue.clear();
@@ -10551,9 +10552,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
if (!isLoopInvariant(Step, L))
return getCouldNotCompute();
- LoopGuards Guards = LoopGuards::collect(L, *this);
// Specialize step for this loop so we get context sensitive facts below.
- const SCEV *StepWLG = applyLoopGuards(Step, Guards);
+ const SCEV *StepWLG = applyLoopGuards(Step, L);
// For positive steps (counting up until unsigned overflow):
// N = -Start/Step (as unsigned)
@@ -10570,7 +10570,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
// N = Distance (as unsigned)
if (StepC &&
(StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
- APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
+ APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
// When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
@@ -10611,7 +10611,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
const SCEV *ConstantMax = getCouldNotCompute();
if (Exact != getCouldNotCompute()) {
- APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
+ APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
ConstantMax =
getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
}
@@ -10629,7 +10629,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
const SCEV *M = E;
if (E != getCouldNotCompute()) {
- APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
+ APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
}
auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
@@ -13674,6 +13674,7 @@ ScalarEvolution::~ScalarEvolution() {
HasRecMap.clear();
BackedgeTakenCounts.clear();
PredicatedBackedgeTakenCounts.clear();
+ LoopGuardsCache.clear();
assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
assert(PendingPhiRanges.empty() && "getRangeRef garbage");
@@ -15889,10 +15890,11 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
- return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
-}
-
-const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
- const LoopGuards &Guards) {
- return Guards.rewrite(Expr);
+ auto Itr = LoopGuardsCache.find(L);
+ if (Itr == LoopGuardsCache.end()) {
+ LoopGuards Guard = LoopGuards::collect(L, *this);
+ LoopGuardsCache.insert({L, Guard});
+ return Guard.rewrite(Expr);
+ }
+ return Itr->second.rewrite(Expr);
}
>From 7a0bf19b515a555f0b80c55e8d700dfcc8e88ff4 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 21 Nov 2024 14:33:06 +0800
Subject: [PATCH 2/4] Rework to memoize loop guards across multiple exits
---
.../llvm/Analysis/LoopAccessAnalysis.h | 3 +
llvm/include/llvm/Analysis/ScalarEvolution.h | 122 +++++++++---------
llvm/lib/Analysis/LoopAccessAnalysis.cpp | 16 ++-
llvm/lib/Analysis/ScalarEvolution.cpp | 119 +++++++++--------
llvm/lib/Transforms/Scalar/IndVarSimplify.cpp | 9 +-
5 files changed, 150 insertions(+), 119 deletions(-)
diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
index 872b68f924e654..a35bc7402d1a89 100644
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -334,6 +334,9 @@ class MemoryDepChecker {
std::pair<const SCEV *, const SCEV *>>
PointerBounds;
+ /// Cache for the loop guards of InnermostLoop.
+ std::optional<ScalarEvolution::LoopGuards> LoopGuards;
+
/// Check whether there is a plausible dependence between the two
/// accesses.
///
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index b7b9384c6e642c..692dd02f4b4324 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1112,6 +1112,46 @@ class ScalarEvolution {
bool isKnownOnEveryIteration(ICmpInst::Predicate Pred,
const SCEVAddRecExpr *LHS, const SCEV *RHS);
+ class LoopGuards {
+ DenseMap<const SCEV *, const SCEV *> RewriteMap;
+ bool PreserveNUW = false;
+ bool PreserveNSW = false;
+ ScalarEvolution &SE;
+
+ LoopGuards(ScalarEvolution &SE) : SE(SE) {}
+
+ /// Recursively collect loop guards in \p Guards, starting from
+ /// block \p Block with predecessor \p Pred. The intended starting point
+ /// is to collect from a loop header and its predecessor.
+ static void
+ collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
+ const BasicBlock *Block, const BasicBlock *Pred,
+ SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
+ unsigned Depth = 0);
+
+ /// Collect loop guards in \p Guards, starting from PHINode \p
+ /// Phi, by calling \p collectFromBlock on the incoming blocks of
+ /// \Phi and trying to merge the found constraints into a single
+ /// combined one for \p Phi.
+ static void collectFromPHI(
+ ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
+ const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
+ SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
+ unsigned Depth);
+
+ public:
+ /// Collect rewrite map for loop guards for loop \p L, together with flags
+ /// indicating if NUW and NSW can be preserved during rewriting.
+ static LoopGuards collect(const Loop *L, ScalarEvolution &SE);
+
+ /// Try to apply the collected loop guards to \p Expr.
+ const SCEV *rewrite(const SCEV *Expr) const;
+ };
+
+ /// Try to apply information from loop guards for \p L to \p Expr.
+ const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
+ const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards);
+
/// Information about the number of loop iterations for which a loop exit's
/// branch condition evaluates to the not-taken path. This is a temporary
/// pair of exact and max expressions that are eventually summarized in
@@ -1167,6 +1207,7 @@ class ScalarEvolution {
/// If \p AllowPredicates is set, this call will try to use a minimal set of
/// SCEV predicates in order to return an exact answer.
ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond,
+ std::function<LoopGuards()> GetLoopGuards,
bool ExitIfTrue, bool ControlsOnlyExit,
bool AllowPredicates = false);
@@ -1308,45 +1349,6 @@ class ScalarEvolution {
/// sharpen it.
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags);
- class LoopGuards {
- DenseMap<const SCEV *, const SCEV *> RewriteMap;
- bool PreserveNUW = false;
- bool PreserveNSW = false;
- ScalarEvolution &SE;
-
- LoopGuards(ScalarEvolution &SE) : SE(SE) {}
-
- /// Recursively collect loop guards in \p Guards, starting from
- /// block \p Block with predecessor \p Pred. The intended starting point
- /// is to collect from a loop header and its predecessor.
- static void
- collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
- const BasicBlock *Block, const BasicBlock *Pred,
- SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
- unsigned Depth = 0);
-
- /// Collect loop guards in \p Guards, starting from PHINode \p
- /// Phi, by calling \p collectFromBlock on the incoming blocks of
- /// \Phi and trying to merge the found constraints into a single
- /// combined one for \p Phi.
- static void collectFromPHI(
- ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
- const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
- SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
- unsigned Depth);
-
- public:
- /// Collect rewrite map for loop guards for loop \p L, together with flags
- /// indicating if NUW and NSW can be preserved during rewriting.
- static LoopGuards collect(const Loop *L, ScalarEvolution &SE);
-
- /// Try to apply the collected loop guards to \p Expr.
- const SCEV *rewrite(const SCEV *Expr) const;
- };
-
- /// Try to apply information from loop guards for \p L to \p Expr.
- const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
-
/// Return true if the loop has no abnormal exits. That is, if the loop
/// is not infinite, it must exit through an explicit edge in the CFG.
/// (As opposed to either a) throwing out of the function or b) entering a
@@ -1650,10 +1652,6 @@ class ScalarEvolution {
/// function as they are computed.
DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
- /// Cache the collected loop guards of the loops of this function as they are
- /// computed.
- DenseMap<const Loop *, LoopGuards> LoopGuardsCache;
-
/// Loops whose backedge taken counts directly use this non-constant SCEV.
DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
BECountUsers;
@@ -1843,6 +1841,7 @@ class ScalarEvolution {
/// this call will try to use a minimal set of SCEV predicates in order to
/// return an exact answer.
ExitLimit computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
+ std::function<LoopGuards()> GetLoopGuards,
bool IsOnlyExit, bool AllowPredicates = false);
// Helper functions for computeExitLimitFromCond to avoid exponential time
@@ -1875,17 +1874,17 @@ class ScalarEvolution {
using ExitLimitCacheTy = ExitLimitCache;
- ExitLimit computeExitLimitFromCondCached(ExitLimitCacheTy &Cache,
- const Loop *L, Value *ExitCond,
- bool ExitIfTrue,
- bool ControlsOnlyExit,
- bool AllowPredicates);
- ExitLimit computeExitLimitFromCondImpl(ExitLimitCacheTy &Cache, const Loop *L,
- Value *ExitCond, bool ExitIfTrue,
- bool ControlsOnlyExit,
- bool AllowPredicates);
+ ExitLimit computeExitLimitFromCondCached(
+ ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+ std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+ bool ControlsOnlyExit, bool AllowPredicates);
+ ExitLimit computeExitLimitFromCondImpl(
+ ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+ std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+ bool ControlsOnlyExit, bool AllowPredicates);
std::optional<ScalarEvolution::ExitLimit> computeExitLimitFromCondFromBinOp(
- ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+ ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+ std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates);
/// Compute the number of times the backedge of the specified loop will
@@ -1894,8 +1893,8 @@ class ScalarEvolution {
/// to use a minimal set of SCEV predicates in order to return an exact
/// answer.
ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond,
- bool ExitIfTrue,
- bool IsSubExpr,
+ std::function<LoopGuards()> GetLoopGuards,
+ bool ExitIfTrue, bool IsSubExpr,
bool AllowPredicates = false);
/// Variant of previous which takes the components representing an ICmp
@@ -1904,16 +1903,16 @@ class ScalarEvolution {
/// has a materialized ICmp.
ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
+ std::function<LoopGuards()> GetLoopGuards,
bool IsSubExpr,
bool AllowPredicates = false);
/// Compute the number of times the backedge of the specified loop will
/// execute if its exit condition were a switch with a single exiting case
/// to ExitingBB.
- ExitLimit computeExitLimitFromSingleExitSwitch(const Loop *L,
- SwitchInst *Switch,
- BasicBlock *ExitingBB,
- bool IsSubExpr);
+ ExitLimit computeExitLimitFromSingleExitSwitch(
+ const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBB,
+ std::function<LoopGuards()> GetLoopGuards, bool IsSubExpr);
/// Compute the exit limit of a loop that is controlled by a
/// "(IV >> 1) != 0" type comparison. We cannot compute the exact trip
@@ -1937,8 +1936,9 @@ class ScalarEvolution {
/// value to zero will execute. If not computable, return CouldNotCompute.
/// If AllowPredicates is set, this call will try to use a minimal set of
/// SCEV predicates in order to return an exact answer.
- ExitLimit howFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr,
- bool AllowPredicates = false);
+ ExitLimit howFarToZero(const SCEV *V, const Loop *L,
+ std::function<LoopGuards()> GetLoopGuards,
+ bool IsSubExpr, bool AllowPredicates = false);
/// Return the number of times an exit condition checking the specified
/// value for nonzero will execute. If not computable, return
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 9dfc3def3140af..907bb7875dc807 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -1945,13 +1945,16 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
!isa<SCEVCouldNotCompute>(SrcEnd_) &&
!isa<SCEVCouldNotCompute>(SinkStart_) &&
!isa<SCEVCouldNotCompute>(SinkEnd_)) {
- auto SrcEnd = SE.applyLoopGuards(SrcEnd_, InnermostLoop);
- auto SinkStart = SE.applyLoopGuards(SinkStart_, InnermostLoop);
+ if (!LoopGuards)
+ LoopGuards.emplace(
+ ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
+ auto SrcEnd = SE.applyLoopGuards(SrcEnd_, *LoopGuards);
+ auto SinkStart = SE.applyLoopGuards(SinkStart_, *LoopGuards);
if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SrcEnd, SinkStart))
return MemoryDepChecker::Dependence::NoDep;
- auto SinkEnd = SE.applyLoopGuards(SinkEnd_, InnermostLoop);
- auto SrcStart = SE.applyLoopGuards(SrcStart_, InnermostLoop);
+ auto SinkEnd = SE.applyLoopGuards(SinkEnd_, *LoopGuards);
+ auto SrcStart = SE.applyLoopGuards(SrcStart_, *LoopGuards);
if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SinkEnd, SrcStart))
return MemoryDepChecker::Dependence::NoDep;
}
@@ -2054,7 +2057,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
return Dependence::NoDep;
}
} else {
- Dist = SE.applyLoopGuards(Dist, InnermostLoop);
+ if (!LoopGuards)
+ LoopGuards.emplace(
+ ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
+ Dist = SE.applyLoopGuards(Dist, *LoopGuards);
}
// Negative distances are not plausible dependencies.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 70a45ef507e279..0ff2c486a06611 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8417,7 +8417,6 @@ void ScalarEvolution::forgetAllLoops() {
// result.
BackedgeTakenCounts.clear();
PredicatedBackedgeTakenCounts.clear();
- LoopGuardsCache.clear();
BECountUsers.clear();
LoopPropertiesCache.clear();
ConstantEvolutionLoopExitValue.clear();
@@ -8807,6 +8806,12 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
const SCEV *MayExitMaxBECount = nullptr;
bool MustExitMaxOrZero = false;
bool IsOnlyExit = ExitingBlocks.size() == 1;
+ std::optional<LoopGuards> LoopGuards;
+ auto GetLoopGuards = [&LoopGuards, &L, this]() {
+ if (!LoopGuards)
+ LoopGuards.emplace(LoopGuards::collect(L, *this));
+ return *LoopGuards;
+ };
// Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
// and compute maxBECount.
@@ -8822,7 +8827,8 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
continue;
}
- ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
+ ExitLimit EL =
+ computeExitLimit(L, ExitBB, GetLoopGuards, IsOnlyExit, AllowPredicates);
assert((AllowPredicates || EL.Predicates.empty()) &&
"Predicated exit limit when predicates are not allowed!");
@@ -8897,6 +8903,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
+ std::function<LoopGuards()> GetLoopGuards,
bool IsOnlyExit, bool AllowPredicates) {
assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
// If our exiting block does not dominate the latch, then its connection with
@@ -8912,9 +8919,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
"It should have one successor in loop and one exit block!");
// Proceed to the next level to examine the exit condition expression.
- return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
- /*ControlsOnlyExit=*/IsOnlyExit,
- AllowPredicates);
+ return computeExitLimitFromCond(
+ L, BI->getCondition(), GetLoopGuards, ExitIfTrue,
+ /*ControlsOnlyExit=*/IsOnlyExit, AllowPredicates);
}
if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8928,18 +8935,19 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
}
assert(Exit && "Exiting block must have at least one exit");
return computeExitLimitFromSingleExitSwitch(
- L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
+ L, SI, Exit, GetLoopGuards, /*ControlsOnlyExit=*/IsOnlyExit);
}
return getCouldNotCompute();
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
- const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
- bool AllowPredicates) {
+ const Loop *L, Value *ExitCond, std::function<LoopGuards()> GetLoopGuards,
+ bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) {
ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
- return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
- ControlsOnlyExit, AllowPredicates);
+ return computeExitLimitFromCondCached(Cache, L, ExitCond, GetLoopGuards,
+ ExitIfTrue, ControlsOnlyExit,
+ AllowPredicates);
}
std::optional<ScalarEvolution::ExitLimit>
@@ -8975,37 +8983,41 @@ void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
- ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+ ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+ std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates) {
if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
AllowPredicates))
return *MaybeEL;
- ExitLimit EL = computeExitLimitFromCondImpl(
- Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
+ ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, GetLoopGuards,
+ ExitIfTrue, ControlsOnlyExit,
+ AllowPredicates);
Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
return EL;
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
- ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+ ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+ std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates) {
// Handle BinOp conditions (And, Or).
if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
- Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
+ Cache, L, ExitCond, GetLoopGuards, ExitIfTrue, ControlsOnlyExit,
+ AllowPredicates))
return *LimitFromBinOp;
// With an icmp, it may be feasible to compute an exact backedge-taken count.
// Proceed to the next level to examine the icmp.
if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
- ExitLimit EL =
- computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
+ ExitLimit EL = computeExitLimitFromICmp(L, ExitCondICmp, GetLoopGuards,
+ ExitIfTrue, ControlsOnlyExit);
if (EL.hasFullInfo() || !AllowPredicates)
return EL;
// Try again, but use SCEV predicates this time.
- return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
+ return computeExitLimitFromICmp(L, ExitCondICmp, GetLoopGuards, ExitIfTrue,
ControlsOnlyExit,
/*AllowPredicates=*/true);
}
@@ -9041,7 +9053,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
if (Offset != 0)
LHS = getAddExpr(LHS, getConstant(Offset));
auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
- ControlsOnlyExit, AllowPredicates);
+ GetLoopGuards, ControlsOnlyExit,
+ AllowPredicates);
if (EL.hasAnyInfo())
return EL;
}
@@ -9052,7 +9065,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
std::optional<ScalarEvolution::ExitLimit>
ScalarEvolution::computeExitLimitFromCondFromBinOp(
- ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+ ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
+ std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates) {
// Check if the controlling expression for this loop is an And or Or.
Value *Op0, *Op1;
@@ -9069,11 +9083,11 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
// br (or Op0 Op1), exit, loop
bool EitherMayExit = IsAnd ^ ExitIfTrue;
ExitLimit EL0 = computeExitLimitFromCondCached(
- Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
- AllowPredicates);
+ Cache, L, Op0, GetLoopGuards, ExitIfTrue,
+ ControlsOnlyExit && !EitherMayExit, AllowPredicates);
ExitLimit EL1 = computeExitLimitFromCondCached(
- Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
- AllowPredicates);
+ Cache, L, Op1, GetLoopGuards, ExitIfTrue,
+ ControlsOnlyExit && !EitherMayExit, AllowPredicates);
// Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
@@ -9132,8 +9146,9 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
- const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
- bool AllowPredicates) {
+ const Loop *L, ICmpInst *ExitCond,
+ std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+ bool ControlsOnlyExit, bool AllowPredicates) {
// If the condition was exit on true, convert the condition to exit on false
ICmpInst::Predicate Pred;
if (!ExitIfTrue)
@@ -9145,8 +9160,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
- ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
- AllowPredicates);
+ ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, GetLoopGuards,
+ ControlsOnlyExit, AllowPredicates);
if (EL.hasAnyInfo())
return EL;
@@ -9161,7 +9176,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
- bool ControlsOnlyExit, bool AllowPredicates) {
+ std::function<LoopGuards()> GetLoopGuards, bool ControlsOnlyExit,
+ bool AllowPredicates) {
// Try to evaluate any dependencies out of the loop.
LHS = getSCEVAtScope(LHS, L);
@@ -9249,8 +9265,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
if (isa<SCEVCouldNotCompute>(RHS))
return RHS;
}
- ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
- AllowPredicates);
+ ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, GetLoopGuards,
+ ControlsOnlyExit, AllowPredicates);
if (EL.hasAnyInfo())
return EL;
break;
@@ -9332,10 +9348,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
}
ScalarEvolution::ExitLimit
-ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
- SwitchInst *Switch,
- BasicBlock *ExitingBlock,
- bool ControlsOnlyExit) {
+ScalarEvolution::computeExitLimitFromSingleExitSwitch(
+ const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock,
+ std::function<LoopGuards()> GetLoopGuards, bool ControlsOnlyExit) {
assert(!L->contains(ExitingBlock) && "Not an exiting block!");
// Give up if the exit is the default dest of a switch.
@@ -9348,7 +9363,8 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
// while (X != Y) --> while (X-Y != 0)
- ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
+ ExitLimit EL =
+ howFarToZero(getMinusSCEV(LHS, RHS), L, GetLoopGuards, ControlsOnlyExit);
if (EL.hasAnyInfo())
return EL;
@@ -10486,10 +10502,10 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
}
-ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
- const Loop *L,
- bool ControlsOnlyExit,
- bool AllowPredicates) {
+ScalarEvolution::ExitLimit
+ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L,
+ std::function<LoopGuards()> GetLoopGuards,
+ bool ControlsOnlyExit, bool AllowPredicates) {
// This is only used for loops with a "x != y" exit test. The exit condition
// is now expressed as a single expression, V = x-y. So the exit test is
@@ -10552,8 +10568,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
if (!isLoopInvariant(Step, L))
return getCouldNotCompute();
+ LoopGuards Guards = GetLoopGuards();
// Specialize step for this loop so we get context sensitive facts below.
- const SCEV *StepWLG = applyLoopGuards(Step, L);
+ const SCEV *StepWLG = applyLoopGuards(Step, Guards);
// For positive steps (counting up until unsigned overflow):
// N = -Start/Step (as unsigned)
@@ -10570,7 +10587,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
// N = Distance (as unsigned)
if (StepC &&
(StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
- APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
+ APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
// When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
@@ -10611,7 +10628,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
const SCEV *ConstantMax = getCouldNotCompute();
if (Exact != getCouldNotCompute()) {
- APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
+ APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
ConstantMax =
getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
}
@@ -10629,7 +10646,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
const SCEV *M = E;
if (E != getCouldNotCompute()) {
- APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
+ APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
}
auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
@@ -13674,7 +13691,6 @@ ScalarEvolution::~ScalarEvolution() {
HasRecMap.clear();
BackedgeTakenCounts.clear();
PredicatedBackedgeTakenCounts.clear();
- LoopGuardsCache.clear();
assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
assert(PendingPhiRanges.empty() && "getRangeRef garbage");
@@ -15890,11 +15906,10 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
- auto Itr = LoopGuardsCache.find(L);
- if (Itr == LoopGuardsCache.end()) {
- LoopGuards Guard = LoopGuards::collect(L, *this);
- LoopGuardsCache.insert({L, Guard});
- return Guard.rewrite(Expr);
- }
- return Itr->second.rewrite(Expr);
+ return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
+}
+
+const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
+ const LoopGuards &Guards) {
+ return Guards.rewrite(Expr);
}
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index 8a3e0bc3eb9712..62e6d541af5c65 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -1335,6 +1335,13 @@ static bool optimizeLoopExitWithUnknownExitCount(
Visited.insert(OldCond);
Worklist.push_back(OldCond);
+ std::optional<ScalarEvolution::LoopGuards> LoopGuards;
+ auto GetLoopGuards = [&LoopGuards, &L, &SE]() {
+ if (!LoopGuards)
+ LoopGuards.emplace(ScalarEvolution::LoopGuards::collect(L, *SE));
+ return *LoopGuards;
+ };
+
auto GoThrough = [&](Value *V) {
Value *LHS = nullptr, *RHS = nullptr;
if (Inverted) {
@@ -1371,7 +1378,7 @@ static bool optimizeLoopExitWithUnknownExitCount(
ScalarEvolution::ExitCountKind::SymbolicMaximum) ==
MaxIter)
for (auto *ICmp : LeafConditions) {
- auto EL = SE->computeExitLimitFromCond(L, ICmp, Inverted,
+ auto EL = SE->computeExitLimitFromCond(L, ICmp, GetLoopGuards, Inverted,
/*ControlsExit*/ false);
const SCEV *ExitMax = EL.SymbolicMaxNotTaken;
if (isa<SCEVCouldNotCompute>(ExitMax))
>From 9c8dac0196c3199842606a208e19235bc6505bc2 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 21 Nov 2024 18:05:58 +0800
Subject: [PATCH 3/4] Use function_ref, return const reference
---
llvm/include/llvm/Analysis/ScalarEvolution.h | 40 ++++++++++---------
llvm/lib/Analysis/ScalarEvolution.cpp | 37 ++++++++---------
llvm/lib/Transforms/Scalar/IndVarSimplify.cpp | 11 ++---
3 files changed, 46 insertions(+), 42 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 692dd02f4b4324..c67cbefd7fb927 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1206,10 +1206,11 @@ class ScalarEvolution {
///
/// If \p AllowPredicates is set, this call will try to use a minimal set of
/// SCEV predicates in order to return an exact answer.
- ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond,
- std::function<LoopGuards()> GetLoopGuards,
- bool ExitIfTrue, bool ControlsOnlyExit,
- bool AllowPredicates = false);
+ ExitLimit
+ computeExitLimitFromCond(const Loop *L, Value *ExitCond,
+ function_ref<const LoopGuards &()> GetLoopGuards,
+ bool ExitIfTrue, bool ControlsOnlyExit,
+ bool AllowPredicates = false);
/// A predicate is said to be monotonically increasing if may go from being
/// false to being true as the loop iterates, but never the other way
@@ -1841,7 +1842,7 @@ class ScalarEvolution {
/// this call will try to use a minimal set of SCEV predicates in order to
/// return an exact answer.
ExitLimit computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
- std::function<LoopGuards()> GetLoopGuards,
+ function_ref<const LoopGuards &()> GetLoopGuards,
bool IsOnlyExit, bool AllowPredicates = false);
// Helper functions for computeExitLimitFromCond to avoid exponential time
@@ -1876,15 +1877,15 @@ class ScalarEvolution {
ExitLimit computeExitLimitFromCondCached(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
- std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+ function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates);
ExitLimit computeExitLimitFromCondImpl(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
- std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+ function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates);
std::optional<ScalarEvolution::ExitLimit> computeExitLimitFromCondFromBinOp(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
- std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+ function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates);
/// Compute the number of times the backedge of the specified loop will
@@ -1892,27 +1893,28 @@ class ScalarEvolution {
/// ExitCond and ExitIfTrue. If AllowPredicates is set, this call will try
/// to use a minimal set of SCEV predicates in order to return an exact
/// answer.
- ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond,
- std::function<LoopGuards()> GetLoopGuards,
- bool ExitIfTrue, bool IsSubExpr,
- bool AllowPredicates = false);
+ ExitLimit
+ computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond,
+ function_ref<const LoopGuards &()> GetLoopGuards,
+ bool ExitIfTrue, bool IsSubExpr,
+ bool AllowPredicates = false);
/// Variant of previous which takes the components representing an ICmp
/// as opposed to the ICmpInst itself. Note that the prior version can
/// return more precise results in some cases and is preferred when caller
/// has a materialized ICmp.
- ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred,
- const SCEV *LHS, const SCEV *RHS,
- std::function<LoopGuards()> GetLoopGuards,
- bool IsSubExpr,
- bool AllowPredicates = false);
+ ExitLimit
+ computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS,
+ function_ref<const LoopGuards &()> GetLoopGuards,
+ bool IsSubExpr, bool AllowPredicates = false);
/// Compute the number of times the backedge of the specified loop will
/// execute if its exit condition were a switch with a single exiting case
/// to ExitingBB.
ExitLimit computeExitLimitFromSingleExitSwitch(
const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBB,
- std::function<LoopGuards()> GetLoopGuards, bool IsSubExpr);
+ function_ref<const LoopGuards &()> GetLoopGuards, bool IsSubExpr);
/// Compute the exit limit of a loop that is controlled by a
/// "(IV >> 1) != 0" type comparison. We cannot compute the exact trip
@@ -1937,7 +1939,7 @@ class ScalarEvolution {
/// If AllowPredicates is set, this call will try to use a minimal set of
/// SCEV predicates in order to return an exact answer.
ExitLimit howFarToZero(const SCEV *V, const Loop *L,
- std::function<LoopGuards()> GetLoopGuards,
+ function_ref<const LoopGuards &()> GetLoopGuards,
bool IsSubExpr, bool AllowPredicates = false);
/// Return the number of times an exit condition checking the specified
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 0ff2c486a06611..d4c4a127868729 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8806,11 +8806,11 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
const SCEV *MayExitMaxBECount = nullptr;
bool MustExitMaxOrZero = false;
bool IsOnlyExit = ExitingBlocks.size() == 1;
- std::optional<LoopGuards> LoopGuards;
- auto GetLoopGuards = [&LoopGuards, &L, this]() {
- if (!LoopGuards)
- LoopGuards.emplace(LoopGuards::collect(L, *this));
- return *LoopGuards;
+ std::optional<LoopGuards> CachedLoopGuards;
+ auto GetLoopGuards = [&CachedLoopGuards, &L, this]() -> const LoopGuards & {
+ if (!CachedLoopGuards)
+ CachedLoopGuards.emplace(LoopGuards::collect(L, *this));
+ return *CachedLoopGuards;
};
// Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
@@ -8901,10 +8901,10 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
MaxBECount, MaxOrZero);
}
-ScalarEvolution::ExitLimit
-ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
- std::function<LoopGuards()> GetLoopGuards,
- bool IsOnlyExit, bool AllowPredicates) {
+ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimit(
+ const Loop *L, BasicBlock *ExitingBlock,
+ function_ref<const LoopGuards &()> GetLoopGuards, bool IsOnlyExit,
+ bool AllowPredicates) {
assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
// If our exiting block does not dominate the latch, then its connection with
// loop's exit limit may be far from trivial.
@@ -8942,8 +8942,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
- const Loop *L, Value *ExitCond, std::function<LoopGuards()> GetLoopGuards,
- bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) {
+ const Loop *L, Value *ExitCond,
+ function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
+ bool ControlsOnlyExit, bool AllowPredicates) {
ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
return computeExitLimitFromCondCached(Cache, L, ExitCond, GetLoopGuards,
ExitIfTrue, ControlsOnlyExit,
@@ -8984,7 +8985,7 @@ void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
- std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+ function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates) {
if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
@@ -9000,7 +9001,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
- std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+ function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates) {
// Handle BinOp conditions (And, Or).
if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
@@ -9066,7 +9067,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
std::optional<ScalarEvolution::ExitLimit>
ScalarEvolution::computeExitLimitFromCondFromBinOp(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
- std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+ function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates) {
// Check if the controlling expression for this loop is an And or Or.
Value *Op0, *Op1;
@@ -9147,7 +9148,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
const Loop *L, ICmpInst *ExitCond,
- std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
+ function_ref<const LoopGuards &()> GetLoopGuards, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates) {
// If the condition was exit on true, convert the condition to exit on false
ICmpInst::Predicate Pred;
@@ -9176,7 +9177,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
- std::function<LoopGuards()> GetLoopGuards, bool ControlsOnlyExit,
+ function_ref<const LoopGuards &()> GetLoopGuards, bool ControlsOnlyExit,
bool AllowPredicates) {
// Try to evaluate any dependencies out of the loop.
@@ -9350,7 +9351,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromSingleExitSwitch(
const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock,
- std::function<LoopGuards()> GetLoopGuards, bool ControlsOnlyExit) {
+ function_ref<const LoopGuards &()> GetLoopGuards, bool ControlsOnlyExit) {
assert(!L->contains(ExitingBlock) && "Not an exiting block!");
// Give up if the exit is the default dest of a switch.
@@ -10504,7 +10505,7 @@ SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
ScalarEvolution::ExitLimit
ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L,
- std::function<LoopGuards()> GetLoopGuards,
+ function_ref<const LoopGuards &()> GetLoopGuards,
bool ControlsOnlyExit, bool AllowPredicates) {
// This is only used for loops with a "x != y" exit test. The exit condition
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index 62e6d541af5c65..1cf5aca2266c61 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -1335,11 +1335,12 @@ static bool optimizeLoopExitWithUnknownExitCount(
Visited.insert(OldCond);
Worklist.push_back(OldCond);
- std::optional<ScalarEvolution::LoopGuards> LoopGuards;
- auto GetLoopGuards = [&LoopGuards, &L, &SE]() {
- if (!LoopGuards)
- LoopGuards.emplace(ScalarEvolution::LoopGuards::collect(L, *SE));
- return *LoopGuards;
+ std::optional<ScalarEvolution::LoopGuards> CachedLoopGuards;
+ auto GetLoopGuards = [&CachedLoopGuards, &L,
+ &SE]() -> const ScalarEvolution::LoopGuards & {
+ if (!CachedLoopGuards)
+ CachedLoopGuards.emplace(ScalarEvolution::LoopGuards::collect(L, *SE));
+ return *CachedLoopGuards;
};
auto GoThrough = [&](Value *V) {
>From a10aad8c6acfcb879f56ab9267f61d0b2c899939 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Thu, 21 Nov 2024 18:40:21 +0800
Subject: [PATCH 4/4] Use reference again and use in howManyLessThans
---
llvm/include/llvm/Analysis/ScalarEvolution.h | 4 +++-
llvm/lib/Analysis/ScalarEvolution.cpp | 21 ++++++++++----------
2 files changed, 14 insertions(+), 11 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index c67cbefd7fb927..756fd0eeeef74a 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1960,7 +1960,9 @@ class ScalarEvolution {
/// If \p AllowPredicates is set, this call will try to use a minimal set of
/// SCEV predicates in order to return an exact answer.
ExitLimit howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L,
- bool isSigned, bool ControlsOnlyExit,
+ bool isSigned,
+ function_ref<const LoopGuards &()> GetLoopGuards,
+ bool ControlsOnlyExit,
bool AllowPredicates = false);
ExitLimit howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L,
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index d4c4a127868729..51ac06121c9c12 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9317,8 +9317,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_ULT: { // while (X < Y)
bool IsSigned = ICmpInst::isSigned(Pred);
- ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
- AllowPredicates);
+ ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, GetLoopGuards,
+ ControlsOnlyExit, AllowPredicates);
if (EL.hasAnyInfo())
return EL;
break;
@@ -10569,7 +10569,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L,
if (!isLoopInvariant(Step, L))
return getCouldNotCompute();
- LoopGuards Guards = GetLoopGuards();
+ const LoopGuards &Guards = GetLoopGuards();
// Specialize step for this loop so we get context sensitive facts below.
const SCEV *StepWLG = applyLoopGuards(Step, Guards);
@@ -12928,10 +12928,10 @@ const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
getConstant(StrideForMaxBECount) /* Step */);
}
-ScalarEvolution::ExitLimit
-ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
- const Loop *L, bool IsSigned,
- bool ControlsOnlyExit, bool AllowPredicates) {
+ScalarEvolution::ExitLimit ScalarEvolution::howManyLessThans(
+ const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
+ function_ref<const LoopGuards &()> GetLoopGuards, bool ControlsOnlyExit,
+ bool AllowPredicates) {
SmallVector<const SCEVPredicate *> Predicates;
const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
@@ -12965,7 +12965,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
Limit = Limit.zext(OuterBitWidth);
- return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
+ return getUnsignedRangeMax(applyLoopGuards(RHS, GetLoopGuards()))
+ .ule(Limit);
};
auto Flags = AR->getNoWrapFlags();
if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
@@ -13216,8 +13217,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (!BECount) {
auto canProveRHSGreaterThanEqualStart = [&]() {
auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
- const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
- const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
+ const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, GetLoopGuards());
+ const SCEV *GuardedStart = applyLoopGuards(OrigStart, GetLoopGuards());
if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
More information about the llvm-commits mailing list