[llvm] [Analysis] Add getPredicatedExitCount to ScalarEvolution (PR #105649)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 22 05:39:35 PDT 2024


https://github.com/david-arm created https://github.com/llvm/llvm-project/pull/105649

Due to a reviewer request on PR #88385, I have created this patch
to add a getPredicatedExitCount function, which is similar to
getExitCount except that it uses the predicated backedge-taken
information. With PR #88385 we will start to care about more
loops with multiple exits, and we want the ability to query exit
counts for a particular exiting block. Such loops may require
predicates in order to be vectorised.

The only way to test this patch is via unit tests that I have
added to unittests/Analysis/ScalarEvolutionTest.cpp.


>From d744f9fc22bec096470eeeac8a0c88e582be6e5e Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Thu, 22 Aug 2024 12:29:40 +0000
Subject: [PATCH] [Analysis] Add getPredicatedExitCount to ScalarEvolution

Due to a reviewer request on PR #88385, I have created this patch
to add a getPredicatedExitCount function, which is similar to
getExitCount except that it uses the predicated backedge-taken
information. With PR #88385 we will start to care about more
loops with multiple exits, and we want the ability to query exit
counts for a particular exiting block. Such loops may require
predicates in order to be vectorised.

The only way to test this patch is via unit tests that I have
added to unittests/Analysis/ScalarEvolutionTest.cpp.
---
 llvm/include/llvm/Analysis/ScalarEvolution.h  | 25 +++++---
 llvm/lib/Analysis/ScalarEvolution.cpp         | 62 +++++++++++++++----
 .../Analysis/ScalarEvolutionTest.cpp          | 52 ++++++++++++++++
 3 files changed, 121 insertions(+), 18 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 5154e2f6659c12..03fb11993448e5 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -871,6 +871,13 @@ class ScalarEvolution {
   const SCEV *getExitCount(const Loop *L, const BasicBlock *ExitingBlock,
                            ExitCountKind Kind = Exact);
 
+  /// Same as above, but uses the predicated backedge-taken info. Predicates
+  /// needed are appended to \p Predicates, or CouldNotCompute if it's null.
+  const SCEV *
+  getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock,
+                         SmallVector<const SCEVPredicate *, 4> *Predicates,
+                         ExitCountKind Kind = Exact);
+
   /// If the specified loop has a predictable backedge-taken count, return it,
   /// otherwise return a SCEVCouldNotCompute object. The backedge-taken count is
   /// the number of times the loop header will be branched to from within the
@@ -1562,16 +1569,19 @@ class ScalarEvolution {
     /// Return the number of times this loop exit may fall through to the back
     /// edge, or SCEVCouldNotCompute. The loop is guaranteed not to exit via
     /// this block before this number of iterations, but may exit via another
-    /// block.
-    const SCEV *getExact(const BasicBlock *ExitingBlock,
-                         ScalarEvolution *SE) const;
+    /// block. If predicates are required and \p Predicates is null, returns
+    /// CouldNotCompute; otherwise they are appended to \p Predicates.
+    const SCEV *
+    getExact(const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+             SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;

     /// Get the constant max backedge taken count for the loop.
     const SCEV *getConstantMax(ScalarEvolution *SE) const;

     /// Get the constant max backedge taken count for the particular loop exit.
-    const SCEV *getConstantMax(const BasicBlock *ExitingBlock,
-                               ScalarEvolution *SE) const;
+    const SCEV *getConstantMax(
+        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+        SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;

     /// Get the symbolic max backedge taken count for the loop.
     const SCEV *
@@ -1579,8 +1589,9 @@ class ScalarEvolution {
                    SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr);

     /// Get the symbolic max backedge taken count for the particular loop exit.
-    const SCEV *getSymbolicMax(const BasicBlock *ExitingBlock,
-                               ScalarEvolution *SE) const;
+    const SCEV *getSymbolicMax(
+        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+        SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
 
     /// Return true if the number of times this backedge is taken is either the
     /// value returned by getConstantMax or zero.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index a19358dee8ef49..3726ff323630ab 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8249,6 +8249,23 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
   llvm_unreachable("Invalid ExitCountKind!");
 }
 
+const SCEV *ScalarEvolution::getPredicatedExitCount(
+    const Loop *L, const BasicBlock *ExitingBlock,
+    SmallVector<const SCEVPredicate *, 4> *Predicates, ExitCountKind Kind) {
+  // Same as getExitCount but backed by the predicated backedge-taken info;
+  // a null Predicates makes any predicate-dependent count CouldNotCompute.
+  BackedgeTakenInfo &BTI = getPredicatedBackedgeTakenInfo(L);
+  switch (Kind) {
+  case Exact:
+    return BTI.getExact(ExitingBlock, this, Predicates);
+  case SymbolicMaximum:
+    return BTI.getSymbolicMax(ExitingBlock, this, Predicates);
+  case ConstantMaximum:
+    return BTI.getConstantMax(ExitingBlock, this, Predicates);
+  }
+  llvm_unreachable("Invalid ExitCountKind!");
+}
+
 const SCEV *
 ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
                                                  SmallVector<const SCEVPredicate *, 4> &Preds) {
@@ -8578,30 +8595,53 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
 }
 
 /// Get the exact not taken count for this loop exit.
-const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
-                                             ScalarEvolution *SE) const {
+const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
+    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.ExactNotTaken;
+    if (ENT.ExitingBlock == ExitingBlock) {
+      if (ENT.hasAlwaysTruePredicate())
+        return ENT.ExactNotTaken;
+      // The count is only valid under ENT.Predicates; pass them to the caller.
+      if (Predicates) {
+        append_range(*Predicates, ENT.Predicates);
+        return ENT.ExactNotTaken;
+      }
+    }
 
   return SE->getCouldNotCompute();
 }
 
 const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
-    const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.ConstantMaxNotTaken;
+    if (ENT.ExitingBlock == ExitingBlock) {
+      if (ENT.hasAlwaysTruePredicate())
+        return ENT.ConstantMaxNotTaken;
+      // The count is only valid under ENT.Predicates; pass them to the caller.
+      if (Predicates) {
+        append_range(*Predicates, ENT.Predicates);
+        return ENT.ConstantMaxNotTaken;
+      }
+    }
 
   return SE->getCouldNotCompute();
 }
 
 const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
-    const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.SymbolicMaxNotTaken;
+    if (ENT.ExitingBlock == ExitingBlock) {
+      if (ENT.hasAlwaysTruePredicate())
+        return ENT.SymbolicMaxNotTaken;
+      // The count is only valid under ENT.Predicates; pass them to the caller.
+      if (Predicates) {
+        append_range(*Predicates, ENT.Predicates);
+        return ENT.SymbolicMaxNotTaken;
+      }
+    }
 
   return SE->getCouldNotCompute();
 }
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index d4d90d80f4cea1..a9bd4789707012 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1707,4 +1707,56 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering) {
   });
 }
 
+TEST_F(ScalarEvolutionsTest, ExitCountWithPredicates) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(R"(
+define void @foo(ptr %dest, ptr %src, i64 noundef %end) {
+entry:
+  %cmp7 = icmp sgt i64 %end, 0
+  br i1 %cmp7, label %for.body, label %exit
+
+for.body:
+  %conv9 = phi i64 [ %conv, %for.body ], [ 0, %entry ]
+  %i.08 = phi i16 [ %inc, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, ptr %src, i64 %conv9
+  %0 = load i32, ptr %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, ptr %dest, i64 %conv9
+  %1 = load i32, ptr %arrayidx3, align 4
+  %add = add i32 %1, %0
+  store i32 %add, ptr %arrayidx3, align 4
+  %inc = add i16 %i.08, 1
+  %conv = zext i16 %inc to i64
+  %cmp = icmp ult i64 %conv, %end
+  br i1 %cmp, label %for.body, label %exit
+
+exit:
+  ret void
+})",
+                                                  Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    // @foo has one loop, and its header is its only exiting block.
+    ASSERT_EQ(LI.getTopLevelLoops().size(), 1u);
+    Loop *L = LI.getTopLevelLoops().front();
+    BasicBlock *ForBodyBB = L->getHeader();
+    ASSERT_EQ(L->getExitingBlock(), ForBodyBB);
+    // The i16 IV is zext'd before the compare, so predicates are needed.
+    for (auto Kind : {ScalarEvolution::Exact, ScalarEvolution::SymbolicMaximum,
+                      ScalarEvolution::ConstantMaximum}) {
+      SmallVector<const SCEVPredicate *, 4> Predicates;
+      const SCEV *ExitCount =
+          SE.getPredicatedExitCount(L, ForBodyBB, &Predicates, Kind);
+      EXPECT_FALSE(isa<SCEVCouldNotCompute>(ExitCount));
+      EXPECT_FALSE(Predicates.empty());
+      // Without somewhere to record the predicates, the query must fail.
+      EXPECT_TRUE(isa<SCEVCouldNotCompute>(
+          SE.getPredicatedExitCount(L, ForBodyBB, nullptr, Kind)));
+    }
+  });
+}
+
 }  // end namespace llvm



More information about the llvm-commits mailing list