[llvm] [SCEV] Add predicated version of getSymbolicMaxBackedgeTakenCount. (PR #93498)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue May 28 14:39:04 PDT 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/93498

>From 68c2420c4bfe6f28094a0b8b576800c7219b6f10 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Mon, 27 May 2024 21:17:21 -0700
Subject: [PATCH] [SCEV] Add predicated version of
 getSymbolicMaxBackedgeTakenCount.

---
 llvm/include/llvm/Analysis/ScalarEvolution.h  | 15 +++++-
 llvm/lib/Analysis/ScalarEvolution.cpp         | 48 +++++++++++++++++--
 ...cated-symbolic-max-backedge-taken-count.ll |  6 +++
 3 files changed, 63 insertions(+), 6 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 1d016b28347d2..d9e10dd3150a3 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -912,6 +912,9 @@ class ScalarEvolution {
     return getBackedgeTakenCount(L, SymbolicMaximum);
   }
 
+  const SCEV *getPredicatedSymbolicMaxBackedgeTakenCount(
+      const Loop *L, SmallVector<const SCEVPredicate *, 4> &Predicates);
+
   /// Return true if the backedge taken count is either the value returned by
   /// getConstantMaxBackedgeTakenCount or zero.
   bool isBackedgeTakenCountMaxOrZero(const Loop *L);
@@ -1549,7 +1552,9 @@ class ScalarEvolution {
                                ScalarEvolution *SE) const;
 
     /// Get the symbolic max backedge taken count for the loop.
-    const SCEV *getSymbolicMax(const Loop *L, ScalarEvolution *SE);
+    const SCEV *
+    getSymbolicMax(const Loop *L, ScalarEvolution *SE,
+                   SmallVector<const SCEVPredicate *, 4> *Predicates = nullptr);
 
     /// Get the symbolic max backedge taken count for the particular loop exit.
     const SCEV *getSymbolicMax(const BasicBlock *ExitingBlock,
@@ -1746,7 +1751,7 @@ class ScalarEvolution {
 
   /// Similar to getBackedgeTakenInfo, but will add predicates as required
   /// with the purpose of returning complete information.
-  const BackedgeTakenInfo &getPredicatedBackedgeTakenInfo(const Loop *L);
+  BackedgeTakenInfo &getPredicatedBackedgeTakenInfo(const Loop *L);
 
   /// Compute the number of times the specified loop will iterate.
   /// If AllowPredicates is set, we will create new SCEV predicates as
@@ -2311,6 +2316,9 @@ class PredicatedScalarEvolution {
   /// Get the (predicated) backedge count for the analyzed loop.
   const SCEV *getBackedgeTakenCount();
 
+  /// Get the (predicated) symbolic max backedge count for the analyzed loop.
+  const SCEV *getSymbolicMaxBackedgeTakenCount();
+
   /// Adds a new predicate.
   void addPredicate(const SCEVPredicate &Pred);
 
@@ -2379,6 +2387,9 @@ class PredicatedScalarEvolution {
 
   /// The backedge taken count.
   const SCEV *BackedgeCount = nullptr;
+
+  /// The symbolic max backedge taken count.
+  const SCEV *SymbolicMaxBackedgeCount = nullptr;
 };
 
 template <> struct DenseMapInfo<ScalarEvolution::FoldID> {
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index bb56b41fe15d5..e46d7183a2a35 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8295,6 +8295,11 @@ const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
   llvm_unreachable("Invalid ExitCountKind!");
 }
 
+const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
+    const Loop *L, SmallVector<const SCEVPredicate *, 4> &Preds) {
+  return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
+}
+
 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
   return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
 }
@@ -8311,7 +8316,7 @@ static void PushLoopPHIs(const Loop *L,
       Worklist.push_back(&PN);
 }
 
-const ScalarEvolution::BackedgeTakenInfo &
+ScalarEvolution::BackedgeTakenInfo &
 ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
   auto &BTI = getBackedgeTakenInfo(L);
   if (BTI.hasFullInfo())
@@ -8644,9 +8649,9 @@ ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
   return getConstantMax();
 }
 
-const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
-                                                   ScalarEvolution *SE) {
+const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
+    const Loop *L, ScalarEvolution *SE,
+    SmallVector<const SCEVPredicate *, 4> *Predicates) {
   if (!SymbolicMax) {
     // Form an expression for the maximum exit count possible for this loop. We
     // merge the max and exact information to approximate a version of
@@ -8661,6 +8666,12 @@ ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
                "We should only have known counts for exiting blocks that "
                "dominate latch!");
         ExitCounts.push_back(ExitCount);
+        if (Predicates)
+          for (const auto *P : ENT.Predicates)
+            Predicates->push_back(P);
+
+        assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
+               "Predicate should be always true!");
       }
     }
     if (ExitCounts.empty())
@@ -13609,6 +13620,24 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
       P->print(OS, 4);
   }
 
+  Preds.clear();
+  auto *PredSymbolicMax =
+      SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
+  if (SymbolicBTC != PredSymbolicMax) {
+    OS << "Loop ";
+    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+    OS << ": ";
+    if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
+      OS << "Predicated symbolic max backedge-taken count is ";
+      PrintSCEVWithTypeHint(OS, PredSymbolicMax);
+    } else
+      OS << "Unpredictable predicated symbolic max backedge-taken count.";
+    OS << "\n";
+    OS << " Predicates:\n";
+    for (const auto *P : Preds)
+      P->print(OS, 4);
+  }
+
   if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
     OS << "Loop ";
     L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
@@ -14822,6 +14851,17 @@ const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
   return BackedgeCount;
 }
 
+const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
+  if (!SymbolicMaxBackedgeCount) {
+    SmallVector<const SCEVPredicate *, 4> Preds;
+    SymbolicMaxBackedgeCount =
+        SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
+    for (const auto *P : Preds)
+      addPredicate(*P);
+  }
+  return SymbolicMaxBackedgeCount;
+}
+
 void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
   if (Preds->implies(&Pred))
     return;
diff --git a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
index d40416359b65c..8dc79a54eb97a 100644
--- a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
+++ b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
@@ -12,6 +12,9 @@ define void @test1(i64 %x, ptr %a, ptr %b) {
 ; CHECK-NEXT:  Loop %header: Unpredictable symbolic max backedge-taken count.
 ; CHECK-NEXT:    symbolic max exit count for header: ***COULDNOTCOMPUTE***
 ; CHECK-NEXT:    symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:  Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
 ;
 entry:
   br label %header
@@ -52,6 +55,9 @@ define void @test2(i64 %x, ptr %a) {
 ; CHECK-NEXT:  Loop %header: Unpredictable symbolic max backedge-taken count.
 ; CHECK-NEXT:    symbolic max exit count for header: ***COULDNOTCOMPUTE***
 ; CHECK-NEXT:    symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:  Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
 ;
 entry:
   br label %header



More information about the llvm-commits mailing list