[llvm] [SCEV] Add predicated version of getSymbolicMaxBackedgeTakenCount. (PR #93498)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Tue May 28 14:39:04 PDT 2024
https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/93498
>From 68c2420c4bfe6f28094a0b8b576800c7219b6f10 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 27 May 2024 21:17:21 -0700
Subject: [PATCH] [SCEV] Add predicated version of
getSymbolicMaxBackedgeTakenCount.
---
llvm/include/llvm/Analysis/ScalarEvolution.h | 15 +++++-
llvm/lib/Analysis/ScalarEvolution.cpp | 48 +++++++++++++++++--
...cated-symbolic-max-backedge-taken-count.ll | 6 +++
3 files changed, 63 insertions(+), 6 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 1d016b28347d2..d9e10dd3150a3 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -912,6 +912,9 @@ class ScalarEvolution {
return getBackedgeTakenCount(L, SymbolicMaximum);
}
+ const SCEV *getPredicatedSymbolicMaxBackedgeTakenCount(
+ const Loop *L, SmallVector<const SCEVPredicate *, 4> &Predicates);
+
/// Return true if the backedge taken count is either the value returned by
/// getConstantMaxBackedgeTakenCount or zero.
bool isBackedgeTakenCountMaxOrZero(const Loop *L);
@@ -1549,7 +1552,9 @@ class ScalarEvolution {
ScalarEvolution *SE) const;
/// Get the symbolic max backedge taken count for the loop.
- const SCEV *getSymbolicMax(const Loop *L, ScalarEvolution *SE);
+ const SCEV *
+ getSymbolicMax(const Loop *L, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr);
/// Get the symbolic max backedge taken count for the particular loop exit.
const SCEV *getSymbolicMax(const BasicBlock *ExitingBlock,
@@ -1746,7 +1751,7 @@ class ScalarEvolution {
/// Similar to getBackedgeTakenInfo, but will add predicates as required
/// with the purpose of returning complete information.
- const BackedgeTakenInfo &getPredicatedBackedgeTakenInfo(const Loop *L);
+ BackedgeTakenInfo &getPredicatedBackedgeTakenInfo(const Loop *L);
/// Compute the number of times the specified loop will iterate.
/// If AllowPredicates is set, we will create new SCEV predicates as
@@ -2311,6 +2316,9 @@ class PredicatedScalarEvolution {
/// Get the (predicated) backedge count for the analyzed loop.
const SCEV *getBackedgeTakenCount();
+ /// Get the (predicated) symbolic max backedge count for the analyzed loop.
+ const SCEV *getSymbolicMaxBackedgeTakenCount();
+
/// Adds a new predicate.
void addPredicate(const SCEVPredicate &Pred);
@@ -2379,6 +2387,9 @@ class PredicatedScalarEvolution {
/// The backedge taken count.
const SCEV *BackedgeCount = nullptr;
+
+ /// The symbolic max backedge taken count.
+ const SCEV *SymbolicMaxBackedgeCount = nullptr;
};
template <> struct DenseMapInfo<ScalarEvolution::FoldID> {
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index bb56b41fe15d5..e46d7183a2a35 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8295,6 +8295,11 @@ const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
llvm_unreachable("Invalid ExitCountKind!");
}
+const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
+ const Loop *L, SmallVector<const SCEVPredicate *, 4> &Preds) {
+ return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
+}
+
bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
}
@@ -8311,7 +8316,7 @@ static void PushLoopPHIs(const Loop *L,
Worklist.push_back(&PN);
}
-const ScalarEvolution::BackedgeTakenInfo &
+ScalarEvolution::BackedgeTakenInfo &
ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
auto &BTI = getBackedgeTakenInfo(L);
if (BTI.hasFullInfo())
@@ -8644,9 +8649,9 @@ ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
return getConstantMax();
}
-const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
- ScalarEvolution *SE) {
+const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
+ const Loop *L, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates) {
if (!SymbolicMax) {
// Form an expression for the maximum exit count possible for this loop. We
// merge the max and exact information to approximate a version of
@@ -8661,6 +8666,12 @@ ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
"We should only have known counts for exiting blocks that "
"dominate latch!");
ExitCounts.push_back(ExitCount);
+ if (Predicates)
+ for (const auto *P : ENT.Predicates)
+ Predicates->push_back(P);
+
+ assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
+ "Predicate should be always true!");
}
}
if (ExitCounts.empty())
@@ -13609,6 +13620,24 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
P->print(OS, 4);
}
+ Preds.clear();
+ auto *PredSymbolicMax =
+ SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
+ if (SymbolicBTC != PredSymbolicMax) {
+ OS << "Loop ";
+ L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+ OS << ": ";
+ if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
+ OS << "Predicated symbolic max backedge-taken count is ";
+ PrintSCEVWithTypeHint(OS, PredSymbolicMax);
+ } else
+ OS << "Unpredictable predicated symbolic max backedge-taken count.";
+ OS << "\n";
+ OS << " Predicates:\n";
+ for (const auto *P : Preds)
+ P->print(OS, 4);
+ }
+
if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
OS << "Loop ";
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
@@ -14822,6 +14851,17 @@ const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
return BackedgeCount;
}
+const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
+ if (!SymbolicMaxBackedgeCount) {
+ SmallVector<const SCEVPredicate *, 4> Preds;
+ SymbolicMaxBackedgeCount =
+ SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
+ for (const auto *P : Preds)
+ addPredicate(*P);
+ }
+ return SymbolicMaxBackedgeCount;
+}
+
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
if (Preds->implies(&Pred))
return;
diff --git a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
index d40416359b65c..8dc79a54eb97a 100644
--- a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
+++ b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
@@ -12,6 +12,9 @@ define void @test1(i64 %x, ptr %a, ptr %b) {
; CHECK-NEXT: Loop %header: Unpredictable symbolic max backedge-taken count.
; CHECK-NEXT: symbolic max exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
;
entry:
br label %header
@@ -52,6 +55,9 @@ define void @test2(i64 %x, ptr %a) {
; CHECK-NEXT: Loop %header: Unpredictable symbolic max backedge-taken count.
; CHECK-NEXT: symbolic max exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
;
entry:
br label %header
More information about the llvm-commits mailing list