[llvm] [Analysis] Add getPredicatedExitCount to ScalarEvolution (PR #105649)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 27 06:21:31 PDT 2024
https://github.com/david-arm updated https://github.com/llvm/llvm-project/pull/105649
>From d744f9fc22bec096470eeeac8a0c88e582be6e5e Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Thu, 22 Aug 2024 12:29:40 +0000
Subject: [PATCH 1/3] [Analysis] Add getPredicatedExitCount to ScalarEvolution
Following a reviewer request on PR #88385, this patch adds a
getPredicatedExitCount function, which is similar to getExitCount
except that it uses the predicated backedge-taken information and
reports the predicates that the returned exit count depends on.
With PR #88385 we will start to care about more loops with multiple
exits, and want the ability to query exit counts for a particular
exiting block. Such loops may require predicates in order to be
vectorised.
The only way to test this patch is via unit tests that I have
added to unittests/Analysis/ScalarEvolutionTest.cpp.
---
llvm/include/llvm/Analysis/ScalarEvolution.h | 25 +++++---
llvm/lib/Analysis/ScalarEvolution.cpp | 62 +++++++++++++++----
.../Analysis/ScalarEvolutionTest.cpp | 52 ++++++++++++++++
3 files changed, 121 insertions(+), 18 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 5154e2f6659c12..03fb11993448e5 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -871,6 +871,13 @@ class ScalarEvolution {
const SCEV *getExitCount(const Loop *L, const BasicBlock *ExitingBlock,
ExitCountKind Kind = Exact);
+ /// Same as above except this uses the predicated backedge taken info and
+ /// may require predicates.
+ const SCEV *
+ getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock,
+ SmallVector<const SCEVPredicate *, 4> *Predicates,
+ ExitCountKind Kind = Exact);
+
/// If the specified loop has a predictable backedge-taken count, return it,
/// otherwise return a SCEVCouldNotCompute object. The backedge-taken count is
/// the number of times the loop header will be branched to from within the
@@ -1562,16 +1569,19 @@ class ScalarEvolution {
/// Return the number of times this loop exit may fall through to the back
/// edge, or SCEVCouldNotCompute. The loop is guaranteed not to exit via
/// this block before this number of iterations, but may exit via another
- /// block.
- const SCEV *getExact(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const;
+ /// block. If \p Predicates is null the function returns CouldNotCompute if
+ /// predicates are required, otherwise it fills in the required predicates.
+ const SCEV *
+ getExact(const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
/// Get the constant max backedge taken count for the loop.
const SCEV *getConstantMax(ScalarEvolution *SE) const;
/// Get the constant max backedge taken count for the particular loop exit.
- const SCEV *getConstantMax(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const;
+ const SCEV *getConstantMax(
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
/// Get the symbolic max backedge taken count for the loop.
const SCEV *
@@ -1579,8 +1589,9 @@ class ScalarEvolution {
SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr);
/// Get the symbolic max backedge taken count for the particular loop exit.
- const SCEV *getSymbolicMax(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const;
+ const SCEV *getSymbolicMax(
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
/// Return true if the number of times this backedge is taken is either the
/// value returned by getConstantMax or zero.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a19358dee8ef49..3726ff323630ab 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8249,6 +8249,23 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
llvm_unreachable("Invalid ExitCountKind!");
}
+const SCEV *ScalarEvolution::getPredicatedExitCount(
+ const Loop *L, const BasicBlock *ExitingBlock,
+ SmallVector<const SCEVPredicate *, 4> *Predicates, ExitCountKind Kind) {
+ switch (Kind) {
+ case Exact:
+ return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
+ Predicates);
+ case SymbolicMaximum:
+ return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
+ Predicates);
+ case ConstantMaximum:
+ return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
+ Predicates);
+ };
+ llvm_unreachable("Invalid ExitCountKind!");
+}
+
const SCEV *
ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
SmallVector<const SCEVPredicate *, 4> &Preds) {
@@ -8578,30 +8595,53 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
}
/// Get the exact not taken count for this loop exit.
-const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const {
+const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates) const {
for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.ExactNotTaken;
+ if (ENT.ExitingBlock == ExitingBlock) {
+ if (ENT.hasAlwaysTruePredicate())
+ return ENT.ExactNotTaken;
+ else if (Predicates) {
+ for (const auto *P : ENT.Predicates)
+ Predicates->push_back(P);
+ return ENT.ExactNotTaken;
+ }
+ }
return SE->getCouldNotCompute();
}
const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
- const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates) const {
for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.ConstantMaxNotTaken;
+ if (ENT.ExitingBlock == ExitingBlock) {
+ if (ENT.hasAlwaysTruePredicate())
+ return ENT.ConstantMaxNotTaken;
+ else if (Predicates) {
+ for (const auto *P : ENT.Predicates)
+ Predicates->push_back(P);
+ return ENT.ConstantMaxNotTaken;
+ }
+ }
return SE->getCouldNotCompute();
}
const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
- const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVector<const SCEVPredicate *, 4> *Predicates) const {
for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.SymbolicMaxNotTaken;
+ if (ENT.ExitingBlock == ExitingBlock) {
+ if (ENT.hasAlwaysTruePredicate())
+ return ENT.SymbolicMaxNotTaken;
+ else if (Predicates) {
+ for (const auto *P : ENT.Predicates)
+ Predicates->push_back(P);
+ return ENT.SymbolicMaxNotTaken;
+ }
+ }
return SE->getCouldNotCompute();
}
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index d4d90d80f4cea1..a9bd4789707012 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1707,4 +1707,56 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering) {
});
}
+TEST_F(ScalarEvolutionsTest, ExitCountWithPredicates) {
+ LLVMContext C;
+ SMDiagnostic Err;
+ std::unique_ptr<Module> M = parseAssemblyString(R"(
+define void @foo(ptr %dest, ptr %src, i64 noundef %end) {
+entry:
+ %cmp7 = icmp sgt i64 %end, 0
+ br i1 %cmp7, label %for.body, label %exit
+
+for.body:
+ %conv9 = phi i64 [ %conv, %for.body ], [ 0, %entry ]
+ %i.08 = phi i16 [ %inc, %for.body ], [ 0, %entry ]
+ %arrayidx = getelementptr inbounds i32, ptr %src, i64 %conv9
+ %0 = load i32, ptr %arrayidx, align 4
+ %arrayidx3 = getelementptr inbounds i32, ptr %dest, i64 %conv9
+ %1 = load i32, ptr %arrayidx3, align 4
+ %add = add i32 %1, %0
+ store i32 %add, ptr %arrayidx3, align 4
+ %inc = add i16 %i.08, 1
+ %conv = zext i16 %inc to i64
+ %cmp = icmp ult i64 %conv, %end
+ br i1 %cmp, label %for.body, label %exit
+
+exit:
+ ret void
+})",
+ Err, C);
+
+ ASSERT_TRUE(M && "Could not parse module?");
+ ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+ runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+ BasicBlock &EntryBB = F.getEntryBlock();
+ BasicBlock *ForBodyBB = nullptr;
+ Loop *Loop = nullptr;
+ for (BasicBlock *Succ : successors(&EntryBB)) {
+ Loop = LI.getLoopFor(Succ);
+ if (Loop) {
+ ForBodyBB = Loop->getHeader();
+ break;
+ }
+ }
+ ASSERT_TRUE(Loop && "Couldn't find the loop!");
+ ASSERT_TRUE(ForBodyBB && "Couldn't find the loop header!");
+ SmallVector<const SCEVPredicate *, 4> Predicates;
+ const SCEV *ExitCount = SE.getPredicatedExitCount(
+ Loop, ForBodyBB, &Predicates, ScalarEvolution::Exact);
+ ASSERT_FALSE(isa<SCEVCouldNotCompute>(ExitCount));
+ ASSERT_FALSE(Predicates.empty());
+ });
+}
+
} // end namespace llvm
>From 7067096b362ed616e09517f05323a2a9eddb77df Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Thu, 22 Aug 2024 14:15:11 +0000
Subject: [PATCH 2/3] Use SmallVectorImpl instead of SmallVector
---
llvm/include/llvm/Analysis/ScalarEvolution.h | 12 ++++++------
llvm/lib/Analysis/ScalarEvolution.cpp | 8 ++++----
2 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 03fb11993448e5..f2d63f6b89a9de 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -875,7 +875,7 @@ class ScalarEvolution {
/// may require predicates.
const SCEV *
getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock,
- SmallVector<const SCEVPredicate *, 4> *Predicates,
+ SmallVectorImpl<const SCEVPredicate *> *Predicates,
ExitCountKind Kind = Exact);
/// If the specified loop has a predictable backedge-taken count, return it,
@@ -1571,9 +1571,9 @@ class ScalarEvolution {
/// this block before this number of iterations, but may exit via another
/// block. If \p Predicates is null the function returns CouldNotCompute if
/// predicates are required, otherwise it fills in the required predicates.
- const SCEV *
- getExact(const BasicBlock *ExitingBlock, ScalarEvolution *SE,
- SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
+ const SCEV *getExact(
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
/// Get the constant max backedge taken count for the loop.
const SCEV *getConstantMax(ScalarEvolution *SE) const;
@@ -1581,7 +1581,7 @@ class ScalarEvolution {
/// Get the constant max backedge taken count for the particular loop exit.
const SCEV *getConstantMax(
const BasicBlock *ExitingBlock, ScalarEvolution *SE,
- SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
+ SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
/// Get the symbolic max backedge taken count for the loop.
const SCEV *
@@ -1591,7 +1591,7 @@ class ScalarEvolution {
/// Get the symbolic max backedge taken count for the particular loop exit.
const SCEV *getSymbolicMax(
const BasicBlock *ExitingBlock, ScalarEvolution *SE,
- SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
+ SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
/// Return true if the number of times this backedge is taken is either the
/// value returned by getConstantMax or zero.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 3726ff323630ab..8416683cbe7407 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8251,7 +8251,7 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
const SCEV *ScalarEvolution::getPredicatedExitCount(
const Loop *L, const BasicBlock *ExitingBlock,
- SmallVector<const SCEVPredicate *, 4> *Predicates, ExitCountKind Kind) {
+ SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
switch (Kind) {
case Exact:
return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
@@ -8597,7 +8597,7 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
/// Get the exact not taken count for this loop exit.
const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
const BasicBlock *ExitingBlock, ScalarEvolution *SE,
- SmallVector<const SCEVPredicate *, 4> *Predicates) const {
+ SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
for (const auto &ENT : ExitNotTaken)
if (ENT.ExitingBlock == ExitingBlock) {
if (ENT.hasAlwaysTruePredicate())
@@ -8614,7 +8614,7 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
const BasicBlock *ExitingBlock, ScalarEvolution *SE,
- SmallVector<const SCEVPredicate *, 4> *Predicates) const {
+ SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
for (const auto &ENT : ExitNotTaken)
if (ENT.ExitingBlock == ExitingBlock) {
if (ENT.hasAlwaysTruePredicate())
@@ -8631,7 +8631,7 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
const BasicBlock *ExitingBlock, ScalarEvolution *SE,
- SmallVector<const SCEVPredicate *, 4> *Predicates) const {
+ SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
for (const auto &ENT : ExitNotTaken)
if (ENT.ExitingBlock == ExitingBlock) {
if (ENT.hasAlwaysTruePredicate())
>From 7a83aeb1c3c300a2c5135986531d2ff393d3b148 Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Tue, 27 Aug 2024 13:19:27 +0000
Subject: [PATCH 3/3] Address review comments
* Print out the predicated exact and symbolic max exit counts,
together with their predicates, for each exiting block whose
unpredicated exit count cannot be computed.
* Add unit tests for the symbolic and constant maximums.
---
llvm/lib/Analysis/ScalarEvolution.cpp | 30 +++++++++-
.../ScalarEvolution/exit-count-non-strict.ll | 22 +++++++
...cated-symbolic-max-backedge-taken-count.ll | 16 +++++
.../Analysis/ScalarEvolutionTest.cpp | 58 ++++++++++++++++++-
4 files changed, 123 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 8416683cbe7407..e8e53046e1713d 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13685,7 +13685,21 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
if (ExitingBlocks.size() > 1)
for (BasicBlock *ExitingBlock : ExitingBlocks) {
OS << " exit count for " << ExitingBlock->getName() << ": ";
- PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
+ const SCEV *EC = SE->getExitCount(L, ExitingBlock);
+ PrintSCEVWithTypeHint(OS, EC);
+ if (isa<SCEVCouldNotCompute>(EC)) {
+ // Retry with predicates.
+ SmallVector<const SCEVPredicate *, 4> Predicates;
+ EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
+ if (!isa<SCEVCouldNotCompute>(EC)) {
+ OS << "\n predicated exit count for " << ExitingBlock->getName()
+ << ": ";
+ PrintSCEVWithTypeHint(OS, EC);
+ OS << "\n Predicates:\n";
+ for (const auto *P : Predicates)
+ P->print(OS, 4);
+ }
+ }
OS << "\n";
}
@@ -13725,6 +13739,20 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
ScalarEvolution::SymbolicMaximum);
PrintSCEVWithTypeHint(OS, ExitBTC);
+ if (isa<SCEVCouldNotCompute>(ExitBTC)) {
+ // Retry with predicates.
+ SmallVector<const SCEVPredicate *, 4> Predicates;
+ ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
+ ScalarEvolution::SymbolicMaximum);
+ if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
+ OS << "\n predicated symbolic max exit count for "
+ << ExitingBlock->getName() << ": ";
+ PrintSCEVWithTypeHint(OS, ExitBTC);
+ OS << "\n Predicates:\n";
+ for (const auto *P : Predicates)
+ P->print(OS, 4);
+ }
+ }
OS << "\n";
}
diff --git a/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll b/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
index e9faf98eee4492..6d64f76494638f 100644
--- a/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
+++ b/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
@@ -93,14 +93,25 @@ define void @ule_from_zero_no_nuw(i32 %M, i32 %N) {
; CHECK-NEXT: Determining loop execution counts for: @ule_from_zero_no_nuw
; CHECK-NEXT: Loop %loop: <multiple exits> Unpredictable backedge-taken count.
; CHECK-NEXT: exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: exit count for latch: %N
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i32 -1
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is %N
; CHECK-NEXT: symbolic max exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated symbolic max exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: symbolic max exit count for latch: %N
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
; CHECK-NEXT: Predicates:
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
;
entry:
br label %loop
@@ -211,14 +222,25 @@ define void @sle_from_int_min_no_nsw(i32 %M, i32 %N) {
; CHECK-NEXT: Determining loop execution counts for: @sle_from_int_min_no_nsw
; CHECK-NEXT: Loop %loop: <multiple exits> Unpredictable backedge-taken count.
; CHECK-NEXT: exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated exit count for loop: (2147483649 + (sext i32 %M to i64))<nsw>
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-EMPTY:
; CHECK-NEXT: exit count for latch: (-2147483648 + %N)
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i32 -1
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (-2147483648 + %N)
; CHECK-NEXT: symbolic max exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated symbolic max exit count for loop: (2147483649 + (sext i32 %M to i64))<nsw>
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-EMPTY:
; CHECK-NEXT: symbolic max exit count for latch: (-2147483648 + %N)
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((zext i32 (-2147483648 + %N) to i64) umin (2147483649 + (sext i32 %M to i64))<nsw>)
; CHECK-NEXT: Predicates:
; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 (-2147483648 + %N) to i64) umin (2147483649 + (sext i32 %M to i64))<nsw>)
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
;
entry:
br label %loop
diff --git a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
index 8dc79a54eb97a5..2ec6158e9b0920 100644
--- a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
+++ b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
@@ -8,10 +8,18 @@ define void @test1(i64 %x, ptr %a, ptr %b) {
; CHECK-NEXT: Loop %header: <multiple exits> Unpredictable backedge-taken count.
; CHECK-NEXT: exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: Loop %header: Unpredictable constant max backedge-taken count.
; CHECK-NEXT: Loop %header: Unpredictable symbolic max backedge-taken count.
; CHECK-NEXT: symbolic max exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated symbolic max exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
; CHECK-NEXT: Predicates:
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
@@ -51,10 +59,18 @@ define void @test2(i64 %x, ptr %a) {
; CHECK-NEXT: Loop %header: <multiple exits> Unpredictable backedge-taken count.
; CHECK-NEXT: exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: Loop %header: Unpredictable constant max backedge-taken count.
; CHECK-NEXT: Loop %header: Unpredictable symbolic max backedge-taken count.
; CHECK-NEXT: symbolic max exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated symbolic max exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
; CHECK-NEXT: Predicates:
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index a9bd4789707012..5ec4a676f69d89 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1711,7 +1711,7 @@ TEST_F(ScalarEvolutionsTest, ExitCountWithPredicates) {
LLVMContext C;
SMDiagnostic Err;
std::unique_ptr<Module> M = parseAssemblyString(R"(
-define void @foo(ptr %dest, ptr %src, i64 noundef %end) {
+define void @foo1(ptr %dest, ptr %src, i64 noundef %end) {
entry:
%cmp7 = icmp sgt i64 %end, 0
br i1 %cmp7, label %for.body, label %exit
@@ -1732,13 +1732,42 @@ for.body:
exit:
ret void
+}
+
+define i32 @foo2(ptr nocapture noundef %src1, ptr nocapture noundef readonly %src2, i8 noundef %start, i32 %end) {
+entry:
+ %st = zext i8 %start to i16
+ %ext = zext i8 %start to i32
+ %ecmp = icmp ult i16 %st, 42
+ br i1 %ecmp, label %for.body, label %exit
+
+for.body:
+ %i.08 = phi i8 [ %inc, %for.inc ], [ 0, %entry ]
+ %arrayidx = getelementptr inbounds i32, ptr %src1, i8 %i.08
+ %0 = load i32, ptr %arrayidx, align 4
+ %arrayidx3 = getelementptr inbounds i32, ptr %src2, i8 %i.08
+ %1 = load i32, ptr %arrayidx3, align 4
+ %cmp.early = icmp eq i32 %0, %1
+ br i1 %cmp.early, label %found, label %for.inc
+
+for.inc:
+ %inc = add i8 %i.08, 1
+ %conv = zext i8 %inc to i32
+ %cmp = icmp ult i32 %conv, %end
+ br i1 %cmp, label %for.body, label %exit
+
+found:
+ ret i32 1
+
+exit:
+ ret i32 0
})",
Err, C);
ASSERT_TRUE(M && "Could not parse module?");
ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
- runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+ runWithSE(*M, "foo1", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
BasicBlock &EntryBB = F.getEntryBlock();
BasicBlock *ForBodyBB = nullptr;
Loop *Loop = nullptr;
@@ -1757,6 +1786,31 @@ for.body:
ASSERT_FALSE(isa<SCEVCouldNotCompute>(ExitCount));
ASSERT_FALSE(Predicates.empty());
});
+
+ runWithSE(*M, "foo2", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+ BasicBlock &EntryBB = F.getEntryBlock();
+ BasicBlock *ForIncBB = nullptr;
+ Loop *Loop = nullptr;
+ for (BasicBlock *Succ : successors(&EntryBB)) {
+ Loop = LI.getLoopFor(Succ);
+ if (Loop) {
+ ForIncBB = Loop->getLoopLatch();
+ break;
+ }
+ }
+ ASSERT_TRUE(Loop && "Couldn't find the loop!");
+ ASSERT_TRUE(ForIncBB && "Couldn't find the loop header!");
+ SmallVector<const SCEVPredicate *, 4> Predicates;
+ const SCEV *ExitCount = SE.getPredicatedExitCount(
+ Loop, ForIncBB, &Predicates, ScalarEvolution::SymbolicMaximum);
+ ASSERT_FALSE(isa<SCEVCouldNotCompute>(ExitCount));
+ ASSERT_FALSE(Predicates.empty());
+ ExitCount = SE.getPredicatedExitCount(Loop, ForIncBB, &Predicates,
+ ScalarEvolution::ConstantMaximum);
+ ASSERT_TRUE(isa<SCEVConstant>(ExitCount));
+ ASSERT_TRUE(cast<SCEVConstant>(ExitCount)->getAPInt().getSExtValue() ==
+ -2ll);
+ });
}
} // end namespace llvm
More information about the llvm-commits
mailing list