[llvm] 585742c - [SCEV] When computing trip count, only zext if necessary

Joshua Cao via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 10 19:58:47 PDT 2023


Author: Joshua Cao
Date: 2023-04-10T19:40:52-07:00
New Revision: 585742cbfccd734b19c75dff9709b20367506668

URL: https://github.com/llvm/llvm-project/commit/585742cbfccd734b19c75dff9709b20367506668
DIFF: https://github.com/llvm/llvm-project/commit/585742cbfccd734b19c75dff9709b20367506668.diff

LOG: [SCEV] When computing trip count, only zext if necessary

This patch improves on https://reviews.llvm.org/D110587. To summarize
that patch: given a backedge-taken count BC, the trip count TC is `BC + 1`.
However, we don't know whether `BC + 1` might overflow, so that patch
computes TC as `1 + zext(BC)`, in a type one bit wider than BC.

This patch adds the zext only when necessary, based on the unsigned
constant range of BC. If BC cannot be the maximum value for its bit
width, then adding 1 cannot overflow and the zext is not needed. Loop
guards are applied to BC before computing TC, which can tighten that range.

The primary motivation is to support my work on more precise trip
multiples in https://reviews.llvm.org/D141823. For example:

```
void foo(void);

void test(unsigned n) {
  __builtin_assume(n % 6 == 0);
  for (unsigned i = 0; i < n; ++i)
    foo();
}
```

Prior to this patch, we had `TC = 1 + zext(-1 + 6 * ((6 umax %n) /u
6))<nuw>`. SCEV range computation is able to determine that BC
cannot be the maximum value, so the zext is not needed. The result is
`TC -> (6 * ((6 umax %n) /u 6))<nuw>`. From here, we would be able to
determine that the trip count is a multiple of 6.

One change to LoopCacheAnalysis/LoopInterchange was required. Before
this patch, if a loop had BC = false (an `i1` zero), it would compute
`TC -> 1 + zext(false) -> 1`, which was fine. After this patch, it
computes `TC -> 1 + false = true`. CacheAnalysis would then sign extend
the `true` to -1, which was not the intended behavior. I modified
CacheAnalysis so that it only zero extends trip counts.

This patch is not NFC, but it also does not change any SCEV outputs. I
would like to land this patch first to make the work on trip multiples
easier.

Differential Revision: https://reviews.llvm.org/D147117

Added: 
    

Modified: 
    llvm/lib/Analysis/LoopCacheAnalysis.cpp
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
index 46198f78b6433..c3a56639b5c8f 100644
--- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp
@@ -297,7 +297,7 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
     Type *WiderType = SE.getWiderType(Stride->getType(), TripCount->getType());
     const SCEV *CacheLineSize = SE.getConstant(WiderType, CLS);
     Stride = SE.getNoopOrAnyExtend(Stride, WiderType);
-    TripCount = SE.getNoopOrAnyExtend(TripCount, WiderType);
+    TripCount = SE.getNoopOrZeroExtend(TripCount, WiderType);
     const SCEV *Numerator = SE.getMulExpr(Stride, TripCount);
     RefCost = SE.getUDivExpr(Numerator, CacheLineSize);
 
@@ -323,8 +323,8 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
       const SCEV *TripCount =
           computeTripCount(*AR->getLoop(), *Sizes.back(), SE);
       Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType());
-      RefCost = SE.getMulExpr(SE.getNoopOrAnyExtend(RefCost, WiderType),
-                              SE.getNoopOrAnyExtend(TripCount, WiderType));
+      RefCost = SE.getMulExpr(SE.getNoopOrZeroExtend(RefCost, WiderType),
+                              SE.getNoopOrZeroExtend(TripCount, WiderType));
     }
 
     LLVM_DEBUG(dbgs().indent(4)
@@ -334,7 +334,7 @@ CacheCostTy IndexedReference::computeRefCost(const Loop &L,
 
   // Attempt to fold RefCost into a constant.
   if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost))
-    return ConstantCost->getValue()->getSExtValue();
+    return ConstantCost->getValue()->getZExtValue();
 
   LLVM_DEBUG(dbgs().indent(4)
              << "RefCost is not a constant! Setting to RefCost=InvalidCost "

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 174eea8d364ab..52bd161cf9ddf 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8045,6 +8045,12 @@ const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
   if (!Extend)
     return getAddExpr(ExitCount, getOne(ExitCountType));
 
+  ConstantRange ExitCountRange =
+      getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
+  if (!ExitCountRange.contains(
+          APInt::getMaxValue(ExitCountRange.getBitWidth())))
+    return getAddExpr(ExitCount, getOne(ExitCountType));
+
   auto *WiderType = Type::getIntNTy(ExitCountType->getContext(),
                                     1 + ExitCountType->getScalarSizeInBits());
   return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType),
@@ -8227,15 +8233,14 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
     return 1;
 
   // Get the trip count
-  const SCEV *TCExpr = getTripCountFromExitCount(ExitCount);
+  const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
 
   const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr);
   if (!TC)
     // Attempt to factor more general cases. Returns the greatest power of
     // two divisor. If overflow happens, the trip count expression is still
     // divisible by the greatest power of 2 divisor returned.
-    return 1U << std::min((uint32_t)31,
-                          GetMinTrailingZeros(applyLoopGuards(TCExpr, L)));
+    return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr));
 
   ConstantInt *Result = TC->getValue();
 


        

