[llvm] [SCEV] Support addrec in right hand side in howManyLessThans (PR #92560)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 25 08:55:54 PDT 2024


https://github.com/hiraditya updated https://github.com/llvm/llvm-project/pull/92560

>From 2861ab8b3309579b6875a5e84cc139e8fb708fec Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Fri, 17 May 2024 19:29:36 +0530
Subject: [PATCH 1/9] Update howManyLessThans

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 3982f5b1b8148..5978cdafcfbcb 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13000,12 +13000,16 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       return RHS;
   }
 
-  // When the RHS is not invariant, we do not know the end bound of the loop and
-  // cannot calculate the ExactBECount needed by ExitLimit. However, we can
-  // calculate the MaxBECount, given the start, stride and max value for the end
-  // bound of the loop (RHS), and the fact that IV does not overflow (which is
-  // checked above).
   if (!isLoopInvariant(RHS, L)) {
+    // If RHS is an add recurrence, try again with lhs=lhs-rhs and rhs=0
+    if(auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
+       return howManyLessThans(getMinusSCEV(IV, RHSAddRec), 
+        getZero(IV->getType()), L, true, ControlsOnlyExit, AllowPredicates);
+    }
+    // If we cannot calculate ExactBECount, we can calculate the MaxBECount, 
+    // given the start, stride and max value for the end bound of the 
+    // loop (RHS), and the fact that IV does not overflow (which is
+    // checked above).
     const SCEV *MaxBECount = computeMaxBECountForLT(
         Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
     return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,

>From 60c174491a474ec9a4515b026c47066dd030459f Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Sun, 19 May 2024 21:29:03 +0530
Subject: [PATCH 2/9] Update howManyLessThans

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 80 +++++++++++++++++++++------
 1 file changed, 64 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 5978cdafcfbcb..650b5fde44e74 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13000,27 +13000,77 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       return RHS;
   }
 
+  const SCEV *End = nullptr, *BECount = nullptr, 
+    *BECountIfBackedgeTaken = nullptr;
   if (!isLoopInvariant(RHS, L)) {
-    // If RHS is an add recurrence, try again with lhs=lhs-rhs and rhs=0
-    if(auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
-       return howManyLessThans(getMinusSCEV(IV, RHSAddRec), 
-        getZero(IV->getType()), L, true, ControlsOnlyExit, AllowPredicates);
-    }
-    // If we cannot calculate ExactBECount, we can calculate the MaxBECount, 
-    // given the start, stride and max value for the end bound of the 
-    // loop (RHS), and the fact that IV does not overflow (which is
-    // checked above).
-    const SCEV *MaxBECount = computeMaxBECountForLT(
-        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
-    return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
-                     MaxBECount, false /*MaxOrZero*/, Predicates);
+    if (auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
+      /*
+        The structure of loop we are trying to calculate backedge-count of:
+        left = left_start
+        right = right_start
+        while(left < right){
+          // ... do something here ...
+          left += s1; // stride of left is s1>0
+          right -= s2; // stride of right is -s2 (s2 > 0)
+        }
+        // left and right are converging at the middle
+        // (maybe not exactly at center)
+
+      */
+      const SCEV *RHSStart = RHSAddRec->getStart();
+      const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
+      // if Stride-RHSStride>0 and does not overflow, we can write
+      //  backedge count as:
+      //  RHSStart >= Start ? (RHSStart - Start)/(Stride - RHSStride) ? 0 
+
+      // check if Stride-RHSStride will not overflow
+      if (willNotOverflow(llvm::Instruction::Sub, true, Stride, RHSStride)) {
+        const SCEV *Denominator = getMinusSCEV(Stride, RHSStride); 
+        if (isKnownPositive(Denominator)) {
+          End = IsSigned ? getSMaxExpr(RHSStart, Start) : 
+            getUMaxExpr(RHSStart, Start); // max(RHSStart, Start)
+
+          const SCEV *Delta = getMinusSCEV(End, Start); // End >= Start
+          
+          BECount = getUDivCeilSCEV(Delta, Denominator);
+          BECountIfBackedgeTaken = getUDivCeilSCEV(
+            getMinusSCEV(RHSStart, Start), Denominator);
+
+          const SCEV *ConstantMaxBECount;
+          bool MaxOrZero = false;
+          if (isa<SCEVConstant>(BECount)) {
+            ConstantMaxBECount = BECount;
+          } else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
+            ConstantMaxBECount = BECountIfBackedgeTaken;
+            MaxOrZero = true;
+          } else {
+            ConstantMaxBECount = computeMaxBECountForLT(
+                Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), 
+                IsSigned);
+          }
+
+          const SCEV *SymbolicMaxBECount = BECount;
+          return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, 
+                          MaxOrZero, Predicates);
+        }
+      }
+    } 
+    if (BECount == nullptr) {
+      // If we cannot calculate ExactBECount, we can calculate the MaxBECount, 
+      // given the start, stride and max value for the end bound of the 
+      // loop (RHS), and the fact that IV does not overflow (which is
+      // checked above).
+      const SCEV *MaxBECount = computeMaxBECountForLT(
+          Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+      return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
+                      MaxBECount, false /*MaxOrZero*/, Predicates);
+    }
   }
 
   // We use the expression (max(End,Start)-Start)/Stride to describe the
   // backedge count, as if the backedge is taken at least once max(End,Start)
   // is End and so the result is as above, and if not max(End,Start) is Start
   // so we get a backedge count of zero.
-  const SCEV *BECount = nullptr;
   auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
   assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
   assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
@@ -13052,7 +13102,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     BECount = getUDivExpr(Numerator, Stride);
   }
 
-  const SCEV *BECountIfBackedgeTaken = nullptr;
   if (!BECount) {
     auto canProveRHSGreaterThanEqualStart = [&]() {
       auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
@@ -13080,7 +13129,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
 
     // If we know that RHS >= Start in the context of loop, then we know that
     // max(RHS, Start) = RHS at this point.
-    const SCEV *End;
     if (canProveRHSGreaterThanEqualStart()) {
       End = RHS;
     } else {

>From 2d78f1fced40c1e66c58a0e2adac067b94e8d612 Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Sun, 19 May 2024 22:18:08 +0530
Subject: [PATCH 3/9] Add negative rhs stride condition

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 650b5fde44e74..043bd69b32fe1 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13023,8 +13023,11 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       //  backedge count as:
       //  RHSStart >= Start ? (RHSStart - Start)/(Stride - RHSStride) ? 0 
 
-      // check if Stride-RHSStride will not overflow
-      if (willNotOverflow(llvm::Instruction::Sub, true, Stride, RHSStride)) {
+      // check if RHSStride<0 and Stride-RHSStride will not overflow 
+      // FIXME: Can RHSStride be positive?
+      if (isKnownNegative(RHSStride) && 
+        willNotOverflow(llvm::Instruction::Sub, true, Stride, RHSStride)) {
+          
         const SCEV *Denominator = getMinusSCEV(Stride, RHSStride); 
         if (isKnownPositive(Denominator)) {
           End = IsSigned ? getSMaxExpr(RHSStart, Start) : 

>From 1b4945b0309cb126d9f1a1700a47cbdbdf46e93e Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Tue, 21 May 2024 02:24:56 +0530
Subject: [PATCH 4/9] Fix formatting

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 363 +++++++++++++-------------
 1 file changed, 176 insertions(+), 187 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 043bd69b32fe1..80225969225fe 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13000,10 +13000,10 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       return RHS;
   }
 
-  const SCEV *End = nullptr, *BECount = nullptr, 
-    *BECountIfBackedgeTaken = nullptr;
+  const SCEV *End = nullptr, *BECount = nullptr,
+             *BECountIfBackedgeTaken = nullptr;
   if (!isLoopInvariant(RHS, L)) {
-    if (auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
+    if (auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)) {
       /*
         The structure of loop we are trying to calculate backedge-count of:
         left = left_start
@@ -13021,213 +13021,202 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
       // if Stride-RHSStride>0 and does not overflow, we can write
       //  backedge count as:
-      //  RHSStart >= Start ? (RHSStart - Start)/(Stride - RHSStride) ? 0 
+      //  RHSStart >= Start ? (RHSStart - Start)/(Stride - RHSStride) ? 0
 
-      // check if RHSStride<0 and Stride-RHSStride will not overflow 
+      // check if RHSStride<0 and Stride-RHSStride will not overflow
       // FIXME: Can RHSStride be positive?
-      if (isKnownNegative(RHSStride) && 
-        willNotOverflow(llvm::Instruction::Sub, true, Stride, RHSStride)) {
-          
-        const SCEV *Denominator = getMinusSCEV(Stride, RHSStride); 
+      if (isKnownNegative(RHSStride) &&
+          willNotOverflow(llvm::Instruction::Sub, true, Stride, RHSStride)) {
+
+        const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
         if (isKnownPositive(Denominator)) {
-          End = IsSigned ? getSMaxExpr(RHSStart, Start) : 
-            getUMaxExpr(RHSStart, Start); // max(RHSStart, Start)
+          End = IsSigned ? getSMaxExpr(RHSStart, Start)
+                         : getUMaxExpr(RHSStart, Start); // max(RHSStart, Start)
 
           const SCEV *Delta = getMinusSCEV(End, Start); // End >= Start
-          
-          BECount = getUDivCeilSCEV(Delta, Denominator);
-          BECountIfBackedgeTaken = getUDivCeilSCEV(
-            getMinusSCEV(RHSStart, Start), Denominator);
-
-          const SCEV *ConstantMaxBECount;
-          bool MaxOrZero = false;
-          if (isa<SCEVConstant>(BECount)) {
-            ConstantMaxBECount = BECount;
-          } else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
-            ConstantMaxBECount = BECountIfBackedgeTaken;
-            MaxOrZero = true;
-          } else {
-            ConstantMaxBECount = computeMaxBECountForLT(
-                Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), 
-                IsSigned);
-          }
 
-          const SCEV *SymbolicMaxBECount = BECount;
-          return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, 
-                          MaxOrZero, Predicates);
+          BECount = getUDivCeilSCEV(Delta, Denominator);
+          BECountIfBackedgeTaken =
+              getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
         }
       }
-    } 
+    }
     if (BECount == nullptr) {
-      // If we cannot calculate ExactBECount, we can calculate the MaxBECount, 
-      // given the start, stride and max value for the end bound of the 
+      // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
+      // given the start, stride and max value for the end bound of the
       // loop (RHS), and the fact that IV does not overflow (which is
       // checked above).
       const SCEV *MaxBECount = computeMaxBECountForLT(
           Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
       return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
-                      MaxBECount, false /*MaxOrZero*/, Predicates);
-    }
-  }
-
-  // We use the expression (max(End,Start)-Start)/Stride to describe the
-  // backedge count, as if the backedge is taken at least once max(End,Start)
-  // is End and so the result is as above, and if not max(End,Start) is Start
-  // so we get a backedge count of zero.
-  auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
-  assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
-  assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
-  assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
-  // Can we prove (max(RHS,Start) > Start - Stride?
-  if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
-      isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
-    // In this case, we can use a refined formula for computing backedge taken
-    // count.  The general formula remains:
-    //   "End-Start /uceiling Stride" where "End = max(RHS,Start)"
-    // We want to use the alternate formula:
-    //   "((End - 1) - (Start - Stride)) /u Stride"
-    // Let's do a quick case analysis to show these are equivalent under
-    // our precondition that max(RHS,Start) > Start - Stride.
-    // * For RHS <= Start, the backedge-taken count must be zero.
-    //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
-    //   "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
-    //   "Stride - 1 /u Stride" which is indeed zero for all non-zero values
-    //     of Stride.  For 0 stride, we've use umin(1,Stride) above, reducing
-    //     this to the stride of 1 case.
-    // * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride".
-    //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
-    //   "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
-    //   "((RHS - (Start - Stride) - 1) /u Stride".
-    //   Our preconditions trivially imply no overflow in that form.
-    const SCEV *MinusOne = getMinusOne(Stride->getType());
-    const SCEV *Numerator =
-        getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
-    BECount = getUDivExpr(Numerator, Stride);
-  }
-
-  if (!BECount) {
-    auto canProveRHSGreaterThanEqualStart = [&]() {
-      auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
-      const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
-      const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
-
-      if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
-          isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
-        return true;
+                       MaxBECount, false /*MaxOrZero*/, Predicates);
+    }
+  } else {
+    // We use the expression (max(End,Start)-Start)/Stride to describe the
+    // backedge count, as if the backedge is taken at least once
+    // max(End,Start) is End and so the result is as above, and if not
+    // max(End,Start) is Start so we get a backedge count of zero.
+    auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
+    assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
+    assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
+    assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
+    // Can we prove (max(RHS,Start) > Start - Stride?
+    if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
+        isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
+      // In this case, we can use a refined formula for computing backedge
+      // taken count.  The general formula remains:
+      //   "End-Start /uceiling Stride" where "End = max(RHS,Start)"
+      // We want to use the alternate formula:
+      //   "((End - 1) - (Start - Stride)) /u Stride"
+      // Let's do a quick case analysis to show these are equivalent under
+      // our precondition that max(RHS,Start) > Start - Stride.
+      // * For RHS <= Start, the backedge-taken count must be zero.
+      //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
+      //   "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
+      //   "Stride - 1 /u Stride" which is indeed zero for all non-zero values
+      //     of Stride.  For 0 stride, we've use umin(1,Stride) above,
+      //     reducing this to the stride of 1 case.
+      // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
+      // Stride".
+      //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
+      //   "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
+      //   "((RHS - (Start - Stride) - 1) /u Stride".
+      //   Our preconditions trivially imply no overflow in that form.
+      const SCEV *MinusOne = getMinusOne(Stride->getType());
+      const SCEV *Numerator =
+          getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
+      BECount = getUDivExpr(Numerator, Stride);
+    }
+
+    if (!BECount) {
+      auto canProveRHSGreaterThanEqualStart = [&]() {
+        auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
+        const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
+        const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
+
+        if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
+            isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
+          return true;
 
-      // (RHS > Start - 1) implies RHS >= Start.
-      // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
-      //   "Start - 1" doesn't overflow.
-      // * For signed comparison, if Start - 1 does overflow, it's equal
-      //   to INT_MAX, and "RHS >s INT_MAX" is trivially false.
-      // * For unsigned comparison, if Start - 1 does overflow, it's equal
-      //   to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
-      //
-      // FIXME: Should isLoopEntryGuardedByCond do this for us?
-      auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
-      auto *StartMinusOne = getAddExpr(OrigStart,
-                                       getMinusOne(OrigStart->getType()));
-      return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
-    };
+        // (RHS > Start - 1) implies RHS >= Start.
+        // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
+        //   "Start - 1" doesn't overflow.
+        // * For signed comparison, if Start - 1 does overflow, it's equal
+        //   to INT_MAX, and "RHS >s INT_MAX" is trivially false.
+        // * For unsigned comparison, if Start - 1 does overflow, it's equal
+        //   to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
+        //
+        // FIXME: Should isLoopEntryGuardedByCond do this for us?
+        auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
+        auto *StartMinusOne =
+            getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
+        return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
+      };
 
-    // If we know that RHS >= Start in the context of loop, then we know that
-    // max(RHS, Start) = RHS at this point.
-    if (canProveRHSGreaterThanEqualStart()) {
-      End = RHS;
-    } else {
-      // If RHS < Start, the backedge will be taken zero times.  So in
-      // general, we can write the backedge-taken count as:
+      // If we know that RHS >= Start in the context of loop, then we know
+      // that max(RHS, Start) = RHS at this point.
+      if (canProveRHSGreaterThanEqualStart()) {
+        End = RHS;
+      } else {
+        // If RHS < Start, the backedge will be taken zero times.  So in
+        // general, we can write the backedge-taken count as:
+        //
+        //     RHS >= Start ? ceil(RHS - Start) / Stride : 0
+        //
+        // We convert it to the following to make it more convenient for SCEV:
+        //
+        //     ceil(max(RHS, Start) - Start) / Stride
+        End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
+
+        // See what would happen if we assume the backedge is taken. This is
+        // used to compute MaxBECount.
+        BECountIfBackedgeTaken =
+            getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
+      }
+
+      // At this point, we know:
       //
-      //     RHS >= Start ? ceil(RHS - Start) / Stride : 0
+      // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
+      // 2. The index variable doesn't overflow.
       //
-      // We convert it to the following to make it more convenient for SCEV:
+      // Therefore, we know N exists such that
+      // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
+      // doesn't overflow.
       //
-      //     ceil(max(RHS, Start) - Start) / Stride
-      End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
-
-      // See what would happen if we assume the backedge is taken. This is
-      // used to compute MaxBECount.
-      BECountIfBackedgeTaken = getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
-    }
-
-    // At this point, we know:
-    //
-    // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
-    // 2. The index variable doesn't overflow.
-    //
-    // Therefore, we know N exists such that
-    // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
-    // doesn't overflow.
-    //
-    // Using this information, try to prove whether the addition in
-    // "(Start - End) + (Stride - 1)" has unsigned overflow.
-    const SCEV *One = getOne(Stride->getType());
-    bool MayAddOverflow = [&] {
-      if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
-        if (StrideC->getAPInt().isPowerOf2()) {
-          // Suppose Stride is a power of two, and Start/End are unsigned
-          // integers.  Let UMAX be the largest representable unsigned
-          // integer.
-          //
-          // By the preconditions of this function, we know
-          // "(Start + Stride * N) >= End", and this doesn't overflow.
-          // As a formula:
-          //
-          //   End <= (Start + Stride * N) <= UMAX
-          //
-          // Subtracting Start from all the terms:
-          //
-          //   End - Start <= Stride * N <= UMAX - Start
-          //
-          // Since Start is unsigned, UMAX - Start <= UMAX.  Therefore:
-          //
-          //   End - Start <= Stride * N <= UMAX
-          //
-          // Stride * N is a multiple of Stride. Therefore,
-          //
-          //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
-          //
-          // Since Stride is a power of two, UMAX + 1 is divisible by Stride.
-          // Therefore, UMAX mod Stride == Stride - 1.  So we can write:
-          //
-          //   End - Start <= Stride * N <= UMAX - Stride - 1
-          //
-          // Dropping the middle term:
-          //
-          //   End - Start <= UMAX - Stride - 1
-          //
-          // Adding Stride - 1 to both sides:
-          //
-          //   (End - Start) + (Stride - 1) <= UMAX
-          //
-          // In other words, the addition doesn't have unsigned overflow.
+      // Using this information, try to prove whether the addition in
+      // "(Start - End) + (Stride - 1)" has unsigned overflow.
+      const SCEV *One = getOne(Stride->getType());
+      bool MayAddOverflow = [&] {
+        if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
+          if (StrideC->getAPInt().isPowerOf2()) {
+            // Suppose Stride is a power of two, and Start/End are unsigned
+            // integers.  Let UMAX be the largest representable unsigned
+            // integer.
+            //
+            // By the preconditions of this function, we know
+            // "(Start + Stride * N) >= End", and this doesn't overflow.
+            // As a formula:
+            //
+            //   End <= (Start + Stride * N) <= UMAX
+            //
+            // Subtracting Start from all the terms:
+            //
+            //   End - Start <= Stride * N <= UMAX - Start
+            //
+            // Since Start is unsigned, UMAX - Start <= UMAX.  Therefore:
+            //
+            //   End - Start <= Stride * N <= UMAX
+            //
+            // Stride * N is a multiple of Stride. Therefore,
+            //
+            //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
+            //
+            // Since Stride is a power of two, UMAX + 1 is divisible by
+            // Stride. Therefore, UMAX mod Stride == Stride - 1.  So we can
+            // write:
+            //
+            //   End - Start <= Stride * N <= UMAX - Stride - 1
+            //
+            // Dropping the middle term:
+            //
+            //   End - Start <= UMAX - Stride - 1
+            //
+            // Adding Stride - 1 to both sides:
+            //
+            //   (End - Start) + (Stride - 1) <= UMAX
+            //
+            // In other words, the addition doesn't have unsigned overflow.
+            //
+            // A similar proof works if we treat Start/End as signed values.
+            // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
+            // to use signed max instead of unsigned max. Note that we're
+            // trying to prove a lack of unsigned overflow in either case.
+            return false;
+          }
+        }
+        if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
+          // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
+          // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
+          // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
+          // 1 <s End.
           //
-          // A similar proof works if we treat Start/End as signed values.
-          // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
-          // use signed max instead of unsigned max. Note that we're trying
-          // to prove a lack of unsigned overflow in either case.
+          // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
+          // End.
           return false;
         }
+        return true;
+      }();
+
+      const SCEV *Delta = getMinusSCEV(End, Start);
+      if (!MayAddOverflow) {
+        // floor((D + (S - 1)) / S)
+        // We prefer this formulation if it's legal because it's fewer
+        // operations.
+        BECount =
+            getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
+      } else {
+        BECount = getUDivCeilSCEV(Delta, Stride);
       }
-      if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
-        // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
-        // If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1 <u End.
-        // If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End - 1 <s End.
-        //
-        // If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
-        return false;
-      }
-      return true;
-    }();
-
-    const SCEV *Delta = getMinusSCEV(End, Start);
-    if (!MayAddOverflow) {
-      // floor((D + (S - 1)) / S)
-      // We prefer this formulation if it's legal because it's fewer operations.
-      BECount =
-          getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
-    } else {
-      BECount = getUDivCeilSCEV(Delta, Stride);
     }
   }
 

>From 778a7ddf39133e606df809459d4fb5ede6dfe9cc Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Tue, 21 May 2024 16:22:48 +0530
Subject: [PATCH 5/9] Add tests

---
 llvm/lib/Analysis/ScalarEvolution.cpp         |  4 +-
 llvm/test/Analysis/ScalarEvolution/pr92560.ll | 69 +++++++++++++++++++
 2 files changed, 71 insertions(+), 2 deletions(-)
 create mode 100644 llvm/test/Analysis/ScalarEvolution/pr92560.ll

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 80225969225fe..08e3ebe8c897d 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13003,7 +13003,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   const SCEV *End = nullptr, *BECount = nullptr,
              *BECountIfBackedgeTaken = nullptr;
   if (!isLoopInvariant(RHS, L)) {
-    if (auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)) {
+    if (const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)) {
       /*
         The structure of loop we are trying to calculate backedge-count of:
         left = left_start
@@ -13026,7 +13026,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       // check if RHSStride<0 and Stride-RHSStride will not overflow
       // FIXME: Can RHSStride be positive?
       if (isKnownNegative(RHSStride) &&
-          willNotOverflow(llvm::Instruction::Sub, true, Stride, RHSStride)) {
+          willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride, RHSStride)) {
 
         const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
         if (isKnownPositive(Denominator)) {
diff --git a/llvm/test/Analysis/ScalarEvolution/pr92560.ll b/llvm/test/Analysis/ScalarEvolution/pr92560.ll
new file mode 100644
index 0000000000000..bfd92c4fe83ee
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/pr92560.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -disable-output "-passes=print<scalar-evolution>" < %s 2>&1 | FileCheck %s
+
+define dso_local void @simple(i32 noundef %n) local_unnamed_addr {
+; CHECK-LABEL: 'simple'
+; CHECK-NEXT:  Classifying expressions for: @simple
+; CHECK-NEXT:    %right.06 = phi i32 [ %dec, %while.body ], [ %n, %entry ]
+; CHECK-NEXT:    --> {%n,+,-4}<nsw><%while.body> U: full-set S: full-set Exits: ((-4 * (((-4 + (-1 * (1 umin (-4 + (4 smax (-4 + %n)))<nsw>))<nuw><nsw> + (4 smax (-4 + %n))) /u 8) + (1 umin (-4 + (4 smax (-4 + %n)))<nsw>)))<nsw> + %n) LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %left.05 = phi i32 [ %inc, %while.body ], [ 0, %entry ]
+; CHECK-NEXT:    --> {0,+,4}<nuw><nsw><%while.body> U: [0,2147483641) S: [0,2147483641) Exits: (4 * (((-4 + (-1 * (1 umin (-4 + (4 smax (-4 + %n)))<nsw>))<nuw><nsw> + (4 smax (-4 + %n))) /u 8) + (1 umin (-4 + (4 smax (-4 + %n)))<nsw>)))<nuw> LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %inc = add nuw nsw i32 %left.05, 4
+; CHECK-NEXT:    --> {4,+,4}<nuw><nsw><%while.body> U: [4,2147483645) S: [4,2147483645) Exits: (4 + (4 * (((-4 + (-1 * (1 umin (-4 + (4 smax (-4 + %n)))<nsw>))<nuw><nsw> + (4 smax (-4 + %n))) /u 8) + (1 umin (-4 + (4 smax (-4 + %n)))<nsw>)))<nuw>)<nuw> LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %dec = add nsw i32 %right.06, -4
+; CHECK-NEXT:    --> {(-4 + %n),+,-4}<nsw><%while.body> U: full-set S: full-set Exits: (-4 + (-4 * (((-4 + (-1 * (1 umin (-4 + (4 smax (-4 + %n)))<nsw>))<nuw><nsw> + (4 smax (-4 + %n))) /u 8) + (1 umin (-4 + (4 smax (-4 + %n)))<nsw>)))<nsw> + %n) LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:  Determining loop execution counts for: @simple
+; CHECK-NEXT:  Loop %while.body: backedge-taken count is (((-4 + (-1 * (1 umin (-4 + (4 smax (-4 + %n)))<nsw>))<nuw><nsw> + (4 smax (-4 + %n))) /u 8) + (1 umin (-4 + (4 smax (-4 + %n)))<nsw>))
+; CHECK-NEXT:  Loop %while.body: constant max backedge-taken count is i32 536870910
+; CHECK-NEXT:  Loop %while.body: symbolic max backedge-taken count is (((-4 + (-1 * (1 umin (-4 + (4 smax (-4 + %n)))<nsw>))<nuw><nsw> + (4 smax (-4 + %n))) /u 8) + (1 umin (-4 + (4 smax (-4 + %n)))<nsw>))
+; CHECK-NEXT:  Loop %while.body: Trip multiple is 1
+;
+entry:
+  %cmp4 = icmp sgt i32 %n, 0
+  br i1 %cmp4, label %while.body, label %while.end
+
+while.body:
+  %right.06 = phi i32 [ %dec, %while.body ], [ %n, %entry ]
+  %left.05 = phi i32 [ %inc, %while.body ], [ 0, %entry ]
+  %inc = add nuw nsw i32 %left.05, 4
+  %dec = add nsw i32 %right.06, -4
+  %cmp = icmp slt i32 %inc, %dec
+  br i1 %cmp, label %while.body, label %while.end
+
+while.end:
+  ret void
+}
+
+define dso_local void @overflow(i32 noundef %n) local_unnamed_addr {
+; CHECK-LABEL: 'overflow'
+; CHECK-NEXT:  Classifying expressions for: @overflow
+; CHECK-NEXT:    %right.06 = phi i32 [ %dec, %while.body ], [ %n, %entry ]
+; CHECK-NEXT:    --> {%n,+,-1}<nsw><%while.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %left.05 = phi i32 [ %inc, %while.body ], [ 2147483647, %entry ]
+; CHECK-NEXT:    --> {2147483647,+,2147483647}<nuw><nsw><%while.body> U: [2147483647,-2147483648) S: [2147483647,-2147483648) Exits: <<Unknown>> LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %inc = add nuw nsw i32 %left.05, 2147483647
+; CHECK-NEXT:    --> {-2,+,2147483647}<nuw><nsw><%while.body> U: [-2,-1) S: [-2,0) Exits: <<Unknown>> LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %dec = add nsw i32 %right.06, -1
+; CHECK-NEXT:    --> {(-1 + %n),+,-1}<nsw><%while.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:  Determining loop execution counts for: @overflow
+; CHECK-NEXT:  Loop %while.body: Unpredictable backedge-taken count.
+; CHECK-NEXT:  Loop %while.body: constant max backedge-taken count is i32 1
+; CHECK-NEXT:  Loop %while.body: symbolic max backedge-taken count is i32 1
+;
+entry:
+  %cmp4 = icmp sgt i32 %n, 0
+  br i1 %cmp4, label %while.body, label %while.end
+
+while.body:
+  %right.06 = phi i32 [ %dec, %while.body ], [ %n, %entry ]
+  %left.05 = phi i32 [ %inc, %while.body ], [ 2147483647, %entry ]
+  %inc = add nuw nsw i32 %left.05, 2147483647
+  %dec = add nsw i32 %right.06, -1
+  %cmp = icmp slt i32 %inc, %dec
+  br i1 %cmp, label %while.body, label %while.end
+
+while.end:
+  ret void
+}
+
+

>From 35afcbb8ce2d0fd48fdfa151aa45fcac0bc7e75b Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Thu, 23 May 2024 10:59:57 +0530
Subject: [PATCH 6/9] Add condition for RHS to be in the same loop, fix
 formatting

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 08e3ebe8c897d..d85f28ca9aa2a 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13003,7 +13003,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   const SCEV *End = nullptr, *BECount = nullptr,
              *BECountIfBackedgeTaken = nullptr;
   if (!isLoopInvariant(RHS, L)) {
-    if (const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)) {
+    const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
+    if (RHSAddRec != nullptr && RHSAddRec->getLoop() == L) {
       /*
         The structure of loop we are trying to calculate backedge-count of:
         left = left_start
@@ -13026,7 +13027,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       // check if RHSStride<0 and Stride-RHSStride will not overflow
       // FIXME: Can RHSStride be positive?
       if (isKnownNegative(RHSStride) &&
-          willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride, RHSStride)) {
+          willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
+                          RHSStride)) {
 
         const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
         if (isKnownPositive(Denominator)) {

>From 3394ae029150a772d5fe65cd08a889875a044ad1 Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Wed, 29 May 2024 12:45:18 +0530
Subject: [PATCH 7/9] Update comment formatting

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 39 ++++++++++++++-------------
 1 file changed, 21 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index d85f28ca9aa2a..6cdd1182e9759 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13005,26 +13005,28 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   if (!isLoopInvariant(RHS, L)) {
     const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
     if (RHSAddRec != nullptr && RHSAddRec->getLoop() == L) {
-      /*
-        The structure of loop we are trying to calculate backedge-count of:
-        left = left_start
-        right = right_start
-        while(left < right){
-          // ... do something here ...
-          left += s1; // stride of left is s1>0
-          right -= s2; // stride of right is -s2 (s2 > 0)
-        }
-        // left and right are converging at the middle
-        // (maybe not exactly at center)
+      // The structure of loop we are trying to calculate backedge count of:
+      //
+      //  left = left_start
+      //  right = right_start
+      //
+      //  while(left < right){
+      //    ... do something here ...
+      //    left += s1; // stride of left is s1>0
+      //    right -= s2; // stride of right is -s2 (s2 > 0)
+      //  }
+      //
+      // Here, left and right are converging somewhere in the middle.
 
-      */
       const SCEV *RHSStart = RHSAddRec->getStart();
       const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
-      // if Stride-RHSStride>0 and does not overflow, we can write
-      //  backedge count as:
-      //  RHSStart >= Start ? (RHSStart - Start)/(Stride - RHSStride) ? 0
 
-      // check if RHSStride<0 and Stride-RHSStride will not overflow
+      // If Stride - RHSStride is positive and does not overflow, we can write
+      // backedge count as ->
+      //    ceil((End - Start) /u (Stride - RHSStride))
+      //    Where, End = max(RHSStart, Start)
+
+      // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
       // FIXME: Can RHSStride be positive?
       if (isKnownNegative(RHSStride) &&
           willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
@@ -13033,9 +13035,10 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
         const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
         if (isKnownPositive(Denominator)) {
           End = IsSigned ? getSMaxExpr(RHSStart, Start)
-                         : getUMaxExpr(RHSStart, Start); // max(RHSStart, Start)
+                         : getUMaxExpr(RHSStart, Start);
 
-          const SCEV *Delta = getMinusSCEV(End, Start); // End >= Start
+          // We can do this because End >= Start, as End = max(RHSStart, Start)
+          const SCEV *Delta = getMinusSCEV(End, Start);
 
           BECount = getUDivCeilSCEV(Delta, Denominator);
           BECountIfBackedgeTaken =

>From 83469ae22475102561438dd621116d0a68a4edf1 Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Thu, 30 May 2024 02:54:23 +0530
Subject: [PATCH 8/9] Check for RHSAddRec nowrap flags and add test

---
 llvm/lib/Analysis/ScalarEvolution.cpp         |  3 +-
 llvm/test/Analysis/ScalarEvolution/pr92560.ll | 44 +++++++++++++++++--
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 6cdd1182e9759..cdaeaf265cad7 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13004,7 +13004,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
              *BECountIfBackedgeTaken = nullptr;
   if (!isLoopInvariant(RHS, L)) {
     const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
-    if (RHSAddRec != nullptr && RHSAddRec->getLoop() == L) {
+    bool RHSNoWrap = RHSAddRec->getNoWrapFlags();
+    if (RHSAddRec != nullptr && RHSAddRec->getLoop() == L && RHSNoWrap) {
       // The structure of loop we are trying to calculate backedge count of:
       //
       //  left = left_start
diff --git a/llvm/test/Analysis/ScalarEvolution/pr92560.ll b/llvm/test/Analysis/ScalarEvolution/pr92560.ll
index bfd92c4fe83ee..3ce3a8f3e4eb8 100644
--- a/llvm/test/Analysis/ScalarEvolution/pr92560.ll
+++ b/llvm/test/Analysis/ScalarEvolution/pr92560.ll
@@ -34,9 +34,10 @@ while.end:
   ret void
 }
 
-define dso_local void @overflow(i32 noundef %n) local_unnamed_addr {
-; CHECK-LABEL: 'overflow'
-; CHECK-NEXT:  Classifying expressions for: @overflow
+; Cannot compute the backedge-taken count because Stride - RHSStride overflows.
+define dso_local void @stride_overflow(i32 noundef %n) local_unnamed_addr {
+; CHECK-LABEL: 'stride_overflow'
+; CHECK-NEXT:  Classifying expressions for: @stride_overflow
 ; CHECK-NEXT:    %right.06 = phi i32 [ %dec, %while.body ], [ %n, %entry ]
 ; CHECK-NEXT:    --> {%n,+,-1}<nsw><%while.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %while.body: Computable }
 ; CHECK-NEXT:    %left.05 = phi i32 [ %inc, %while.body ], [ 2147483647, %entry ]
@@ -45,7 +46,7 @@ define dso_local void @overflow(i32 noundef %n) local_unnamed_addr {
 ; CHECK-NEXT:    --> {-2,+,2147483647}<nuw><nsw><%while.body> U: [-2,-1) S: [-2,0) Exits: <<Unknown>> LoopDispositions: { %while.body: Computable }
 ; CHECK-NEXT:    %dec = add nsw i32 %right.06, -1
 ; CHECK-NEXT:    --> {(-1 + %n),+,-1}<nsw><%while.body> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %while.body: Computable }
-; CHECK-NEXT:  Determining loop execution counts for: @overflow
+; CHECK-NEXT:  Determining loop execution counts for: @stride_overflow
 ; CHECK-NEXT:  Loop %while.body: Unpredictable backedge-taken count.
 ; CHECK-NEXT:  Loop %while.body: constant max backedge-taken count is i32 1
 ; CHECK-NEXT:  Loop %while.body: symbolic max backedge-taken count is i32 1
@@ -66,4 +67,39 @@ while.end:
   ret void
 }
 
+; Cannot compute the backedge-taken count because %conv110 has no no-wrap flags.
+define dso_local void @rhs_wrapping() local_unnamed_addr {
+; CHECK-LABEL: 'rhs_wrapping'
+; CHECK-NEXT:  Classifying expressions for: @rhs_wrapping
+; CHECK-NEXT:    %a = alloca i8, align 1
+; CHECK-NEXT:    --> %a U: full-set S: full-set
+; CHECK-NEXT:    %conv110 = phi i32 [ 0, %entry ], [ %sext8, %while.body ]
+; CHECK-NEXT:    --> {0,+,-1090519040}<%while.body> U: [0,-16777215) S: [-2147483648,2130706433) Exits: <<Unknown>> LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %conv9 = phi i32 [ -2147483648, %entry ], [ %sext, %while.body ]
+; CHECK-NEXT:    --> {-2147483648,+,16777216}<nsw><%while.body> U: [0,-16777215) S: [-2147483648,2113929217) Exits: <<Unknown>> LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %sext = add nsw i32 %conv9, 16777216
+; CHECK-NEXT:    --> {-2130706432,+,16777216}<nsw><%while.body> U: [0,-16777215) S: [-2130706432,2130706433) Exits: <<Unknown>> LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %sext8 = add i32 %conv110, -1090519040
+; CHECK-NEXT:    --> {-1090519040,+,-1090519040}<%while.body> U: [0,-16777215) S: [-2147483648,2130706433) Exits: <<Unknown>> LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:  Determining loop execution counts for: @rhs_wrapping
+; CHECK-NEXT:  Loop %while.body: Unpredictable backedge-taken count.
+; CHECK-NEXT:  Loop %while.body: constant max backedge-taken count is i32 254
+; CHECK-NEXT:  Loop %while.body: symbolic max backedge-taken count is i32 254
+;
+entry:
+  %a = alloca i8, align 1
+  br label %while.body
+
+while.body:
+  %conv110 = phi i32 [ 0, %entry ], [ %sext8, %while.body ]
+  %conv9 = phi i32 [ -2147483648, %entry ], [ %sext, %while.body ]
+  %sext = add nsw i32 %conv9, 16777216
+  %sext8 = add i32 %conv110, -1090519040
+  %cmp = icmp slt i32 %sext, %sext8
+  br i1 %cmp, label %while.body, label %while.end
+
+while.end:
+  ret void
+}
+
 

>From 2d49c3b56c03783d629f3b1be3c092d7501359f4 Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Thu, 30 May 2024 13:15:53 +0530
Subject: [PATCH 9/9] Add test and fix last commit

---
 llvm/lib/Analysis/ScalarEvolution.cpp         | 10 +++---
 llvm/test/Analysis/ScalarEvolution/pr92560.ll | 31 +++++++++++++++++++
 2 files changed, 35 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index cdaeaf265cad7..a79a6e18b1a6c 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13004,8 +13004,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
              *BECountIfBackedgeTaken = nullptr;
   if (!isLoopInvariant(RHS, L)) {
     const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
-    bool RHSNoWrap = RHSAddRec->getNoWrapFlags();
-    if (RHSAddRec != nullptr && RHSAddRec->getLoop() == L && RHSNoWrap) {
+    if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
+        RHSAddRec->getNoWrapFlags()) {
       // The structure of loop we are trying to calculate backedge count of:
       //
       //  left = left_start
@@ -13013,11 +13013,10 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       //
       //  while(left < right){
       //    ... do something here ...
-      //    left += s1; // stride of left is s1>0
-      //    right -= s2; // stride of right is -s2 (s2 > 0)
+      //    left += s1; // stride of left is s1 (s1 > 0)
+      //    right += s2; // stride of right is s2 (s2 < 0)
       //  }
       //
-      // Here, left and right are converging somewhere in the middle.
 
       const SCEV *RHSStart = RHSAddRec->getStart();
       const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
@@ -13028,7 +13027,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       //    Where, End = max(RHSStart, Start)
 
       // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
-      // FIXME: Can RHSStride be positive?
       if (isKnownNegative(RHSStride) &&
           willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
                           RHSStride)) {
diff --git a/llvm/test/Analysis/ScalarEvolution/pr92560.ll b/llvm/test/Analysis/ScalarEvolution/pr92560.ll
index 3ce3a8f3e4eb8..bf1feec555af9 100644
--- a/llvm/test/Analysis/ScalarEvolution/pr92560.ll
+++ b/llvm/test/Analysis/ScalarEvolution/pr92560.ll
@@ -102,4 +102,35 @@ while.end:
   ret void
 }
 
+; abs(left_stride) != abs(right_stride)
+define dso_local void @simple2() local_unnamed_addr {
+; CHECK-LABEL: 'simple2'
+; CHECK-NEXT:  Classifying expressions for: @simple2
+; CHECK-NEXT:    %right.08 = phi i32 [ 50, %entry ], [ %add2, %while.body ]
+; CHECK-NEXT:    --> {50,+,-5}<nsw><%while.body> U: [25,51) S: [25,51) Exits: 25 LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %left.07 = phi i32 [ 0, %entry ], [ %add, %while.body ]
+; CHECK-NEXT:    --> {0,+,4}<nuw><nsw><%while.body> U: [0,21) S: [0,21) Exits: 20 LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %add = add nuw nsw i32 %left.07, 4
+; CHECK-NEXT:    --> {4,+,4}<nuw><nsw><%while.body> U: [4,25) S: [4,25) Exits: 24 LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:    %add2 = add nsw i32 %right.08, -5
+; CHECK-NEXT:    --> {45,+,-5}<nsw><%while.body> U: [20,46) S: [20,46) Exits: 20 LoopDispositions: { %while.body: Computable }
+; CHECK-NEXT:  Determining loop execution counts for: @simple2
+; CHECK-NEXT:  Loop %while.body: backedge-taken count is i32 5
+; CHECK-NEXT:  Loop %while.body: constant max backedge-taken count is i32 5
+; CHECK-NEXT:  Loop %while.body: symbolic max backedge-taken count is i32 5
+; CHECK-NEXT:  Loop %while.body: Trip multiple is 6
+;
+entry:
+  br label %while.body
 
+while.body:
+  %right.08 = phi i32 [ 50, %entry ], [ %add2, %while.body ]
+  %left.07 = phi i32 [ 0, %entry ], [ %add, %while.body ]
+  %add = add nuw nsw i32 %left.07, 4
+  %add2 = add nsw i32 %right.08, -5
+  %cmp = icmp slt i32 %add, %add2
+  br i1 %cmp, label %while.body, label %while.end
+
+while.end:
+  ret void
+}



More information about the llvm-commits mailing list