[llvm] cbba71b - [ScalarEvolution] Fix overflow in computeBECount.

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 16 16:15:27 PDT 2021


Author: Eli Friedman
Date: 2021-07-16T16:15:18-07:00
New Revision: cbba71bfb50fb668b80ed430125a460279928272

URL: https://github.com/llvm/llvm-project/commit/cbba71bfb50fb668b80ed430125a460279928272
DIFF: https://github.com/llvm/llvm-project/commit/cbba71bfb50fb668b80ed430125a460279928272.diff

LOG: [ScalarEvolution] Fix overflow in computeBECount.

The current implementation of computeBECount doesn't account for the
possibility that adding "Stride - 1" to Delta might overflow. For almost
all loops, it doesn't, but it's not actually proven anywhere.

To deal with this, use a variety of tricks to try to prove that the
addition doesn't overflow.  If no such proof can be found, use an alternate
sequence which never overflows.

Differential Revision: https://reviews.llvm.org/D105216

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll
    llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 788e9cae49727..ae9c73fede961 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -2032,13 +2032,6 @@ class ScalarEvolution {
   Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
   createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI);
 
-  /// Compute the backedge taken count knowing the interval difference, and
-  /// the stride for an inequality.  Result takes the form:
-  /// (Delta + (Stride - 1)) udiv Stride.
-  /// Caller must ensure that this expression either does not overflow or
-  /// that the result is undefined if it does.
-  const SCEV *computeBECount(const SCEV *Delta, const SCEV *Stride);
-
   /// Compute the maximum backedge count based on the range of values
   /// permitted by Start, End, and Stride. This is for loops of the form
   /// {Start, +, Stride} LT End.

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 88858badb597a..32023e567d3f3 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11528,13 +11528,6 @@ const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
   return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
 }
 
-const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta,
-                                            const SCEV *Step) {
-  const SCEV *One = getOne(Step->getType());
-  Delta = getAddExpr(Delta, getMinusSCEV(Step, One));
-  return getUDivExpr(Delta, Step);
-}
-
 const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
                                                     const SCEV *Stride,
                                                     const SCEV *End,
@@ -11743,7 +11736,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       return RHS;
   }
 
-  const SCEV *End = RHS;
   // When the RHS is not invariant, we do not know the end bound of the loop and
   // cannot calculate the ExactBECount needed by ExitLimit. However, we can
   // calculate the MaxBECount, given the start, stride and max value for the end
@@ -11755,13 +11747,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
                      false /*MaxOrZero*/, Predicates);
   }
-  // If the backedge is taken at least once, then it will be taken
-  // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start
-  // is the LHS value of the less-than comparison the first time it is evaluated
-  // and End is the RHS.
-  const SCEV *BECountIfBackedgeTaken =
-    computeBECount(getMinusSCEV(End, Start), Stride);
-  
+
   // We use the expression (max(End,Start)-Start)/Stride to describe the
   // backedge count, as if the backedge is taken at least once max(End,Start)
   // is End and so the result is as above, and if not max(End,Start) is Start
@@ -11796,6 +11782,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       BECount = getUDivExpr(Numerator, Stride);
     }
   }
+
+  const SCEV *BECountIfBackedgeTaken = nullptr;
   if (!BECount) {
     auto canProveRHSGreaterThanEqualStart = [&]() {
       auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
@@ -11819,18 +11807,112 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
 
     // If we know that RHS >= Start in the context of loop, then we know that
     // max(RHS, Start) = RHS at this point.
-    if (canProveRHSGreaterThanEqualStart())
+    const SCEV *End;
+    if (canProveRHSGreaterThanEqualStart()) {
       End = RHS;
-    else
+    } else {
+      // If RHS < Start, the backedge will be taken zero times.  So in
+      // general, we can write the backedge-taken count as:
+      //
+      //     RHS >= Start ? ceil((RHS - Start) / Stride) : 0
+      //
+      // We convert it to the following to make it more convenient for SCEV:
+      //
+      //     ceil((max(RHS, Start) - Start) / Stride)
       End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
-    BECount = computeBECount(getMinusSCEV(End, Start), Stride);
+
+      // See what would happen if we assume the backedge is taken. This is
+      // used to compute MaxBECount.
+      BECountIfBackedgeTaken = getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
+    }
+
+    // At this point, we know:
+    //
+    // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
+    // 2. The index variable doesn't overflow.
+    //
+    // Therefore, we know N exists such that
+    // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
+    // doesn't overflow.
+    //
+    // Using this information, try to prove whether the addition in
+    // "(End - Start) + (Stride - 1)" has unsigned overflow.
+    const SCEV *One = getOne(Stride->getType());
+    bool MayAddOverflow = [&] {
+      if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
+        if (StrideC->getAPInt().isPowerOf2()) {
+          // Suppose Stride is a power of two, and Start/End are unsigned
+          // integers.  Let UMAX be the largest representable unsigned
+          // integer.
+          //
+          // By the preconditions of this function, we know
+          // "(Start + Stride * N) >= End", and this doesn't overflow.
+          // As a formula:
+          //
+          //   End <= (Start + Stride * N) <= UMAX
+          //
+          // Subtracting Start from all the terms:
+          //
+          //   End - Start <= Stride * N <= UMAX - Start
+          //
+          // Since Start is unsigned, UMAX - Start <= UMAX.  Therefore:
+          //
+          //   End - Start <= Stride * N <= UMAX
+          //
+          // Stride * N is a multiple of Stride. Therefore,
+          //
+          //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
+          //
+          // Since Stride is a power of two, UMAX + 1 is divisible by Stride.
+          // Therefore, UMAX mod Stride == Stride - 1.  So we can write:
+          //
+          //   End - Start <= Stride * N <= UMAX - (Stride - 1)
+          //
+          // Dropping the middle term:
+          //
+          //   End - Start <= UMAX - (Stride - 1)
+          //
+          // Adding Stride - 1 to both sides:
+          //
+          //   (End - Start) + (Stride - 1) <= UMAX
+          //
+          // In other words, the addition doesn't have unsigned overflow.
+          //
+          // A similar proof works if we treat Start/End as signed values.
+          // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
+          // use signed max instead of unsigned max. Note that we're trying
+          // to prove a lack of unsigned overflow in either case.
+          return false;
+        }
+      }
+      if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
+        // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
+        // If !IsSigned, 0 <u Stride == Start <=u End; so 0 <=u End - 1 <u End.
+        // If IsSigned, 0 <s Stride == Start <=s End; so 0 <=s End - 1 <s End.
+        //
+        // If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
+        return false;
+      }
+      return true;
+    }();
+
+    const SCEV *Delta = getMinusSCEV(End, Start);
+    if (!MayAddOverflow) {
+      // floor((D + (S - 1)) / S)
+      // We prefer this formulation if it's legal because it's fewer operations.
+      BECount =
+          getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
+    } else {
+      BECount = getUDivCeilSCEV(Delta, Stride);
+    }
   }
 
   const SCEV *MaxBECount;
   bool MaxOrZero = false;
-  if (isa<SCEVConstant>(BECount))
+  if (isa<SCEVConstant>(BECount)) {
     MaxBECount = BECount;
-  else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
+  } else if (BECountIfBackedgeTaken &&
+             isa<SCEVConstant>(BECountIfBackedgeTaken)) {
     // If we know exactly how many times the backedge will be taken if it's
     // taken at least once, then the backedge count will either be that or
     // zero.
@@ -11909,7 +11991,12 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
       return End;
   }
 
-  const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride);
+  // Compute ((Start - End) + (Stride - 1)) / Stride.
+  // FIXME: This can overflow. Holding off on fixing this for now;
+  // howManyGreaterThans will hopefully be gone soon.
+  const SCEV *One = getOne(Stride->getType());
+  const SCEV *BECount = getUDivExpr(
+      getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
 
   APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
                             : getUnsignedRangeMax(Start);

diff  --git a/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll b/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll
index 390f9747e7c72..79f3fa5544769 100644
--- a/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll
+++ b/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll
@@ -1,7 +1,7 @@
 ; RUN: opt < %s -analyze -enable-new-pm=0 -scalar-evolution 2>&1 | FileCheck %s
 ; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 2>&1 | FileCheck %s
 
-; CHECK: Loop %bb: backedge-taken count is ((-1 + (-1 * %x) + (1000 umax (3 + %x))) /u 3)
+; CHECK: Loop %bb: backedge-taken count is (((-3 + (-1 * (1 umin (-3 + (-1 * %x) + (1000 umax (3 + %x)))))<nuw><nsw> + (-1 * %x) + (1000 umax (3 + %x))) /u 3) + (1 umin (-3 + (-1 * %x) + (1000 umax (3 + %x)))))
 ; CHECK: Loop %bb: max backedge-taken count is 334
 
 

diff  --git a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll
index a085ee997d12f..ec03623fb9fbf 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll
@@ -35,7 +35,7 @@ for.end:                                          ; preds = %for.body, %entry
 
 ; Check that we are able to compute trip count of a loop without an entry guard.
 ; CHECK: Determining loop execution counts for: @foo2
-; CHECK: backedge-taken count is ((-1 + (-1 * %s) + (1 umax %s) + (%n smax %s)) /u (1 umax %s))
+; CHECK: backedge-taken count is ((((-1 * (1 umin ((-1 * %s) + (%n smax %s))))<nuw><nsw> + (-1 * %s) + (%n smax %s)) /u (1 umax %s)) + (1 umin ((-1 * %s) + (%n smax %s))))
 
 ; We should have a conservative estimate for the max backedge taken count for
 ; loops with unknown stride.
@@ -85,7 +85,7 @@ for.end:                                          ; preds = %for.body, %entry
 
 ; Same as foo2, but with mustprogress on loop, not function
 ; CHECK: Determining loop execution counts for: @foo4
-; CHECK: backedge-taken count is ((-1 + (-1 * %s) + (1 umax %s) + (%n smax %s)) /u (1 umax %s))
+; CHECK: backedge-taken count is ((((-1 * (1 umin ((-1 * %s) + (%n smax %s))))<nuw><nsw> + (-1 * %s) + (%n smax %s)) /u (1 umax %s)) + (1 umin ((-1 * %s) + (%n smax %s))))
 ; CHECK: max backedge-taken count is -1
 
 define void @foo4(i32* nocapture %A, i32 %n, i32 %s) {
@@ -108,7 +108,7 @@ for.end:                                          ; preds = %for.body, %entry
 
 ; A more complex case with pre-increment compare instead of post-increment.
 ; CHECK-LABEL: Determining loop execution counts for: @foo5
-; CHECK: Loop %for.body: backedge-taken count is ((-1 + (-1 * %start) + (1 umax %s) + (%n smax %start)) /u (1 umax %s))
+; CHECK: Loop %for.body: backedge-taken count is ((((-1 * (1 umin ((-1 * %start) + (%n smax %start))))<nuw><nsw> + (-1 * %start) + (%n smax %start)) /u (1 umax %s)) + (1 umin ((-1 * %start) + (%n smax %start))))
 
 ; We should have a conservative estimate for the max backedge taken count for
 ; loops with unknown stride.


        


More information about the llvm-commits mailing list