[llvm] [Analysis] Add getPredicatedExitCount to ScalarEvolution (PR #105649)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 30 08:26:29 PDT 2024


https://github.com/david-arm updated https://github.com/llvm/llvm-project/pull/105649

>From 0581251be50caa5dc4a88e2e3863faa44c1f2236 Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Thu, 22 Aug 2024 12:29:40 +0000
Subject: [PATCH 1/5] [Analysis] Add getPredicatedExitCount to ScalarEvolution

Following a reviewer request on PR #88385, this patch adds a
getPredicatedExitCount function, which is similar to getExitCount
except that it uses the predicated backedge-taken information.
With PR #88385 we will start to care about more loops with
multiple exits, and want the ability to query the exit count for
a particular exiting block. Such loops may require predicates in
order to be vectorised.

The only way to test this patch is via unit tests that I have
added to unittests/Analysis/ScalarEvolutionTest.cpp.
---
 llvm/include/llvm/Analysis/ScalarEvolution.h  | 25 +++++---
 llvm/lib/Analysis/ScalarEvolution.cpp         | 62 +++++++++++++++----
 .../Analysis/ScalarEvolutionTest.cpp          | 52 ++++++++++++++++
 3 files changed, 121 insertions(+), 18 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index fe46a504bce5d1..5095cb661f36ce 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -871,6 +871,13 @@ class ScalarEvolution {
   const SCEV *getExitCount(const Loop *L, const BasicBlock *ExitingBlock,
                            ExitCountKind Kind = Exact);
 
+  /// Same as above except this uses the predicated backedge taken info and
+  /// may require predicates.
+  const SCEV *
+  getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock,
+                         SmallVector<const SCEVPredicate *, 4> *Predicates,
+                         ExitCountKind Kind = Exact);
+
   /// If the specified loop has a predictable backedge-taken count, return it,
   /// otherwise return a SCEVCouldNotCompute object. The backedge-taken count is
   /// the number of times the loop header will be branched to from within the
@@ -1563,16 +1570,19 @@ class ScalarEvolution {
     /// Return the number of times this loop exit may fall through to the back
     /// edge, or SCEVCouldNotCompute. The loop is guaranteed not to exit via
     /// this block before this number of iterations, but may exit via another
-    /// block.
-    const SCEV *getExact(const BasicBlock *ExitingBlock,
-                         ScalarEvolution *SE) const;
+    /// block. If \p Predicates is null the function returns CouldNotCompute if
+    /// predicates are required, otherwise it fills in the required predicates.
+    const SCEV *
+    getExact(const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+             SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
 
     /// Get the constant max backedge taken count for the loop.
     const SCEV *getConstantMax(ScalarEvolution *SE) const;
 
     /// Get the constant max backedge taken count for the particular loop exit.
-    const SCEV *getConstantMax(const BasicBlock *ExitingBlock,
-                               ScalarEvolution *SE) const;
+    const SCEV *getConstantMax(
+        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+        SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
 
     /// Get the symbolic max backedge taken count for the loop.
     const SCEV *getSymbolicMax(
@@ -1580,8 +1590,9 @@ class ScalarEvolution {
         SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr);
 
     /// Get the symbolic max backedge taken count for the particular loop exit.
-    const SCEV *getSymbolicMax(const BasicBlock *ExitingBlock,
-                               ScalarEvolution *SE) const;
+    const SCEV *getSymbolicMax(
+        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+        SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
 
     /// Return true if the number of times this backedge is taken is either the
     /// value returned by getConstantMax or zero.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 54dde8401cdff0..3c0d154d31abb8 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8247,6 +8247,23 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
   llvm_unreachable("Invalid ExitCountKind!");
 }
 
+const SCEV *ScalarEvolution::getPredicatedExitCount(
+    const Loop *L, const BasicBlock *ExitingBlock,
+    SmallVector<const SCEVPredicate *, 4> *Predicates, ExitCountKind Kind) {
+  switch (Kind) {
+  case Exact:
+    return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
+                                                      Predicates);
+  case SymbolicMaximum:
+    return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
+                                                            Predicates);
+  case ConstantMaximum:
+    return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
+                                                            Predicates);
+  };
+  llvm_unreachable("Invalid ExitCountKind!");
+}
+
 const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
     const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
   return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
@@ -8575,30 +8592,53 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
 }
 
 /// Get the exact not taken count for this loop exit.
-const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
-                                             ScalarEvolution *SE) const {
+const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
+    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.ExactNotTaken;
+    if (ENT.ExitingBlock == ExitingBlock) {
+      if (ENT.hasAlwaysTruePredicate())
+        return ENT.ExactNotTaken;
+      else if (Predicates) {
+        for (const auto *P : ENT.Predicates)
+          Predicates->push_back(P);
+        return ENT.ExactNotTaken;
+      }
+    }
 
   return SE->getCouldNotCompute();
 }
 
 const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
-    const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.ConstantMaxNotTaken;
+    if (ENT.ExitingBlock == ExitingBlock) {
+      if (ENT.hasAlwaysTruePredicate())
+        return ENT.ConstantMaxNotTaken;
+      else if (Predicates) {
+        for (const auto *P : ENT.Predicates)
+          Predicates->push_back(P);
+        return ENT.ConstantMaxNotTaken;
+      }
+    }
 
   return SE->getCouldNotCompute();
 }
 
 const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
-    const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.SymbolicMaxNotTaken;
+    if (ENT.ExitingBlock == ExitingBlock) {
+      if (ENT.hasAlwaysTruePredicate())
+        return ENT.SymbolicMaxNotTaken;
+      else if (Predicates) {
+        for (const auto *P : ENT.Predicates)
+          Predicates->push_back(P);
+        return ENT.SymbolicMaxNotTaken;
+      }
+    }
 
   return SE->getCouldNotCompute();
 }
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index d4d90d80f4cea1..a9bd4789707012 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1707,4 +1707,56 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering) {
   });
 }
 
+TEST_F(ScalarEvolutionsTest, ExitCountWithPredicates) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(R"(
+define void @foo(ptr %dest, ptr %src, i64 noundef %end) {
+entry:
+  %cmp7 = icmp sgt i64 %end, 0
+  br i1 %cmp7, label %for.body, label %exit
+
+for.body:
+  %conv9 = phi i64 [ %conv, %for.body ], [ 0, %entry ]
+  %i.08 = phi i16 [ %inc, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, ptr %src, i64 %conv9
+  %0 = load i32, ptr %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, ptr %dest, i64 %conv9
+  %1 = load i32, ptr %arrayidx3, align 4
+  %add = add i32 %1, %0
+  store i32 %add, ptr %arrayidx3, align 4
+  %inc = add i16 %i.08, 1
+  %conv = zext i16 %inc to i64
+  %cmp = icmp ult i64 %conv, %end
+  br i1 %cmp, label %for.body, label %exit
+
+exit:
+  ret void
+})",
+                                                  Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    BasicBlock &EntryBB = F.getEntryBlock();
+    BasicBlock *ForBodyBB = nullptr;
+    Loop *Loop = nullptr;
+    for (BasicBlock *Succ : successors(&EntryBB)) {
+      Loop = LI.getLoopFor(Succ);
+      if (Loop) {
+        ForBodyBB = Loop->getHeader();
+        break;
+      }
+    }
+    ASSERT_TRUE(Loop && "Couldn't find the loop!");
+    ASSERT_TRUE(ForBodyBB && "Couldn't find the loop header!");
+    SmallVector<const SCEVPredicate *, 4> Predicates;
+    const SCEV *ExitCount = SE.getPredicatedExitCount(
+        Loop, ForBodyBB, &Predicates, ScalarEvolution::Exact);
+    ASSERT_FALSE(isa<SCEVCouldNotCompute>(ExitCount));
+    ASSERT_FALSE(Predicates.empty());
+  });
+}
+
 }  // end namespace llvm

>From 0df3907c10d0d81a8fae606368d304f2930c36e4 Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Thu, 22 Aug 2024 14:15:11 +0000
Subject: [PATCH 2/5] Use SmallVectorImpl instead of SmallVector

---
 llvm/include/llvm/Analysis/ScalarEvolution.h | 12 ++++++------
 llvm/lib/Analysis/ScalarEvolution.cpp        |  8 ++++----
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 5095cb661f36ce..7e323cb6e1d45f 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -875,7 +875,7 @@ class ScalarEvolution {
   /// may require predicates.
   const SCEV *
   getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock,
-                         SmallVector<const SCEVPredicate *, 4> *Predicates,
+                         SmallVectorImpl<const SCEVPredicate *> *Predicates,
                          ExitCountKind Kind = Exact);
 
   /// If the specified loop has a predictable backedge-taken count, return it,
@@ -1572,9 +1572,9 @@ class ScalarEvolution {
     /// this block before this number of iterations, but may exit via another
     /// block. If \p Predicates is null the function returns CouldNotCompute if
     /// predicates are required, otherwise it fills in the required predicates.
-    const SCEV *
-    getExact(const BasicBlock *ExitingBlock, ScalarEvolution *SE,
-             SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
+    const SCEV *getExact(
+        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
 
     /// Get the constant max backedge taken count for the loop.
     const SCEV *getConstantMax(ScalarEvolution *SE) const;
@@ -1582,7 +1582,7 @@ class ScalarEvolution {
     /// Get the constant max backedge taken count for the particular loop exit.
     const SCEV *getConstantMax(
         const BasicBlock *ExitingBlock, ScalarEvolution *SE,
-        SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
+        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
 
     /// Get the symbolic max backedge taken count for the loop.
     const SCEV *getSymbolicMax(
@@ -1592,7 +1592,7 @@ class ScalarEvolution {
     /// Get the symbolic max backedge taken count for the particular loop exit.
     const SCEV *getSymbolicMax(
         const BasicBlock *ExitingBlock, ScalarEvolution *SE,
-        SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr) const;
+        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
 
     /// Return true if the number of times this backedge is taken is either the
     /// value returned by getConstantMax or zero.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 3c0d154d31abb8..81048c035f4d7c 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8249,7 +8249,7 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
 
 const SCEV *ScalarEvolution::getPredicatedExitCount(
     const Loop *L, const BasicBlock *ExitingBlock,
-    SmallVector<const SCEVPredicate *, 4> *Predicates, ExitCountKind Kind) {
+    SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
   switch (Kind) {
   case Exact:
     return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
@@ -8594,7 +8594,7 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
 /// Get the exact not taken count for this loop exit.
 const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
     const BasicBlock *ExitingBlock, ScalarEvolution *SE,
-    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
+    SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
     if (ENT.ExitingBlock == ExitingBlock) {
       if (ENT.hasAlwaysTruePredicate())
@@ -8611,7 +8611,7 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
 
 const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
     const BasicBlock *ExitingBlock, ScalarEvolution *SE,
-    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
+    SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
     if (ENT.ExitingBlock == ExitingBlock) {
       if (ENT.hasAlwaysTruePredicate())
@@ -8628,7 +8628,7 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
 
 const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
     const BasicBlock *ExitingBlock, ScalarEvolution *SE,
-    SmallVector<const SCEVPredicate *, 4> *Predicates) const {
+    SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
     if (ENT.ExitingBlock == ExitingBlock) {
       if (ENT.hasAlwaysTruePredicate())

>From 106a02e3257425ae14c9e34612a6e2fe8d8bb79c Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Tue, 27 Aug 2024 13:19:27 +0000
Subject: [PATCH 3/5] Address review comments

* Print out the predicated exact and symbolic max exit counts for
  each exiting block if the corresponding unpredicated exit count
  cannot be computed.
* Add unit tests for the symbolic and constant maximums.
---
 llvm/lib/Analysis/ScalarEvolution.cpp         | 30 +++++++++-
 .../ScalarEvolution/exit-count-non-strict.ll  | 22 +++++++
 ...cated-symbolic-max-backedge-taken-count.ll | 16 +++++
 .../Analysis/ScalarEvolutionTest.cpp          | 58 ++++++++++++++++++-
 4 files changed, 123 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 81048c035f4d7c..07012a2b43bd08 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13682,7 +13682,21 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
   if (ExitingBlocks.size() > 1)
     for (BasicBlock *ExitingBlock : ExitingBlocks) {
       OS << "  exit count for " << ExitingBlock->getName() << ": ";
-      PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
+      const SCEV *EC = SE->getExitCount(L, ExitingBlock);
+      PrintSCEVWithTypeHint(OS, EC);
+      if (isa<SCEVCouldNotCompute>(EC)) {
+        // Retry with predicates.
+        SmallVector<const SCEVPredicate *, 4> Predicates;
+        EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
+        if (!isa<SCEVCouldNotCompute>(EC)) {
+          OS << "\n  predicated exit count for " << ExitingBlock->getName()
+             << ": ";
+          PrintSCEVWithTypeHint(OS, EC);
+          OS << "\n   Predicates:\n";
+          for (const auto *P : Predicates)
+            P->print(OS, 4);
+        }
+      }
       OS << "\n";
     }
 
@@ -13722,6 +13736,20 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
       auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
                                        ScalarEvolution::SymbolicMaximum);
       PrintSCEVWithTypeHint(OS, ExitBTC);
+      if (isa<SCEVCouldNotCompute>(ExitBTC)) {
+        // Retry with predicates.
+        SmallVector<const SCEVPredicate *, 4> Predicates;
+        ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
+                                             ScalarEvolution::SymbolicMaximum);
+        if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
+          OS << "\n  predicated symbolic max exit count for "
+             << ExitingBlock->getName() << ": ";
+          PrintSCEVWithTypeHint(OS, ExitBTC);
+          OS << "\n   Predicates:\n";
+          for (const auto *P : Predicates)
+            P->print(OS, 4);
+        }
+      }
       OS << "\n";
     }
 
diff --git a/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll b/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
index e9faf98eee4492..6d64f76494638f 100644
--- a/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
+++ b/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
@@ -93,14 +93,25 @@ define void @ule_from_zero_no_nuw(i32 %M, i32 %N) {
 ; CHECK-NEXT:  Determining loop execution counts for: @ule_from_zero_no_nuw
 ; CHECK-NEXT:  Loop %loop: <multiple exits> Unpredictable backedge-taken count.
 ; CHECK-NEXT:    exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:    exit count for latch: %N
 ; CHECK-NEXT:  Loop %loop: constant max backedge-taken count is i32 -1
 ; CHECK-NEXT:  Loop %loop: symbolic max backedge-taken count is %N
 ; CHECK-NEXT:    symbolic max exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:    symbolic max exit count for latch: %N
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:      {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-NEXT:  Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:      {0,+,1}<%loop> Added Flags: <nusw>
 ;
 entry:
   br label %loop
@@ -211,14 +222,25 @@ define void @sle_from_int_min_no_nsw(i32 %M, i32 %N) {
 ; CHECK-NEXT:  Determining loop execution counts for: @sle_from_int_min_no_nsw
 ; CHECK-NEXT:  Loop %loop: <multiple exits> Unpredictable backedge-taken count.
 ; CHECK-NEXT:    exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for loop: (2147483649 + (sext i32 %M to i64))<nsw>
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:    exit count for latch: (-2147483648 + %N)
 ; CHECK-NEXT:  Loop %loop: constant max backedge-taken count is i32 -1
 ; CHECK-NEXT:  Loop %loop: symbolic max backedge-taken count is (-2147483648 + %N)
 ; CHECK-NEXT:    symbolic max exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for loop: (2147483649 + (sext i32 %M to i64))<nsw>
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:    symbolic max exit count for latch: (-2147483648 + %N)
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is ((zext i32 (-2147483648 + %N) to i64) umin (2147483649 + (sext i32 %M to i64))<nsw>)
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:      {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-NEXT:  Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 (-2147483648 + %N) to i64) umin (2147483649 + (sext i32 %M to i64))<nsw>)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:      {-2147483648,+,1}<%loop> Added Flags: <nssw>
 ;
 entry:
   br label %loop
diff --git a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
index 8dc79a54eb97a5..2ec6158e9b0920 100644
--- a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
+++ b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
@@ -8,10 +8,18 @@ define void @test1(i64 %x, ptr %a, ptr %b) {
 ; CHECK-NEXT:  Loop %header: <multiple exits> Unpredictable backedge-taken count.
 ; CHECK-NEXT:    exit count for header: ***COULDNOTCOMPUTE***
 ; CHECK-NEXT:    exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:  Loop %header: Unpredictable constant max backedge-taken count.
 ; CHECK-NEXT:  Loop %header: Unpredictable symbolic max backedge-taken count.
 ; CHECK-NEXT:    symbolic max exit count for header: ***COULDNOTCOMPUTE***
 ; CHECK-NEXT:    symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:  Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
@@ -51,10 +59,18 @@ define void @test2(i64 %x, ptr %a) {
 ; CHECK-NEXT:  Loop %header: <multiple exits> Unpredictable backedge-taken count.
 ; CHECK-NEXT:    exit count for header: ***COULDNOTCOMPUTE***
 ; CHECK-NEXT:    exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:  Loop %header: Unpredictable constant max backedge-taken count.
 ; CHECK-NEXT:  Loop %header: Unpredictable symbolic max backedge-taken count.
 ; CHECK-NEXT:    symbolic max exit count for header: ***COULDNOTCOMPUTE***
 ; CHECK-NEXT:    symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:  Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index a9bd4789707012..5ec4a676f69d89 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1711,7 +1711,7 @@ TEST_F(ScalarEvolutionsTest, ExitCountWithPredicates) {
   LLVMContext C;
   SMDiagnostic Err;
   std::unique_ptr<Module> M = parseAssemblyString(R"(
-define void @foo(ptr %dest, ptr %src, i64 noundef %end) {
+define void @foo1(ptr %dest, ptr %src, i64 noundef %end) {
 entry:
   %cmp7 = icmp sgt i64 %end, 0
   br i1 %cmp7, label %for.body, label %exit
@@ -1732,13 +1732,42 @@ for.body:
 
 exit:
   ret void
+}
+
+define i32 @foo2(ptr nocapture noundef %src1, ptr nocapture noundef readonly %src2, i8 noundef %start, i32 %end) {
+entry:
+  %st = zext i8 %start to i16
+  %ext = zext i8 %start to i32
+  %ecmp = icmp ult i16 %st, 42
+  br i1 %ecmp, label %for.body, label %exit
+
+for.body:
+  %i.08 = phi i8 [ %inc, %for.inc ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, ptr %src1, i8 %i.08
+  %0 = load i32, ptr %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, ptr %src2, i8 %i.08
+  %1 = load i32, ptr %arrayidx3, align 4
+  %cmp.early = icmp eq i32 %0, %1
+  br i1 %cmp.early, label %found, label %for.inc
+
+for.inc:
+  %inc = add i8 %i.08, 1
+  %conv = zext i8 %inc to i32
+  %cmp = icmp ult i32 %conv, %end
+  br i1 %cmp, label %for.body, label %exit
+
+found:
+  ret i32 1
+
+exit:
+  ret i32 0
 })",
                                                   Err, C);
 
   ASSERT_TRUE(M && "Could not parse module?");
   ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
 
-  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+  runWithSE(*M, "foo1", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
     BasicBlock &EntryBB = F.getEntryBlock();
     BasicBlock *ForBodyBB = nullptr;
     Loop *Loop = nullptr;
@@ -1757,6 +1786,31 @@ for.body:
     ASSERT_FALSE(isa<SCEVCouldNotCompute>(ExitCount));
     ASSERT_FALSE(Predicates.empty());
   });
+
+  runWithSE(*M, "foo2", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    BasicBlock &EntryBB = F.getEntryBlock();
+    BasicBlock *ForIncBB = nullptr;
+    Loop *Loop = nullptr;
+    for (BasicBlock *Succ : successors(&EntryBB)) {
+      Loop = LI.getLoopFor(Succ);
+      if (Loop) {
+        ForIncBB = Loop->getLoopLatch();
+        break;
+      }
+    }
+    ASSERT_TRUE(Loop && "Couldn't find the loop!");
+    ASSERT_TRUE(ForIncBB && "Couldn't find the loop header!");
+    SmallVector<const SCEVPredicate *, 4> Predicates;
+    const SCEV *ExitCount = SE.getPredicatedExitCount(
+        Loop, ForIncBB, &Predicates, ScalarEvolution::SymbolicMaximum);
+    ASSERT_FALSE(isa<SCEVCouldNotCompute>(ExitCount));
+    ASSERT_FALSE(Predicates.empty());
+    ExitCount = SE.getPredicatedExitCount(Loop, ForIncBB, &Predicates,
+                                          ScalarEvolution::ConstantMaximum);
+    ASSERT_TRUE(isa<SCEVConstant>(ExitCount));
+    ASSERT_TRUE(cast<SCEVConstant>(ExitCount)->getAPInt().getSExtValue() ==
+                -2ll);
+  });
 }
 
 }  // end namespace llvm

>From d556ac39ca81f2545735e8f9707079df2f1aee8d Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Fri, 30 Aug 2024 09:54:35 +0000
Subject: [PATCH 4/5] Remove unit tests in favour of ll test

---
 .../ScalarEvolution/predicated-exit-count.ll  |  65 +++++++++++
 .../Analysis/ScalarEvolutionTest.cpp          | 106 ------------------
 2 files changed, 65 insertions(+), 106 deletions(-)
 create mode 100644 llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll

diff --git a/llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll b/llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll
new file mode 100644
index 00000000000000..de214183710ab3
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -disable-output "-passes=print<scalar-evolution>" -scalar-evolution-classify-expressions=0 < %s 2>&1 | FileCheck %s
+
+
+define i32 @multiple_exits_with_predicates(ptr %src1, ptr readonly %src2, i32 %end) {
+; CHECK-LABEL: 'multiple_exits_with_predicates'
+; CHECK-NEXT:  Determining loop execution counts for: @multiple_exits_with_predicates
+; CHECK-NEXT:  Loop %for.body: <multiple exits> Unpredictable backedge-taken count.
+; CHECK-NEXT:    exit count for for.body: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for for.body: i32 1023
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT:    exit count for for.work: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    exit count for for.inc: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for for.inc: (-1 + (1 umax %end))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT:  Loop %for.body: Unpredictable constant max backedge-taken count.
+; CHECK-NEXT:  Loop %for.body: Unpredictable symbolic max backedge-taken count.
+; CHECK-NEXT:    symbolic max exit count for for.body: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for for.body: i32 1023
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT:    symbolic max exit count for for.work: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    symbolic max exit count for for.inc: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for for.inc: (-1 + (1 umax %end))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT:  Loop %for.body: Predicated symbolic max backedge-taken count is (1023 umin (-1 + (1 umax %end)))
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+;
+entry:
+  br label %for.body
+
+for.body:
+  %index = phi i8 [ %index.next, %for.inc ], [ 0, %entry ]
+  %index.next = add i8 %index, 1
+  %conv = zext i8 %index.next to i32
+  %cmp.body = icmp ne i32 %conv, 1024
+  br i1 %cmp.body, label %for.work, label %exit
+
+for.work:
+  %arrayidx = getelementptr inbounds i32, ptr %src1, i8 %index
+  %0 = load i32, ptr %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, ptr %src2, i8 %index
+  %1 = load i32, ptr %arrayidx3, align 4
+  %cmp.work = icmp eq i32 %0, %1
+  br i1 %cmp.work, label %found, label %for.inc
+
+for.inc:
+  %cmp.inc = icmp ult i32 %conv, %end
+  br i1 %cmp.inc, label %for.body, label %exit
+
+found:
+  ret i32 1
+
+exit:
+  ret i32 0
+}
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 5ec4a676f69d89..d4d90d80f4cea1 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1707,110 +1707,4 @@ TEST_F(ScalarEvolutionsTest, ComplexityComparatorIsStrictWeakOrdering) {
   });
 }
 
-TEST_F(ScalarEvolutionsTest, ExitCountWithPredicates) {
-  LLVMContext C;
-  SMDiagnostic Err;
-  std::unique_ptr<Module> M = parseAssemblyString(R"(
-define void @foo1(ptr %dest, ptr %src, i64 noundef %end) {
-entry:
-  %cmp7 = icmp sgt i64 %end, 0
-  br i1 %cmp7, label %for.body, label %exit
-
-for.body:
-  %conv9 = phi i64 [ %conv, %for.body ], [ 0, %entry ]
-  %i.08 = phi i16 [ %inc, %for.body ], [ 0, %entry ]
-  %arrayidx = getelementptr inbounds i32, ptr %src, i64 %conv9
-  %0 = load i32, ptr %arrayidx, align 4
-  %arrayidx3 = getelementptr inbounds i32, ptr %dest, i64 %conv9
-  %1 = load i32, ptr %arrayidx3, align 4
-  %add = add i32 %1, %0
-  store i32 %add, ptr %arrayidx3, align 4
-  %inc = add i16 %i.08, 1
-  %conv = zext i16 %inc to i64
-  %cmp = icmp ult i64 %conv, %end
-  br i1 %cmp, label %for.body, label %exit
-
-exit:
-  ret void
-}
-
-define i32 @foo2(ptr nocapture noundef %src1, ptr nocapture noundef readonly %src2, i8 noundef %start, i32 %end) {
-entry:
-  %st = zext i8 %start to i16
-  %ext = zext i8 %start to i32
-  %ecmp = icmp ult i16 %st, 42
-  br i1 %ecmp, label %for.body, label %exit
-
-for.body:
-  %i.08 = phi i8 [ %inc, %for.inc ], [ 0, %entry ]
-  %arrayidx = getelementptr inbounds i32, ptr %src1, i8 %i.08
-  %0 = load i32, ptr %arrayidx, align 4
-  %arrayidx3 = getelementptr inbounds i32, ptr %src2, i8 %i.08
-  %1 = load i32, ptr %arrayidx3, align 4
-  %cmp.early = icmp eq i32 %0, %1
-  br i1 %cmp.early, label %found, label %for.inc
-
-for.inc:
-  %inc = add i8 %i.08, 1
-  %conv = zext i8 %inc to i32
-  %cmp = icmp ult i32 %conv, %end
-  br i1 %cmp, label %for.body, label %exit
-
-found:
-  ret i32 1
-
-exit:
-  ret i32 0
-})",
-                                                  Err, C);
-
-  ASSERT_TRUE(M && "Could not parse module?");
-  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
-
-  runWithSE(*M, "foo1", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
-    BasicBlock &EntryBB = F.getEntryBlock();
-    BasicBlock *ForBodyBB = nullptr;
-    Loop *Loop = nullptr;
-    for (BasicBlock *Succ : successors(&EntryBB)) {
-      Loop = LI.getLoopFor(Succ);
-      if (Loop) {
-        ForBodyBB = Loop->getHeader();
-        break;
-      }
-    }
-    ASSERT_TRUE(Loop && "Couldn't find the loop!");
-    ASSERT_TRUE(ForBodyBB && "Couldn't find the loop header!");
-    SmallVector<const SCEVPredicate *, 4> Predicates;
-    const SCEV *ExitCount = SE.getPredicatedExitCount(
-        Loop, ForBodyBB, &Predicates, ScalarEvolution::Exact);
-    ASSERT_FALSE(isa<SCEVCouldNotCompute>(ExitCount));
-    ASSERT_FALSE(Predicates.empty());
-  });
-
-  runWithSE(*M, "foo2", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
-    BasicBlock &EntryBB = F.getEntryBlock();
-    BasicBlock *ForIncBB = nullptr;
-    Loop *Loop = nullptr;
-    for (BasicBlock *Succ : successors(&EntryBB)) {
-      Loop = LI.getLoopFor(Succ);
-      if (Loop) {
-        ForIncBB = Loop->getLoopLatch();
-        break;
-      }
-    }
-    ASSERT_TRUE(Loop && "Couldn't find the loop!");
-    ASSERT_TRUE(ForIncBB && "Couldn't find the loop header!");
-    SmallVector<const SCEVPredicate *, 4> Predicates;
-    const SCEV *ExitCount = SE.getPredicatedExitCount(
-        Loop, ForIncBB, &Predicates, ScalarEvolution::SymbolicMaximum);
-    ASSERT_FALSE(isa<SCEVCouldNotCompute>(ExitCount));
-    ASSERT_FALSE(Predicates.empty());
-    ExitCount = SE.getPredicatedExitCount(Loop, ForIncBB, &Predicates,
-                                          ScalarEvolution::ConstantMaximum);
-    ASSERT_TRUE(isa<SCEVConstant>(ExitCount));
-    ASSERT_TRUE(cast<SCEVConstant>(ExitCount)->getAPInt().getSExtValue() ==
-                -2ll);
-  });
-}
-
 }  // end namespace llvm

>From 30282cc5b7ee4523d10ede0786de825366cdc7fb Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Fri, 30 Aug 2024 15:25:35 +0000
Subject: [PATCH 5/5] Refactor getExact, etc. to call a common getExitNotTaken
 function

---
 llvm/include/llvm/Analysis/ScalarEvolution.h | 25 +++++++++--
 llvm/lib/Analysis/ScalarEvolution.cpp        | 46 +++-----------------
 2 files changed, 28 insertions(+), 43 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 7e323cb6e1d45f..89f9395959779d 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1524,6 +1524,10 @@ class ScalarEvolution {
     bool isComplete() const { return IsComplete; }
     const SCEV *getConstantMax() const { return ConstantMax; }
 
+    const ExitNotTakenInfo *getExitNotTaken(
+        const BasicBlock *ExitingBlock,
+        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
+
   public:
     BackedgeTakenInfo() = default;
     BackedgeTakenInfo(BackedgeTakenInfo &&) = default;
@@ -1574,7 +1578,12 @@ class ScalarEvolution {
     /// predicates are required, otherwise it fills in the required predicates.
     const SCEV *getExact(
         const BasicBlock *ExitingBlock, ScalarEvolution *SE,
-        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
+        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
+      if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
+        return ENT->ExactNotTaken;
+      else
+        return SE->getCouldNotCompute();
+    }
 
     /// Get the constant max backedge taken count for the loop.
     const SCEV *getConstantMax(ScalarEvolution *SE) const;
@@ -1582,7 +1591,12 @@ class ScalarEvolution {
     /// Get the constant max backedge taken count for the particular loop exit.
     const SCEV *getConstantMax(
         const BasicBlock *ExitingBlock, ScalarEvolution *SE,
-        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
+        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
+      if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
+        return ENT->ConstantMaxNotTaken;
+      else
+        return SE->getCouldNotCompute();
+    }
 
     /// Get the symbolic max backedge taken count for the loop.
     const SCEV *getSymbolicMax(
@@ -1592,7 +1606,12 @@ class ScalarEvolution {
     /// Get the symbolic max backedge taken count for the particular loop exit.
     const SCEV *getSymbolicMax(
         const BasicBlock *ExitingBlock, ScalarEvolution *SE,
-        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
+        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
+      if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
+        return ENT->SymbolicMaxNotTaken;
+      else
+        return SE->getCouldNotCompute();
+    }
 
     /// Return true if the number of times this backedge is taken is either the
     /// value returned by getConstantMax or zero.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 07012a2b43bd08..6b4a81c217b3c2 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8591,56 +8591,22 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
   return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
 }
 
-/// Get the exact not taken count for this loop exit.
-const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
-    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
-    SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
-  for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock) {
-      if (ENT.hasAlwaysTruePredicate())
-        return ENT.ExactNotTaken;
-      else if (Predicates) {
-        for (const auto *P : ENT.Predicates)
-          Predicates->push_back(P);
-        return ENT.ExactNotTaken;
-      }
-    }
-
-  return SE->getCouldNotCompute();
-}
-
-const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
-    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+const ScalarEvolution::ExitNotTakenInfo *
+ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
+    const BasicBlock *ExitingBlock,
     SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
     if (ENT.ExitingBlock == ExitingBlock) {
       if (ENT.hasAlwaysTruePredicate())
-        return ENT.ConstantMaxNotTaken;
+        return &ENT;
       else if (Predicates) {
         for (const auto *P : ENT.Predicates)
           Predicates->push_back(P);
-        return ENT.ConstantMaxNotTaken;
+        return &ENT;
       }
     }
 
-  return SE->getCouldNotCompute();
-}
-
-const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
-    const BasicBlock *ExitingBlock, ScalarEvolution *SE,
-    SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
-  for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock) {
-      if (ENT.hasAlwaysTruePredicate())
-        return ENT.SymbolicMaxNotTaken;
-      else if (Predicates) {
-        for (const auto *P : ENT.Predicates)
-          Predicates->push_back(P);
-        return ENT.SymbolicMaxNotTaken;
-      }
-    }
-
-  return SE->getCouldNotCompute();
+  return nullptr;
 }
 
 /// getConstantMax - Get the constant max backedge taken count for the loop.



More information about the llvm-commits mailing list