[llvm] [LoopUnroll] Rotate loop before unrolling inside of UnrollRuntimeLoopRemainder (PR #148243)
Marek Sedláček via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 28 08:29:26 PDT 2025
================
@@ -587,21 +536,116 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
UnrollRuntimeEpilog.getNumOccurrences() ? UnrollRuntimeEpilog
: isEpilogProfitable(L);
- if (ULO.Runtime &&
- !UnrollRuntimeLoopRemainder(L, ULO.Count, ULO.AllowExpensiveTripCount,
- EpilogProfitability, ULO.UnrollRemainder,
- ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI,
- PreserveLCSSA, ULO.SCEVExpansionBudget,
- ULO.RuntimeUnrollMultiExit, RemainderLoop)) {
+ bool LoopRotated = false;
+ bool ReminderUnrolled = false;
+ if (ULO.Runtime) {
+ // Call unroll with disabled rotation, to see if it is possible without it.
+ ReminderUnrolled = UnrollRuntimeLoopRemainder(
+ L, ULO.Count, ULO.AllowExpensiveTripCount, EpilogProfitability,
+ ULO.UnrollRemainder, ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI,
+ PreserveLCSSA, ULO.SCEVExpansionBudget, ULO.RuntimeUnrollMultiExit,
+ RemainderLoop);
+
+ // If unroll is not possible, then try with loop rotation.
+ if (!ReminderUnrolled) {
+ BasicBlock *OrigHeader = L->getHeader();
+ BranchInst *BI = dyn_cast<BranchInst>(OrigHeader->getTerminator());
+ if (BI && !BI->isUnconditional() &&
+ isa<SCEVCouldNotCompute>(SE->getExitCount(L, L->getLoopLatch())) &&
+ !isa<SCEVCouldNotCompute>(SE->getExitCount(L, OrigHeader))) {
+ LLVM_DEBUG(
+ dbgs() << " Rotating loop to make the exit count computable.\n");
+ SimplifyQuery SQ{OrigHeader->getDataLayout()};
+ SQ.TLI = nullptr;
+ SQ.DT = DT;
+ SQ.AC = AC;
+ LoopRotated =
+ llvm::LoopRotation(L, LI, TTI, AC, DT, SE,
+ /*MemorySSAUpdater*/ nullptr, SQ,
+ /*RotationOnly*/ false, /*Threshold*/ 16,
+ /*IsUtilMode*/ false, /*PrepareForLTO*/ false,
+ [](Loop *, ScalarEvolution *) { return true; });
----------------
mark-sed wrote:
I do not, but I did try to quickly put it together now, it will not compile and some stuff could be optimized, but it would be something like this and I much prefer the current implementation. The reason why it is like this is because the unrolling uses those values that are being checked and many of those need to be recalculated after the rotation.
It would be something like this:
```diff
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index 36c976e23eed..d1a41e6e73a5 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -616,106 +616,157 @@ bool llvm::UnrollRuntimeLoopRemainder(
return false;
}
+ // Use Scalar Evolution to compute the trip count. This allows more loops to
+ // be unrolled than relying on induction var simplification.
+ if (!SE)
+ return false;
+
// Guaranteed by LoopSimplifyForm.
BasicBlock *Latch = L->getLoopLatch();
BasicBlock *Header = L->getHeader();
BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator());
-
- if (!LatchBR || LatchBR->isUnconditional()) {
- // The loop-rotate pass can be helpful to avoid this in many cases.
- LLVM_DEBUG(
- dbgs()
- << "Loop latch not terminated by a conditional branch.\n");
- return false;
- }
-
unsigned ExitIndex = LatchBR->getSuccessor(0) == Header ? 1 : 0;
BasicBlock *LatchExit = LatchBR->getSuccessor(ExitIndex);
- if (L->contains(LatchExit)) {
- // Cloning the loop basic blocks (`CloneLoopBlocks`) requires that one of the
- // targets of the Latch be an exit block out of the loop.
- LLVM_DEBUG(
- dbgs()
- << "One of the loop latch successors must be the exit block.\n");
- return false;
- }
-
// These are exit blocks other than the target of the latch exiting block.
SmallVector<BasicBlock *, 4> OtherExits;
L->getUniqueNonLatchExitBlocks(OtherExits);
- // Support only single exit and exiting block unless multi-exit loop
- // unrolling is enabled.
- if (!L->getExitingBlock() || OtherExits.size()) {
- // We rely on LCSSA form being preserved when the exit blocks are transformed.
- // (Note that only an off-by-default mode of the old PM disables PreserveLCCA.)
- if (!PreserveLCSSA)
+
+ BasicBlock *PreHeader = L->getLoopPreheader();
+ BranchInst *PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator());
+ const DataLayout &DL = Header->getDataLayout();
+
+ auto canUnrollReminder = [&]() {
+ // Only unroll loops with a computable trip count.
+ // We calculate the backedge count by using getExitCount on the Latch block,
+ // which is proven to be the only exiting block in this loop. This is same as
+ // calculating getBackedgeTakenCount on the loop (which computes SCEV for all
+ // exiting blocks).
+ const SCEV *BECountSC = SE->getExitCount(L, Latch);
+
+ // Add 1 since the backedge count doesn't include the first loop iteration.
+ // (Note that overflow can occur, this is handled explicitly below)
+ SCEV *TripCountSC =
+ SE->getAddExpr(BECountSC, SE->getConstant(BECountSC->getType(), 1));
+
+ if (!LatchBR || LatchBR->isUnconditional()) {
+ // The loop-rotate pass can be helpful to avoid this in many cases.
+ LLVM_DEBUG(
+ dbgs()
+ << "Loop latch not terminated by a conditional branch.\n");
return false;
+ }
- // Priority goes to UnrollRuntimeMultiExit if it's supplied.
- if (UnrollRuntimeMultiExit.getNumOccurrences()) {
- if (!UnrollRuntimeMultiExit)
- return false;
- } else {
- // Otherwise perform multi-exit unrolling, if either the target indicates
- // it is profitable or the general profitability heuristics apply.
- if (!RuntimeUnrollMultiExit &&
- !canProfitablyRuntimeUnrollMultiExitLoop(L, BPI, TTI, OtherExits,
- LatchExit,
- UseEpilogRemainder)) {
- LLVM_DEBUG(dbgs() << "Multiple exit/exiting blocks in loop and "
- "multi-exit unrolling not enabled!\n");
+ if (L->contains(LatchExit)) {
+ // Cloning the loop basic blocks (`CloneLoopBlocks`) requires that one of the
+ // targets of the Latch be an exit block out of the loop.
+ LLVM_DEBUG(
+ dbgs()
+ << "One of the loop latch successors must be the exit block.\n");
+ return false;
+ }
+
+ // Support only single exit and exiting block unless multi-exit loop
+ // unrolling is enabled.
+ if (!L->getExitingBlock() || OtherExits.size()) {
+ // We rely on LCSSA form being preserved when the exit blocks are transformed.
+ // (Note that only an off-by-default mode of the old PM disables PreserveLCCA.)
+ if (!PreserveLCSSA)
return false;
+
+ // Priority goes to UnrollRuntimeMultiExit if it's supplied.
+ if (UnrollRuntimeMultiExit.getNumOccurrences()) {
+ if (!UnrollRuntimeMultiExit)
+ return false;
+ } else {
+ // Otherwise perform multi-exit unrolling, if either the target indicates
+ // it is profitable or the general profitability heuristics apply.
+ if (!RuntimeUnrollMultiExit &&
+ !canProfitablyRuntimeUnrollMultiExitLoop(L, BPI, TTI, OtherExits,
+ LatchExit,
+ UseEpilogRemainder)) {
+ LLVM_DEBUG(dbgs() << "Multiple exit/exiting blocks in loop and "
+ "multi-exit unrolling not enabled!\n");
+ return false;
+ }
}
}
- }
- // Use Scalar Evolution to compute the trip count. This allows more loops to
- // be unrolled than relying on induction var simplification.
- if (!SE)
- return false;
- // Only unroll loops with a computable trip count.
- // We calculate the backedge count by using getExitCount on the Latch block,
- // which is proven to be the only exiting block in this loop. This is same as
- // calculating getBackedgeTakenCount on the loop (which computes SCEV for all
- // exiting blocks).
- const SCEV *BECountSC = SE->getExitCount(L, Latch);
- if (isa<SCEVCouldNotCompute>(BECountSC)) {
- LLVM_DEBUG(dbgs() << "Could not compute exit block SCEV\n");
- return false;
+ if (isa<SCEVCouldNotCompute>(BECountSC)) {
+ LLVM_DEBUG(dbgs() << "Could not compute exit block SCEV\n");
+ return false;
+ }
+
+ unsigned BEWidth = cast<IntegerType>(BECountSC->getType())->getBitWidth();
+
+ if (isa<SCEVCouldNotCompute>(TripCountSC)) {
+ LLVM_DEBUG(dbgs() << "Could not compute trip count SCEV.\n");
+ return false;
+ }
+
+ SCEVExpander Expander(*SE, DL, "loop-unroll");
+ if (!AllowExpensiveTripCount &&
+ Expander.isHighCostExpansion(TripCountSC, L, SCEVExpansionBudget, TTI,
+ PreHeaderBR)) {
+ LLVM_DEBUG(dbgs() << "High cost for expanding trip count scev!\n");
+ return false;
+ }
+
+ // This constraint lets us deal with an overflowing trip count easily; see the
+ // comment on ModVal below.
+ if (Log2_32(Count) > BEWidth) {
+ LLVM_DEBUG(
+ dbgs()
+ << "Count failed constraint on overflow trip count calculation.\n");
+ return false;
+ }
+
+ return true;
+ };
+
+ bool LoopRotated = false;
+ if (!canUnrollReminder()) {
+ BasicBlock *OrigHeader = L->getHeader();
+ BranchInst *BI = dyn_cast<BranchInst>(OrigHeader->getTerminator());
+ if (BI && !BI->isUnconditional() &&
+ isa<SCEVCouldNotCompute>(SE->getExitCount(L, L->getLoopLatch())) &&
+ !isa<SCEVCouldNotCompute>(SE->getExitCount(L, OrigHeader))) {
+ LLVM_DEBUG(
+ dbgs() << " Rotating loop to make the exit count computable.\n");
+ SimplifyQuery SQ{OrigHeader->getDataLayout()};
+ SQ.TLI = nullptr;
+ SQ.DT = DT;
+ SQ.AC = AC;
+ LoopRotated =
+ llvm::LoopRotation(L, LI, TTI, AC, DT, SE,
+ /*MemorySSAUpdater*/ nullptr, SQ,
+ /*RotationOnly*/ false, /*Threshold*/ 16,
+ /*IsUtilMode*/ false, /*PrepareForLTO*/ false,
+ [](Loop *, ScalarEvolution *) { return true; });
+ }
}
+ if (!LoopRotated)
+ return false;
- unsigned BEWidth = cast<IntegerType>(BECountSC->getType())->getBitWidth();
+ Latch = L->getLoopLatch();
+ Header = L->getHeader();
- // Add 1 since the backedge count doesn't include the first loop iteration.
- // (Note that overflow can occur, this is handled explicitly below)
- const SCEV *TripCountSC =
+ LatchBR = cast<BranchInst>(Latch->getTerminator());
+ ExitIndex = LatchBR->getSuccessor(0) == Header ? 1 : 0;
+ LatchExit = LatchBR->getSuccessor(ExitIndex);
+
+ L->getUniqueNonLatchExitBlocks(OtherExits);
+ const SCEV *BECountSC = SE->getExitCount(L, Latch);
+ TripCountSC =
SE->getAddExpr(BECountSC, SE->getConstant(BECountSC->getType(), 1));
- if (isa<SCEVCouldNotCompute>(TripCountSC)) {
- LLVM_DEBUG(dbgs() << "Could not compute trip count SCEV.\n");
- return false;
- }
- BasicBlock *PreHeader = L->getLoopPreheader();
- BranchInst *PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator());
- const DataLayout &DL = Header->getDataLayout();
+ PreHeader = L->getLoopPreheader();
+ PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator());
SCEVExpander Expander(*SE, DL, "loop-unroll");
- if (!AllowExpensiveTripCount &&
- Expander.isHighCostExpansion(TripCountSC, L, SCEVExpansionBudget, TTI,
- PreHeaderBR)) {
- LLVM_DEBUG(dbgs() << "High cost for expanding trip count scev!\n");
- return false;
- }
- // This constraint lets us deal with an overflowing trip count easily; see the
- // comment on ModVal below.
- if (Log2_32(Count) > BEWidth) {
- LLVM_DEBUG(
- dbgs()
- << "Count failed constraint on overflow trip count calculation.\n");
+ if (!canUnrollReminder())
return false;
- }
// Loop structure is the following:
//
```
https://github.com/llvm/llvm-project/pull/148243
More information about the llvm-commits
mailing list