[llvm] b7aec4f - [SCEV] Support rewriting ZExt expressions with loop guard info.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 16 03:16:24 PST 2021
Author: Florian Hahn
Date: 2021-11-16T11:16:07Z
New Revision: b7aec4f08e5e231cda3366af961e227e18c5587b
URL: https://github.com/llvm/llvm-project/commit/b7aec4f08e5e231cda3366af961e227e18c5587b
DIFF: https://github.com/llvm/llvm-project/commit/b7aec4f08e5e231cda3366af961e227e18c5587b.diff
LOG: [SCEV] Support rewriting ZExt expressions with loop guard info.
So far, applying loop guard information has been restricted to
SCEVUnknown. In a few cases, like PR40961 and PR52464, this leads to
SCEV failing to determine tight upper bounds for the backedge taken
count.
This patch adjusts SCEVLoopGuardRewriter and applyLoopGuards to support
re-writing ZExt expressions.
This is a first step towards fixing PR40961 and PR52464.
Reviewed By: reames
Differential Revision: https://reviews.llvm.org/D113577
Added:
Modified:
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 5bdbe4029eb42..e8556c2094353 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13694,7 +13694,8 @@ ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
/// in the map. It skips AddRecExpr because we cannot guarantee that the
/// replacement is loop invariant in the loop of the AddRec.
///
-/// At the moment only rewriting SCEVUnknown is supported.
+/// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is
+/// supported.
class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
const DenseMap<const SCEV *, const SCEV *> &Map;
@@ -13711,9 +13712,18 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
return Expr;
return I->second;
}
+
+ const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+ auto I = Map.find(Expr);
+ if (I == Map.end())
+ return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
+ Expr);
+ return I->second;
+ }
};
const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
+ SmallVector<const SCEV *> ExprsToRewrite;
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
const SCEV *RHS,
DenseMap<const SCEV *, const SCEV *>
@@ -13736,6 +13746,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
RewriteMap[LHSUnknown] = Multiple;
+ ExprsToRewrite.push_back(LHSUnknown);
return;
}
}
@@ -13749,7 +13760,8 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
// Check for a condition of the form (-C1 + X < C2). InstCombine will
// create this form when combining two checks of the form (X u< C2 + C1) and
// (X >=u C1).
- auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap]() {
+ auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
+ &ExprsToRewrite]() {
auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
if (!AddExpr || AddExpr->getNumOperands() != 2)
return false;
@@ -13772,21 +13784,35 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
RewriteMap[LHSUnknown] = getUMaxExpr(
getConstant(ExactRegion.getUnsignedMin()),
getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
+ ExprsToRewrite.push_back(LHSUnknown);
return true;
};
if (MatchRangeCheckIdiom())
return;
- // For now, limit to conditions that provide information about unknown
- // expressions. RHS also cannot contain add recurrences.
- auto *LHSUnknown = dyn_cast<SCEVUnknown>(LHS);
- if (!LHSUnknown || containsAddRecurrence(RHS))
+ // If RHS is SCEVUnknown, make sure the information is applied to it.
+ if (isa<SCEVUnknown>(RHS)) {
+ std::swap(LHS, RHS);
+ Predicate = CmpInst::getSwappedPredicate(Predicate);
+ }
+ // If LHS is a constant, apply information to the other expression.
+ if (isa<SCEVConstant>(LHS)) {
+ std::swap(LHS, RHS);
+ Predicate = CmpInst::getSwappedPredicate(Predicate);
+ }
+ // Do not apply information for constants or if RHS contains an AddRec.
+ if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
+ return;
+
+ // Limit to expressions that can be rewritten.
+ if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS))
return;
// Check whether LHS has already been rewritten. In that case we want to
// chain further rewrites onto the already rewritten value.
auto I = RewriteMap.find(LHS);
const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
+
const SCEV *RewrittenRHS = nullptr;
switch (Predicate) {
case CmpInst::ICMP_ULT:
@@ -13830,8 +13856,11 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
break;
}
- if (RewrittenRHS)
+ if (RewrittenRHS) {
RewriteMap[LHS] = RewrittenRHS;
+ if (LHS == RewrittenLHS)
+ ExprsToRewrite.push_back(LHS);
+ }
};
// Starting at the loop predecessor, climb up the predecessor chain, as long
// as there are predecessors that can be found that have unique successors
@@ -13887,6 +13916,19 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
if (RewriteMap.empty())
return Expr;
+
+ // Now that all rewrite information is collected, rewrite the collected
+ // expressions with the information in the map. This applies information to
+ // sub-expressions.
+ if (ExprsToRewrite.size() > 1) {
+ for (const SCEV *Expr : ExprsToRewrite) {
+ const SCEV *RewriteTo = RewriteMap[Expr];
+ RewriteMap.erase(Expr);
+ SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
+ RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
+ }
+ }
+
SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
return Rewriter.visit(Expr);
}
diff --git a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
index cceb25d5cff6a..de8ca288a9bd0 100644
--- a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
+++ b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
@@ -7,7 +7,7 @@
define void @rewrite_zext(i32 %n) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 2
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
@@ -36,7 +36,7 @@ exit:
define i32 @rewrite_zext_min_max(i32 %N, i32* %arr) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_min_max
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))<nuw><nsw>)<nsw> /u 4)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 4611686018427387903
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))<nuw><nsw>)<nsw> /u 4)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
@@ -134,7 +134,7 @@ exit:
define void @rewrite_zext_and_base_1(i32 %n) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
@@ -168,7 +168,7 @@ exit:
define void @rewrite_zext_and_base_2(i32 %n) {
; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))<nuw><nsw>)<nsw> /u 8)
; CHECK-NEXT: Predicates:
; CHECK: Loop %loop: Trip multiple is 1
More information about the llvm-commits
mailing list