[llvm] 5b35018 - [ScalarEvolution] Fix overflow in computeBECount.

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 8 10:10:54 PDT 2021


Author: Eli Friedman
Date: 2021-07-08T10:09:55-07:00
New Revision: 5b350183cdabd83573bc760ddf513f3e1d991bcb

URL: https://github.com/llvm/llvm-project/commit/5b350183cdabd83573bc760ddf513f3e1d991bcb
DIFF: https://github.com/llvm/llvm-project/commit/5b350183cdabd83573bc760ddf513f3e1d991bcb.diff

LOG: [ScalarEvolution] Fix overflow in computeBECount.

There are two issues with the current implementation of computeBECount:

1. It doesn't account for the possibility that adding "Stride - 1" to
Delta might overflow. For almost all loops it doesn't, but nothing
actually proves that.
2. It doesn't account for the possibility that Stride is zero. If Delta
is zero, the backedge is never taken and the value of Stride is
irrelevant, so Stride may legitimately be zero. To handle this, we have
to make sure that the expression returned by computeBECount still
evaluates to zero rather than dividing by zero.

To deal with this, add two new checks:

1. Use a variety of tricks to try to prove that the addition doesn't
overflow. If no proof is found, use an alternate sequence which never
overflows.
2. Use umax(Stride, 1) to handle the possibility that Stride is zero.
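To make the two problems concrete, here is a minimal sketch that models the old formula on 8-bit unsigned values. The 8-bit width is an assumption chosen only so that wraparound is easy to hit (SCEV works at any width), and `oldBECount` / `exactBECount` are hypothetical helpers, not LLVM APIs.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical 8-bit model of the old computeBECount formula,
// (Delta + (Stride - 1)) /u Stride, using wrapping uint8_t arithmetic.
// Returns std::nullopt where the old expression would divide by zero.
std::optional<uint8_t> oldBECount(uint8_t Delta, uint8_t Stride) {
  if (Stride == 0)
    return std::nullopt;
  // The addition is truncated back to 8 bits, so it can wrap around.
  uint8_t Sum = static_cast<uint8_t>(Delta + (Stride - 1));
  return static_cast<uint8_t>(Sum / Stride);
}

// Reference count: ceil(Delta / Stride) computed without wraparound.
// Delta == 0 means the backedge is never taken, whatever Stride is.
unsigned exactBECount(uint8_t Delta, uint8_t Stride) {
  if (Delta == 0)
    return 0;
  assert(Stride != 0 && "Stride must be nonzero when Delta is nonzero");
  return (unsigned(Delta) + Stride - 1) / Stride;
}

// Example uses:
//   oldBECount(255, 2)   -> 0   (255 + 1 wraps to 0)
//   exactBECount(255, 2) -> 128
//   oldBECount(0, 0)     -> std::nullopt (division by zero)
//   exactBECount(0, 0)   -> 0
```

With Delta = 255 and Stride = 2 the old form wraps 255 + 1 to 0 and reports 0 instead of 128; with Delta = Stride = 0 it has no value at all, even though the backedge is simply never taken.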

Differential Revision: https://reviews.llvm.org/D105216

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride1.ll
    llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll
    llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 9a560a14ca6fb..67664fdd5ff00 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -2023,11 +2023,38 @@ class ScalarEvolution {
   createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI);
 
   /// Compute the backedge taken count knowing the interval difference, and
-  /// the stride for an inequality.  Result takes the form:
-  /// (Delta + (Stride - 1)) udiv Stride.
-  /// Caller must ensure that this expression either does not overflow or
-  /// that the result is undefined if it does.
-  const SCEV *computeBECount(const SCEV *Delta, const SCEV *Stride);
+  /// the stride for an inequality.
+  ///
+  /// Caller must ensure that non-negative N exists such that
+  /// (Start + Stride * N) >= End, and that computing "(Start + Stride * N)"
+  /// doesn't overflow. In other words:
+  /// 1. If IsSigned is true, Start <=s End. Otherwise, Start <=u End.
+  /// 2. If End is not equal to Start and IsSigned is true, Stride >s 0. If
+  ///    End is not equal to Start and IsSigned is false, Stride >u 0.
+  /// 3. The index variable doesn't overflow.
+  ///
+  /// If the preconditions hold, the backedge taken count is N.
+  ///
+  /// IsSigned determines whether End, Start, and Stride are treated as
+  /// signed values, for the purpose of optimizing the form of the result.
+  ///
+  /// This function tries to use an optimized form:
+  /// ((End - Start) + (Stride - 1)) /u Stride
+  ///
+  /// If it can't prove the addition doesn't overflow in that form, it uses
+  /// getUDivCeilSCEV.
+  const SCEV *computeBECount(bool IsSigned, const SCEV *Start, const SCEV *End,
+                             const SCEV *Stride);
+
+  /// Compute ceil(N / D). N and D are treated as unsigned values.
+  ///
+  /// Since SCEV doesn't have native ceiling division, this generates a
+  /// SCEV expression of the following form:
+  ///
+  /// umin(N, 1) + floor((N - umin(N, 1)) / D)
+  ///
+  /// A denominator of zero or poison is handled the same way as getUDivExpr().
+  const SCEV *getUDivCeilSCEV(const SCEV *N, const SCEV *D);
 
   /// Compute the maximum backedge count based on the range of values
   /// permitted by Start, End, and Stride. This is for loops of the form

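The identity behind getUDivCeilSCEV can be checked exhaustively at a small width. The sketch below is a hypothetical 8-bit model of the expression, not LLVM code. It relies on the fact that N - umin(N, 1) never wraps and that the final sum never exceeds N, which is why this form is safe where (N + (D - 1)) / D is not. Division by zero is deliberately left out, since the header defers that case to getUDivExpr().

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// 8-bit model of the expression getUDivCeilSCEV builds:
//   umin(N, 1) + floor((N - umin(N, 1)) / D)
// All values are uint8_t, like SCEV arithmetic on i8.
uint8_t udivCeilModel(uint8_t N, uint8_t D) {
  assert(D != 0 && "this sketch does not model division by zero");
  uint8_t MinNOne = std::min<uint8_t>(N, 1);
  uint8_t NMinusOne = static_cast<uint8_t>(N - MinNOne); // never wraps
  return static_cast<uint8_t>(MinNOne + NMinusOne / D);  // result <= N
}

// Exhaustively compare against ceil(N / D) computed in a wider type.
bool checkAllI8() {
  for (unsigned N = 0; N <= 255; ++N)
    for (unsigned D = 1; D <= 255; ++D) {
      unsigned Expected = (N + D - 1) / D; // no overflow in 'unsigned'
      if (udivCeilModel(uint8_t(N), uint8_t(D)) != Expected)
        return false;
    }
  return true;
}
```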
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 843c04855bf74..7b72ed7c144de 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11497,11 +11497,108 @@ bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
   return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
 }
 
-const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta,
-                                            const SCEV *Step) {
-  const SCEV *One = getOne(Step->getType());
-  Delta = getAddExpr(Delta, getMinusSCEV(Step, One));
-  return getUDivExpr(Delta, Step);
+const SCEV *ScalarEvolution::computeBECount(bool IsSigned, const SCEV *Start,
+                                            const SCEV *End,
+                                            const SCEV *Stride) {
+  // The basic formula here is ceil((End - Start) / Stride).  Since SCEV
+  // doesn't natively have division that rounds up, we need to convert to
+  // floor division.
+  //
+  // MayOverflow is whether adding (End - Start) + (Stride - 1)
+  // can overflow if Stride is positive. It's a precondition of the
+  // function that "End - Start" doesn't overflow. We handle the case where
+  // Stride isn't positive later.
+  //
+  // In practice, the arithmetic almost never overflows, but we have to prove
+  // it.  We have a variety of ways to come up with a proof.
+  const SCEV *One = getOne(Stride->getType());
+  bool MayOverflow = [&] {
+    if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
+      if (StrideC->getAPInt().isPowerOf2()) {
+        // Suppose Stride is a power of two, and Start/End are unsigned
+        // integers.  Let UMAX be the largest representable unsigned
+        // integer.
+        //
+        // By the preconditions of this function (see comment in header), we
+        // know "(Start + Stride * N)" >= End, and this doesn't overflow.
+        // As a formula:
+        //
+        //   End <= (Start + Stride * N) <= UMAX
+        //
+        // Subtracting Start from all the terms:
+        //
+        //   End - Start <= Stride * N <= UMAX - Start
+        //
+        // Since Start is unsigned, UMAX - Start <= UMAX.  Therefore:
+        //
+        //   End - Start <= Stride * N <= UMAX
+        //
+        // Stride * N is a multiple of Stride. Therefore,
+        //
+        //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
+        //
+        // Since Stride is a power of two, UMAX + 1 is divisible by Stride.
+        // Therefore, UMAX mod Stride == Stride - 1.  So we can write:
+        //
+        //   End - Start <= Stride * N <= UMAX - (Stride - 1)
+        //
+        // Dropping the middle term:
+        //
+        //   End - Start <= UMAX - (Stride - 1)
+        //
+        // Adding Stride - 1 to both sides:
+        //
+        //   (End - Start) + (Stride - 1) <= UMAX
+        //
+        // In other words, the addition doesn't have unsigned overflow.
+        //
+        // A similar proof works if we treat Start/End as signed values.
+        // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
+        // use signed max instead of unsigned max. Note that we're trying
+        // to prove a lack of unsigned overflow in either case.
+        return false;
+      }
+    }
+    if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
+      // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
+      // If !IsSigned, 0 <u Stride == Start <=u End; so 0 <=u End - 1 <u End.
+      // If IsSigned, 0 <s Stride == Start <=s End; so 0 <=s End - 1 <s End.
+      //
+      // If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
+      return false;
+    }
+    if (IsSigned && isKnownNonNegative(Start)) {
+      // IsSigned implies "Start <=s End <=s INT_MAX".
+      // "isKnownNonNegative(Start)" implies "Start >=s 0".
+      // Therefore, "0 <=s End - Start <=s INT_MAX - Start <= INT_MAX".
+      // IsSigned also implies "0 <=s Stride - 1 <s INT_MAX". Therefore,
+      // "(End - Start) + (Stride - 1) <u INT_MAX * 2 <u UINT_MAX".
+      return false;
+    }
+    return true;
+  }();
+
+  // Force the stride to at least one, so we don't divide by zero. The stride
+  // can be zero if Delta is zero. We don't actually care what value we use
+  // for Stride in this case, as long as it isn't zero.
+  Stride = getUMaxExpr(Stride, One);
+
+  const SCEV *Delta = getMinusSCEV(End, Start);
+  if (!MayOverflow) {
+    // floor((D + (S - 1)) / S)
+    // We prefer this formulation if it's legal because it's fewer operations.
+    return getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
+  }
+  return getUDivCeilSCEV(Delta, Stride);
+}
+
+const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
+  // umin(N, 1) + floor((N - umin(N, 1)) / D)
+  // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
+  // expression fixes the case of N=0.
+  const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
+  const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
+  return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
 }
 
 const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
@@ -11540,8 +11637,8 @@ const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
   APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
                           : APIntOps::umin(getUnsignedRangeMax(End), Limit);
 
-  MaxBECount = computeBECount(getConstant(MaxEnd - MinStart) /* Delta */,
-                              getConstant(StrideForMaxBECount) /* Step */);
+  MaxBECount = getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
+                               getConstant(StrideForMaxBECount) /* Step */);
 
   return MaxBECount;
 }
@@ -11699,7 +11796,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   // is the LHS value of the less-than comparison the first time it is evaluated
   // and End is the RHS.
   const SCEV *BECountIfBackedgeTaken =
-    computeBECount(getMinusSCEV(End, Start), Stride);
+      computeBECount(IsSigned, Start, End, Stride);
   // If the loop entry is guarded by the result of the backedge test of the
   // first loop iteration, then we know the backedge will be taken at least
   // once and so the backedge taken count is as above. If not then we use the
@@ -11718,7 +11815,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
       End = RHS;
     else
       End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
-    BECount = computeBECount(getMinusSCEV(End, Start), Stride);
+    BECount = computeBECount(IsSigned, Start, End, Stride);
   }
 
   const SCEV *MaxBECount;
@@ -11804,7 +11901,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
       return End;
   }
 
-  const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride);
+  const SCEV *BECount = computeBECount(IsSigned, End, Start, Stride);
 
   APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
                             : getUnsignedRangeMax(Start);
@@ -11825,11 +11922,8 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
 
   const SCEV *MaxBECount = isa<SCEVConstant>(BECount)
                                ? BECount
-                               : computeBECount(getConstant(MaxStart - MinEnd),
-                                                getConstant(MinStride));
-
-  if (isa<SCEVCouldNotCompute>(MaxBECount))
-    MaxBECount = BECount;
+                               : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
+                                                 getConstant(MinStride));
 
   return ExitLimit(BECount, MaxBECount, false, Predicates);
 }

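One way to gain confidence in the MayOverflow reasoning is to brute-force it at 8 bits. The sketch below is a hypothetical model, not the LLVM implementation: SCEV pointer equality (`Start == Stride`) is modelled as value equality, `isKnownNonNegative(Start)` as `Start >=s 0`, and "the backedge taken count is N" is read as the smallest such N. For every input that satisfies the three preconditions from the header, it checks that the fast path never wraps and that both paths produce that N.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical 8-bit model of the MayOverflow lambda in computeBECount.
bool mayOverflowModel(bool IsSigned, uint8_t Start, uint8_t Stride) {
  bool PowerOf2 = Stride != 0 && (Stride & (Stride - 1)) == 0;
  if (PowerOf2)
    return false;
  if (Start == Stride || Start == static_cast<uint8_t>(Stride - 1))
    return false;
  if (IsSigned && static_cast<int8_t>(Start) >= 0)
    return false;
  return true;
}

// For every 8-bit (Start, End, Stride) that satisfies the preconditions,
// check that the fast path never wraps and that the result equals the
// smallest N with Start + Stride * N >= End.
bool checkComputeBECountI8(bool IsSigned) {
  const int Max = IsSigned ? 127 : 255;
  for (int S = 0; S < 256; ++S)
    for (int E = 0; E < 256; ++E)
      for (int St = 0; St < 256; ++St) {
        // Interpret the bit patterns according to IsSigned.
        int Start = IsSigned ? static_cast<int8_t>(S) : S;
        int End = IsSigned ? static_cast<int8_t>(E) : E;
        int Stride = IsSigned ? static_cast<int8_t>(St) : St;
        if (Start > End)
          continue; // precondition 1
        if (End != Start && Stride <= 0)
          continue; // precondition 2
        int N = End == Start ? 0 : (End - Start + Stride - 1) / Stride;
        if (Start + Stride * N > Max)
          continue; // precondition 3: the index variable doesn't overflow

        unsigned Delta = static_cast<uint8_t>(E - S); // End - Start
        unsigned StrideU = std::max(St, 1);           // umax(Stride, 1)
        unsigned Result;
        if (!mayOverflowModel(IsSigned, static_cast<uint8_t>(S),
                              static_cast<uint8_t>(St))) {
          // floor((Delta + (Stride - 1)) / Stride) must not wrap at 8 bits.
          unsigned Sum = Delta + (StrideU - 1);
          if (Sum > 255)
            return false;
          Result = Sum / StrideU;
        } else {
          // getUDivCeilSCEV form: umin(D, 1) + floor((D - umin(D, 1)) / S).
          unsigned MinDOne = std::min(Delta, 1u);
          Result = MinDOne + (Delta - MinDOne) / StrideU;
        }
        if (Result != static_cast<unsigned>(N))
          return false;
      }
  return true;
}
```

In this model, a signed input such as Start = -128, End = 127, Stride = 3 takes the ceiling-division path; the fast form would compute 255 + 2, which wraps at 8 bits.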
diff  --git a/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride1.ll b/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride1.ll
index d780feb1251e3..bdf8eaf310283 100644
--- a/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride1.ll
+++ b/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride1.ll
@@ -1,14 +1,9 @@
 ; RUN: opt < %s -analyze -enable-new-pm=0 -scalar-evolution | FileCheck %s
 ; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
 
-; CHECK: Loop %bb: backedge-taken count is ((-5 + %x) /u 3)
+; CHECK: Loop %bb: backedge-taken count is (((-7 + (-1 * (1 umin (-7 + %x)))<nuw><nsw> + %x) /u 3) + (1 umin (-7 + %x)))
 ; CHECK: Loop %bb: max backedge-taken count is 1431655764
 
-
-; ScalarEvolution can't compute a trip count because it doesn't know if
-; dividing by the stride will have a remainder. This could theoretically
-; be teaching it how to use a more elaborate trip count computation.
-
 define i32 @f(i32 %x) nounwind readnone {
 entry:
 	%0 = icmp ugt i32 %x, 4		; <i1> [#uses=1]

diff  --git a/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll b/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll
index cece09305e0c6..adee3a11fae41 100644
--- a/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll
+++ b/llvm/test/Analysis/ScalarEvolution/2008-11-18-Stride2.ll
@@ -1,12 +1,10 @@
 ; RUN: opt < %s -analyze -enable-new-pm=0 -scalar-evolution 2>&1 | FileCheck %s
 ; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 2>&1 | FileCheck %s
 
-; CHECK: Loop %bb: backedge-taken count is ((999 + (-1 * %x)) /u 3)
+; CHECK: Loop %bb: backedge-taken count is (((997 + (-1 * (1 umin (997 + (-1 * %x))))<nuw><nsw> + (-1 * %x)) /u 3) + (1 umin (997 + (-1 * %x))))
 ; CHECK: Loop %bb: max backedge-taken count is 334
 
-
-; This is a tricky testcase for unsigned wrap detection which ScalarEvolution
-; doesn't yet know how to do.
+; This is a tricky testcase for unsigned wrap detection.
 
 define i32 @f(i32 %x) nounwind readnone {
 entry:

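The new CHECK lines in the two 2008-11-18 tests can be compared with the old ones by transcribing both into wrapping 32-bit arithmetic (the function names below are hypothetical). In this sketch the two forms agree for every %x except x = 5 and x = 6 in Stride1.ll, and x = 998 and x = 999 in Stride2.ll. Those are exactly the values for which Delta (the argument of the `1 umin` term) is 2^32 - 2 or 2^32 - 1, so the old form's `Delta + (Stride - 1)` wraps. The sketch does not model the loops themselves, so it says nothing about which of those inputs can actually reach the backedge.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// 2008-11-18-Stride1.ll, old CHECK: ((-5 + %x) /u 3)
uint32_t stride1Old(uint32_t x) { return (x - 5u) / 3u; }

// 2008-11-18-Stride1.ll, new CHECK:
// (((-7 + (-1 * M) + %x) /u 3) + M) with M = (1 umin (-7 + %x))
uint32_t stride1New(uint32_t x) {
  uint32_t M = std::min<uint32_t>(x - 7u, 1u);
  return (x - 7u - M) / 3u + M;
}

// 2008-11-18-Stride2.ll, old CHECK: ((999 + (-1 * %x)) /u 3)
uint32_t stride2Old(uint32_t x) { return (999u - x) / 3u; }

// 2008-11-18-Stride2.ll, new CHECK:
// (((997 + (-1 * M) + (-1 * %x)) /u 3) + M), M = (1 umin (997 + (-1 * %x)))
uint32_t stride2New(uint32_t x) {
  uint32_t M = std::min<uint32_t>(997u - x, 1u);
  return (997u - M - x) / 3u + M;
}
```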
diff  --git a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll
index 6c599104dbbcb..8d0722191047a 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-unknown-stride.ll
@@ -4,8 +4,8 @@
 ; ScalarEvolution should be able to compute trip count of the loop by proving
 ; that this is not an infinite loop with side effects.
 
-; CHECK: Determining loop execution counts for: @foo1
-; CHECK: backedge-taken count is ((-1 + %n) /u %s)
+; CHECK-LABEL: Determining loop execution counts for: @foo1
+; CHECK: backedge-taken count is ((-1 + (-1 * %s) + (1 umax %s) + %n) /u (1 umax %s))
 
 ; We should have a conservative estimate for the max backedge taken count for
 ; loops with unknown stride.
@@ -34,8 +34,8 @@ for.end:                                          ; preds = %for.body, %entry
 
 
 ; Check that we are able to compute trip count of a loop without an entry guard.
-; CHECK: Determining loop execution counts for: @foo2
-; CHECK: backedge-taken count is ((-1 + (%n smax %s)) /u %s)
+; CHECK-LABEL: Determining loop execution counts for: @foo2
+; CHECK: backedge-taken count is ((-1 + (-1 * %s) + (1 umax %s) + (%n smax %s)) /u (1 umax %s))
 
 ; We should have a conservative estimate for the max backedge taken count for
 ; loops with unknown stride.
@@ -61,7 +61,7 @@ for.end:                                          ; preds = %for.body, %entry
 
 ; Check that without mustprogress we don't make assumptions about infinite
 ; loops being UB.
-; CHECK: Determining loop execution counts for: @foo3
+; CHECK-LABEL: Determining loop execution counts for: @foo3
 ; CHECK: Loop %for.body: Unpredictable backedge-taken count.
 ; CHECK: Loop %for.body: Unpredictable max backedge-taken count.
 
@@ -84,8 +84,8 @@ for.end:                                          ; preds = %for.body, %entry
 }
 
 ; Same as foo2, but with mustprogress on loop, not function
-; CHECK: Determining loop execution counts for: @foo4
-; CHECK: backedge-taken count is ((-1 + (%n smax %s)) /u %s)
+; CHECK-LABEL: Determining loop execution counts for: @foo4
+; CHECK: backedge-taken count is ((-1 + (-1 * %s) + (1 umax %s) + (%n smax %s)) /u (1 umax %s))
 ; CHECK: max backedge-taken count is -1
 
 define void @foo4(i32* nocapture %A, i32 %n, i32 %s) {
@@ -106,5 +106,31 @@ for.end:                                          ; preds = %for.body, %entry
   ret void
 }
 
+; A more complex case with pre-increment compare instead of post-increment.
+; CHECK-LABEL: Determining loop execution counts for: @foo5
+; CHECK: Loop %for.body: backedge-taken count is ((((-1 * (1 umin ((-1 * %start) + (%n smax %start))))<nuw><nsw> + (-1 * %start) + (%n smax %start)) /u (1 umax %s)) + (1 umin ((-1 * %start) + (%n smax %start))))
+
+; We should have a conservative estimate for the max backedge taken count for
+; loops with unknown stride.
+; CHECK: max backedge-taken count is -1
+
+define void @foo5(i32* nocapture %A, i32 %n, i32 %s, i32 %start) mustprogress {
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.05 = phi i32 [ %add, %for.body ], [ %start, %entry ]
+  %arrayidx = getelementptr inbounds i32, i32* %A, i32 %i.05
+  %0 = load i32, i32* %arrayidx, align 4
+  %inc = add nsw i32 %0, 1
+  store i32 %inc, i32* %arrayidx, align 4
+  %add = add nsw i32 %i.05, %s
+  %cmp = icmp slt i32 %i.05, %n
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:                                          ; preds = %for.body, %entry
+  ret void
+}
+
 !8 = distinct !{!8, !9}
 !9 = !{!"llvm.loop.mustprogress"}
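For @foo5 the new count has the getUDivCeilSCEV shape, applied to D = (%n smax %start) - %start and a stride of (1 umax %s). The sketch below (hypothetical helper names) transcribes that CHECK line into wrapping 32-bit arithmetic and compares it with a direct simulation of the loop's control flow, which tests %i.05 before the increment. It only covers small inputs with s > 0, where the nsw add cannot overflow.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// @foo5's new CHECK line in wrapping 32-bit arithmetic:
//   D = (%n smax %start) - %start
//   ((D - (1 umin D)) /u (1 umax %s)) + (1 umin D)
uint32_t foo5BECountExpr(int32_t n, int32_t s, int32_t start) {
  uint32_t D = static_cast<uint32_t>(std::max(n, start)) -
               static_cast<uint32_t>(start);
  uint32_t MinDOne = std::min<uint32_t>(D, 1u);
  uint32_t Stride = std::max<uint32_t>(static_cast<uint32_t>(s), 1u);
  return (D - MinDOne) / Stride + MinDOne;
}

// Direct simulation of @foo5: the backedge is taken while the index value
// before the increment (%i.05) is less than %n. int64_t keeps the
// simulation itself free of overflow.
uint32_t foo5Simulate(int32_t n, int32_t s, int32_t start) {
  assert(s > 0 && "sketch assumes a positive stride");
  uint32_t Backedges = 0;
  for (int64_t i = start; i < n; i += s)
    ++Backedges;
  return Backedges;
}
```

For @foo1, @foo2 and @foo4 the change is smaller: whenever %s is nonzero, (1 umax %s) equals %s, so the added (-1 * %s) and (1 umax %s) terms cancel and the new expressions reduce to the old ones.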
