[llvm] 09d879d - [SCEV] Common code for computing trip count in a fixed type [NFC-ish]
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Tue Apr 25 12:05:12 PDT 2023
Author: Philip Reames
Date: 2023-04-25T12:04:42-07:00
New Revision: 09d879d060ed31b22a6e72f7f5e44fe9b5660aa3
URL: https://github.com/llvm/llvm-project/commit/09d879d060ed31b22a6e72f7f5e44fe9b5660aa3
DIFF: https://github.com/llvm/llvm-project/commit/09d879d060ed31b22a6e72f7f5e44fe9b5660aa3.diff
LOG: [SCEV] Common code for computing trip count in a fixed type [NFC-ish]
This is a follow-on to D147117 and D147355. In both cases, we were adding special cases to compute zext(BTC+1) instead of zext(BTC)+1 when the BTC+1 computation was known not to overflow.
Differential Revision: https://reviews.llvm.org/D148661
Added:
Modified:
llvm/include/llvm/Analysis/ScalarEvolution.h
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/lib/Transforms/Scalar/LoopFlatten.cpp
llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index f27cf22639c96..0f281d09eefea 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -791,16 +791,19 @@ class ScalarEvolution {
bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS);
+ /// A version of getTripCountFromExitCount below which always picks an
+ /// evaluation type which can not result in overflow.
+ const SCEV *getTripCountFromExitCount(const SCEV *ExitCount);
+
/// Convert from an "exit count" (i.e. "backedge taken count") to a "trip
/// count". A "trip count" is the number of times the header of the loop
/// will execute if an exit is taken after the specified number of backedges
/// have been taken. (e.g. TripCount = ExitCount + 1). Note that the
- /// expression can overflow if ExitCount = UINT_MAX. \p Extend controls
- /// how potential overflow is handled. If true, a wider result type is
- /// returned. ex: EC = 255 (i8), TC = 256 (i9). If false, result unsigned
- /// wraps with 2s-complement semantics. ex: EC = 255 (i8), TC = 0 (i8)
- const SCEV *getTripCountFromExitCount(const SCEV *ExitCount,
- bool Extend = true);
+ /// expression can overflow if ExitCount = UINT_MAX. If EvalTy is not wide
+ /// enough to hold the result without overflow, result unsigned wraps with
+ /// 2s-complement semantics. ex: EC = 255 (i8), TC = 0 (i8)
+ const SCEV *getTripCountFromExitCount(const SCEV *ExitCount, Type *EvalTy,
+ const Loop *L);
/// Returns the exact trip count of the loop if we can compute it, and
/// the result is a small constant. '0' is used to represent an unknown
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 6a3e91a0e3249..15bb95429ca6d 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -8037,27 +8037,45 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
// Iteration Count Computation Code
//
-const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
- bool Extend) {
+const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
if (isa<SCEVCouldNotCompute>(ExitCount))
return getCouldNotCompute();
auto *ExitCountType = ExitCount->getType();
assert(ExitCountType->isIntegerTy());
+ auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
+ 1 + ExitCountType->getScalarSizeInBits());
+ return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
+}
+
+const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
+ Type *EvalTy,
+ const Loop *L) {
+ if (isa<SCEVCouldNotCompute>(ExitCount))
+ return getCouldNotCompute();
- if (!Extend)
- return getAddExpr(ExitCount, getOne(ExitCountType));
+ unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
+ unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
- ConstantRange ExitCountRange =
+ auto CanAddOneWithoutOverflow = [&]() {
+ ConstantRange ExitCountRange =
getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
- if (!ExitCountRange.contains(
- APInt::getMaxValue(ExitCountRange.getBitWidth())))
- return getAddExpr(ExitCount, getOne(ExitCountType));
-
- auto *WiderType = Type::getIntNTy(ExitCountType->getContext(),
- 1 + ExitCountType->getScalarSizeInBits());
- return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType),
- getOne(WiderType));
+ if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
+ return true;
+
+ return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
+ getMinusOne(ExitCount->getType()));
+ };
+
+ // If we need to zero extend the backedge count, check if we can add one to
+ // it prior to zero extending without overflow. Provided this is safe, it
+ // allows better simplification of the +1.
+ if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
+ return getZeroExtendExpr(
+ getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
+
+ // Get the total trip count from the count by adding 1. This may wrap.
+ return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
}
static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
index 591f30c7d2264..edc8a4956dd1c 100644
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -315,12 +315,12 @@ static bool verifyTripCount(Value *RHS, Loop *L,
return false;
}
- // The Extend=false flag is used for getTripCountFromExitCount as we want
- // to verify and match it with the pattern matched tripcount. Please note
- // that overflow checks are performed in checkOverflow, but are first tried
- // to avoid by widening the IV.
+ // Evaluating in the trip count's type can not overflow here as the overflow
+ // checks are performed in checkOverflow, but are first tried to avoid by
+ // widening the IV.
const SCEV *SCEVTripCount =
- SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false);
+ SE->getTripCountFromExitCount(BackedgeTakenCount,
+ BackedgeTakenCount->getType(), L);
const SCEV *SCEVRHS = SE->getSCEV(RHS);
if (SCEVRHS == SCEVTripCount)
@@ -333,7 +333,8 @@ static bool verifyTripCount(Value *RHS, Loop *L,
// Find the extended backedge taken count and extended trip count using
// SCEV. One of these should now match the RHS of the compare.
BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
- SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt, false);
+ SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt,
+ RHS->getType(), L);
if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
return false;
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 2c999d7210186..bb0099e409a9d 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -983,33 +983,6 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
return SE->getMinusSCEV(Start, Index);
}
-/// Compute trip count from the backedge taken count.
-static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
- Loop *CurLoop, const DataLayout *DL,
- ScalarEvolution *SE) {
- const SCEV *TripCountS = nullptr;
- // The # stored bytes is (BECount+1). Expand the trip count out to
- // pointer size if it isn't already.
- //
- // If we're going to need to zero extend the BE count, check if we can add
- // one to it prior to zero extending without overflow. Provided this is safe,
- // it allows better simplification of the +1.
- if (DL->getTypeSizeInBits(BECount->getType()) <
- DL->getTypeSizeInBits(IntPtr) &&
- SE->isLoopEntryGuardedByCond(
- CurLoop, ICmpInst::ICMP_NE, BECount,
- SE->getMinusOne(BECount->getType()))) {
- TripCountS = SE->getZeroExtendExpr(
- SE->getAddExpr(BECount, SE->getOne(BECount->getType())),
- IntPtr);
- } else {
- TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
- SE->getOne(IntPtr));
- }
-
- return TripCountS;
-}
-
/// Compute the number of bytes as a SCEV from the backedge taken count.
///
/// This also maps the SCEV into the provided type and tries to handle the
@@ -1017,8 +990,8 @@ static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
const SCEV *StoreSizeSCEV, Loop *CurLoop,
const DataLayout *DL, ScalarEvolution *SE) {
- const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
-
+ const SCEV *TripCountSCEV =
+ SE->getTripCountFromExitCount(BECount, IntPtr, CurLoop);
return SE->getMulExpr(TripCountSCEV,
SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
SCEV::FlagNUW);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 435cb9a3018a0..645b62d72ea9f 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -987,35 +987,7 @@ const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");
ScalarEvolution &SE = *PSE.getSE();
-
- unsigned BackEdgeSize = SE.getTypeSizeInBits(BackedgeTakenCount->getType());
- unsigned IdxSize = IdxTy->getPrimitiveSizeInBits();
-
- // If we need to need to zero extend the backedge count, check if we can
- // add one to it prior to zero extending without overflow. Provided this is
- // safe, it allows better simplification of the +1.
- if (OrigLoop && BackEdgeSize < IdxSize &&
- SE.isLoopEntryGuardedByCond(
- OrigLoop, ICmpInst::ICMP_NE, BackedgeTakenCount,
- SE.getMinusOne(BackedgeTakenCount->getType()))) {
- return SE.getZeroExtendExpr(
- SE.getAddExpr(BackedgeTakenCount,
- SE.getOne(BackedgeTakenCount->getType())),
- IdxTy);
- }
-
- // The exit count might have the type of i64 while the phi is i32. This can
- // happen if we have an induction variable that is sign extended before the
- // compare. The only way that we get a backedge taken count is that the
- // induction variable was signed and as such will not overflow. In such a case
- // truncation is legal.
- if (BackEdgeSize > IdxSize)
- BackedgeTakenCount = SE.getTruncateOrNoop(BackedgeTakenCount, IdxTy);
- BackedgeTakenCount = SE.getNoopOrZeroExtend(BackedgeTakenCount, IdxTy);
-
- // Get the total trip count from the count by adding 1.
- return SE.getAddExpr(BackedgeTakenCount,
- SE.getOne(BackedgeTakenCount->getType()));
+ return SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, OrigLoop);
}
static Value *getRuntimeVFAsFloat(IRBuilderBase &B, Type *FTy,
More information about the llvm-commits mailing list