[llvm] f7c54c4 - [LoopUnroll] Fold all exits based on known trip count/multiple
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 17 11:58:43 PDT 2021
Author: Nikita Popov
Date: 2021-06-17T20:58:34+02:00
New Revision: f7c54c4603a2df8c0833f5ddc04a5e109bca1c14
URL: https://github.com/llvm/llvm-project/commit/f7c54c4603a2df8c0833f5ddc04a5e109bca1c14
DIFF: https://github.com/llvm/llvm-project/commit/f7c54c4603a2df8c0833f5ddc04a5e109bca1c14.diff
LOG: [LoopUnroll] Fold all exits based on known trip count/multiple
Fold all exits based on known trip count/multiple information from
SCEV. Previously, only the latch exit or the single exit was folded.
This doesn't yet eliminate ULO.TripCount and ULO.TripMultiple
entirely: they're still used a) to decide whether runtime unrolling
should be performed and b) for ORE remarks. However, the core
unrolling logic is now independent of them.
Differential Revision: https://reviews.llvm.org/D104203
Added:
Modified:
llvm/lib/Transforms/Utils/LoopUnroll.cpp
llvm/test/Transforms/LoopUnroll/multiple-exits.ll
llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
index b3658fbe9e1c6..4d6112787a526 100644
--- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
@@ -328,6 +328,37 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
if (MaxTripCount && ULO.Count > MaxTripCount)
ULO.Count = MaxTripCount;
+ struct ExitInfo {
+ unsigned TripCount;
+ unsigned TripMultiple;
+ unsigned BreakoutTrip;
+ bool ExitOnTrue;
+ SmallVector<BasicBlock *> ExitingBlocks;
+ };
+ DenseMap<BasicBlock *, ExitInfo> ExitInfos;
+ SmallVector<BasicBlock *, 4> ExitingBlocks;
+ L->getExitingBlocks(ExitingBlocks);
+ for (auto *ExitingBlock : ExitingBlocks) {
+ // The folding code is not prepared to deal with non-branch instructions
+ // right now.
+ auto *BI = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
+ if (!BI)
+ continue;
+
+ ExitInfo &Info = ExitInfos.try_emplace(ExitingBlock).first->second;
+ Info.TripCount = SE->getSmallConstantTripCount(L, ExitingBlock);
+ Info.TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock);
+ if (Info.TripCount != 0) {
+ Info.BreakoutTrip = Info.TripCount % ULO.Count;
+ Info.TripMultiple = 0;
+ } else {
+ Info.BreakoutTrip = Info.TripMultiple =
+ (unsigned)GreatestCommonDivisor64(ULO.Count, Info.TripMultiple);
+ }
+ Info.ExitOnTrue = !L->contains(BI->getSuccessor(0));
+ Info.ExitingBlocks.push_back(ExitingBlock);
+ }
+
// Are we eliminating the loop control altogether? Note that we can know
// we're eliminating the backedge without knowing exactly which iteration
// of the unrolled body exits.
@@ -362,31 +393,12 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
// A conditional branch which exits the loop, which can be optimized to an
// unconditional branch in the unrolled loop in some cases.
- BranchInst *ExitingBI = nullptr;
bool LatchIsExiting = L->isLoopExiting(LatchBlock);
- if (LatchIsExiting)
- ExitingBI = LatchBI;
- else if (BasicBlock *ExitingBlock = L->getExitingBlock())
- ExitingBI = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
if (!LatchBI || (LatchBI->isConditional() && !LatchIsExiting)) {
LLVM_DEBUG(
dbgs() << "Can't unroll; a conditional latch must exit the loop");
return LoopUnrollResult::Unmodified;
}
- LLVM_DEBUG({
- if (ExitingBI)
- dbgs() << " Exiting Block = " << ExitingBI->getParent()->getName()
- << "\n";
- else
- dbgs() << " No single exiting block\n";
- });
-
- // Warning: ExactTripCount is the exact trip count for the block ending in
- // ExitingBI, not neccessarily an exact exit count *for the loop*. The
- // distinction comes when we have an exiting latch, but the loop exits
- // through another exit first.
- const unsigned ExactTripCount = ExitingBI ?
- SE->getSmallConstantTripCount(L,ExitingBI->getParent()) : 0;
// Loops containing convergent instructions must have a count that divides
// their TripMultiple.
@@ -421,6 +433,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
}
// If we know the trip count, we know the multiple...
+ // TODO: This is only used for the ORE code, remove it.
unsigned BreakoutTrip = 0;
if (ULO.TripCount != 0) {
BreakoutTrip = ULO.TripCount % ULO.Count;
@@ -504,12 +517,9 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
}
std::vector<BasicBlock *> Headers;
- std::vector<BasicBlock *> ExitingBlocks;
std::vector<BasicBlock *> Latches;
Headers.push_back(Header);
Latches.push_back(LatchBlock);
- if (ExitingBI)
- ExitingBlocks.push_back(ExitingBI->getParent());
// The current on-the-fly SSA update requires blocks to be processed in
// reverse postorder so that LastValueMap contains the correct value at each
@@ -609,9 +619,9 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
// Keep track of the exiting block and its successor block contained in
// the loop for the current iteration.
- if (ExitingBI)
- if (*BB == ExitingBlocks[0])
- ExitingBlocks.push_back(New);
+ auto ExitInfoIt = ExitInfos.find(*BB);
+ if (ExitInfoIt != ExitInfos.end())
+ ExitInfoIt->second.ExitingBlocks.push_back(New);
NewBlocks.push_back(New);
UnrolledLoopBlocks.push_back(New);
@@ -701,71 +711,79 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
- if (ExitingBI) {
- auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) {
- auto *Term = cast<BranchInst>(Src->getTerminator());
- const unsigned Idx = ExitOnTrue ^ WillExit;
- BasicBlock *Dest = Term->getSuccessor(Idx);
- BasicBlock *DeadSucc = Term->getSuccessor(1-Idx);
+ auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) {
+ auto *Term = cast<BranchInst>(Src->getTerminator());
+ const unsigned Idx = ExitOnTrue ^ WillExit;
+ BasicBlock *Dest = Term->getSuccessor(Idx);
+ BasicBlock *DeadSucc = Term->getSuccessor(1-Idx);
- // Remove predecessors from all non-Dest successors.
- DeadSucc->removePredecessor(Src, /* KeepOneInputPHIs */ true);
+ // Remove predecessors from all non-Dest successors.
+ DeadSucc->removePredecessor(Src, /* KeepOneInputPHIs */ true);
- // Replace the conditional branch with an unconditional one.
- BranchInst::Create(Dest, Term);
- Term->eraseFromParent();
+ // Replace the conditional branch with an unconditional one.
+ BranchInst::Create(Dest, Term);
+ Term->eraseFromParent();
- DTU.applyUpdates({{DominatorTree::Delete, Src, DeadSucc}});
- };
+ DTU.applyUpdates({{DominatorTree::Delete, Src, DeadSucc}});
+ };
- auto WillExit = [&](unsigned i, unsigned j) -> Optional<bool> {
- if (CompletelyUnroll) {
- if (PreserveOnlyFirst) {
- if (i == 0)
- return None;
- return j == 0;
- }
- // Complete (but possibly inexact) unrolling
- if (j == 0)
- return true;
- // Warning: ExactTripCount is the trip count of the exiting
- // block which ends in ExitingBI, not neccessarily the loop.
- if (ExactTripCount && j != ExactTripCount)
- return false;
- return None;
+ auto WillExit = [&](const ExitInfo &Info, unsigned i, unsigned j,
+ bool IsLatch) -> Optional<bool> {
+ if (CompletelyUnroll) {
+ if (PreserveOnlyFirst) {
+ if (i == 0)
+ return None;
+ return j == 0;
}
-
- if (RuntimeTripCount && j != 0)
+ // Complete (but possibly inexact) unrolling
+ if (j == 0)
+ return true;
+ if (Info.TripCount && j != Info.TripCount)
return false;
+ return None;
+ }
- if (j != BreakoutTrip &&
- (ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) {
- // If we know the trip count or a multiple of it, we can safely use an
- // unconditional branch for some iterations.
+ if (RuntimeTripCount) {
+ // If runtime unrolling inserts a prologue, information about non-latch
+ // exits may be stale.
+ if (IsLatch && j != 0)
return false;
- }
return None;
- };
+ }
+
+ if (j != Info.BreakoutTrip &&
+ (Info.TripMultiple == 0 || j % Info.TripMultiple != 0)) {
+ // If we know the trip count or a multiple of it, we can safely use an
+ // unconditional branch for some iterations.
+ return false;
+ }
+ return None;
+ };
- // Fold branches for iterations where we know that they will exit or not
- // exit.
- bool ExitOnTrue = !L->contains(ExitingBI->getSuccessor(0));
- for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
+ // Fold branches for iterations where we know that they will exit or not
+ // exit.
+ for (const auto &Pair : ExitInfos) {
+ const ExitInfo &Info = Pair.second;
+ for (unsigned i = 0, e = Info.ExitingBlocks.size(); i != e; ++i) {
// The branch destination.
unsigned j = (i + 1) % e;
- Optional<bool> KnownWillExit = WillExit(i, j);
+ bool IsLatch = Pair.first == LatchBlock;
+ Optional<bool> KnownWillExit = WillExit(Info, i, j, IsLatch);
if (!KnownWillExit)
continue;
- // TODO: Also fold known-exiting branches for non-latch exits.
- if (*KnownWillExit && !LatchIsExiting)
+ // We don't fold known-exiting branches for non-latch exits here,
+ // because this ensures that both all loop blocks and all exit blocks
+ // remain reachable in the CFG.
+ // TODO: We could fold these branches, but it would require much more
+ // sophisticated updates to LoopInfo.
+ if (*KnownWillExit && !IsLatch)
continue;
- SetDest(ExitingBlocks[i], *KnownWillExit, ExitOnTrue);
+ SetDest(Info.ExitingBlocks[i], *KnownWillExit, Info.ExitOnTrue);
}
}
-
// When completely unrolling, the last latch becomes unreachable.
if (!LatchIsExiting && CompletelyUnroll)
changeToUnreachable(Latches.back()->getTerminator(), /* UseTrap */ false,
diff --git a/llvm/test/Transforms/LoopUnroll/multiple-exits.ll b/llvm/test/Transforms/LoopUnroll/multiple-exits.ll
index 8a3f51a1fb94e..decba5c654ddc 100644
--- a/llvm/test/Transforms/LoopUnroll/multiple-exits.ll
+++ b/llvm/test/Transforms/LoopUnroll/multiple-exits.ll
@@ -9,49 +9,49 @@ define void @test1() {
; CHECK-NEXT: br label [[LOOP:%.*]]
; CHECK: loop:
; CHECK-NEXT: call void @bar()
-; CHECK-NEXT: br i1 true, label [[LATCH:%.*]], label [[EXIT:%.*]]
+; CHECK-NEXT: br label [[LATCH:%.*]]
; CHECK: latch:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: call void @bar()
-; CHECK-NEXT: br i1 true, label [[LATCH_1:%.*]], label [[EXIT]]
+; CHECK-NEXT: br label [[LATCH_1:%.*]]
; CHECK: exit:
; CHECK-NEXT: ret void
; CHECK: latch.1:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: call void @bar()
-; CHECK-NEXT: br i1 true, label [[LATCH_2:%.*]], label [[EXIT]]
+; CHECK-NEXT: br label [[LATCH_2:%.*]]
; CHECK: latch.2:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: call void @bar()
-; CHECK-NEXT: br i1 true, label [[LATCH_3:%.*]], label [[EXIT]]
+; CHECK-NEXT: br label [[LATCH_3:%.*]]
; CHECK: latch.3:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: call void @bar()
-; CHECK-NEXT: br i1 true, label [[LATCH_4:%.*]], label [[EXIT]]
+; CHECK-NEXT: br label [[LATCH_4:%.*]]
; CHECK: latch.4:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: call void @bar()
-; CHECK-NEXT: br i1 true, label [[LATCH_5:%.*]], label [[EXIT]]
+; CHECK-NEXT: br label [[LATCH_5:%.*]]
; CHECK: latch.5:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: call void @bar()
-; CHECK-NEXT: br i1 true, label [[LATCH_6:%.*]], label [[EXIT]]
+; CHECK-NEXT: br label [[LATCH_6:%.*]]
; CHECK: latch.6:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: call void @bar()
-; CHECK-NEXT: br i1 true, label [[LATCH_7:%.*]], label [[EXIT]]
+; CHECK-NEXT: br label [[LATCH_7:%.*]]
; CHECK: latch.7:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: call void @bar()
-; CHECK-NEXT: br i1 true, label [[LATCH_8:%.*]], label [[EXIT]]
+; CHECK-NEXT: br label [[LATCH_8:%.*]]
; CHECK: latch.8:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: call void @bar()
-; CHECK-NEXT: br i1 true, label [[LATCH_9:%.*]], label [[EXIT]]
+; CHECK-NEXT: br label [[LATCH_9:%.*]]
; CHECK: latch.9:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: call void @bar()
-; CHECK-NEXT: br i1 false, label [[LATCH_10:%.*]], label [[EXIT]]
+; CHECK-NEXT: br i1 false, label [[LATCH_10:%.*]], label [[EXIT:%.*]]
; CHECK: latch.10:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: br label [[EXIT]]
diff --git a/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll b/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll
index 3076084c53309..1cd86ec145df8 100644
--- a/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll
+++ b/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll
@@ -168,7 +168,7 @@ define void @test3(i32* noalias %A, i1 %cond) {
; CHECK-NEXT: call void @bar(i32 [[TMP0]])
; CHECK-NEXT: br i1 [[COND:%.*]], label [[FOR_BODY:%.*]], label [[FOR_END:%.*]]
; CHECK: for.body:
-; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE:%.*]], label [[FOR_END]]
+; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE:%.*]]
; CHECK: for.body.for.body_crit_edge:
; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 1
; CHECK-NEXT: [[DOTPRE:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT]], align 4
@@ -177,14 +177,14 @@ define void @test3(i32* noalias %A, i1 %cond) {
; CHECK: for.end:
; CHECK-NEXT: ret void
; CHECK: for.body.1:
-; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_1:%.*]], label [[FOR_END]]
+; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_1:%.*]]
; CHECK: for.body.for.body_crit_edge.1:
; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_1:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 2
; CHECK-NEXT: [[DOTPRE_1:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_1]], align 4
; CHECK-NEXT: call void @bar(i32 [[DOTPRE_1]])
; CHECK-NEXT: br i1 [[COND]], label [[FOR_BODY_2:%.*]], label [[FOR_END]]
; CHECK: for.body.2:
-; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_2:%.*]], label [[FOR_END]]
+; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_2:%.*]]
; CHECK: for.body.for.body_crit_edge.2:
; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_2:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 3
; CHECK-NEXT: [[DOTPRE_2:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_2]], align 4
More information about the llvm-commits mailing list