[llvm] change contents of ScalarEvolution from private to protected (PR #83052)

Joshua Ferguson via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 14 12:25:35 PDT 2024


https://github.com/skewballfox updated https://github.com/llvm/llvm-project/pull/83052

>From eea887cf6be39856fa441ed48f72c1c9177a76a6 Mon Sep 17 00:00:00 2001
From: Joshua Ferguson <joshua.ferguson.273 at gmail.com>
Date: Sun, 25 Feb 2024 14:06:02 -0600
Subject: [PATCH 01/13] mainly pushing to switch machines

---
 llvm/include/llvm/Analysis/ScalarEvolution.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 0880f9c65aa45d..1b03437de30c28 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1345,7 +1345,7 @@ class ScalarEvolution {
     }
   };
 
-private:
+protected:
   /// A CallbackVH to arrange for ScalarEvolution to be notified whenever a
   /// Value is deleted.
   class SCEVCallbackVH final : public CallbackVH {

>From e47436b767d635c14c10fc8c0bfc4fe30b8967d6 Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Thu, 29 Feb 2024 08:35:45 -0600
Subject: [PATCH 02/13] added AssumeLoopExits bool to SE, lifting MustExit code
 into SE

---
 llvm/include/llvm/Analysis/ScalarEvolution.h  |  9 ++-
 .../llvm/Analysis/Utils/EnzymeFunctionUtils.h | 71 +++++++++++++++++++
 llvm/lib/Analysis/ScalarEvolution.cpp         |  8 ++-
 3 files changed, 84 insertions(+), 4 deletions(-)
 create mode 100644 llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 1b03437de30c28..3075358e95791f 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -460,6 +460,9 @@ class ScalarEvolution {
     LoopComputable ///< The SCEV varies predictably with the loop.
   };
 
+  bool AssumeLoopExists = false;
+  void setAssumeLoopExists();
+
   /// An enum describing the relationship between a SCEV and a basic block.
   enum BlockDisposition {
     DoesNotDominateBlock,  ///< The SCEV does not dominate the block.
@@ -1345,7 +1348,7 @@ class ScalarEvolution {
     }
   };
 
-protected:
+  private:
   /// A CallbackVH to arrange for ScalarEvolution to be notified whenever a
   /// Value is deleted.
   class SCEVCallbackVH final : public CallbackVH {
@@ -1364,7 +1367,7 @@ class ScalarEvolution {
 
   /// The function we are analyzing.
   Function &F;
-
+  
   /// Does the module have any calls to the llvm.experimental.guard intrinsic
   /// at all?  If this is false, we avoid doing work that will only help if
   /// thare are guards present in the IR.
@@ -1765,7 +1768,7 @@ class ScalarEvolution {
   /// an arbitrary expression as opposed to only constants.
   const SCEV *computeSymbolicMaxBackedgeTakenCount(const Loop *L);
 
-  // Helper functions for computeExitLimitFromCond to avoid exponential time
+// Helper functions for computeExitLimitFromCond to avoid exponential time
   // complexity.
 
   class ExitLimitCache {
diff --git a/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h b/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h
new file mode 100644
index 00000000000000..a211bdca6a47d6
--- /dev/null
+++ b/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h
@@ -0,0 +1,71 @@
+
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+
+#include "llvm/IR/Function.h"
+
+#include "llvm/IR/Instructions.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
+#include <deque>
+
+
+// TODO note this doesn't go through [loop, unreachable], and we could get more
+// performance by doing this can consider doing some domtree magic potentially
+static inline llvm::SmallPtrSet<llvm::BasicBlock *, 4>
+getGuaranteedUnreachable(llvm::Function *F) {
+  // Returns the blocks of F from which a `ret` can never be reached: blocks
+  // ending in `unreachable` or `resume` (resumes are assumed not to happen),
+  // plus, to a fixed point, every non-`ret` block whose successors are all
+  // in the set. A block with no successors that ends in anything else also
+  // qualifies, because "all successors" then holds vacuously.
+  llvm::SmallPtrSet<llvm::BasicBlock *, 4> knownUnreachables;
+  if (F->empty())
+    return knownUnreachables;
+
+  // Seed the worklist with every block; predecessors are re-queued whenever
+  // a block is added, so the loop stops only at a fixed point.
+  std::deque<llvm::BasicBlock *> todo;
+  for (llvm::BasicBlock &BB : *F)
+    todo.push_back(&BB);
+
+  // Record BB and revisit its predecessors, which may now qualify too.
+  auto markUnreachable = [&](llvm::BasicBlock *BB) {
+    knownUnreachables.insert(BB);
+    for (llvm::BasicBlock *Pred : llvm::predecessors(BB))
+      todo.push_back(Pred);
+  };
+
+  while (!todo.empty()) {
+    llvm::BasicBlock *next = todo.front();
+    todo.pop_front();
+
+    if (knownUnreachables.contains(next))
+      continue;
+
+    const llvm::Instruction *Term = next->getTerminator();
+    // A returning block can always leave the function normally.
+    if (llvm::isa<llvm::ReturnInst>(Term))
+      continue;
+
+    // `unreachable` never returns; assume resumes don't happen.
+    // TODO consider EH
+    if (llvm::isa<llvm::UnreachableInst, llvm::ResumeInst>(Term)) {
+      markUnreachable(next);
+      continue;
+    }
+
+    // Otherwise the block never returns only if no successor can return.
+    bool unreachable = true;
+    for (llvm::BasicBlock *Succ : llvm::successors(next)) {
+      if (!knownUnreachables.contains(Succ)) {
+        unreachable = false;
+        break;
+      }
+    }
+    if (unreachable)
+      markUnreachable(next);
+  }
+
+  return knownUnreachables;
+}
\ No newline at end of file
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 4b2db80bc1ec30..6dc59108f5e188 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -82,6 +82,7 @@
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Config/llvm-config.h"
+#include "llvm/Analysis/Utils/EnzymeFunctionUtils.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
@@ -509,6 +510,10 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) {
   return S;
 }
 
+void ScalarEvolution::setAssumeLoopExists() {
+  this->AssumeLoopExists=true;
+}
+
 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                            const SCEV *op, Type *ty)
     : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
@@ -7413,7 +7418,7 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
   // A mustprogress loop without side effects must be finite.
   // TODO: The check used here is very conservative.  It's only *specific*
   // side effects which are well defined in infinite loops.
-  return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
+  return this->AssumeLoopExists || isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
 }
 
 const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
@@ -13354,6 +13359,7 @@ const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
   return getSizeOfExpr(ETy, Ty);
 }
 
+
 //===----------------------------------------------------------------------===//
 //                   SCEVCallbackVH Class Implementation
 //===----------------------------------------------------------------------===//

>From f55e361a3ba1d4a5ca30f4b9719d23d57d273cc5 Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Thu, 29 Feb 2024 09:51:55 -0600
Subject: [PATCH 03/13] added MustExitcode for computeExitLimit

---
 llvm/include/llvm/Analysis/ScalarEvolution.h  |  7 ++--
 .../llvm/Analysis/Utils/EnzymeFunctionUtils.h |  1 -
 llvm/lib/Analysis/ScalarEvolution.cpp         | 32 +++++++++++++++----
 3 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 3075358e95791f..4cc1954c1233f6 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -462,6 +462,7 @@ class ScalarEvolution {
 
   bool AssumeLoopExists = false;
   void setAssumeLoopExists();
+  llvm::SmallPtrSet<llvm::BasicBlock *, 4> GuaranteedUnreachable;
 
   /// An enum describing the relationship between a SCEV and a basic block.
   enum BlockDisposition {
@@ -1348,7 +1349,7 @@ class ScalarEvolution {
     }
   };
 
-  private:
+private:
   /// A CallbackVH to arrange for ScalarEvolution to be notified whenever a
   /// Value is deleted.
   class SCEVCallbackVH final : public CallbackVH {
@@ -1367,7 +1368,7 @@ class ScalarEvolution {
 
   /// The function we are analyzing.
   Function &F;
-  
+
   /// Does the module have any calls to the llvm.experimental.guard intrinsic
   /// at all?  If this is false, we avoid doing work that will only help if
   /// thare are guards present in the IR.
@@ -1768,7 +1769,7 @@ class ScalarEvolution {
   /// an arbitrary expression as opposed to only constants.
   const SCEV *computeSymbolicMaxBackedgeTakenCount(const Loop *L);
 
-// Helper functions for computeExitLimitFromCond to avoid exponential time
+  // Helper functions for computeExitLimitFromCond to avoid exponential time
   // complexity.
 
   class ExitLimitCache {
diff --git a/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h b/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h
index a211bdca6a47d6..59032cbe6dddd4 100644
--- a/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h
@@ -9,7 +9,6 @@
 #include "llvm/Transforms/Utils/ValueMapper.h"
 #include <deque>
 
-
 // TODO note this doesn't go through [loop, unreachable], and we could get more
 // performance by doing this can consider doing some domtree magic potentially
 static inline llvm::SmallPtrSet<llvm::BasicBlock *, 4>
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 6dc59108f5e188..c1071f07b7f280 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -80,9 +80,9 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/Utils/EnzymeFunctionUtils.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Config/llvm-config.h"
-#include "llvm/Analysis/Utils/EnzymeFunctionUtils.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
@@ -510,9 +510,7 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) {
   return S;
 }
 
-void ScalarEvolution::setAssumeLoopExists() {
-  this->AssumeLoopExists=true;
-}
+void ScalarEvolution::setAssumeLoopExists() { this->AssumeLoopExists = true; }
 
 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                            const SCEV *op, Type *ty)
@@ -7418,7 +7416,8 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
   // A mustprogress loop without side effects must be finite.
   // TODO: The check used here is very conservative.  It's only *specific*
   // side effects which are well defined in infinite loops.
-  return this->AssumeLoopExists || isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
+  return this->AssumeLoopExists || isFinite(L) ||
+         (isMustProgress(L) && loopHasNoSideEffects(L));
 }
 
 const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
@@ -8833,6 +8832,26 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
 ScalarEvolution::ExitLimit
 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
                                       bool AllowPredicates) {
+  if (AssumeLoopExists) {
+    SmallVector<BasicBlock *, 8> ExitingBlocks;
+    L->getExitingBlocks(ExitingBlocks);
+    for (auto &ExitingBlock : ExitingBlocks) {
+      BasicBlock *Exit = nullptr;
+      for (auto *SBB : successors(ExitingBlock)) {
+        if (!L->contains(SBB)) {
+          if (GuaranteedUnreachable.count(SBB))
+            continue;
+          Exit = SBB;
+          break;
+        }
+      }
+      if (!Exit)
+        ExitingBlock = nullptr;
+    }
+    ExitingBlocks.erase(
+        std::remove(ExitingBlocks.begin(), ExitingBlocks.end(), nullptr),
+        ExitingBlocks.end());
+  }
   assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
   // If our exiting block does not dominate the latch, then its connection with
   // loop's exit limit may be far from trivial.
@@ -8858,6 +8877,8 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
     BasicBlock *Exit = nullptr;
     for (auto *SBB : successors(ExitingBlock))
       if (!L->contains(SBB)) {
+        if (AssumeLoopExists and GuaranteedUnreachable.count(SBB))
+          continue;
         if (Exit) // Multiple exit successors.
           return getCouldNotCompute();
         Exit = SBB;
@@ -13359,7 +13380,6 @@ const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
   return getSizeOfExpr(ETy, Ty);
 }
 
-
 //===----------------------------------------------------------------------===//
 //                   SCEVCallbackVH Class Implementation
 //===----------------------------------------------------------------------===//

>From 8e85c0653be244e036e68eb31a4022ff05b23257 Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Thu, 29 Feb 2024 10:33:22 -0600
Subject: [PATCH 04/13] added enzyme mustExit code to 
 computeExitLimitFromSingleExitSwitch

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index c1071f07b7f280..d28436e02466be 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9264,8 +9264,14 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
   if (Switch->getDefaultDest() == ExitingBlock)
     return getCouldNotCompute();
 
-  assert(L->contains(Switch->getDefaultDest()) &&
-         "Default case must not exit the loop!");
+  // if not using enzyme executes by default
+  // if using enzyme and the code is guaranteed unreachable,
+  // the default destination doesn't matter
+  if (!AssumeLoopExists ||
+      !GuaranteedUnreachable.count(Switch->getDefaultDest())) {
+    assert(L->contains(Switch->getDefaultDest()) &&
+           "Default case must not exit the loop!");
+  }
   const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
   const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
 

>From 3f378b5c9370355e3b5fc66709df06ec4f3970f3 Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Thu, 29 Feb 2024 11:05:50 -0600
Subject: [PATCH 05/13] add enzyme must exit code to
 computeExitLimitFromCondImpl

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 109 ++++++++++++++++++++++++--
 1 file changed, 104 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index d28436e02466be..62f8ddfa720812 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8949,10 +8949,104 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
     ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
-  // Handle BinOp conditions (And, Or).
-  if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
-          Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
-    return *LimitFromBinOp;
+  if (!AssumeLoopExists) {
+    // Handle BinOp conditions (And, Or).
+    if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
+            Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
+      return *LimitFromBinOp;
+  } else {
+    // Check if the controlling expression for this loop is an And or Or.
+    if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
+      if (BO->getOpcode() == Instruction::And) {
+        // Recurse on the operands of the and.
+        bool EitherMayExit = !ExitIfTrue;
+        ExitLimit EL0 = computeExitLimitFromCondCached(
+            Cache, L, BO->getOperand(0), ExitIfTrue,
+            ControlsOnlyExit && !EitherMayExit, AllowPredicates);
+        ExitLimit EL1 = computeExitLimitFromCondCached(
+            Cache, L, BO->getOperand(1), ExitIfTrue,
+            ControlsOnlyExit && !EitherMayExit, AllowPredicates);
+        const SCEV *BECount = getCouldNotCompute();
+        const SCEV *MaxBECount = getCouldNotCompute();
+        if (EitherMayExit) {
+          // Both conditions must be true for the loop to continue executing.
+          // Choose the less conservative count.
+          if (EL0.ExactNotTaken == getCouldNotCompute() ||
+              EL1.ExactNotTaken == getCouldNotCompute())
+            BECount = getCouldNotCompute();
+          else
+            BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken,
+                                                 EL1.ExactNotTaken);
+
+          if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
+            MaxBECount = EL1.ConstantMaxNotTaken;
+          else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
+            MaxBECount = EL0.ConstantMaxNotTaken;
+          else
+            MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
+                                                    EL1.ConstantMaxNotTaken);
+        } else {
+          // Both conditions must be true at the same time for the loop to exit.
+          // For now, be conservative.
+          if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken)
+            MaxBECount = EL0.ConstantMaxNotTaken;
+          if (EL0.ExactNotTaken == EL1.ExactNotTaken)
+            BECount = EL0.ExactNotTaken;
+        }
+
+        // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
+        // to be more aggressive when computing BECount than when computing
+        // MaxBECount.  In these cases it is possible for EL0.ExactNotTaken and
+        // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
+        // EL1.ConstantMaxNotTaken to not.
+        if (isa<SCEVCouldNotCompute>(MaxBECount) &&
+            !isa<SCEVCouldNotCompute>(BECount))
+          MaxBECount = getConstant(getUnsignedRangeMax(BECount));
+
+        return ExitLimit(BECount, MaxBECount, MaxBECount, false,
+                         {&EL0.Predicates, &EL1.Predicates});
+      }
+      if (BO->getOpcode() == Instruction::Or) {
+        // Recurse on the operands of the or.
+        bool EitherMayExit = ExitIfTrue;
+        ExitLimit EL0 = computeExitLimitFromCondCached(
+            Cache, L, BO->getOperand(0), ExitIfTrue,
+            ControlsOnlyExit && !EitherMayExit, AllowPredicates);
+        ExitLimit EL1 = computeExitLimitFromCondCached(
+            Cache, L, BO->getOperand(1), ExitIfTrue,
+            ControlsOnlyExit && !EitherMayExit, AllowPredicates);
+        const SCEV *BECount = getCouldNotCompute();
+        const SCEV *MaxBECount = getCouldNotCompute();
+        if (EitherMayExit) {
+          // Both conditions must be false for the loop to continue executing.
+          // Choose the less conservative count.
+          if (EL0.ExactNotTaken == getCouldNotCompute() ||
+              EL1.ExactNotTaken == getCouldNotCompute())
+            BECount = getCouldNotCompute();
+          else
+            BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken,
+                                                 EL1.ExactNotTaken);
+
+          if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
+            MaxBECount = EL1.ConstantMaxNotTaken;
+          else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
+            MaxBECount = EL0.ConstantMaxNotTaken;
+          else
+            MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
+                                                    EL1.ConstantMaxNotTaken);
+        } else {
+          // Both conditions must be false at the same time for the loop to
+          // exit. For now, be conservative.
+          if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken)
+            MaxBECount = EL0.ConstantMaxNotTaken;
+          if (EL0.ExactNotTaken == EL1.ExactNotTaken)
+            BECount = EL0.ExactNotTaken;
+        }
+        return ExitLimit(BECount, MaxBECount, MaxBECount, false,
+                         {&EL0.Predicates, &EL1.Predicates});
+      }
+    }
+  }
 
   // With an icmp, it may be feasible to compute an exact backedge-taken count.
   // Proceed to the next level to examine the icmp.
@@ -8973,12 +9067,17 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
   // preserve the CFG and is temporarily leaving constant conditions
   // in place.
   if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
-    if (ExitIfTrue == !CI->getZExtValue())
+    if (ExitIfTrue == !CI->getZExtValue()) {
       // The backedge is always taken.
       return getCouldNotCompute();
+    }
     // The backedge is never taken.
     return getZero(CI->getType());
   }
+  // The rest of this code was missing from the MustExitScalarEvolution
+  // overrides
+  // so this should never be reached if using enzyme
+  assert(!AssumeLoopExists);
 
   // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
   // with a constant step, we can form an equivalent icmp predicate and figure

>From 14a0c6c187d61db2e017202283be20d17cc93ed7 Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Thu, 29 Feb 2024 11:31:39 -0600
Subject: [PATCH 06/13] implemented enzyme must exit code in
 computeExitLimitFromICmp

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 183 +++++++++++++++++++++-----
 1 file changed, 151 insertions(+), 32 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 62f8ddfa720812..b6e88b563e2724 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9074,35 +9074,32 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
     // The backedge is never taken.
     return getZero(CI->getType());
   }
-  // The rest of this code was missing from the MustExitScalarEvolution
-  // overrides
-  // so this should never be reached if using enzyme
-  assert(!AssumeLoopExists);
-
-  // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
-  // with a constant step, we can form an equivalent icmp predicate and figure
-  // out how many iterations will be taken before we exit.
-  const WithOverflowInst *WO;
-  const APInt *C;
-  if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
-      match(WO->getRHS(), m_APInt(C))) {
-    ConstantRange NWR =
-      ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
-                                           WO->getNoWrapKind());
-    CmpInst::Predicate Pred;
-    APInt NewRHSC, Offset;
-    NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
-    if (!ExitIfTrue)
-      Pred = ICmpInst::getInversePredicate(Pred);
-    auto *LHS = getSCEV(WO->getLHS());
-    if (Offset != 0)
-      LHS = getAddExpr(LHS, getConstant(Offset));
-    auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
-                                       ControlsOnlyExit, AllowPredicates);
-    if (EL.hasAnyInfo())
-      return EL;
-  }
 
+  // block was never executed in MustExitScalarEvolution code
+  if (!AssumeLoopExists) {
+    // If we're exiting based on the overflow flag of an x.with.overflow
+    // intrinsic with a constant step, we can form an equivalent icmp predicate
+    // and figure out how many iterations will be taken before we exit.
+    const WithOverflowInst *WO;
+    const APInt *C;
+    if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
+        match(WO->getRHS(), m_APInt(C))) {
+      ConstantRange NWR = ConstantRange::makeExactNoWrapRegion(
+          WO->getBinaryOp(), *C, WO->getNoWrapKind());
+      CmpInst::Predicate Pred;
+      APInt NewRHSC, Offset;
+      NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
+      if (!ExitIfTrue)
+        Pred = ICmpInst::getInversePredicate(Pred);
+      auto *LHS = getSCEV(WO->getLHS());
+      if (Offset != 0)
+        LHS = getAddExpr(LHS, getConstant(Offset));
+      auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
+                                         ControlsOnlyExit, AllowPredicates);
+      if (EL.hasAnyInfo())
+        return EL;
+    }
+  }
   // If it's not an integer or pointer comparison then compute it the hard way.
   return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
 }
@@ -9201,12 +9198,134 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
 
   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
+  if (!AssumeLoopExists) {
+    ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
+                                            AllowPredicates);
+    if (EL.hasAnyInfo())
+      return EL;
+  } else {
+#define PROP_PHI(LHS)                                                          \
+  if (auto un = dyn_cast<SCEVUnknown>(LHS)) {                                  \
+    if (auto pn = dyn_cast_or_null<PHINode>(un->getValue())) {                 \
+      const SCEV *sc = nullptr;                                                \
+      bool failed = false;                                                     \
+      for (auto &a : pn->incoming_values()) {                                  \
+        auto subsc = getSCEV(a);                                               \
+        if (sc == nullptr) {                                                   \
+          sc = subsc;                                                          \
+          continue;                                                            \
+        }                                                                      \
+        if (subsc != sc) {                                                     \
+          failed = true;                                                       \
+          break;                                                               \
+        }                                                                      \
+      }                                                                        \
+      if (!failed) {                                                           \
+        LHS = sc;                                                              \
+      }                                                                        \
+    }                                                                          \
+  }
+    PROP_PHI(LHS)
+    PROP_PHI(RHS)
+
+    // Try to evaluate any dependencies out of the loop.
+    LHS = getSCEVAtScope(LHS, L);
+    RHS = getSCEVAtScope(RHS, L);
+
+    // At this point, we would like to compute how many iterations of the
+    // loop the predicate will return true for these inputs.
+    if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
+      // If there is a loop-invariant, force it into the RHS.
+      std::swap(LHS, RHS);
+      Pred = ICmpInst::getSwappedPredicate(Pred);
+    }
 
-  ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
-                                          AllowPredicates);
-  if (EL.hasAnyInfo())
-    return EL;
+    // Simplify the operands before analyzing them.
+    (void)SimplifyICmpOperands(Pred, LHS, RHS);
 
+    // If we have a comparison of a chrec against a constant, try to use value
+    // ranges to answer this query.
+    if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
+      if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
+        if (AddRec->getLoop() == L) {
+          // Form the constant range.
+          ConstantRange CompRange =
+              ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
+
+          const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
+          if (!isa<SCEVCouldNotCompute>(Ret))
+            return Ret;
+        }
+
+    switch (Pred) {
+    case ICmpInst::ICMP_NE: { // while (X != Y)
+      // Convert to: while (X-Y != 0)
+      ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
+                                  AllowPredicates);
+      if (EL.hasAnyInfo())
+        return EL;
+      break;
+    }
+    case ICmpInst::ICMP_EQ: { // while (X == Y)
+      // Convert to: while (X-Y == 0)
+      ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
+      if (EL.hasAnyInfo())
+        return EL;
+      break;
+    }
+    case ICmpInst::ICMP_SLT:
+    case ICmpInst::ICMP_ULT:
+    case ICmpInst::ICMP_SLE:
+    case ICmpInst::ICMP_ULE: { // while (X < Y)
+      bool IsSigned = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE;
+
+      if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE) {
+        if (!isa<IntegerType>(RHS->getType()))
+          break;
+        SmallVector<const SCEV *, 2> sv = {
+            RHS, getConstant(
+                     ConstantInt::get(cast<IntegerType>(RHS->getType()), 1))};
+        // Since this is not an infinite loop by induction, RHS cannot be
+        // int_max/uint_max Therefore adding 1 does not wrap.
+        if (IsSigned)
+          RHS = getAddExpr(sv, SCEV::FlagNSW);
+        else
+          RHS = getAddExpr(sv, SCEV::FlagNUW);
+      }
+      ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
+                                      AllowPredicates);
+      if (EL.hasAnyInfo())
+        return EL;
+      break;
+    }
+    case ICmpInst::ICMP_SGT:
+    case ICmpInst::ICMP_UGT:
+    case ICmpInst::ICMP_SGE:
+    case ICmpInst::ICMP_UGE: { // while (X > Y)
+      bool IsSigned = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE;
+      if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) {
+        if (!isa<IntegerType>(RHS->getType()))
+          break;
+        SmallVector<const SCEV *, 2> sv = {
+            RHS, getConstant(
+                     ConstantInt::get(cast<IntegerType>(RHS->getType()), -1))};
+        // X >= Y  <=>  X > Y + (-1). Y != INT_MIN by induction, so the signed
+        // add is NSW; adding all-ones carries for any Y != 0, so never NUW.
+        if (IsSigned)
+          RHS = getAddExpr(sv, SCEV::FlagNSW);
+        else
+          RHS = getAddExpr(sv, SCEV::FlagAnyWrap);
+      }
+      ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned,
+                                         ControlsOnlyExit, AllowPredicates);
+      if (EL.hasAnyInfo())
+        return EL;
+      break;
+    }
+    default:
+      break;
+    }
+  }
   auto *ExhaustiveCount =
       computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
 

>From abb0ab463de42b5b66261fed48de69d8980b30c0 Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Thu, 29 Feb 2024 14:30:36 -0600
Subject: [PATCH 07/13]  add Enzyme changes to SE howManyLessThans

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 100 ++++++++++++++++----------
 1 file changed, 63 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index b6e88b563e2724..854cfec1e6805d 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12983,38 +12983,50 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
       const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
       if (AR && AR->getLoop() == L && AR->isAffine()) {
-        auto canProveNUW = [&]() {
-          // We can use the comparison to infer no-wrap flags only if it fully
-          // controls the loop exit.
-          if (!ControlsOnlyExit)
-            return false;
-
-          if (!isLoopInvariant(RHS, L))
-            return false;
-
-          if (!isKnownNonZero(AR->getStepRecurrence(*this)))
-            // We need the sequence defined by AR to strictly increase in the
-            // unsigned integer domain for the logic below to hold.
-            return false;
-
-          const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
-          const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
-          // If RHS <=u Limit, then there must exist a value V in the sequence
-          // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
-          // V <=u UINT_MAX.  Thus, we must exit the loop before unsigned
-          // overflow occurs.  This limit also implies that a signed comparison
-          // (in the wide bitwidth) is equivalent to an unsigned comparison as
-          // the high bits on both sides must be zero.
-          APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
-          APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
-          Limit = Limit.zext(OuterBitWidth);
-          return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
-        };
-        auto Flags = AR->getNoWrapFlags();
-        if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
-          Flags = setFlags(Flags, SCEV::FlagNUW);
+        if (!AssumeLoopExists) {
+          auto canProveNUW = [&]() {
+            // We can use the comparison to infer no-wrap flags only if it fully
+            // controls the loop exit.
+            if (!ControlsOnlyExit)
+              return false;
 
-        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
+            if (!isLoopInvariant(RHS, L))
+              return false;
+
+            if (!isKnownNonZero(AR->getStepRecurrence(*this)))
+              // We need the sequence defined by AR to strictly increase in the
+              // unsigned integer domain for the logic below to hold.
+              return false;
+
+            const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
+            const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
+            // If RHS <=u Limit, then there must exist a value V in the sequence
+            // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
+            // V <=u UINT_MAX.  Thus, we must exit the loop before unsigned
+            // overflow occurs.  This limit also implies that a signed
+            // comparison (in the wide bitwidth) is equivalent to an unsigned
+            // comparison as the high bits on both sides must be zero.
+            APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
+            APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
+            Limit = Limit.zext(OuterBitWidth);
+            return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
+          };
+          auto Flags = AR->getNoWrapFlags();
+          if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
+            Flags = setFlags(Flags, SCEV::FlagNUW);
+
+          setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
+        } else {
+          auto Flags = AR->getNoWrapFlags();
+          if (!hasFlags(Flags, SCEV::FlagNW) && canAssumeNoSelfWrap(AR)) {
+            Flags = setFlags(Flags, SCEV::FlagNW);
+
+            SmallVector<const SCEV *, 4> Operands{AR->operands()};
+            Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
+
+            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
+          }
+        }
         if (AR->hasNoUnsignedWrap()) {
           // Emulate what getZeroExtendExpr would have done during construction
           // if we'd been able to infer the fact just above at that time.
@@ -13098,6 +13110,13 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
         !loopHasNoAbnormalExits(L))
       return getCouldNotCompute();
 
+    // This bailout is protecting the logic in computeMaxBECountForLT which
+    // has not yet been sufficiently auditted or tested with negative strides.
+    // We used to filter out all known-non-positive cases here, we're in the
+    // process of being less restrictive bit by bit.
+    if (AssumeLoopExists && IsSigned && isKnownNonPositive(Stride))
+      return getCouldNotCompute();
+
     if (!isKnownNonZero(Stride)) {
       // If we have a step of zero, and RHS isn't invariant in L, we don't know
       // if it might eventually be greater than start and if so, on which
@@ -13227,13 +13246,17 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   if (!BECount) {
     auto canProveRHSGreaterThanEqualStart = [&]() {
       auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
-      const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
-      const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
 
-      if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
-          isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
-        return true;
+      if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart)) {
+        if (AssumeLoopExists) {
+          return true;
+        }
+        const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
+        const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
 
+        if (isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
+          return true;
+      }
       // (RHS > Start - 1) implies RHS >= Start.
       // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
       //   "Start - 1" doesn't overflow.
@@ -13370,7 +13393,10 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
       !isa<SCEVCouldNotCompute>(BECount))
     ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
-
+  if (AssumeLoopExists) {
+    return ExitLimit(BECount, ConstantMaxBECount, ConstantMaxBECount, MaxOrZero,
+                     Predicates);
+  }
   const SCEV *SymbolicMaxBECount =
       isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
   return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,

>From c1d83de8b8bb83f3a93e4c271ab9dfd10f7e7950 Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Fri, 1 Mar 2024 11:21:01 -0600
Subject: [PATCH 08/13] fixed howManyLessThans so the isLoopEntryGuardedByCond
 check no longer depends on AssumeLoopExists

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 854cfec1e6805d..492b33e0a7c233 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7416,7 +7416,7 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
   // A mustprogress loop without side effects must be finite.
   // TODO: The check used here is very conservative.  It's only *specific*
   // side effects which are well defined in infinite loops.
-  return this->AssumeLoopExists || isFinite(L) ||
+  return AssumeLoopExists || isFinite(L) ||
          (isMustProgress(L) && loopHasNoSideEffects(L));
 }
 
@@ -13248,9 +13248,12 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
 
       if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart)) {
-        if (AssumeLoopExists) {
-          return true;
-        }
+        return true;
+      }
+      // In the Enzyme MustExitScalarEvolutionCode, this check was missing
+      // I do not have enough context to know if these two checks should be
+      // mutually Exclusive. If they aren't then this bool check is unnecessary
+      if (!AssumeLoopExists) {
         const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
         const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
 

>From 66ab0c3e093988ac51258837ccc063f0955d7417 Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Thu, 7 Mar 2024 09:47:02 -0600
Subject: [PATCH 09/13] incorporating changes from code review

---
 llvm/include/llvm/Analysis/ScalarEvolution.h |   4 +-
 llvm/lib/Analysis/ScalarEvolution.cpp        | 139 +++++++++----------
 2 files changed, 66 insertions(+), 77 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 4cc1954c1233f6..50dbe2aeec884a 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -460,8 +460,8 @@ class ScalarEvolution {
     LoopComputable ///< The SCEV varies predictably with the loop.
   };
 
-  bool AssumeLoopExists = false;
-  void setAssumeLoopExists();
+  bool AssumeLoopExits = false;
+  void setAssumeLoopExits();
   llvm::SmallPtrSet<llvm::BasicBlock *, 4> GuaranteedUnreachable;
 
   /// An enum describing the relationship between a SCEV and a basic block.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 492b33e0a7c233..3b1fcab6e333b8 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -510,7 +510,7 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) {
   return S;
 }
 
-void ScalarEvolution::setAssumeLoopExists() { this->AssumeLoopExists = true; }
+void ScalarEvolution::setAssumeLoopExits() { this->AssumeLoopExits = true; }
 
 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                            const SCEV *op, Type *ty)
@@ -7416,7 +7416,7 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
   // A mustprogress loop without side effects must be finite.
   // TODO: The check used here is very conservative.  It's only *specific*
   // side effects which are well defined in infinite loops.
-  return AssumeLoopExists || isFinite(L) ||
+  return AssumeLoopExits || isFinite(L) ||
          (isMustProgress(L) && loopHasNoSideEffects(L));
 }
 
@@ -8832,7 +8832,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
 ScalarEvolution::ExitLimit
 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
                                       bool AllowPredicates) {
-  if (AssumeLoopExists) {
+  if (AssumeLoopExits) {
     SmallVector<BasicBlock *, 8> ExitingBlocks;
     L->getExitingBlocks(ExitingBlocks);
     for (auto &ExitingBlock : ExitingBlocks) {
@@ -8877,7 +8877,7 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
     BasicBlock *Exit = nullptr;
     for (auto *SBB : successors(ExitingBlock))
       if (!L->contains(SBB)) {
-        if (AssumeLoopExists and GuaranteedUnreachable.count(SBB))
+        if (AssumeLoopExits and GuaranteedUnreachable.count(SBB))
           continue;
         if (Exit) // Multiple exit successors.
           return getCouldNotCompute();
@@ -8949,7 +8949,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
     ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
-  if (!AssumeLoopExists) {
+  if (!AssumeLoopExits) {
     // Handle BinOp conditions (And, Or).
     if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
             Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
@@ -9076,30 +9076,30 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
   }
 
   // block was never executed in MustExitScalarEvolution code
-  if (!AssumeLoopExists) {
-    // If we're exiting based on the overflow flag of an x.with.overflow
-    // intrinsic with a constant step, we can form an equivalent icmp predicate
-    // and figure out how many iterations will be taken before we exit.
-    const WithOverflowInst *WO;
-    const APInt *C;
-    if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
-        match(WO->getRHS(), m_APInt(C))) {
-      ConstantRange NWR = ConstantRange::makeExactNoWrapRegion(
-          WO->getBinaryOp(), *C, WO->getNoWrapKind());
-      CmpInst::Predicate Pred;
-      APInt NewRHSC, Offset;
-      NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
-      if (!ExitIfTrue)
-        Pred = ICmpInst::getInversePredicate(Pred);
-      auto *LHS = getSCEV(WO->getLHS());
-      if (Offset != 0)
-        LHS = getAddExpr(LHS, getConstant(Offset));
-      auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
-                                         ControlsOnlyExit, AllowPredicates);
-      if (EL.hasAnyInfo())
-        return EL;
-    }
+
+  // If we're exiting based on the overflow flag of an x.with.overflow
+  // intrinsic with a constant step, we can form an equivalent icmp predicate
+  // and figure out how many iterations will be taken before we exit.
+  const WithOverflowInst *WO;
+  const APInt *C;
+  if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
+      match(WO->getRHS(), m_APInt(C))) {
+    ConstantRange NWR = ConstantRange::makeExactNoWrapRegion(
+        WO->getBinaryOp(), *C, WO->getNoWrapKind());
+    CmpInst::Predicate Pred;
+    APInt NewRHSC, Offset;
+    NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
+    if (!ExitIfTrue)
+      Pred = ICmpInst::getInversePredicate(Pred);
+    auto *LHS = getSCEV(WO->getLHS());
+    if (Offset != 0)
+      LHS = getAddExpr(LHS, getConstant(Offset));
+    auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
+                                       ControlsOnlyExit, AllowPredicates);
+    if (EL.hasAnyInfo())
+      return EL;
   }
+
   // If it's not an integer or pointer comparison then compute it the hard way.
   return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
 }
@@ -9198,7 +9198,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
 
   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
-  if (!AssumeLoopExists) {
+  if (!AssumeLoopExits) {
     ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
                                             AllowPredicates);
     if (EL.hasAnyInfo())
@@ -9485,7 +9485,7 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
   // if not using enzyme executes by default
   // if using enzyme and the code is guaranteed unreachable,
   // the default destination doesn't matter
-  if (!AssumeLoopExists ||
+  if (!AssumeLoopExits ||
       !GuaranteedUnreachable.count(Switch->getDefaultDest())) {
     assert(L->contains(Switch->getDefaultDest()) &&
            "Default case must not exit the loop!");
@@ -12983,50 +12983,39 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
       const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
       if (AR && AR->getLoop() == L && AR->isAffine()) {
-        if (!AssumeLoopExists) {
-          auto canProveNUW = [&]() {
-            // We can use the comparison to infer no-wrap flags only if it fully
-            // controls the loop exit.
-            if (!ControlsOnlyExit)
-              return false;
-
-            if (!isLoopInvariant(RHS, L))
-              return false;
-
-            if (!isKnownNonZero(AR->getStepRecurrence(*this)))
-              // We need the sequence defined by AR to strictly increase in the
-              // unsigned integer domain for the logic below to hold.
-              return false;
-
-            const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
-            const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
-            // If RHS <=u Limit, then there must exist a value V in the sequence
-            // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
-            // V <=u UINT_MAX.  Thus, we must exit the loop before unsigned
-            // overflow occurs.  This limit also implies that a signed
-            // comparison (in the wide bitwidth) is equivalent to an unsigned
-            // comparison as the high bits on both sides must be zero.
-            APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
-            APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
-            Limit = Limit.zext(OuterBitWidth);
-            return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
-          };
-          auto Flags = AR->getNoWrapFlags();
-          if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
-            Flags = setFlags(Flags, SCEV::FlagNUW);
-
-          setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
-        } else {
-          auto Flags = AR->getNoWrapFlags();
-          if (!hasFlags(Flags, SCEV::FlagNW) && canAssumeNoSelfWrap(AR)) {
-            Flags = setFlags(Flags, SCEV::FlagNW);
+        auto canProveNUW = [&]() {
+          // We can use the comparison to infer no-wrap flags only if it fully
+          // controls the loop exit.
+          if (!ControlsOnlyExit)
+            return false;
+
+          if (!isLoopInvariant(RHS, L))
+            return false;
+
+          if (!isKnownNonZero(AR->getStepRecurrence(*this)))
+            // We need the sequence defined by AR to strictly increase in the
+            // unsigned integer domain for the logic below to hold.
+            return false;
+
+          const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
+          const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
+          // If RHS <=u Limit, then there must exist a value V in the sequence
+          // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
+          // V <=u UINT_MAX.  Thus, we must exit the loop before unsigned
+          // overflow occurs.  This limit also implies that a signed
+          // comparison (in the wide bitwidth) is equivalent to an unsigned
+          // comparison as the high bits on both sides must be zero.
+          APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
+          APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
+          Limit = Limit.zext(OuterBitWidth);
+          return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
+        };
+        auto Flags = AR->getNoWrapFlags();
+        if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
+          Flags = setFlags(Flags, SCEV::FlagNUW);
 
-            SmallVector<const SCEV *, 4> Operands{AR->operands()};
-            Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
+        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
 
-            setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
-          }
-        }
         if (AR->hasNoUnsignedWrap()) {
           // Emulate what getZeroExtendExpr would have done during construction
           // if we'd been able to infer the fact just above at that time.
@@ -13114,7 +13103,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     // has not yet been sufficiently auditted or tested with negative strides.
     // We used to filter out all known-non-positive cases here, we're in the
     // process of being less restrictive bit by bit.
-    if (AssumeLoopExists && IsSigned && isKnownNonPositive(Stride))
+    if (AssumeLoopExits && IsSigned && isKnownNonPositive(Stride))
       return getCouldNotCompute();
 
     if (!isKnownNonZero(Stride)) {
@@ -13253,7 +13242,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       // In the Enzyme MustExitScalarEvolutionCode, this check was missing
       // I do not have enough context to know if these two checks should be
       // mutually Exclusive. If they aren't then this bool check is unnecessary
-      if (!AssumeLoopExists) {
+      if (!AssumeLoopExits) {
         const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
         const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
 
@@ -13396,7 +13385,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
       !isa<SCEVCouldNotCompute>(BECount))
     ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
-  if (AssumeLoopExists) {
+  if (AssumeLoopExits) {
     return ExitLimit(BECount, ConstantMaxBECount, ConstantMaxBECount, MaxOrZero,
                      Predicates);
   }

>From 9b57191bf32c57dc62927cbc7d1c17ad04f4d91d Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Mon, 11 Mar 2024 08:44:39 -0500
Subject: [PATCH 10/13] removed unrelated change

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 3b1fcab6e333b8..53aa2faacf1cd4 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9067,16 +9067,14 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
   // preserve the CFG and is temporarily leaving constant conditions
   // in place.
   if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
-    if (ExitIfTrue == !CI->getZExtValue()) {
+    if (ExitIfTrue == !CI->getZExtValue())
       // The backedge is always taken.
       return getCouldNotCompute();
-    }
+
     // The backedge is never taken.
     return getZero(CI->getType());
   }
 
-  // block was never executed in MustExitScalarEvolution code
-
   // If we're exiting based on the overflow flag of an x.with.overflow
   // intrinsic with a constant step, we can form an equivalent icmp predicate
   // and figure out how many iterations will be taken before we exit.

>From 57767932c2ce69a71f672da7bc115e2796e529f9 Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Thu, 14 Mar 2024 12:13:45 -0500
Subject: [PATCH 11/13] moved mustexit (AssumeLoopExits) code to the
 SCEV-operand overload of computeExitLimitFromICmp

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 191 +++++++++-----------------
 1 file changed, 66 insertions(+), 125 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 53aa2faacf1cd4..9c54dcb0e3f905 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9196,12 +9196,25 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
 
   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
-  if (!AssumeLoopExits) {
+
     ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
                                             AllowPredicates);
     if (EL.hasAnyInfo())
       return EL;
-  } else {
+
+    auto *ExhaustiveCount =
+        computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
+
+    if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
+      return ExhaustiveCount;
+
+    return computeShiftCompareExitLimit(
+        ExitCond->getOperand(0), ExitCond->getOperand(1), L, OriginalPred);
+}
+ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
+    const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
+    bool ControlsOnlyExit, bool AllowPredicates) {
+  if (AssumeLoopExits) {
 #define PROP_PHI(LHS)                                                          \
   if (auto un = dyn_cast<SCEVUnknown>(LHS)) {                                  \
     if (auto pn = dyn_cast_or_null<PHINode>(un->getValue())) {                 \
@@ -9225,118 +9238,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
   }
     PROP_PHI(LHS)
     PROP_PHI(RHS)
-
-    // Try to evaluate any dependencies out of the loop.
-    LHS = getSCEVAtScope(LHS, L);
-    RHS = getSCEVAtScope(RHS, L);
-
-    // At this point, we would like to compute how many iterations of the
-    // loop the predicate will return true for these inputs.
-    if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
-      // If there is a loop-invariant, force it into the RHS.
-      std::swap(LHS, RHS);
-      Pred = ICmpInst::getSwappedPredicate(Pred);
-    }
-
-    // Simplify the operands before analyzing them.
-    (void)SimplifyICmpOperands(Pred, LHS, RHS);
-
-    // If we have a comparison of a chrec against a constant, try to use value
-    // ranges to answer this query.
-    if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
-      if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
-        if (AddRec->getLoop() == L) {
-          // Form the constant range.
-          ConstantRange CompRange =
-              ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
-
-          const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
-          if (!isa<SCEVCouldNotCompute>(Ret))
-            return Ret;
-        }
-
-    switch (Pred) {
-    case ICmpInst::ICMP_NE: { // while (X != Y)
-      // Convert to: while (X-Y != 0)
-      ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
-                                  AllowPredicates);
-      if (EL.hasAnyInfo())
-        return EL;
-      break;
-    }
-    case ICmpInst::ICMP_EQ: { // while (X == Y)
-      // Convert to: while (X-Y == 0)
-      ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
-      if (EL.hasAnyInfo())
-        return EL;
-      break;
-    }
-    case ICmpInst::ICMP_SLT:
-    case ICmpInst::ICMP_ULT:
-    case ICmpInst::ICMP_SLE:
-    case ICmpInst::ICMP_ULE: { // while (X < Y)
-      bool IsSigned = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE;
-
-      if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE) {
-        if (!isa<IntegerType>(RHS->getType()))
-          break;
-        SmallVector<const SCEV *, 2> sv = {
-            RHS, getConstant(
-                     ConstantInt::get(cast<IntegerType>(RHS->getType()), 1))};
-        // Since this is not an infinite loop by induction, RHS cannot be
-        // int_max/uint_max Therefore adding 1 does not wrap.
-        if (IsSigned)
-          RHS = getAddExpr(sv, SCEV::FlagNSW);
-        else
-          RHS = getAddExpr(sv, SCEV::FlagNUW);
-      }
-      ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
-                                      AllowPredicates);
-      if (EL.hasAnyInfo())
-        return EL;
-      break;
-    }
-    case ICmpInst::ICMP_SGT:
-    case ICmpInst::ICMP_UGT:
-    case ICmpInst::ICMP_SGE:
-    case ICmpInst::ICMP_UGE: { // while (X > Y)
-      bool IsSigned = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE;
-      if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) {
-        if (!isa<IntegerType>(RHS->getType()))
-          break;
-        SmallVector<const SCEV *, 2> sv = {
-            RHS, getConstant(
-                     ConstantInt::get(cast<IntegerType>(RHS->getType()), -1))};
-        // Since this is not an infinite loop by induction, RHS cannot be
-        // int_min/uint_min Therefore subtracting 1 does not wrap.
-        if (IsSigned)
-          RHS = getAddExpr(sv, SCEV::FlagNSW);
-        else
-          RHS = getAddExpr(sv, SCEV::FlagNUW);
-      }
-      ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned,
-                                         ControlsOnlyExit, AllowPredicates);
-      if (EL.hasAnyInfo())
-        return EL;
-      break;
-    }
-    default:
-      break;
-    }
   }
-  auto *ExhaustiveCount =
-      computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
-
-  if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
-    return ExhaustiveCount;
-
-  return computeShiftCompareExitLimit(ExitCond->getOperand(0),
-                                      ExitCond->getOperand(1), L, OriginalPred);
-}
-ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
-    const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
-    bool ControlsOnlyExit, bool AllowPredicates) {
-
   // Try to evaluate any dependencies out of the loop.
   LHS = getSCEVAtScope(LHS, L);
   RHS = getSCEVAtScope(RHS, L);
@@ -9349,6 +9251,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
     Pred = ICmpInst::getSwappedPredicate(Pred);
   }
 
+  // was not present in Enzyme code, the last condition is true if
+  // AssumeLoopExits is true
+  // will the first two checks cause enzyme to fail?
   bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
                                loopIsFiniteByAssumption(L);
   // Simplify the operands before analyzing them.
@@ -9426,18 +9331,37 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
     if (EL.hasAnyInfo()) return EL;
     break;
   }
+
   case ICmpInst::ICMP_SLE:
   case ICmpInst::ICMP_ULE:
-    // Since the loop is finite, an invariant RHS cannot include the boundary
-    // value, otherwise it would loop forever.
-    if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
-        !isLoopInvariant(RHS, L))
-      break;
-    RHS = getAddExpr(getOne(RHS->getType()), RHS);
+    if (!AssumeLoopExits) {
+      // Since the loop is finite, an invariant RHS cannot include the boundary
+      // value, otherwise it would loop forever.
+      if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
+          !isLoopInvariant(RHS, L))
+        break;
+      RHS = getAddExpr(getOne(RHS->getType()), RHS);
+    }
     [[fallthrough]];
+
   case ICmpInst::ICMP_SLT:
   case ICmpInst::ICMP_ULT: { // while (X < Y)
     bool IsSigned = ICmpInst::isSigned(Pred);
+    if (AssumeLoopExits) {
+      if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE) {
+        if (!isa<IntegerType>(RHS->getType()))
+          break;
+        SmallVector<const SCEV *, 2> sv = {
+            RHS, getConstant(
+                     ConstantInt::get(cast<IntegerType>(RHS->getType()), 1))};
+        // Since this is not an infinite loop by induction, RHS cannot be
+        // int_max/uint_max Therefore adding 1 does not wrap.
+        if (IsSigned)
+          RHS = getAddExpr(sv, SCEV::FlagNSW);
+        else
+          RHS = getAddExpr(sv, SCEV::FlagNUW);
+      }
+    }
     ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
                                     AllowPredicates);
     if (EL.hasAnyInfo())
@@ -9446,16 +9370,33 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
   }
   case ICmpInst::ICMP_SGE:
   case ICmpInst::ICMP_UGE:
-    // Since the loop is finite, an invariant RHS cannot include the boundary
-    // value, otherwise it would loop forever.
-    if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
-        !isLoopInvariant(RHS, L))
-      break;
-    RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
+    if (!AssumeLoopExits) {
+      // Since the loop is finite, an invariant RHS cannot include the boundary
+      // value, otherwise it would loop forever.
+      if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
+          !isLoopInvariant(RHS, L))
+        break;
+      RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
+    }
     [[fallthrough]];
   case ICmpInst::ICMP_SGT:
   case ICmpInst::ICMP_UGT: { // while (X > Y)
     bool IsSigned = ICmpInst::isSigned(Pred);
+    if (AssumeLoopExits) {
+      if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) {
+        if (!isa<IntegerType>(RHS->getType()))
+          break;
+        SmallVector<const SCEV *, 2> sv = {
+            RHS, getConstant(
+                     ConstantInt::get(cast<IntegerType>(RHS->getType()), -1))};
+        // Since this is not an infinite loop by induction, RHS cannot be
+        // int_min/uint_min Therefore subtracting 1 does not wrap.
+        if (IsSigned)
+          RHS = getAddExpr(sv, SCEV::FlagNSW);
+        else
+          RHS = getAddExpr(sv, SCEV::FlagNUW);
+      }
+    }
     ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
                                        AllowPredicates);
     if (EL.hasAnyInfo())

>From bdabce85e51a2da335dd9c669981624c1eb1e6b6 Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Thu, 14 Mar 2024 12:45:11 -0500
Subject: [PATCH 12/13] reran git clang-format HEAD~1

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 9c54dcb0e3f905..e11d02f2c12e14 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9197,19 +9197,18 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
   const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
   const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
 
-    ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
-                                            AllowPredicates);
-    if (EL.hasAnyInfo())
-      return EL;
+  ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
+                                          AllowPredicates);
+  if (EL.hasAnyInfo())
+    return EL;
 
-    auto *ExhaustiveCount =
-        computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
+  auto *ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
 
-    if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
-      return ExhaustiveCount;
+  if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
+    return ExhaustiveCount;
 
-    return computeShiftCompareExitLimit(
-        ExitCond->getOperand(0), ExitCond->getOperand(1), L, OriginalPred);
+  return computeShiftCompareExitLimit(ExitCond->getOperand(0),
+                                      ExitCond->getOperand(1), L, OriginalPred);
 }
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
     const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,

>From a9c9251be1038e8af49402a30f96249448d68ecc Mon Sep 17 00:00:00 2001
From: skewballfox <joshua.ferguson.273 at gmail.com>
Date: Thu, 14 Mar 2024 14:25:18 -0500
Subject: [PATCH 13/13] removed redundant binOp code from CondImpl

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 123 +++++---------------------
 1 file changed, 20 insertions(+), 103 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e11d02f2c12e14..4375c254c83610 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8949,104 +8949,11 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
     ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
     bool ControlsOnlyExit, bool AllowPredicates) {
-  if (!AssumeLoopExits) {
-    // Handle BinOp conditions (And, Or).
-    if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
-            Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
-      return *LimitFromBinOp;
-  } else {
-    // Check if the controlling expression for this loop is an And or Or.
-    if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
-      if (BO->getOpcode() == Instruction::And) {
-        // Recurse on the operands of the and.
-        bool EitherMayExit = !ExitIfTrue;
-        ExitLimit EL0 = computeExitLimitFromCondCached(
-            Cache, L, BO->getOperand(0), ExitIfTrue,
-            ControlsOnlyExit && !EitherMayExit, AllowPredicates);
-        ExitLimit EL1 = computeExitLimitFromCondCached(
-            Cache, L, BO->getOperand(1), ExitIfTrue,
-            ControlsOnlyExit && !EitherMayExit, AllowPredicates);
-        const SCEV *BECount = getCouldNotCompute();
-        const SCEV *MaxBECount = getCouldNotCompute();
-        if (EitherMayExit) {
-          // Both conditions must be true for the loop to continue executing.
-          // Choose the less conservative count.
-          if (EL0.ExactNotTaken == getCouldNotCompute() ||
-              EL1.ExactNotTaken == getCouldNotCompute())
-            BECount = getCouldNotCompute();
-          else
-            BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken,
-                                                 EL1.ExactNotTaken);
-
-          if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
-            MaxBECount = EL1.ConstantMaxNotTaken;
-          else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
-            MaxBECount = EL0.ConstantMaxNotTaken;
-          else
-            MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
-                                                    EL1.ConstantMaxNotTaken);
-        } else {
-          // Both conditions must be true at the same time for the loop to exit.
-          // For now, be conservative.
-          if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken)
-            MaxBECount = EL0.ConstantMaxNotTaken;
-          if (EL0.ExactNotTaken == EL1.ExactNotTaken)
-            BECount = EL0.ExactNotTaken;
-        }
 
-        // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
-        // to be more aggressive when computing BECount than when computing
-        // MaxBECount.  In these cases it is possible for EL0.ExactNotTaken and
-        // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
-        // EL1.ConstantMaxNotTaken to not.
-        if (isa<SCEVCouldNotCompute>(MaxBECount) &&
-            !isa<SCEVCouldNotCompute>(BECount))
-          MaxBECount = getConstant(getUnsignedRangeMax(BECount));
-
-        return ExitLimit(BECount, MaxBECount, MaxBECount, false,
-                         {&EL0.Predicates, &EL1.Predicates});
-      }
-      if (BO->getOpcode() == Instruction::Or) {
-        // Recurse on the operands of the or.
-        bool EitherMayExit = ExitIfTrue;
-        ExitLimit EL0 = computeExitLimitFromCondCached(
-            Cache, L, BO->getOperand(0), ExitIfTrue,
-            ControlsOnlyExit && !EitherMayExit, AllowPredicates);
-        ExitLimit EL1 = computeExitLimitFromCondCached(
-            Cache, L, BO->getOperand(1), ExitIfTrue,
-            ControlsOnlyExit && !EitherMayExit, AllowPredicates);
-        const SCEV *BECount = getCouldNotCompute();
-        const SCEV *MaxBECount = getCouldNotCompute();
-        if (EitherMayExit) {
-          // Both conditions must be false for the loop to continue executing.
-          // Choose the less conservative count.
-          if (EL0.ExactNotTaken == getCouldNotCompute() ||
-              EL1.ExactNotTaken == getCouldNotCompute())
-            BECount = getCouldNotCompute();
-          else
-            BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken,
-                                                 EL1.ExactNotTaken);
-
-          if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
-            MaxBECount = EL1.ConstantMaxNotTaken;
-          else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
-            MaxBECount = EL0.ConstantMaxNotTaken;
-          else
-            MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
-                                                    EL1.ConstantMaxNotTaken);
-        } else {
-          // Both conditions must be false at the same time for the loop to
-          // exit. For now, be conservative.
-          if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken)
-            MaxBECount = EL0.ConstantMaxNotTaken;
-          if (EL0.ExactNotTaken == EL1.ExactNotTaken)
-            BECount = EL0.ExactNotTaken;
-        }
-        return ExitLimit(BECount, MaxBECount, MaxBECount, false,
-                         {&EL0.Predicates, &EL1.Predicates});
-      }
-    }
-  }
+  // Handle BinOp conditions (And, Or).
+  if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
+          Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
+    return *LimitFromBinOp;
 
   // With an icmp, it may be feasible to compute an exact backedge-taken count.
   // Proceed to the next level to examine the icmp.
@@ -9139,6 +9046,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
   const SCEV *SymbolicMaxBECount = getCouldNotCompute();
   if (EitherMayExit) {
     bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
+
     // Both conditions must be same for the loop to continue executing.
     // Choose the less conservative count.
     if (EL0.ExactNotTaken != getCouldNotCompute() &&
@@ -9146,6 +9054,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
       BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
                                            UseSequentialUMin);
     }
+
     if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
       ConstantMaxBECount = EL1.ConstantMaxNotTaken;
     else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
@@ -9165,6 +9074,12 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
     // For now, be conservative.
     if (EL0.ExactNotTaken == EL1.ExactNotTaken)
       BECount = EL0.ExactNotTaken;
+    // This was executed in Enzyme's must exit code under the
+    // logic for when the binary op was OR
+    if (AssumeLoopExits && !IsAnd) {
+      if (EL0.ExactNotTaken == EL1.ExactNotTaken)
+        ConstantMaxBECount = EL0.ExactNotTaken;
+    }
   }
 
   // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
@@ -9173,12 +9088,14 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
   // and
   // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
   // EL1.ConstantMaxNotTaken to not.
-  if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
-      !isa<SCEVCouldNotCompute>(BECount))
-    ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
-  if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
-    SymbolicMaxBECount =
-        isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
+  if (!AssumeLoopExits || !IsAnd) { // should skip if assume exits and OR
+    if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
+        !isa<SCEVCouldNotCompute>(BECount))
+      ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
+    if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
+      SymbolicMaxBECount =
+          isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
+  }
   return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
                    { &EL0.Predicates, &EL1.Predicates });
 }



More information about the llvm-commits mailing list