[llvm] 7f55209 - [SCEV] Extend trip count to avoid overflow by default

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 11 09:56:01 PDT 2021


Author: Philip Reames
Date: 2021-10-11T09:55:55-07:00
New Revision: 7f55209cee55fa2f7d5954f7ec7df77d90585a7b

URL: https://github.com/llvm/llvm-project/commit/7f55209cee55fa2f7d5954f7ec7df77d90585a7b
DIFF: https://github.com/llvm/llvm-project/commit/7f55209cee55fa2f7d5954f7ec7df77d90585a7b.diff

LOG: [SCEV] Extend trip count to avoid overflow by default

As a brief reminder, an "exit count" is the number of times the backedge executes before some event. It can be zero if we exit before the backedge is reached. A "trip count" is the number of times the loop header is entered if we branch into the loop. In general, TC = BTC + 1, and thus a zero trip count is ill-defined.
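The following C++ sketch models the two definitions on a bottom-tested loop. It is not LLVM code: `LoopCounts` and `runLoop` are hypothetical names introduced only for illustration. Even when the loop exits before reaching the backedge (exit count 0), the header has run once, which is where TC = BTC + 1 comes from.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical counters for one execution of a loop (not an LLVM API).
struct LoopCounts {
  std::uint64_t HeaderEntries = 0;  // "trip count": times the header runs
  std::uint64_t BackedgesTaken = 0; // "exit count" / backedge-taken count
};

// Models a bottom-tested loop: enter the header, run the body, then either
// take the backedge (branch back to the header) or exit.
// Iterations must be >= 1, because a loop we branch into runs its header
// at least once.
LoopCounts runLoop(std::uint64_t Iterations) {
  assert(Iterations >= 1 && "the header executes at least once");
  LoopCounts C;
  for (std::uint64_t I = 0;; ++I) {
    ++C.HeaderEntries;       // the header executes
    if (I + 1 == Iterations) // exit test in the latch
      break;
    ++C.BackedgesTaken;      // the backedge is taken
  }
  return C;
}
```

For a loop that exits on its first pass, `runLoop(1)` reports one header entry and zero backedges; for `runLoop(5)` it reports five and four.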

There is a corner case which we don't handle well. Let's assume i8 for our examples to keep things simple. If BTC = 255, then the correct trip count is 256. However, 256 is not representable in i8.
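The i8 corner case can be reproduced with ordinary C++ fixed-width integers. This sketch assumes `std::uint8_t` as a stand-in for i8 and `std::uint16_t` as a stand-in for i9 (C++ has no 9-bit integer type); both function names are hypothetical.

```cpp
#include <cstdint>

// TC = BTC + 1 computed in the same 8-bit width as BTC, mimicking the old
// wrap-to-zero behavior. BTC is promoted to int for the addition, and the
// conversion back to uint8_t reduces the result modulo 256.
std::uint8_t tripCountWrapping(std::uint8_t BTC) {
  return static_cast<std::uint8_t>(BTC + 1);
}

// TC = BTC + 1 computed in a wider type, so BTC = 255 yields 256.
std::uint16_t tripCountWidened(std::uint8_t BTC) {
  return static_cast<std::uint16_t>(BTC + 1);
}
```

With BTC = 254 both functions agree on 255; only BTC = 255 separates them, yielding 0 versus 256.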

In theory, code which needs to reason about trip counts is responsible for checking for this corner case, and either bailing out or handling it correctly. Historically, we don't have a great track record of actually doing so.
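As a sketch of the "bail out" option, a caller holding a backedge-taken count of a known bit width can refuse to produce a trip count when BTC is the all-ones value. The helper `tripCountOrBail` is hypothetical and works on concrete integers rather than SCEV expressions; widths are limited to 1..63 so the arithmetic fits in `std::uint64_t`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical caller-side guard: given a backedge-taken count stored in an
// unsigned integer of BitWidth bits, return BTC + 1 only when it is
// representable in that same width; otherwise bail out with nullopt.
std::optional<std::uint64_t> tripCountOrBail(std::uint64_t BTC,
                                             unsigned BitWidth) {
  assert(BitWidth >= 1 && BitWidth <= 63 && "width out of modelled range");
  const std::uint64_t Max = (std::uint64_t{1} << BitWidth) - 1;
  assert(BTC <= Max && "BTC does not fit in BitWidth bits");
  if (BTC == Max)
    return std::nullopt; // BTC + 1 would wrap to 0, so refuse to answer
  return BTC + 1;
}
```

The alternative, handling the case correctly, means carrying the result in a wider type, which is what the default behavior below now does.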

When reviewing D109676, I found myself asking a basic question. Was there any good reason to preserve the current wrap-to-zero behavior when converting from backedge taken counts to trip counts? After reviewing existing code, I could not find a single case which appears to correctly and precisely handle the overflow case.

This patch changes the default behavior to extend instead of wrap. That is, if the result might be 256, we return a value of i9 type to ensure we interpret the count correctly. I did leave the legacy behavior available as an option, since a) loop-flatten stops triggering if I extend, due to weirdly specific pattern matching I didn't understand, and b) we could reasonably use the mode if we had externally established a lack of overflow.
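The two modes of the new `Extend` parameter can be modelled on concrete values as follows. This is only a sketch, assuming a SCEV constant can be represented as a (value, bit width) pair; the real `getTripCountFromExitCount` builds SCEV expressions with `getNoopOrZeroExtend` and `getAddExpr`, as the ScalarEvolution.cpp hunk below shows, and the `TypedCount` type here is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Toy stand-in for a SCEV constant: a value tagged with the bit width of the
// LLVM integer type it models (e.g. Width = 8 for i8).
struct TypedCount {
  std::uint64_t Value;
  unsigned Width;
};

// Mirrors, on concrete values only, the two modes of
// getTripCountFromExitCount(ExitCount, Extend):
//  - Extend == false: add 1 in the original width, wrapping modulo 2^Width.
//  - Extend == true:  zero-extend to Width + 1 bits, then add 1; this cannot
//    wrap because EC.Value <= 2^Width - 1.
// Widths are limited to 1..62 so all arithmetic fits in uint64_t.
TypedCount tripCountFromExitCount(TypedCount EC, bool Extend = true) {
  assert(EC.Width >= 1 && EC.Width <= 62);
  if (!Extend) {
    const std::uint64_t Mask = (std::uint64_t{1} << EC.Width) - 1;
    return {(EC.Value + 1) & Mask, EC.Width};
  }
  return {EC.Value + 1, EC.Width + 1};
}
```

As in the patch, the extended mode widens the type unconditionally: an exit count of 10 in i8 becomes a trip count of 11 in i9, even though no overflow was possible.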

I want to emphasize that this change is *not* NFC. There are two call sites (one in ScalarEvolution.cpp, one in LoopCacheAnalysis.cpp) which are switched to the extend semantics. The former appears imprecise (but correct) for a constant 255 BTC. The latter appears incorrect, though I don't have a test case.

Differential Revision: https://reviews.llvm.org/D110587

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/lib/Transforms/Scalar/LoopFlatten.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index e10d52e3a975..79e8634775ed 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -759,9 +759,13 @@ class ScalarEvolution {
   /// Convert from an "exit count" (i.e. "backedge taken count") to a "trip
   /// count".  A "trip count" is the number of times the header of the loop
   /// will execute if an exit is taken after the specified number of backedges
-  /// have been taken.  (e.g. TripCount = ExitCount + 1)  A zero result
-  /// must be interpreted as a loop having an unknown trip count.
-  const SCEV *getTripCountFromExitCount(const SCEV *ExitCount);
+  /// have been taken.  (e.g. TripCount = ExitCount + 1).  Note that the
+  /// expression can overflow if ExitCount = UINT_MAX.  \p Extend controls
+  /// how potential overflow is handled.  If true, a wider result type is
+  /// returned. ex: EC = 255 (i8), TC = 256 (i9).  If false, result unsigned
+  /// wraps with 2s-complement semantics.  ex: EC = 255 (i8), TC = 0 (i8)
+  const SCEV *getTripCountFromExitCount(const SCEV *ExitCount,
+                                        bool Extend = true);
 
   /// Returns the exact trip count of the loop if we can compute it, and
   /// the result is a small constant.  '0' is used to represent an unknown

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 6fbd402f1e7e..d8d39307adcf 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7216,10 +7216,21 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
 //                   Iteration Count Computation Code
 //
 
-const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
-  // Get the trip count from the BE count by adding 1.  Overflow, results
-  // in zero which means "unknown".
-  return getAddExpr(ExitCount, getOne(ExitCount->getType()));
+const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
+                                                       bool Extend) {
+  if (isa<SCEVCouldNotCompute>(ExitCount))
+    return getCouldNotCompute();
+
+  auto *ExitCountType = ExitCount->getType();
+  assert(ExitCountType->isIntegerTy());
+
+  if (!Extend)
+    return getAddExpr(ExitCount, getOne(ExitCountType));
+
+  auto *WiderType = Type::getIntNTy(ExitCountType->getContext(),
+                                    1 + ExitCountType->getScalarSizeInBits());
+  return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType),
+                    getOne(WiderType));
 }
 
 static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {

diff  --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index 181c9790169c..965d1575518e 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -202,7 +202,12 @@ static bool findLoopComponents(
     LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
     return false;
   }
-  const SCEV *SCEVTripCount = SE->getTripCountFromExitCount(BackedgeTakenCount);
+  // The use of the Extend=false flag on getTripCountFromExitCount was added
+  // during a refactoring to preserve existing behavior.  However, there's
+  // nothing obvious in the surrounding code when handles the overflow case.
+  // FIXME: audit code to establish whether there's a latent bug here.
+  const SCEV *SCEVTripCount =
+    SE->getTripCountFromExitCount(BackedgeTakenCount, false);
   const SCEV *SCEVRHS = SE->getSCEV(RHS);
   if (SCEVRHS == SCEVTripCount)
     return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
@@ -214,7 +219,7 @@ static bool findLoopComponents(
       // Find the extended backedge taken count and extended trip count using
       // SCEV. One of these should now match the RHS of the compare.
       BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
-      SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt);
+      SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
       if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
         LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
         return false;

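One plausible reading of why loop-flatten stops matching with `Extend = true` (an assumption, not something the commit states): `findLoopComponents` compares SCEV pointers directly, and uniqued SCEV expressions of different integer types are never the same object, so an i(N+1) trip count cannot match an iN compare operand. The toy model below illustrates this; `ToyExpr` and `tripCountMatchesRHS` are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

// Toy model of a uniqued SCEV: two expressions are "the same" only if both
// the value and the type width match.
struct ToyExpr {
  std::uint64_t Value;
  unsigned Width;
  bool operator==(const ToyExpr &O) const {
    return Value == O.Value && Width == O.Width;
  }
};

// Models the `SCEVRHS == SCEVTripCount` check: the trip count is derived
// from a backedge-taken count of the same width as the compare's RHS, either
// wrapping in that width (Extend = false) or widened by one bit
// (Extend = true).
bool tripCountMatchesRHS(ToyExpr RHS, std::uint64_t BTC, bool Extend) {
  assert(RHS.Width >= 1 && RHS.Width <= 62);
  if (Extend)
    return ToyExpr{BTC + 1, RHS.Width + 1} == RHS; // width differs: never equal
  const std::uint64_t Mask = (std::uint64_t{1} << RHS.Width) - 1;
  return ToyExpr{(BTC + 1) & Mask, RHS.Width} == RHS;
}
```

Under this model, a loop whose i32 compare operand equals BTC + 1 matches only in the wrapping mode, which is consistent with the diff passing `false` at both LoopFlatten call sites.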

        

