[llvm] [SimpleLoopUnswitch] Record loops from unswitching non-trivial conditions (PR #141121)
Antonio Frighetto via llvm-commits
llvm-commits at lists.llvm.org
Thu May 22 11:45:03 PDT 2025
https://github.com/antoniofrighetto created https://github.com/llvm/llvm-project/pull/141121
Track the loops newly cloned when unswitching non-trivial invariant conditions, so that conditions in those cloned loops are not unswitched again. Although this alone is expected to suffice, also take the size (number of basic blocks) of the enclosing outer loop into account when estimating the cost of unswitching non-trivial conditions.
Fixes: https://github.com/llvm/llvm-project/issues/138509.
>From cd421580cca0f56934ea06e77834e3f76467c8d5 Mon Sep 17 00:00:00 2001
From: Antonio Frighetto <me at antoniofrighetto.com>
Date: Thu, 22 May 2025 19:19:32 +0200
Subject: [PATCH] =?UTF-8?q?[SimpleLoopUnswitch]=C2=A0Record=20loops=20from?=
=?UTF-8?q?=20unswitching=20non-trivial=20conditions?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Track the loops newly cloned when unswitching non-trivial invariant
conditions, so that conditions in those cloned loops are not unswitched
again. Although this alone is expected to suffice, also take the size
(number of basic blocks) of the enclosing outer loop into account when
estimating the cost of unswitching non-trivial conditions.
Fixes: https://github.com/llvm/llvm-project/issues/138509.
---
.../Transforms/Scalar/SimpleLoopUnswitch.cpp | 61 +++++++++++--------
.../Transforms/SimpleLoopUnswitch/pr138509.ll | 49 +++++++++++++++
2 files changed, 84 insertions(+), 26 deletions(-)
create mode 100644 llvm/test/Transforms/SimpleLoopUnswitch/pr138509.ll
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index 0bf90036b8b82..4ebb73e917370 100644
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -2142,9 +2142,22 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) {
void postUnswitch(Loop &L, LPMUpdater &U, StringRef LoopName,
bool CurrentLoopValid, bool PartiallyInvariant,
bool InjectedCondition, ArrayRef<Loop *> NewLoops) {
- // If we did a non-trivial unswitch, we have added new (cloned) loops.
- if (!NewLoops.empty())
+ auto RecordLoopAsUnswitched = [&](Loop *TargetLoop, StringRef Tag) {
+ auto &Ctx = TargetLoop->getHeader()->getContext();
+ const auto &DisableMDName = (Twine(Tag) + ".disable").str();
+ MDNode *DisableMD = MDNode::get(Ctx, MDString::get(Ctx, DisableMDName));
+ MDNode *NewLoopID = makePostTransformationMetadata(
+ Ctx, TargetLoop->getLoopID(), {Tag}, {DisableMD});
+ TargetLoop->setLoopID(NewLoopID);
+ };
+
+ // If we performed a non-trivial unswitch, we have added new cloned loops.
+ // Mark such newly-created loops as visited.
+ if (!NewLoops.empty()) {
+ for (Loop *NL : NewLoops)
+ RecordLoopAsUnswitched(NL, "llvm.loop.unswitch.nontrivial");
U.addSiblingLoops(NewLoops);
+ }
// If the current loop remains valid, we should revisit it to catch any
// other unswitch opportunities. Otherwise, we need to mark it as deleted.
@@ -2152,24 +2165,10 @@ void postUnswitch(Loop &L, LPMUpdater &U, StringRef LoopName,
if (PartiallyInvariant) {
// Mark the new loop as partially unswitched, to avoid unswitching on
// the same condition again.
- auto &Context = L.getHeader()->getContext();
- MDNode *DisableUnswitchMD = MDNode::get(
- Context,
- MDString::get(Context, "llvm.loop.unswitch.partial.disable"));
- MDNode *NewLoopID = makePostTransformationMetadata(
- Context, L.getLoopID(), {"llvm.loop.unswitch.partial"},
- {DisableUnswitchMD});
- L.setLoopID(NewLoopID);
+ RecordLoopAsUnswitched(&L, "llvm.loop.unswitch.partial");
} else if (InjectedCondition) {
// Do the same for injection of invariant conditions.
- auto &Context = L.getHeader()->getContext();
- MDNode *DisableUnswitchMD = MDNode::get(
- Context,
- MDString::get(Context, "llvm.loop.unswitch.injection.disable"));
- MDNode *NewLoopID = makePostTransformationMetadata(
- Context, L.getLoopID(), {"llvm.loop.unswitch.injection"},
- {DisableUnswitchMD});
- L.setLoopID(NewLoopID);
+ RecordLoopAsUnswitched(&L, "llvm.loop.unswitch.injection");
} else
U.revisitCurrentLoop();
} else
@@ -2806,9 +2805,9 @@ static BranchInst *turnGuardIntoBranch(IntrinsicInst *GI, Loop &L,
}
/// Cost multiplier is a way to limit potentially exponential behavior
-/// of loop-unswitch. Cost is multipied in proportion of 2^number of unswitch
-/// candidates available. Also accounting for the number of "sibling" loops with
-/// the idea to account for previous unswitches that already happened on this
+/// of loop-unswitch. Cost is multiplied in proportion of 2^number of unswitch
+/// candidates available. Also consider the number of "sibling" loops with
+/// the idea of accounting for previous unswitches that already happened on this
/// cluster of loops. There was an attempt to keep this formula simple,
/// just enough to limit the worst case behavior. Even if it is not that simple
/// now it is still not an attempt to provide a detailed heuristic size
@@ -2839,7 +2838,14 @@ static int CalculateUnswitchCostMultiplier(
return 1;
}
+ // When dealing with nested loops, the basic block size of the outer loop may
+ // increase significantly during unswitching non-trivial conditions. The final
+ // cost may be adjusted taking this into account.
auto *ParentL = L.getParentLoop();
+ int ParentSizeMultiplier = 1;
+ if (ParentL)
+ ParentSizeMultiplier = std::max((int)ParentL->getNumBlocks(), 1);
+
int SiblingsCount = (ParentL ? ParentL->getSubLoopsVector().size()
: std::distance(LI.begin(), LI.end()));
// Count amount of clones that all the candidates might cause during
@@ -2887,11 +2893,13 @@ static int CalculateUnswitchCostMultiplier(
SiblingsMultiplier > UnswitchThreshold)
CostMultiplier = UnswitchThreshold;
else
- CostMultiplier = std::min(SiblingsMultiplier * (1 << ClonesPower),
- (int)UnswitchThreshold);
+ CostMultiplier =
+ std::min(SiblingsMultiplier * ParentSizeMultiplier * (1 << ClonesPower),
+ (int)UnswitchThreshold);
LLVM_DEBUG(dbgs() << " Computed multiplier " << CostMultiplier
- << " (siblings " << SiblingsMultiplier << " * clones "
+ << " (siblings " << SiblingsMultiplier << "* parent size "
+ << ParentSizeMultiplier << " * clones "
<< (1 << ClonesPower) << ")"
<< " for unswitch candidate: " << TI << "\n");
return CostMultiplier;
@@ -3504,8 +3512,9 @@ static bool unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
SmallVector<NonTrivialUnswitchCandidate, 4> UnswitchCandidates;
IVConditionInfo PartialIVInfo;
Instruction *PartialIVCondBranch = nullptr;
- collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo,
- PartialIVCondBranch, L, LI, AA, MSSAU);
+ if (!findOptionMDForLoop(&L, "llvm.loop.unswitch.nontrivial.disable"))
+ collectUnswitchCandidates(UnswitchCandidates, PartialIVInfo,
+ PartialIVCondBranch, L, LI, AA, MSSAU);
if (!findOptionMDForLoop(&L, "llvm.loop.unswitch.injection.disable"))
collectUnswitchCandidatesWithInjections(UnswitchCandidates, PartialIVInfo,
PartialIVCondBranch, L, DT, LI, AA,
diff --git a/llvm/test/Transforms/SimpleLoopUnswitch/pr138509.ll b/llvm/test/Transforms/SimpleLoopUnswitch/pr138509.ll
new file mode 100644
index 0000000000000..e24d17f088427
--- /dev/null
+++ b/llvm/test/Transforms/SimpleLoopUnswitch/pr138509.ll
@@ -0,0 +1,49 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes="loop-mssa(loop-simplifycfg,licm,loop-rotate,simple-loop-unswitch<nontrivial>)" < %s | FileCheck %s
+
+ at a = global i32 0, align 4
+ at b = global i32 0, align 4
+ at c = global i32 0, align 4
+ at d = global i32 0, align 4
+
+define i32 @main() {
+entry:
+ br label %outer.loop.header
+
+outer.loop.header: ; preds = %outer.loop.latch, %entry
+ br i1 false, label %exit, label %outer.loop.body
+
+outer.loop.body: ; preds = %inner.loop.header, %outer.loop.header
+ store i32 1, ptr @c, align 4
+ %cmp = icmp sgt i32 0, -1
+ br i1 %cmp, label %outer.loop.latch, label %exit
+
+inner.loop.header: ; preds = %outer.loop.latch, %inner.loop.body
+ %a_val = load i32, ptr @a, align 4
+ %c_val = load i32, ptr @c, align 4
+ %mul = mul nsw i32 %c_val, %a_val
+ store i32 %mul, ptr @b, align 4
+ %cmp2 = icmp sgt i32 %mul, -1
+ br i1 %cmp2, label %inner.loop.body, label %outer.loop.body
+
+inner.loop.body: ; preds = %inner.loop.header
+ %mul2 = mul nsw i32 %c_val, 3
+ store i32 %mul2, ptr @c, align 4
+ store i32 %c_val, ptr @d, align 4
+ %mul3 = mul nsw i32 %c_val, %a_val
+ %cmp3 = icmp sgt i32 %mul3, -1
+ br i1 %cmp3, label %inner.loop.header, label %exit
+
+outer.loop.latch: ; preds = %outer.loop.body
+ %d_val = load i32, ptr @d, align 4
+ store i32 %d_val, ptr @b, align 4
+ %cmp4 = icmp eq i32 %d_val, 0
+ br i1 %cmp4, label %inner.loop.header, label %outer.loop.header
+
+exit: ; preds = %inner.loop.body, %outer.loop.body, %outer.loop.header
+ ret i32 0
+}
+
+; CHECK: [[LOOP0:.*]] = distinct !{[[LOOP0]], [[META1:![0-9]+]]}
+; CHECK: [[META1]] = !{!"llvm.loop.unswitch.nontrivial.disable"}
+; CHECK: [[LOOP2:.*]] = distinct !{[[LOOP2]], [[META1]]}
More information about the llvm-commits
mailing list