[llvm] [SCEV] Collect and merge loop guards through PHI nodes with multiple incoming values (PR #113915)
Julian Nagele via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 28 07:39:52 PDT 2024
https://github.com/juliannagele created https://github.com/llvm/llvm-project/pull/113915
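This patch strengthens the collection of loop guards by handling PHI nodes with multiple incoming values: it collects the guards for each incoming value/block and tries to merge them into a single guard for the PHI node. For example, if every incoming value is bounded above by a constant, the PHI is bounded above by the largest of those constants.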
>From b595b8c98d547915c47f99e84dcf9204409aaaae Mon Sep 17 00:00:00 2001
From: Julian Nagele <j.nagele at apple.com>
Date: Fri, 18 Oct 2024 17:40:06 +0100
Subject: [PATCH] [SCEV] Collect and merge loop guards through PHI nodes with
multiple incoming values
---
llvm/include/llvm/Analysis/ScalarEvolution.h | 5 ++
llvm/lib/Analysis/ScalarEvolution.cpp | 85 +++++++++++++++++--
.../Analysis/ScalarEvolution/trip-count.ll | 82 ++++++++++++++++++
.../Transforms/PhaseOrdering/X86/pr38280.ll | 2 +-
4 files changed, 164 insertions(+), 10 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 179a2c38d9d3c2..cdc46cf24a0546 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1316,6 +1316,11 @@ class ScalarEvolution {
LoopGuards(ScalarEvolution &SE) : SE(SE) {}
+ /// Collect into \p Guards the guards that hold on entry to \p Block when it
+ /// is reached from \p Pred: dominating assumptions and guard calls, branch
+ /// conditions along the chain of unique predecessors starting at \p Pred,
+ /// and merged min/max bounds for the PHIs of the multi-predecessor block at
+ /// which that chain stops. Incoming blocks already in \p VisitedBlocks are
+ /// not recursed into. \p Guards is filled in place.
+ /// NOTE(review): \p VisitedBlocks is taken by value, so the set is copied on
+ /// every (recursive) call and blocks visited by a callee are not seen by its
+ /// caller; consider passing SmallPtrSetImpl<const BasicBlock *> & instead.
+ static LoopGuards
+ collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
+ const BasicBlock *Block, const BasicBlock *Pred,
+ SmallPtrSet<const BasicBlock *, 8> VisitedBlocks);
+
public:
/// Collect rewrite map for loop guards for loop \p L, together with flags
/// indicating if NUW and NSW can be preserved during rewriting.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index c939270ed39a65..d9cab0471ef0f3 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10648,7 +10648,7 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
if (const Loop *L = LI.getLoopFor(BB))
return {L->getLoopPredecessor(), L->getHeader()};
- return {nullptr, nullptr};
+ return {nullptr, BB};
}
/// SCEV structural equivalence is usually sufficient for testing whether two
@@ -15217,7 +15217,16 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
ScalarEvolution::LoopGuards
ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
+ // The guards for L are those that hold on entry to its header when it is
+ // reached from the loop predecessor, starting with no visited blocks. If
+ // Pred is null, the predecessor walk in collectFromBlock does not run.
+ BasicBlock *Header = L->getHeader();
+ BasicBlock *Pred = L->getLoopPredecessor();
LoopGuards Guards(SE);
+ return collectFromBlock(SE, Guards, Header, Pred, {});
+}
+
+ScalarEvolution::LoopGuards ScalarEvolution::LoopGuards::collectFromBlock(
+ ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
+ const BasicBlock *Block, const BasicBlock *Pred,
+ SmallPtrSet<const BasicBlock *, 8> VisitedBlocks) {
SmallVector<const SCEV *> ExprsToRewrite;
auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
const SCEV *RHS,
@@ -15556,14 +15565,13 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
}
};
- BasicBlock *Header = L->getHeader();
SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
// First, collect information from assumptions dominating the loop.
+ // (More precisely: assumptions that dominate Block, which is the loop header
+ // for the top-level call from collect().)
for (auto &AssumeVH : SE.AC.assumptions()) {
if (!AssumeVH)
continue;
auto *AssumeI = cast<CallInst>(AssumeVH);
- if (!SE.DT.dominates(AssumeI, Header))
+ if (!SE.DT.dominates(AssumeI, Block))
continue;
Terms.emplace_back(AssumeI->getOperand(0), true);
}
@@ -15574,8 +15582,8 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
+ // Likewise, only guard calls in Block's function that dominate Block are used.
if (GuardDecl)
for (const auto *GU : GuardDecl->users())
if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
- if (Guard->getFunction() == Header->getParent() &&
- SE.DT.dominates(Guard, Header))
+ if (Guard->getFunction() == Block->getParent() &&
+ SE.DT.dominates(Guard, Block))
Terms.emplace_back(Guard->getArgOperand(0), true);
// Third, collect conditions from dominating branches. Starting at the loop
@@ -15583,11 +15591,10 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
// predecessors that can be found that have unique successors leading to the
// original header.
// TODO: share this logic with isLoopEntryGuardedByCond.
- for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
- L->getLoopPredecessor(), Header);
- Pair.first;
+ std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
+ // Pair lives outside the loop so that, once the walk stops, Pair.second is
+ // the last block reached; the PHI handling below starts from it. This relies
+ // on getPredecessorWithUniqueSuccessorForBB returning {nullptr, BB} (not
+ // {nullptr, nullptr}) when there is no further predecessor.
+ for (; Pair.first;
Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
-
+ // Record each block on the chain so the PHI handling's recursion does not
+ // re-enter it.
+ VisitedBlocks.insert(Pair.second);
const BranchInst *LoopEntryPredicate =
dyn_cast<BranchInst>(Pair.first->getTerminator());
if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
@@ -15596,6 +15603,66 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
Terms.emplace_back(LoopEntryPredicate->getCondition(),
LoopEntryPredicate->getSuccessor(0) == Pair.second);
}
+  // Finally, if we stopped climbing the predecessor chain because the block
+  // reached has no unique predecessor to continue with, try to derive guards
+  // for its PHI nodes: collect the guards along every incoming edge
+  // (recursively, starting from the incoming block) and, when every incoming
+  // value is rewritten to a min/max of the same kind against a constant,
+  // merge those constant bounds into a single guard for the PHI.
+  if (Pair.second->hasNPredecessorsOrMore(2)) {
+    using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
+
+    // Guards collected along each incoming edge. They are computed at most
+    // once per incoming block and shared by all PHIs of the block. Collecting
+    // them inside the per-PHI loop would leave every PHI after the first one
+    // without guards, because its incoming blocks would already be marked as
+    // visited. Caching also avoids re-running the recursive collection per PHI.
+    SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
+
+    // Returns the constant operand and kind of the min/max expression that the
+    // guards on incoming edge In rewrite Phi's incoming value to, or
+    // {nullptr, scCouldNotCompute} if there is no such expression.
+    auto GetMinMaxConst = [&](const PHINode &Phi,
+                              unsigned In) -> MinMaxPattern {
+      const BasicBlock *InBlock = Phi.getIncomingBlock(In);
+      auto [It, Inserted] =
+          IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
+      // Recurse only the first time InBlock is seen, and only if it has not
+      // been visited already; otherwise reuse the cached (possibly empty)
+      // guards.
+      if (Inserted && VisitedBlocks.insert(InBlock).second)
+        collectFromBlock(SE, It->second, Pair.second, InBlock, VisitedBlocks);
+      // Use find() rather than operator[] so a missing entry is not
+      // default-inserted into the map.
+      const auto &RewriteMap = It->second.RewriteMap;
+      auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(In)));
+      if (S == RewriteMap.end())
+        return {nullptr, scCouldNotCompute};
+      const auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
+      if (!SM)
+        return {nullptr, scCouldNotCompute};
+      if (const auto *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
+        return {C0, SM->getSCEVType()};
+      if (const auto *C1 = dyn_cast<SCEVConstant>(SM->getOperand(1)))
+        return {C1, SM->getSCEVType()};
+      return {nullptr, scCouldNotCompute};
+    };
+
+    // Merges the bounds from two incoming edges into one that holds on both:
+    // the smaller constant for lower bounds (max kinds), the larger constant
+    // for upper bounds (min kinds). Fails if either side has no bound or the
+    // kinds differ.
+    auto MergeMinMaxConst = [](MinMaxPattern P1,
+                               MinMaxPattern P2) -> MinMaxPattern {
+      auto [C1, T1] = P1;
+      auto [C2, T2] = P2;
+      if (!C1 || !C2 || T1 != T2)
+        return {nullptr, scCouldNotCompute};
+      switch (T1) {
+      case scUMaxExpr:
+        return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
+      case scSMaxExpr:
+        return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
+      case scUMinExpr:
+        return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
+      case scSMinExpr:
+        return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
+      default:
+        llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
+      }
+    };
+
+    for (const PHINode &Phi : Pair.second->phis()) {
+      if (!SE.isSCEVable(Phi.getType()))
+        continue;
+      MinMaxPattern P = GetMinMaxConst(Phi, 0);
+      for (unsigned In = 1, E = Phi.getNumIncomingValues(); P.first && In < E;
+           ++In)
+        P = MergeMinMaxConst(P, GetMinMaxConst(Phi, In));
+      if (!P.first)
+        continue;
+      // Phi takes one of its incoming values, each of which satisfies the
+      // merged bound on its edge, so Phi satisfies that bound as well.
+      const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
+      SmallVector<const SCEV *, 2> Ops({P.first, LHS});
+      const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
+      Guards.RewriteMap.insert({LHS, RHS});
+    }
+  }
// Now apply the information from the collected conditions to
// Guards.RewriteMap. Conditions are processed in reverse order, so the
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count.ll b/llvm/test/Analysis/ScalarEvolution/trip-count.ll
index 8fc5b9b4096127..7304409814b0e1 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count.ll
@@ -211,3 +211,85 @@ for.body:
exit:
ret void
}
+
+; %count.epilogue is at most 7 (unsigned) on both incoming edges of
+; %epilogue.preheader: %count ule 7 on the edge from %entry, and %sub ule 7 on
+; the edge through %while.loopexit. Merging these guards through the PHI and
+; using %count.epilogue != 0 on entry to %epilogue bounds the epilogue loop's
+; constant max backedge-taken count by 6.
+define void @epilogue(i64 %count) {
+; CHECK-LABEL: 'epilogue'
+; CHECK-NEXT: Determining loop execution counts for: @epilogue
+; CHECK-NEXT: Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
+; CHECK-NEXT: Loop %epilogue: constant max backedge-taken count is i64 6
+; CHECK-NEXT: Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
+; CHECK-NEXT: Loop %epilogue: Trip multiple is 1
+; CHECK-NEXT: Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
+; CHECK-NEXT: Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
+; CHECK-NEXT: Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
+; CHECK-NEXT: Loop %while.body: Trip multiple is 1
+entry:
+ %cmp = icmp ugt i64 %count, 7
+ br i1 %cmp, label %while.body, label %epilogue.preheader
+
+while.body:
+ %iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
+ %sub = add i64 %iv, -8
+ %exitcond.not = icmp ugt i64 %sub, 7
+ br i1 %exitcond.not, label %while.body, label %while.loopexit
+
+while.loopexit:
+ %sub.exit = phi i64 [ %sub, %while.body ]
+ br label %epilogue.preheader
+
+epilogue.preheader:
+ %count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
+ %epilogue.cmp = icmp eq i64 %count.epilogue, 0
+ br i1 %epilogue.cmp, label %exit, label %epilogue
+
+epilogue:
+ %iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
+ %dec = add i64 %iv.epilogue, -1
+ %exitcond.epilogue = icmp eq i64 %dec, 0
+ br i1 %exitcond.epilogue, label %exit, label %epilogue
+
+exit:
+ ret void
+
+}
+
+; Same as @epilogue, but the edge from %entry only guarantees %count ule 9,
+; while the edge through %while.loopexit still guarantees %sub ule 7. The
+; merged guard for %count.epilogue must use the weaker bound (9), giving a
+; constant max backedge-taken count of 8.
+define void @epilogue2(i64 %count) {
+; CHECK-LABEL: 'epilogue2'
+; CHECK-NEXT: Determining loop execution counts for: @epilogue2
+; CHECK-NEXT: Loop %epilogue: backedge-taken count is (-1 + %count.epilogue)
+; CHECK-NEXT: Loop %epilogue: constant max backedge-taken count is i64 8
+; CHECK-NEXT: Loop %epilogue: symbolic max backedge-taken count is (-1 + %count.epilogue)
+; CHECK-NEXT: Loop %epilogue: Trip multiple is 1
+; CHECK-NEXT: Loop %while.body: backedge-taken count is ((-8 + %count) /u 8)
+; CHECK-NEXT: Loop %while.body: constant max backedge-taken count is i64 2305843009213693951
+; CHECK-NEXT: Loop %while.body: symbolic max backedge-taken count is ((-8 + %count) /u 8)
+; CHECK-NEXT: Loop %while.body: Trip multiple is 1
+entry:
+ %cmp = icmp ugt i64 %count, 9
+ br i1 %cmp, label %while.body, label %epilogue.preheader
+
+while.body:
+ %iv = phi i64 [ %sub, %while.body ], [ %count, %entry ]
+ %sub = add i64 %iv, -8
+ %exitcond.not = icmp ugt i64 %sub, 7
+ br i1 %exitcond.not, label %while.body, label %while.loopexit
+
+while.loopexit:
+ %sub.exit = phi i64 [ %sub, %while.body ]
+ br label %epilogue.preheader
+
+epilogue.preheader:
+ %count.epilogue = phi i64 [ %count, %entry ], [ %sub.exit, %while.loopexit ]
+ %epilogue.cmp = icmp eq i64 %count.epilogue, 0
+ br i1 %epilogue.cmp, label %exit, label %epilogue
+
+epilogue:
+ %iv.epilogue = phi i64 [ %dec, %epilogue ], [ %count.epilogue, %epilogue.preheader ]
+ %dec = add i64 %iv.epilogue, -1
+ %exitcond.epilogue = icmp eq i64 %dec, 0
+ br i1 %exitcond.epilogue, label %exit, label %epilogue
+
+exit:
+ ret void
+
+}
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/pr38280.ll b/llvm/test/Transforms/PhaseOrdering/X86/pr38280.ll
index 70b002f766b753..966d7e3cded0ab 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/pr38280.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/pr38280.ll
@@ -41,7 +41,7 @@ define void @apply_delta(ptr nocapture noundef %dst, ptr nocapture noundef reado
; CHECK-NEXT: [[INCDEC_PTR]] = getelementptr inbounds i8, ptr [[DST_ADDR_130]], i64 1
; CHECK-NEXT: [[INCDEC_PTR8]] = getelementptr inbounds i8, ptr [[SRC_ADDR_129]], i64 1
; CHECK-NEXT: [[TOBOOL_NOT:%.*]] = icmp eq i64 [[DEC]], 0
-; CHECK-NEXT: br i1 [[TOBOOL_NOT]], label [[WHILE_END9]], label [[WHILE_BODY4]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NEXT: br i1 [[TOBOOL_NOT]], label [[WHILE_END9]], label [[WHILE_BODY4]]
; CHECK: while.end9:
; CHECK-NEXT: ret void
;
More information about the llvm-commits mailing list