[llvm] [SCEV] Handle backedge-count logic for std::reverse like loops (PR #92560)

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Wed May 22 12:29:22 PDT 2024


================
@@ -12941,179 +12941,223 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       return RHS;
   }
 
-  // When the RHS is not invariant, we do not know the end bound of the loop and
-  // cannot calculate the ExactBECount needed by ExitLimit. However, we can
-  // calculate the MaxBECount, given the start, stride and max value for the end
-  // bound of the loop (RHS), and the fact that IV does not overflow (which is
-  // checked above).
+  const SCEV *End = nullptr, *BECount = nullptr,
+             *BECountIfBackedgeTaken = nullptr;
   if (!isLoopInvariant(RHS, L)) {
-    const SCEV *MaxBECount = computeMaxBECountForLT(
-        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
-    return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
-                     MaxBECount, false /*MaxOrZero*/, Predicates);
-  }
-
-  // We use the expression (max(End,Start)-Start)/Stride to describe the
-  // backedge count, as if the backedge is taken at least once max(End,Start)
-  // is End and so the result is as above, and if not max(End,Start) is Start
-  // so we get a backedge count of zero.
-  const SCEV *BECount = nullptr;
-  auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
-  assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
-  assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
-  assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
-  // Can we prove (max(RHS,Start) > Start - Stride?
-  if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
-      isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
-    // In this case, we can use a refined formula for computing backedge taken
-    // count.  The general formula remains:
-    //   "End-Start /uceiling Stride" where "End = max(RHS,Start)"
-    // We want to use the alternate formula:
-    //   "((End - 1) - (Start - Stride)) /u Stride"
-    // Let's do a quick case analysis to show these are equivalent under
-    // our precondition that max(RHS,Start) > Start - Stride.
-    // * For RHS <= Start, the backedge-taken count must be zero.
-    //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
-    //   "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
-    //   "Stride - 1 /u Stride" which is indeed zero for all non-zero values
-    //     of Stride.  For 0 stride, we've use umin(1,Stride) above, reducing
-    //     this to the stride of 1 case.
-    // * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride".
-    //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
-    //   "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
-    //   "((RHS - (Start - Stride) - 1) /u Stride".
-    //   Our preconditions trivially imply no overflow in that form.
-    const SCEV *MinusOne = getMinusOne(Stride->getType());
-    const SCEV *Numerator =
-        getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
-    BECount = getUDivExpr(Numerator, Stride);
-  }
-
-  const SCEV *BECountIfBackedgeTaken = nullptr;
-  if (!BECount) {
-    auto canProveRHSGreaterThanEqualStart = [&]() {
-      auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
-      const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
-      const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
-
-      if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
-          isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
-        return true;
+    if (const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)) {
----------------
efriedma-quic wrote:

Do you need to check `RHS->getLoop() == L`?

https://github.com/llvm/llvm-project/pull/92560


More information about the llvm-commits mailing list