[llvm] [LoopRotate] Update branch weights for multi-exit loops (PR #164006)
Shimin Cui via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 17 12:49:44 PDT 2025
https://github.com/scui-ibm updated https://github.com/llvm/llvm-project/pull/164006
>From 54d5e5e49cb2dc6ea82f05b429439e3e1032ecae Mon Sep 17 00:00:00 2001
From: Shimin Cui <scui at xlperflep9.rtp.raleigh.ibm.com>
Date: Fri, 17 Oct 2025 19:20:52 +0000
Subject: [PATCH] [LoopRotate] Update branch weights for multi-exit loops
---
.../Transforms/Utils/LoopRotationUtils.cpp | 100 +++++++++------
.../LoopRotate/update-branch-weights.ll | 8 +-
.../LoopRotate/update-multi-exit-loop-weights | 115 ++++++++++++++++++
3 files changed, 185 insertions(+), 38 deletions(-)
create mode 100644 llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 0c8d6fa47b9ae..52ae973935c88 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -46,9 +46,6 @@ STATISTIC(NumInstrsHoisted,
STATISTIC(NumInstrsDuplicated,
"Number of instructions cloned into loop preheader");
-// Probability that a rotated loop has zero trip count / is never entered.
-static constexpr uint32_t ZeroTripCountWeights[] = {1, 127};
-
namespace {
/// A simple loop rotation transformation.
class LoopRotate {
@@ -200,7 +197,8 @@ static bool profitableToRotateLoopExitingLatch(Loop *L) {
return false;
}
-static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
+static void updateBranchWeights(Loop *L, BranchInst &PreHeaderBI,
+ BranchInst &LoopBI,
bool HasConditionalPreHeader,
bool SuccsSwapped) {
MDNode *WeightMD = getBranchWeightMDNode(PreHeaderBI);
@@ -218,12 +216,49 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
if (Weights.size() != 2)
return;
uint32_t OrigLoopExitWeight = Weights[0];
- uint32_t OrigLoopBackedgeWeight = Weights[1];
+ uint32_t OrigLoopEnterWeight = Weights[1];
if (SuccsSwapped)
- std::swap(OrigLoopExitWeight, OrigLoopBackedgeWeight);
+ std::swap(OrigLoopExitWeight, OrigLoopEnterWeight);
+
+ // For a multiple-exit loop, find the total weight of other exits.
+ uint32_t OtherLoopExitWeight = 0;
+ SmallVector<BasicBlock *, 16> ExitingBlocks;
+ L->getExitingBlocks(ExitingBlocks);
+ for (BasicBlock *ExitingBB : ExitingBlocks) {
+ Instruction *TI = ExitingBB->getTerminator();
+ if (TI == &LoopBI)
+ continue;
+
+ if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) ||
+ isa<IndirectBrInst>(TI) || isa<InvokeInst>(TI) ||
+ isa<CallBrInst>(TI)))
+ continue;
+
+ MDNode *WeightsNode = getValidBranchWeightMDNode(*TI);
+ if (!WeightsNode)
+ continue;
+
+ SmallVector<uint32_t, 2> Weights;
+ extractBranchWeights(WeightsNode, Weights);
+ for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
+ BasicBlock *Exit = TI->getSuccessor(I);
+ if (L->contains(Exit))
+ continue;
+
+ OtherLoopExitWeight += Weights[I];
+ }
+ }
- // Update branch weights. Consider the following edge-counts:
+ // Clamp OtherLoopExitWeight so it never exceeds the loop-enter weight;
+ // otherwise the unsigned subtraction below would underflow.
+ if (OtherLoopExitWeight > OrigLoopEnterWeight)
+ OtherLoopExitWeight = OrigLoopEnterWeight;
+
+ uint32_t OrigLoopBackedgeWeight = OrigLoopEnterWeight - OtherLoopExitWeight;
+
+ // Update branch weights. Consider the following edge-counts (z is the total
+ // weight of the other exits of a multi-exit loop):
//
// | |-------- |
// V V | V
@@ -231,16 +266,18 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
// | | | | |
// x| y| | becomes: | y0| |-----
// V V | | V V |
- // Exit Loop | | Loop |
- // | | | Br i1 ... |
+ // Exit <- Loop | | Loop |
+ // z | | | Br i1 ... |
// ----- | | | |
// x0| x1| y1 | |
// V V ----
- // Exit
+ // Exit <----|
+ // z
//
// The following must hold:
// - x == x0 + x1 # counts to "exit" must stay the same.
- // - y0 == x - x0 == x1 # how often loop was entered at all.
+ // - y0 == x - x0 + z # how often loop was entered at all.
+ // == x1 + z
// - y1 == y - y0 # How often loop was repeated (after first iter.).
//
// We cannot generally deduce how often we had a zero-trip count loop so we
@@ -255,19 +292,12 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
if (HasConditionalPreHeader) {
// Here we cannot know how many 0-trip count loops we have, so we guess:
if (OrigLoopBackedgeWeight >= OrigLoopExitWeight) {
- // If the loop count is bigger than the exit count then we set
- // probabilities as if 0-trip count nearly never happens.
- ExitWeight0 = ZeroTripCountWeights[0];
- // Scale up counts if necessary so we can match `ZeroTripCountWeights`
- // for the `ExitWeight0`:`ExitWeight1` (aka `x0`:`x1` ratio`) ratio.
- while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
- // ... but don't overflow.
- uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
- if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
- (OrigLoopExitWeight & HighBit) != 0)
- break;
- OrigLoopBackedgeWeight <<= 1;
- OrigLoopExitWeight <<= 1;
+ ExitWeight0 =
+ (OrigLoopExitWeight * (OrigLoopExitWeight + OtherLoopExitWeight)) /
+ (OrigLoopExitWeight + OrigLoopEnterWeight);
+ // Clamp ExitWeight0 to a minimum of 1.
+ if (ExitWeight0 == 0) {
+ ExitWeight0 = 1;
}
} else {
// If there's a higher exit-count than backedge-count then we set
@@ -280,36 +310,38 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
// weight collected by sampling-based PGO may be not very accurate due to
// sampling. Therefore this workaround is required here to avoid underflow
// of unsigned in following update of branch weight.
- if (OrigLoopExitWeight > OrigLoopBackedgeWeight)
+ if (OrigLoopExitWeight > OrigLoopBackedgeWeight) {
OrigLoopBackedgeWeight = OrigLoopExitWeight;
+ OrigLoopEnterWeight = OrigLoopBackedgeWeight + OtherLoopExitWeight;
+ }
}
assert(OrigLoopExitWeight >= ExitWeight0 && "Bad branch weight");
ExitWeight1 = OrigLoopExitWeight - ExitWeight0;
- EnterWeight = ExitWeight1;
- assert(OrigLoopBackedgeWeight >= EnterWeight && "Bad branch weight");
- LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight;
+ EnterWeight = ExitWeight1 + OtherLoopExitWeight;
+ assert(OrigLoopEnterWeight >= EnterWeight && "Bad branch weight");
+ LoopBackWeight = OrigLoopEnterWeight - EnterWeight;
} else if (OrigLoopExitWeight == 0) {
if (OrigLoopBackedgeWeight == 0) {
// degenerate case... keep everything zero...
ExitWeight0 = 0;
ExitWeight1 = 0;
- EnterWeight = 0;
+ EnterWeight = OtherLoopExitWeight;
LoopBackWeight = 0;
} else {
// Special case "LoopExitWeight == 0" weights which behaves like an
- // endless where we don't want loop-enttry (y0) to be the same as
+ // endless where we don't want loop-entry (y0) to be the same as
// loop-exit (x1).
ExitWeight0 = 0;
ExitWeight1 = 0;
- EnterWeight = 1;
+ EnterWeight = (OtherLoopExitWeight != 0) ? OtherLoopExitWeight : 1;
LoopBackWeight = OrigLoopBackedgeWeight;
}
} else {
// loop is never entered.
assert(OrigLoopBackedgeWeight == 0 && "remaining case is backedge zero");
- ExitWeight0 = 1;
+ ExitWeight0 = OrigLoopExitWeight;
ExitWeight1 = 1;
- EnterWeight = 0;
+ EnterWeight = OtherLoopExitWeight;
LoopBackWeight = 0;
}
@@ -748,7 +780,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
!isa<ConstantInt>(Cond) ||
PHBI->getSuccessor(cast<ConstantInt>(Cond)->isZero()) != NewHeader;
- updateBranchWeights(*PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);
+ updateBranchWeights(L, *PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);
if (HasConditionalPreHeader) {
// The conditional branch can't be folded, handle the general case.
diff --git a/llvm/test/Transforms/LoopRotate/update-branch-weights.ll b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
index 9a1f36ec5ff2b..77157a1f45e8a 100644
--- a/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
+++ b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
@@ -70,9 +70,9 @@ outer_loop_exit:
; BFI_AFTER-LABEL: block-frequency-info: func1
; BFI_AFTER: - entry: {{.*}} count = 1024
-; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 1016
+; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 512
; BFI_AFTER: - loop_body: {{.*}} count = 20480
-; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 1016
+; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 512
; BFI_AFTER: - loop_exit: {{.*}} count = 1024
; IR-LABEL: define void @func1
@@ -285,8 +285,8 @@ loop_exit:
; IR: [[PROF_FUNC0_0]] = !{!"branch_weights", i32 2000, i32 1000}
; IR: [[PROF_FUNC0_1]] = !{!"branch_weights", i32 999, i32 1}
-; IR: [[PROF_FUNC1_0]] = !{!"branch_weights", i32 127, i32 1}
-; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 2433, i32 127}
+; IR: [[PROF_FUNC1_0]] = !{!"branch_weights", i32 1, i32 1}
+; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 39, i32 1}
; IR: [[PROF_FUNC2_0]] = !{!"branch_weights", i32 9920, i32 320}
; IR: [[PROF_FUNC2_1]] = !{!"branch_weights", i32 320, i32 0}
; IR: [[PROF_FUNC3_0]] = !{!"branch_weights", i32 0, i32 1}
diff --git a/llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights b/llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights
new file mode 100644
index 0000000000000..1a62a47d65810
--- /dev/null
+++ b/llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights
@@ -0,0 +1,115 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt < %s -passes='loop(loop-rotate)' -S | FileCheck %s
+
+@g = global i64 0
+
+define void @func_branch_weight(i64 %n) !prof !0 {
+; CHECK-LABEL: define void @func_branch_weight(
+; CHECK-SAME: i64 [[N:%.*]]) !prof [[PROF0:![0-9]+]] {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i64 0, [[N]]
+; CHECK-NEXT: br i1 [[CMP1]], label %[[LOOP_EXIT:.*]], label %[[LOOP_BODY_LR_PH:.*]], !prof [[PROF1:![0-9]+]]
+; CHECK: [[LOOP_BODY_LR_PH]]:
+; CHECK-NEXT: br label %[[LOOP_BODY:.*]]
+; CHECK: [[LOOP_HEADER:.*]]:
+; CHECK-NEXT: [[I:%.*]] = phi i64 [ [[I_INC:%.*]], %[[LOOP_BODY]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[I]], [[N]]
+; CHECK-NEXT: br i1 [[CMP]], label %[[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE:.*]], label %[[LOOP_BODY]], !prof [[PROF2:![0-9]+]]
+; CHECK: [[LOOP_BODY]]:
+; CHECK-NEXT: [[I2:%.*]] = phi i64 [ 0, %[[LOOP_BODY_LR_PH]] ], [ [[I]], %[[LOOP_HEADER]] ]
+; CHECK-NEXT: [[GP:%.*]] = getelementptr inbounds i8, ptr @g, i64 [[I2]]
+; CHECK-NEXT: [[GI:%.*]] = load i64, ptr [[GP]], align 8
+; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i64 [[GI]], 0
+; CHECK-NEXT: [[I_INC]] = add i64 [[I2]], 1
+; CHECK-NEXT: br i1 [[CMP_NOT]], label %[[LOOP_HEADER]], label %[[LOOP_BODY_LOOP_EXIT_CRIT_EDGE:.*]], !prof [[PROF3:![0-9]+]]
+; CHECK: [[LOOP_BODY_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT: br label %[[LOOP_EXIT]]
+; CHECK: [[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT: br label %[[LOOP_EXIT]]
+; CHECK: [[LOOP_EXIT]]:
+; CHECK-NEXT: ret void
+;
+entry:
+ br label %loop_header
+
+loop_header:
+ %i = phi i64 [0, %entry], [%i_inc, %if_then]
+ %cmp = icmp slt i64 %i, %n
+ br i1 %cmp, label %loop_exit, label %loop_body, !prof !1
+
+loop_body:
+ %gp = getelementptr inbounds i8, ptr @g, i64 %i
+ %gi = load i64, ptr %gp, align 8
+ %cmp.not = icmp eq i64 %gi, 0
+ br i1 %cmp.not, label %if_then, label %loop_exit, !prof !2
+
+if_then:
+ %i_inc = add i64 %i, 1
+ br label %loop_header
+
+loop_exit:
+ ret void
+}
+
+
+define void @func_zero_backage_weight(i64 %n) !prof !0 {
+; CHECK-LABEL: define void @func_zero_backage_weight(
+; CHECK-SAME: i64 [[N:%.*]]) !prof [[PROF0]] {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i64 0, [[N]]
+; CHECK-NEXT: br i1 [[CMP1]], label %[[LOOP_EXIT:.*]], label %[[LOOP_BODY_LR_PH:.*]], !prof [[PROF1]]
+; CHECK: [[LOOP_BODY_LR_PH]]:
+; CHECK-NEXT: br label %[[LOOP_BODY:.*]]
+; CHECK: [[LOOP_HEADER:.*]]:
+; CHECK-NEXT: [[I:%.*]] = phi i64 [ [[I_INC:%.*]], %[[LOOP_BODY]] ]
+; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[I]], [[N]]
+; CHECK-NEXT: br i1 [[CMP]], label %[[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE:.*]], label %[[LOOP_BODY]], !prof [[PROF4:![0-9]+]]
+; CHECK: [[LOOP_BODY]]:
+; CHECK-NEXT: [[I2:%.*]] = phi i64 [ 0, %[[LOOP_BODY_LR_PH]] ], [ [[I]], %[[LOOP_HEADER]] ]
+; CHECK-NEXT: [[GP:%.*]] = getelementptr inbounds i8, ptr @g, i64 [[I2]]
+; CHECK-NEXT: [[GI:%.*]] = load i64, ptr [[GP]], align 8
+; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i64 [[GI]], 0
+; CHECK-NEXT: [[I_INC]] = add i64 [[I2]], 1
+; CHECK-NEXT: br i1 [[CMP_NOT]], label %[[LOOP_HEADER]], label %[[LOOP_BODY_LOOP_EXIT_CRIT_EDGE:.*]], !prof [[PROF5:![0-9]+]]
+; CHECK: [[LOOP_BODY_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT: br label %[[LOOP_EXIT]]
+; CHECK: [[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT: br label %[[LOOP_EXIT]]
+; CHECK: [[LOOP_EXIT]]:
+; CHECK-NEXT: ret void
+;
+entry:
+ br label %loop_header
+
+loop_header:
+ %i = phi i64 [0, %entry], [%i_inc, %if_then]
+ %cmp = icmp slt i64 %i, %n
+ br i1 %cmp, label %loop_exit, label %loop_body, !prof !3
+
+loop_body:
+ %gp = getelementptr inbounds i8, ptr @g, i64 %i
+ %gi = load i64, ptr %gp, align 8
+ %cmp.not = icmp eq i64 %gi, 0
+ br i1 %cmp.not, label %if_then, label %loop_exit, !prof !4
+
+if_then:
+ %i_inc = add i64 %i, 1
+ br label %loop_header
+
+loop_exit:
+ ret void
+}
+
+!0 = !{!"function_entry_count", i64 1000}
+!1 = !{!"branch_weights", i32 200, i32 900}
+!2 = !{!"branch_weights", i32 100, i32 800}
+!3 = !{!"branch_weights", i32 100, i32 900}
+!4 = !{!"branch_weights", i32 0, i32 900}
+;.
+; CHECK: [[PROF0]] = !{!"function_entry_count", i64 1000}
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 100, i32 900}
+; CHECK: [[PROF2]] = !{!"branch_weights", i32 100, i32 0}
+; CHECK: [[PROF3]] = !{!"branch_weights", i32 100, i32 800}
+; CHECK: [[PROF4]] = !{!"branch_weights", i32 1, i32 0}
+; CHECK: [[PROF5]] = !{!"branch_weights", i32 0, i32 900}
+;.
More information about the llvm-commits
mailing list