[llvm] Fix exact backedge count algorithm in Scalar-Evolution (PR #92560)
via llvm-commits
llvm-commits at lists.llvm.org
Sun May 19 09:28:36 PDT 2024
https://github.com/mrdaybird updated https://github.com/llvm/llvm-project/pull/92560
>From 250adcbc8456d31a8601528b43f586b618a90aa2 Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Fri, 17 May 2024 19:29:36 +0530
Subject: [PATCH 1/2] Update howManyLessThans
---
llvm/lib/Analysis/ScalarEvolution.cpp | 14 +++++++++-----
1 file changed, 9 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 515b9d0744f6e..9bf5bf80b4570 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12941,12 +12941,16 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
return RHS;
}
- // When the RHS is not invariant, we do not know the end bound of the loop and
- // cannot calculate the ExactBECount needed by ExitLimit. However, we can
- // calculate the MaxBECount, given the start, stride and max value for the end
- // bound of the loop (RHS), and the fact that IV does not overflow (which is
- // checked above).
if (!isLoopInvariant(RHS, L)) {
+ // If RHS is an add recurrence, try again with lhs=lhs-rhs and rhs=0
+ if(auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
+ return howManyLessThans(getMinusSCEV(IV, RHSAddRec),
+ getZero(IV->getType()), L, true, ControlsOnlyExit, AllowPredicates);
+ }
+ // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
+ // given the start, stride and max value for the end bound of the
+ // loop (RHS), and the fact that IV does not overflow (which is
+ // checked above).
const SCEV *MaxBECount = computeMaxBECountForLT(
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
>From b0b81efc0b337528b2a9b08d2ac92be54ba3a8de Mon Sep 17 00:00:00 2001
From: Vaibhav Pathak <pathakvaibhav at protonmail.com>
Date: Sun, 19 May 2024 21:29:03 +0530
Subject: [PATCH 2/2] Update howManyLessThans
---
llvm/lib/Analysis/ScalarEvolution.cpp | 80 +++++++++++++++++++++------
1 file changed, 64 insertions(+), 16 deletions(-)
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 9bf5bf80b4570..5fe7219b59f93 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12941,27 +12941,77 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
return RHS;
}
+ const SCEV *End = nullptr, *BECount = nullptr,
+ *BECountIfBackedgeTaken = nullptr;
if (!isLoopInvariant(RHS, L)) {
- // If RHS is an add recurrence, try again with lhs=lhs-rhs and rhs=0
- if(auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
- return howManyLessThans(getMinusSCEV(IV, RHSAddRec),
- getZero(IV->getType()), L, true, ControlsOnlyExit, AllowPredicates);
- }
- // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
- // given the start, stride and max value for the end bound of the
- // loop (RHS), and the fact that IV does not overflow (which is
- // checked above).
- const SCEV *MaxBECount = computeMaxBECountForLT(
- Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
- return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
- MaxBECount, false /*MaxOrZero*/, Predicates);
+ if (auto RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS)){
+ /*
+ The structure of the loop whose backedge count we are trying to calculate:
+ left = left_start
+ right = right_start
+ while(left < right){
+ // ... do something here ...
+ left += s1; // stride of left is s1>0
+ right -= s2; // stride of right is -s2 (s2 > 0)
+ }
+ // left and right are converging at the middle
+ // (maybe not exactly at center)
+
+ */
+ const SCEV *RHSStart = RHSAddRec->getStart();
+ const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
+ // if Stride-RHSStride>0 and does not overflow, we can write
+ // backedge count as:
+ // RHSStart >= Start ? (RHSStart - Start)/(Stride - RHSStride) : 0
+
+ // check if Stride-RHSStride will not overflow
+ if (willNotOverflow(llvm::Instruction::Sub, true, Stride, RHSStride)) {
+ const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
+ if (isKnownPositive(Denominator)) {
+ End = IsSigned ? getSMaxExpr(RHSStart, Start) :
+ getUMaxExpr(RHSStart, Start); // max(RHSStart, Start)
+
+ const SCEV *Delta = getMinusSCEV(End, Start); // End >= Start
+
+ BECount = getUDivCeilSCEV(Delta, Denominator);
+ BECountIfBackedgeTaken = getUDivCeilSCEV(
+ getMinusSCEV(RHSStart, Start), Denominator);
+
+ const SCEV *ConstantMaxBECount;
+ bool MaxOrZero = false;
+ if (isa<SCEVConstant>(BECount)) {
+ ConstantMaxBECount = BECount;
+ } else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) {
+ ConstantMaxBECount = BECountIfBackedgeTaken;
+ MaxOrZero = true;
+ } else {
+ ConstantMaxBECount = computeMaxBECountForLT(
+ Start, Stride, RHS, getTypeSizeInBits(LHS->getType()),
+ IsSigned);
+ }
+
+ const SCEV *SymbolicMaxBECount = BECount;
+ return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount,
+ MaxOrZero, Predicates);
+ }
+ }
+ }
+ if (BECount == nullptr) {
+ // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
+ // given the start, stride and max value for the end bound of the
+ // loop (RHS), and the fact that IV does not overflow (which is
+ // checked above).
+ const SCEV *MaxBECount = computeMaxBECountForLT(
+ Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+ return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
+ MaxBECount, false /*MaxOrZero*/, Predicates);
+ }
}
// We use the expression (max(End,Start)-Start)/Stride to describe the
// backedge count, as if the backedge is taken at least once max(End,Start)
// is End and so the result is as above, and if not max(End,Start) is Start
// so we get a backedge count of zero.
- const SCEV *BECount = nullptr;
auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
@@ -12993,7 +13043,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
BECount = getUDivExpr(Numerator, Stride);
}
- const SCEV *BECountIfBackedgeTaken = nullptr;
if (!BECount) {
auto canProveRHSGreaterThanEqualStart = [&]() {
auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
@@ -13021,7 +13070,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
// If we know that RHS >= Start in the context of loop, then we know that
// max(RHS, Start) = RHS at this point.
- const SCEV *End;
if (canProveRHSGreaterThanEqualStart()) {
End = RHS;
} else {
More information about the llvm-commits
mailing list