[llvm] df3d70b - [Analysis] Add getPredicatedExitCount to ScalarEvolution (#105649)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 2 06:05:29 PDT 2024


Author: David Sherwood
Date: 2024-09-02T14:05:26+01:00
New Revision: df3d70b5a72fee43af3793c8b7a138bd44cac8cf

URL: https://github.com/llvm/llvm-project/commit/df3d70b5a72fee43af3793c8b7a138bd44cac8cf
DIFF: https://github.com/llvm/llvm-project/commit/df3d70b5a72fee43af3793c8b7a138bd44cac8cf.diff

LOG: [Analysis] Add getPredicatedExitCount to ScalarEvolution (#105649)

Following a reviewer request on PR #88385, this patch adds a
getPredicatedExitCount function. It is similar to getExitCount,
except that it uses the predicated backedge-taken information and
fills in the predicates that the returned exit count relies on.
With PR #88385 we will start to care about more loops with multiple
exits, and want the ability to query exit counts for a particular
exiting block. Such loops may require predicates in order to be
vectorised.

New tests added here:

Analysis/ScalarEvolution/predicated-exit-count.ll

Added: 
    llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
    llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index fe46a504bce5d1..89f9395959779d 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -871,6 +871,13 @@ class ScalarEvolution {
   const SCEV *getExitCount(const Loop *L, const BasicBlock *ExitingBlock,
                            ExitCountKind Kind = Exact);
 
+  /// Same as above except this uses the predicated backedge taken info and
+  /// may require predicates.
+  const SCEV *
+  getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock,
+                         SmallVectorImpl<const SCEVPredicate *> *Predicates,
+                         ExitCountKind Kind = Exact);
+
   /// If the specified loop has a predictable backedge-taken count, return it,
   /// otherwise return a SCEVCouldNotCompute object. The backedge-taken count is
   /// the number of times the loop header will be branched to from within the
@@ -1517,6 +1524,10 @@ class ScalarEvolution {
     bool isComplete() const { return IsComplete; }
     const SCEV *getConstantMax() const { return ConstantMax; }
 
+    const ExitNotTakenInfo *getExitNotTaken(
+        const BasicBlock *ExitingBlock,
+        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const;
+
   public:
     BackedgeTakenInfo() = default;
     BackedgeTakenInfo(BackedgeTakenInfo &&) = default;
@@ -1563,16 +1574,29 @@ class ScalarEvolution {
     /// Return the number of times this loop exit may fall through to the back
     /// edge, or SCEVCouldNotCompute. The loop is guaranteed not to exit via
     /// this block before this number of iterations, but may exit via another
-    /// block.
-    const SCEV *getExact(const BasicBlock *ExitingBlock,
-                         ScalarEvolution *SE) const;
+    /// block. If \p Predicates is null the function returns CouldNotCompute if
+    /// predicates are required, otherwise it fills in the required predicates.
+    const SCEV *getExact(
+        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
+      if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
+        return ENT->ExactNotTaken;
+      else
+        return SE->getCouldNotCompute();
+    }
 
     /// Get the constant max backedge taken count for the loop.
     const SCEV *getConstantMax(ScalarEvolution *SE) const;
 
     /// Get the constant max backedge taken count for the particular loop exit.
-    const SCEV *getConstantMax(const BasicBlock *ExitingBlock,
-                               ScalarEvolution *SE) const;
+    const SCEV *getConstantMax(
+        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
+      if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
+        return ENT->ConstantMaxNotTaken;
+      else
+        return SE->getCouldNotCompute();
+    }
 
     /// Get the symbolic max backedge taken count for the loop.
     const SCEV *getSymbolicMax(
@@ -1580,8 +1604,14 @@ class ScalarEvolution {
         SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr);
 
     /// Get the symbolic max backedge taken count for the particular loop exit.
-    const SCEV *getSymbolicMax(const BasicBlock *ExitingBlock,
-                               ScalarEvolution *SE) const;
+    const SCEV *getSymbolicMax(
+        const BasicBlock *ExitingBlock, ScalarEvolution *SE,
+        SmallVectorImpl<const SCEVPredicate *> *Predicates = nullptr) const {
+      if (auto *ENT = getExitNotTaken(ExitingBlock, Predicates))
+        return ENT->SymbolicMaxNotTaken;
+      else
+        return SE->getCouldNotCompute();
+    }
 
     /// Return true if the number of times this backedge is taken is either the
     /// value returned by getConstantMax or zero.

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 54dde8401cdff0..6b4a81c217b3c2 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8247,6 +8247,23 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
   llvm_unreachable("Invalid ExitCountKind!");
 }
 
+const SCEV *ScalarEvolution::getPredicatedExitCount(
+    const Loop *L, const BasicBlock *ExitingBlock,
+    SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
+  switch (Kind) {
+  case Exact:
+    return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
+                                                      Predicates);
+  case SymbolicMaximum:
+    return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
+                                                            Predicates);
+  case ConstantMaximum:
+    return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
+                                                            Predicates);
+  };
+  llvm_unreachable("Invalid ExitCountKind!");
+}
+
 const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
     const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
   return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
@@ -8574,33 +8591,22 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
   return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
 }
 
-/// Get the exact not taken count for this loop exit.
-const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
-                                             ScalarEvolution *SE) const {
-  for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.ExactNotTaken;
-
-  return SE->getCouldNotCompute();
-}
-
-const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
-    const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
+const ScalarEvolution::ExitNotTakenInfo *
+ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
+    const BasicBlock *ExitingBlock,
+    SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
   for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.ConstantMaxNotTaken;
-
-  return SE->getCouldNotCompute();
-}
-
-const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
-    const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
-  for (const auto &ENT : ExitNotTaken)
-    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
-      return ENT.SymbolicMaxNotTaken;
+    if (ENT.ExitingBlock == ExitingBlock) {
+      if (ENT.hasAlwaysTruePredicate())
+        return &ENT;
+      else if (Predicates) {
+        for (const auto *P : ENT.Predicates)
+          Predicates->push_back(P);
+        return &ENT;
+      }
+    }
 
-  return SE->getCouldNotCompute();
+  return nullptr;
 }
 
 /// getConstantMax - Get the constant max backedge taken count for the loop.
@@ -13642,7 +13648,21 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
   if (ExitingBlocks.size() > 1)
     for (BasicBlock *ExitingBlock : ExitingBlocks) {
       OS << "  exit count for " << ExitingBlock->getName() << ": ";
-      PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
+      const SCEV *EC = SE->getExitCount(L, ExitingBlock);
+      PrintSCEVWithTypeHint(OS, EC);
+      if (isa<SCEVCouldNotCompute>(EC)) {
+        // Retry with predicates.
+        SmallVector<const SCEVPredicate *, 4> Predicates;
+        EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
+        if (!isa<SCEVCouldNotCompute>(EC)) {
+          OS << "\n  predicated exit count for " << ExitingBlock->getName()
+             << ": ";
+          PrintSCEVWithTypeHint(OS, EC);
+          OS << "\n   Predicates:\n";
+          for (const auto *P : Predicates)
+            P->print(OS, 4);
+        }
+      }
       OS << "\n";
     }
 
@@ -13682,6 +13702,20 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
       auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
                                        ScalarEvolution::SymbolicMaximum);
       PrintSCEVWithTypeHint(OS, ExitBTC);
+      if (isa<SCEVCouldNotCompute>(ExitBTC)) {
+        // Retry with predicates.
+        SmallVector<const SCEVPredicate *, 4> Predicates;
+        ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
+                                             ScalarEvolution::SymbolicMaximum);
+        if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
+          OS << "\n  predicated symbolic max exit count for "
+             << ExitingBlock->getName() << ": ";
+          PrintSCEVWithTypeHint(OS, ExitBTC);
+          OS << "\n   Predicates:\n";
+          for (const auto *P : Predicates)
+            P->print(OS, 4);
+        }
+      }
       OS << "\n";
     }
 

diff  --git a/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll b/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
index e9faf98eee4492..6d64f76494638f 100644
--- a/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
+++ b/llvm/test/Analysis/ScalarEvolution/exit-count-non-strict.ll
@@ -93,14 +93,25 @@ define void @ule_from_zero_no_nuw(i32 %M, i32 %N) {
 ; CHECK-NEXT:  Determining loop execution counts for: @ule_from_zero_no_nuw
 ; CHECK-NEXT:  Loop %loop: <multiple exits> Unpredictable backedge-taken count.
 ; CHECK-NEXT:    exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:    exit count for latch: %N
 ; CHECK-NEXT:  Loop %loop: constant max backedge-taken count is i32 -1
 ; CHECK-NEXT:  Loop %loop: symbolic max backedge-taken count is %N
 ; CHECK-NEXT:    symbolic max exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for loop: (1 + (zext i32 %M to i64))<nuw><nsw>
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:    symbolic max exit count for latch: %N
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:      {0,+,1}<%loop> Added Flags: <nusw>
+; CHECK-NEXT:  Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 %N to i64) umin (1 + (zext i32 %M to i64))<nuw><nsw>)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:      {0,+,1}<%loop> Added Flags: <nusw>
 ;
 entry:
   br label %loop
@@ -211,14 +222,25 @@ define void @sle_from_int_min_no_nsw(i32 %M, i32 %N) {
 ; CHECK-NEXT:  Determining loop execution counts for: @sle_from_int_min_no_nsw
 ; CHECK-NEXT:  Loop %loop: <multiple exits> Unpredictable backedge-taken count.
 ; CHECK-NEXT:    exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for loop: (2147483649 + (sext i32 %M to i64))<nsw>
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:    exit count for latch: (-2147483648 + %N)
 ; CHECK-NEXT:  Loop %loop: constant max backedge-taken count is i32 -1
 ; CHECK-NEXT:  Loop %loop: symbolic max backedge-taken count is (-2147483648 + %N)
 ; CHECK-NEXT:    symbolic max exit count for loop: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for loop: (2147483649 + (sext i32 %M to i64))<nsw>
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:    symbolic max exit count for latch: (-2147483648 + %N)
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is ((zext i32 (-2147483648 + %N) to i64) umin (2147483649 + (sext i32 %M to i64))<nsw>)
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:      {-2147483648,+,1}<%loop> Added Flags: <nssw>
+; CHECK-NEXT:  Loop %loop: Predicated symbolic max backedge-taken count is ((zext i32 (-2147483648 + %N) to i64) umin (2147483649 + (sext i32 %M to i64))<nsw>)
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:      {-2147483648,+,1}<%loop> Added Flags: <nssw>
 ;
 entry:
   br label %loop

diff  --git a/llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll b/llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll
new file mode 100644
index 00000000000000..de214183710ab3
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/predicated-exit-count.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -disable-output "-passes=print<scalar-evolution>" -scalar-evolution-classify-expressions=0 < %s 2>&1 | FileCheck %s
+
+
+define i32 @multiple_exits_with_predicates(ptr %src1, ptr readonly %src2, i32 %end) {
+; CHECK-LABEL: 'multiple_exits_with_predicates'
+; CHECK-NEXT:  Determining loop execution counts for: @multiple_exits_with_predicates
+; CHECK-NEXT:  Loop %for.body: <multiple exits> Unpredictable backedge-taken count.
+; CHECK-NEXT:    exit count for for.body: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for for.body: i32 1023
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT:    exit count for for.work: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    exit count for for.inc: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for for.inc: (-1 + (1 umax %end))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT:  Loop %for.body: Unpredictable constant max backedge-taken count.
+; CHECK-NEXT:  Loop %for.body: Unpredictable symbolic max backedge-taken count.
+; CHECK-NEXT:    symbolic max exit count for for.body: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for for.body: i32 1023
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT:    symbolic max exit count for for.work: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    symbolic max exit count for for.inc: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for for.inc: (-1 + (1 umax %end))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-EMPTY:
+; CHECK-NEXT:  Loop %for.body: Predicated symbolic max backedge-taken count is (1023 umin (-1 + (1 umax %end)))
+; CHECK-NEXT:   Predicates:
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+; CHECK-NEXT:      {1,+,1}<%for.body> Added Flags: <nusw>
+;
+entry:
+  br label %for.body
+
+for.body:
+  %index = phi i8 [ %index.next, %for.inc ], [ 0, %entry ]
+  %index.next = add i8 %index, 1
+  %conv = zext i8 %index.next to i32
+  %cmp.body = icmp ne i32 %conv, 1024
+  br i1 %cmp.body, label %for.work, label %exit
+
+for.work:
+  %arrayidx = getelementptr inbounds i32, ptr %src1, i8 %index
+  %0 = load i32, ptr %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, ptr %src2, i8 %index
+  %1 = load i32, ptr %arrayidx3, align 4
+  %cmp.work = icmp eq i32 %0, %1
+  br i1 %cmp.work, label %found, label %for.inc
+
+for.inc:
+  %cmp.inc = icmp ult i32 %conv, %end
+  br i1 %cmp.inc, label %for.body, label %exit
+
+found:
+  ret i32 1
+
+exit:
+  ret i32 0
+}

diff  --git a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
index 8dc79a54eb97a5..2ec6158e9b0920 100644
--- a/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
+++ b/llvm/test/Analysis/ScalarEvolution/predicated-symbolic-max-backedge-taken-count.ll
@@ -8,10 +8,18 @@ define void @test1(i64 %x, ptr %a, ptr %b) {
 ; CHECK-NEXT:  Loop %header: <multiple exits> Unpredictable backedge-taken count.
 ; CHECK-NEXT:    exit count for header: ***COULDNOTCOMPUTE***
 ; CHECK-NEXT:    exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:  Loop %header: Unpredictable constant max backedge-taken count.
 ; CHECK-NEXT:  Loop %header: Unpredictable symbolic max backedge-taken count.
 ; CHECK-NEXT:    symbolic max exit count for header: ***COULDNOTCOMPUTE***
 ; CHECK-NEXT:    symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:  Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
@@ -51,10 +59,18 @@ define void @test2(i64 %x, ptr %a) {
 ; CHECK-NEXT:  Loop %header: <multiple exits> Unpredictable backedge-taken count.
 ; CHECK-NEXT:    exit count for header: ***COULDNOTCOMPUTE***
 ; CHECK-NEXT:    exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:  Loop %header: Unpredictable constant max backedge-taken count.
 ; CHECK-NEXT:  Loop %header: Unpredictable symbolic max backedge-taken count.
 ; CHECK-NEXT:    symbolic max exit count for header: ***COULDNOTCOMPUTE***
 ; CHECK-NEXT:    symbolic max exit count for latch: ***COULDNOTCOMPUTE***
+; CHECK-NEXT:    predicated symbolic max exit count for latch: (-1 + (1 umax %x))
+; CHECK-NEXT:     Predicates:
+; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>
+; CHECK-EMPTY:
 ; CHECK-NEXT:  Loop %header: Predicated symbolic max backedge-taken count is (-1 + (1 umax %x))
 ; CHECK-NEXT:   Predicates:
 ; CHECK-NEXT:      {1,+,1}<%header> Added Flags: <nusw>


        


More information about the llvm-commits mailing list