[llvm] e4b4397 - [ScalarEvolution] Fix overflow when computing max trip counts

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 13 10:01:34 PDT 2021


Author: Philip Reames
Date: 2021-07-13T10:01:10-07:00
New Revision: e4b43973fbd41aee3b8197cf250e9fb9ac40f986

URL: https://github.com/llvm/llvm-project/commit/e4b43973fbd41aee3b8197cf250e9fb9ac40f986
DIFF: https://github.com/llvm/llvm-project/commit/e4b43973fbd41aee3b8197cf250e9fb9ac40f986.diff

LOG: [ScalarEvolution] Fix overflow when computing max trip counts

This is split from D105216 to reduce patch complexity.  Original code by Eli, with very minor modifications by me.

The primary point of this patch is to add the getUDivCeilSCEV routine.  I included the two callers with constant arguments, since we know those must constant-fold even without any of the fancy inference logic.

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 9a560a14ca6fb..dc8c14422f16f 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -2029,6 +2029,16 @@ class ScalarEvolution {
   /// that the result is undefined if it does.
   const SCEV *computeBECount(const SCEV *Delta, const SCEV *Stride);
 
+  /// Compute ceil(N / D). N and D are treated as unsigned values.
+  ///
+  /// Since SCEV doesn't have native ceiling division, this generates a
+  /// SCEV expression of the following form:
+  ///
+  /// umin(N, 1) + floor((N - umin(N, 1)) / D)
+  ///
+  /// A denominator of zero or poison is handled the same way as getUDivExpr().
+  const SCEV *getUDivCeilSCEV(const SCEV *N, const SCEV *D);
+
   /// Compute the maximum backedge count based on the range of values
   /// permitted by Start, End, and Stride. This is for loops of the form
   /// {Start, +, Stride} LT End.

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index c9612ec54e0a9..5cb96a2111c76 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -11519,6 +11519,15 @@ bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
   return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
 }
 
+const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
+  // Build ceil(N / D) as: umin(N, 1) + floor((N - umin(N, 1)) / D).
+  // For N != 0 this is "1 + floor((N - 1) / D)"; for N == 0 the umin term
+  // is 0, so the result is floor(0 / D) rather than wrapping on N - 1.
+  const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
+  const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
+  return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
+}
+
 const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta,
                                             const SCEV *Step) {
   const SCEV *One = getOne(Step->getType());
@@ -11562,8 +11571,8 @@ const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
   APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
                           : APIntOps::umin(getUnsignedRangeMax(End), Limit);
 
-  MaxBECount = computeBECount(getConstant(MaxEnd - MinStart) /* Delta */,
-                              getConstant(StrideForMaxBECount) /* Step */);
+  MaxBECount = getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
+                               getConstant(StrideForMaxBECount) /* Step */);
 
   return MaxBECount;
 }
@@ -11847,8 +11856,8 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
 
   const SCEV *MaxBECount = isa<SCEVConstant>(BECount)
                                ? BECount
-                               : computeBECount(getConstant(MaxStart - MinEnd),
-                                                getConstant(MinStride));
+                               : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
+                                                 getConstant(MinStride));
 
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     MaxBECount = BECount;


        


More information about the llvm-commits mailing list