[llvm] df3d70b - [Analysis] Add getPredicatedExitCount to ScalarEvolution (#105649)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 2 06:05:29 PDT 2024
Author: David Sherwood
Date: 2024-09-02T14:05:26+01:00
New Revision: df3d70b5a72fee43af3793c8b7a138bd44cac8cf
URL: https://github.com/llvm/llvm-project/commit/df3d70b5a72fee43af3793c8b7a138bd44cac8cf
DIFF: https://github.com/llvm/llvm-project/commit/df3d70b5a72fee43af3793c8b7a138bd44cac8cf.diff
LOG: [Analysis] Add getPredicatedExitCount to ScalarEvolution (#105649)
Due to a reviewer request on PR #88385, I have created this patch
to add a getPredicatedExitCount function, which is similar to
getExitCount except that it uses the predicated backedge-taken
information. With PR #88385, we will start to care about more
loops with multiple exits and want the ability to query exit
counts for a particular exiting block. Such loops may require
predicates in order to be vectorised.
New tests added here:
Analysis/ScalarEvolution/predicated-exit-count.ll
Added:
llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll
Modified:
llvm/include/llvm/Analysis/ScalarEvolution.h
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index fe46a504bce5d1..89f9395959779d 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -871,6 +871,15 @@ class ScalarEvolution {
const SCEV *getExitCount(const Loop *L, const BasicBlock *ExitingBlock,
ExitCountKind Kind = Exact);
+ /// Same as above except this uses the predicated backedge taken info and
+ /// may require predicates. Any predicates the result depends on are
+ /// appended to \p Predicates. If \p Predicates is null, an exit whose count
+ /// needs predicates yields SCEVCouldNotCompute instead.
+ const SCEV *
+ getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock,
+ SmallVectorImpl<const SCEVPredicate *> *Predicates,
+ ExitCountKind Kind = Exact);
+
/// If the specified loop has a predictable backedge-taken count, return it,
/// otherwise return a SCEVCouldNotCompute object. The backedge-taken count is
/// the number of times the loop header will be branched to from within the
@@ -1517,6 +1526,13 @@ class ScalarEvolution {
bool isComplete() const { return IsComplete; }
const SCEV *getConstantMax() const { return ConstantMax; }
+ /// Return the exit-not-taken entry for \p ExitingBlock, or null if none is
+ /// usable. An entry that needs predicates is returned only when
+ /// \p Predicates is non-null, and its predicates are then appended to it.
+ const ExitNotTakenInfo *getExitNotTaken(
+ const BasicBlock *ExitingBlock,
+ SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
+
public:
BackedgeTakenInfo() = default;
BackedgeTakenInfo(BackedgeTakenInfo &&) = default;
@@ -1563,16 +1579,28 @@ class ScalarEvolution {
/// Return the number of times this loop exit may fall through to the back
/// edge, or SCEVCouldNotCompute. The loop is guaranteed not to exit via
/// this block before this number of iterations, but may exit via another
- /// block.
- const SCEV *getExact(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const;
+ /// block. If \p Predicates is null the function returns CouldNotCompute if
+ /// predicates are required, otherwise it appends them to \p Predicates.
+ const SCEV *getExact(
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
+ if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
+ return ENT->ExactNotTaken;
+ return SE->getCouldNotCompute();
+ }
/// Get the constant max backedge taken count for the loop.
const SCEV *getConstantMax(ScalarEvolution *SE) const;
/// Get the constant max backedge taken count for the particular loop exit.
- const SCEV *getConstantMax(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const;
+ /// \p Predicates is treated the same way as in getExact above.
+ const SCEV *getConstantMax(
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
+ if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
+ return ENT->ConstantMaxNotTaken;
+ return SE->getCouldNotCompute();
+ }
/// Get the symbolic max backedge taken count for the loop.
const SCEV *getSymbolicMax(
@@ -1580,8 +1608,14 @@ class ScalarEvolution {
SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr);
/// Get the symbolic max backedge taken count for the particular loop exit.
- const SCEV *getSymbolicMax(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const;
+ /// \p Predicates is treated the same way as in getExact above.
+ const SCEV *getSymbolicMax(
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+ SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
+ if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
+ return ENT->SymbolicMaxNotTaken;
+ return SE->getCouldNotCompute();
+ }
/// Return true if the number of times this backedge is taken is either the
/// value returned by getConstantMax or zero.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 54dde8401cdff0..6b4a81c217b3c2 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8247,6 +8247,23 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
llvm_unreachable("Invalid ExitCountKind!");
}
+/// Like getExitCount, but from the predicated backedge-taken info. Needed
+/// predicates go into \p Predicates; if null, such exits give CouldNotCompute.
+const SCEV *ScalarEvolution::getPredicatedExitCount(
+ const Loop *L, const BasicBlock *ExitingBlock,
+ SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
+ const BackedgeTakenInfo &BTI = getPredicatedBackedgeTakenInfo(L);
+ switch (Kind) {
+ case Exact:
+ return BTI.getExact(ExitingBlock, this, Predicates);
+ case SymbolicMaximum:
+ return BTI.getSymbolicMax(ExitingBlock, this, Predicates);
+ case ConstantMaximum:
+ return BTI.getConstantMax(ExitingBlock, this, Predicates);
+ }
+ llvm_unreachable("Invalid ExitCountKind!");
+}
+
const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
@@ -8574,33 +8591,22 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
}
-/// Get the exact not taken count for this loop exit.
-const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const {
- for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.ExactNotTaken;
-
- return SE->getCouldNotCompute();
-}
-
-const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
- const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+const ScalarEvolution::ExitNotTakenInfo *
+ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
+ const BasicBlock *ExitingBlock,
+ SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.ConstantMaxNotTaken;
-
- return SE->getCouldNotCompute();
-}
-
-const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
- const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
- for (const auto &ENT : ExitNotTaken)
- if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
- return ENT.SymbolicMaxNotTaken;
+ if (ENT.ExitingBlock == ExitingBlock) {
+ if (ENT.hasAlwaysTruePredicate())
+ return &ENT;
+ // Entries needing predicates require the caller to collect them.
+ if (Predicates) {
+ append_range(*Predicates, ENT.Predicates);
+ return &ENT;
+ }
+ }
- return SE->getCouldNotCompute();
+ return nullptr;
}
/// getConstantMax - Get the constant max backedge taken count for the loop.
@@ -13642,7 +13648,21 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
if (ExitingBlocks.size() > 1)
for (BasicBlock *ExitingBlock : ExitingBlocks) {
OS << " exit count for " << ExitingBlock->getName() << ": ";
- PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
+ const SCEV *EC = SE->getExitCount(L, ExitingBlock);
+ PrintSCEVWithTypeHint(OS, EC);
+ if (isa<SCEVCouldNotCompute>(EC)) {
+ // Retry with predicates.
+ SmallVector<const SCEVPredicate *, 4> Predicates;
+ EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
+ if (!isa<SCEVCouldNotCompute>(EC)) {
+ OS << "\n predicated exit count for " << ExitingBlock->getName()
+ << ": ";
+ PrintSCEVWithTypeHint(OS, EC);
+ OS << "\n Predicates:\n";
+ for (const auto *P : Predicates)
+ P->print(OS, 4);
+ }
+ }
OS << "\n";
}
@@ -13682,6 +13702,20 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
ScalarEvolution::SymbolicMaximum);
PrintSCEVWithTypeHint(OS, ExitBTC);
+ if (isa<SCEVCouldNotCompute>(ExitBTC)) {
+ // Retry with predicates.
+ SmallVector<const SCEVPredicate *, 4> Predicates;
+ ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
+ ScalarEvolution::SymbolicMaximum);
+ if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
+ OS << "\n predicated symbolic max exit count for "
+ << ExitingBlock->getName() << ": ";
+ PrintSCEVWithTypeHint(OS, ExitBTC);
+ OS << "\n Predicates:\n";
+ for (const auto *P : Predicates)
+ P->print(OS, 4);
+ }
+ }
OS << "\n";
}
diff --git a/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll b/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
index e9faf98eee4492..6d64f76494638f 100644
--- a/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
+++ b/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
@@ -93,14 +93,25 @@ define void @ule_from_zero_no_nuw(i32 %M, i32 %N) {
; CHECK-NEXT: Determining loop execution counts for: @ule_from_zero_no_nuw
; CHECK-NEXT: Loop %loop: <multiple exits> Unpredictable backedge-taken count.
; CHECK-NEXT: exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: exit count for latch: %N
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i32 -1
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is %N
; CHECK-NEXT: symbolic max exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated symbolic max exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: symbolic max exit count for latch: %N
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
; CHECK-NEXT: Predicates:
; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {0,+,1}<%loop> Added Flags: <nusw>
;
entry:
br label %loop
@@ -211,14 +222,25 @@ define void @sle_from_int_min_no_nsw(i32 %M, i32 %N) {
; CHECK-NEXT: Determining loop execution counts for: @sle_from_int_min_no_nsw
; CHECK-NEXT: Loop %loop: <multiple exits> Unpredictable backedge-taken count.
; CHECK-NEXT: exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated exit count for loop: (2147483649 + (sext i32 %M to i64))<nsw>
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-EMPTY:
; CHECK-NEXT: exit count for latch: (-2147483648 + %N)
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i32 -1
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is (-2147483648 + %N)
; CHECK-NEXT: symbolic max exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated symbolic max exit count for loop: (2147483649 + (sext i32 %M to i64))<nsw>
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-EMPTY:
; CHECK-NEXT: symbolic max exit count for latch: (-2147483648 + %N)
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((zext i32 (-2147483648 + %N) to i64) umin (2147483649 + (sext i32 %M to i64))<nsw>)
; CHECK-NEXT: Predicates:
; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 (-2147483648 + %N) to i64) umin (2147483649 + (sext i32 %M to i64))<nsw>)
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {-2147483648,+,1}<%loop> Added Flags: <nssw>
;
entry:
br label %loop
diff --git a/llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll b/llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll
new file mode 100644
index 00000000000000..de214183710ab3
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -disable-output "-passes=print<scalar-evolution>" -scalar-evolution-classify-expressions=0 < %s 2>&1 | FileCheck %s
+
+
+define i32 @multiple_exits_with_predicates(ptr %src1, ptr readonly %src2, i32 %end) {
+; CHECK-LABEL: 'multiple_exits_with_predicates'
+; CHECK-NEXT: Determining loop execution counts for: @multiple_exits_with_predicates
+; CHECK-NEXT: Loop %for.body: <multiple exits> Unpredictable backedge-taken count.
+; CHECK-NEXT: exit count for for.body: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated exit count for for.body: i32 1023
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT: exit count for for.work: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: exit count for for.inc: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated exit count for for.inc: (-1 + (1 umax %end))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT: Loop %for.body: Unpredictable constant max backedge-taken count.
+; CHECK-NEXT: Loop %for.body: Unpredictable symbolic max backedge-taken count.
+; CHECK-NEXT: symbolic max exit count for for.body: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated symbolic max exit count for for.body: i32 1023
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT: symbolic max exit count for for.work: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: symbolic max exit count for for.inc: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated symbolic max exit count for for.inc: (-1 + (1 umax %end))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT: Loop %for.body: Predicated symbolic max backedge-taken count is (1023 umin (-1 + (1 umax %end)))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-NEXT: {1,+,1}<%for.body> Added Flags: <nusw>
+;
+entry:
+ br label %for.body
+
+for.body:
+ %index = phi i8 [ %index.next, %for.inc ], [ 0, %entry ]
+ %index.next = add i8 %index, 1
+ %conv = zext i8 %index.next to i32 ; zext of i8 is at most 255, so the for.body exit count relies on a no-wrap predicate
+ %cmp.body = icmp ne i32 %conv, 1024
+ br i1 %cmp.body, label %for.work, label %exit
+
+for.work:
+ %arrayidx = getelementptr inbounds i32, ptr %src1, i8 %index
+ %0 = load i32, ptr %arrayidx, align 4
+ %arrayidx3 = getelementptr inbounds i32, ptr %src2, i8 %index
+ %1 = load i32, ptr %arrayidx3, align 4
+ %cmp.work = icmp eq i32 %0, %1 ; data-dependent exit: not computable, even with predicates
+ br i1 %cmp.work, label %found, label %for.inc
+
+for.inc:
+ %cmp.inc = icmp ult i32 %conv, %end ; likewise only computable under the no-wrap predicate
+ br i1 %cmp.inc, label %for.body, label %exit
+
+found:
+ ret i32 1
+
+exit:
+ ret i32 0
+}
diff --git a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
index 8dc79a54eb97a5..2ec6158e9b0920 100644
--- a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
+++ b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
@@ -8,10 +8,18 @@ define void @test1(i64 %x, ptr %a, ptr %b) {
; CHECK-NEXT: Loop %header: <multiple exits> Unpredictable backedge-taken count.
; CHECK-NEXT: exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: Loop %header: Unpredictable constant max backedge-taken count.
; CHECK-NEXT: Loop %header: Unpredictable symbolic max backedge-taken count.
; CHECK-NEXT: symbolic max exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated symbolic max exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
; CHECK-NEXT: Predicates:
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
@@ -51,10 +59,18 @@ define void @test2(i64 %x, ptr %a) {
; CHECK-NEXT: Loop %header: <multiple exits> Unpredictable backedge-taken count.
; CHECK-NEXT: exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: Loop %header: Unpredictable constant max backedge-taken count.
; CHECK-NEXT: Loop %header: Unpredictable symbolic max backedge-taken count.
; CHECK-NEXT: symbolic max exit count for header: ***COULDNOTCOMPUTE***
; CHECK-NEXT: symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT: predicated symbolic max exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT: Predicates:
+; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
; CHECK-NEXT: Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
; CHECK-NEXT: Predicates:
; CHECK-NEXT: {1,+,1}<%header> Added Flags: <nusw>
More information about the llvm-commits mailing list