[llvm] Fix exact backedge count algorithm in Scalar-Evolution (PR #92560)

via llvm-commits llvm-commits at lists.llvm.org
Sun May 19 09:28:36 PDT 2024


https://github.com/mrdaybird updated https://github.com/llvm/llvm-project/pull/92560

>From 250adcbc8456d31a8601528b43f586b618a90aa2 Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Fri, 17 May 2024 19:29:36 +0530
Subject: [PATCH 1/2] Update howManyLessThans

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 515b9d0744f6e..9bf5bf80b4570 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12941,12 +12941,16 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       return RHS;
   }
 
-  // When the RHS is not invariant, we do not know the end bound of the loop and
-  // cannot calculate the ExactBECount needed by ExitLimit. However, we can
-  // calculate the MaxBECount, given the start, stride and max value for the end
-  // bound of the loop (RHS), and the fact that IV does not overflow (which is
-  // checked above).
   if (!isLoopInvariant(RHS, L)) {
+    // If RHS is an add recurrence, try again with lhs=lhs-rhs and rhs=0
+    if(auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
+       return howManyLessThans(getMinusSCEV(IV, RHSAddRec), 
+        getZero(IV->getType()), L, true, ControlsOnlyExit, AllowPredicates);
+    }
+    // If we cannot calculate ExactBECount, we can calculate the MaxBECount, 
+    // given the start, stride and max value for the end bound of the 
+    // loop (RHS), and the fact that IV does not overflow (which is
+    // checked above).
     const SCEV *MaxBECount = computeMaxBECountForLT(
         Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
     return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,

>From b0b81efc0b337528b2a9b08d2ac92be54ba3a8de Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Sun, 19 May 2024 21:29:03 +0530
Subject: [PATCH 2/2] Update howManyLessThans

---
 llvm/lib/Analysis/ScalarEvolution.cpp | 80 +++++++++++++++++++++------
 1 file changed, 64 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 9bf5bf80b4570..5fe7219b59f93 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12941,27 +12941,77 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       return RHS;
   }
 
+  const SCEV *End = nullptr, *BECount = nullptr, 
+    *BECountIfBackedgeTaken = nullptr;
   if (!isLoopInvariant(RHS, L)) {
-    // If RHS is an add recurrence, try again with lhs=lhs-rhs and rhs=0
-    if(auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
-       return howManyLessThans(getMinusSCEV(IV, RHSAddRec), 
-        getZero(IV->getType()), L, true, ControlsOnlyExit, AllowPredicates);
-    }
-    // If we cannot calculate ExactBECount, we can calculate the MaxBECount, 
-    // given the start, stride and max value for the end bound of the 
-    // loop (RHS), and the fact that IV does not overflow (which is
-    // checked above).
-    const SCEV *MaxBECount = computeMaxBECountForLT(
-        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
-    return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
-                     MaxBECount, false /*MaxOrZero*/, Predicates);
+    if (auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
+      /*
+        Shape of the loop whose backedge count we are trying to compute:
+        left = left_start
+        right = right_start
+        while(left < right){
+          // ... do something here ...
+          left += s1; // stride of left is s1>0
+          right -= s2; // stride of right is -s2 (s2 > 0)
+        }
+        // left and right are converging at the middle
+        // (maybe not exactly at center)
+
+      */
+      const SCEV *RHSStart = RHSAddRec->getStart();
+      const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
+      // if Stride-RHSStride>0 and does not overflow, we can write
+      //  backedge count as:
+      //  RHSStart >= Start ? (RHSStart - Start)/(Stride - RHSStride) : 0
+
+      // check if Stride-RHSStride will not overflow
+      if (willNotOverflow(llvm::Instruction::Sub, true, Stride, RHSStride)) {
+        const SCEV *Denominator = getMinusSCEV(Stride, RHSStride); 
+        if (isKnownPositive(Denominator)) {
+          End = IsSigned ? getSMaxExpr(RHSStart, Start) : 
+            getUMaxExpr(RHSStart, Start); // max(RHSStart, Start)
+
+          const SCEV *Delta = getMinusSCEV(End, Start); // End >= Start
+          
+          BECount = getUDivCeilSCEV(Delta, Denominator);
+          BECountIfBackedgeTaken = getUDivCeilSCEV(
+            getMinusSCEV(RHSStart, Start), Denominator);
+
+          const SCEV *ConstantMaxBECount;
+          bool MaxOrZero = false;
+          if (isa<SCEVConstant>(BECount)) {
+            ConstantMaxBECount = BECount;
+          } else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
+            ConstantMaxBECount = BECountIfBackedgeTaken;
+            MaxOrZero = true;
+          } else {
+            ConstantMaxBECount = computeMaxBECountForLT(
+                Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), 
+                IsSigned);
+          }
+
+          const SCEV *SymbolicMaxBECount = BECount;
+          return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, 
+                          MaxOrZero, Predicates);
+        }
+      }
+    } 
+    if (BECount == nullptr) {
+      // If we cannot calculate ExactBECount, we can calculate the MaxBECount, 
+      // given the start, stride and max value for the end bound of the 
+      // loop (RHS), and the fact that IV does not overflow (which is
+      // checked above).
+      const SCEV *MaxBECount = computeMaxBECountForLT(
+          Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+      return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
+                      MaxBECount, false /*MaxOrZero*/, Predicates);
+    }
   }
 
   // We use the expression (max(End,Start)-Start)/Stride to describe the
   // backedge count, as if the backedge is taken at least once max(End,Start)
   // is End and so the result is as above, and if not max(End,Start) is Start
   // so we get a backedge count of zero.
-  const SCEV *BECount = nullptr;
   auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
   assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
   assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
@@ -12993,7 +13043,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     BECount = getUDivExpr(Numerator, Stride);
   }
 
-  const SCEV *BECountIfBackedgeTaken = nullptr;
   if (!BECount) {
     auto canProveRHSGreaterThanEqualStart = [&]() {
       auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
@@ -13021,7 +13070,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
 
     // If we know that RHS >= Start in the context of loop, then we know that
     // max(RHS, Start) = RHS at this point.
-    const SCEV *End;
     if (canProveRHSGreaterThanEqualStart()) {
       End = RHS;
     } else {



More information about the llvm-commits mailing list