[llvm] [LoopRotate] Update branch weights for multi-exit loops (PR #164006)

Shimin Cui via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 17 12:49:44 PDT 2025


https://github.com/scui-ibm updated https://github.com/llvm/llvm-project/pull/164006

>From 54d5e5e49cb2dc6ea82f05b429439e3e1032ecae Mon Sep 17 00:00:00 2001
From: Shimin Cui <scui at xlperflep9.rtp.raleigh.ibm.com>
Date: Fri, 17 Oct 2025 19:20:52 +0000
Subject: [PATCH] [LoopRotate] Update branch weights for multi-exit loops

---
 .../Transforms/Utils/LoopRotationUtils.cpp    | 100 +++++++++------
 .../LoopRotate/update-branch-weights.ll       |   8 +-
 .../LoopRotate/update-multi-exit-loop-weights | 115 ++++++++++++++++++
 3 files changed, 185 insertions(+), 38 deletions(-)
 create mode 100644 llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights

diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 0c8d6fa47b9ae..52ae973935c88 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -46,9 +46,6 @@ STATISTIC(NumInstrsHoisted,
 STATISTIC(NumInstrsDuplicated,
           "Number of instructions cloned into loop preheader");
 
-// Probability that a rotated loop has zero trip count / is never entered.
-static constexpr uint32_t ZeroTripCountWeights[] = {1, 127};
-
 namespace {
 /// A simple loop rotation transformation.
 class LoopRotate {
@@ -200,7 +197,8 @@ static bool profitableToRotateLoopExitingLatch(Loop *L) {
   return false;
 }
 
-static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
+static void updateBranchWeights(Loop *L, BranchInst &PreHeaderBI,
+                                BranchInst &LoopBI,
                                 bool HasConditionalPreHeader,
                                 bool SuccsSwapped) {
   MDNode *WeightMD = getBranchWeightMDNode(PreHeaderBI);
@@ -218,12 +216,49 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
   if (Weights.size() != 2)
     return;
   uint32_t OrigLoopExitWeight = Weights[0];
-  uint32_t OrigLoopBackedgeWeight = Weights[1];
+  uint32_t OrigLoopEnterWeight = Weights[1];
 
   if (SuccsSwapped)
-    std::swap(OrigLoopExitWeight, OrigLoopBackedgeWeight);
+    std::swap(OrigLoopExitWeight, OrigLoopEnterWeight);
+
+  // For a multiple-exit loop, find the total weight of other exits.
+  uint32_t OtherLoopExitWeight = 0;
+  SmallVector<BasicBlock *, 16> ExitingBlocks;
+  L->getExitingBlocks(ExitingBlocks);
+  for (BasicBlock *ExitingBB : ExitingBlocks) {
+    Instruction *TI = ExitingBB->getTerminator();
+    if (TI == &LoopBI)
+      continue;
+
+    if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) ||
+          isa<IndirectBrInst>(TI) || isa<InvokeInst>(TI) ||
+          isa<CallBrInst>(TI)))
+      continue;
+
+    MDNode *WeightsNode = getValidBranchWeightMDNode(*TI);
+    if (!WeightsNode)
+      continue;
+
+    SmallVector<uint32_t, 2> Weights;
+    extractBranchWeights(WeightsNode, Weights);
+    for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
+      BasicBlock *Exit = TI->getSuccessor(I);
+      if (L->contains(Exit))
+        continue;
+
+      OtherLoopExitWeight += Weights[I];
+    }
+  }
 
-  // Update branch weights. Consider the following edge-counts:
+  // Adjust OtherLoopExitWeight as it should not be larger than the loop enter
+  // weight.
+  if (OtherLoopExitWeight > OrigLoopEnterWeight)
+    OtherLoopExitWeight = OrigLoopEnterWeight;
+
+  uint32_t OrigLoopBackedgeWeight = OrigLoopEnterWeight - OtherLoopExitWeight;
+
+  // Update branch weights. Consider the following edge-counts (z is the total
+  // weight of the other exits of a multi-exit loop):
   //
   //    |  |--------             |
   //    V  V       |             V
@@ -231,16 +266,18 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
   //   |       |   |            |     |
   //  x|      y|   |  becomes:  |   y0|  |-----
   //   V       V   |            |     V  V    |
-  // Exit    Loop  |            |    Loop     |
-  //           |   |            |   Br i1 ... |
+  // Exit <- Loop  |            |    Loop     |
+  //      z    |   |            |   Br i1 ... |
   //           -----            |   |      |  |
   //                          x0| x1|   y1 |  |
   //                            V   V      ----
-  //                            Exit
+  //                            Exit  <----|
+  //                                    z
   //
   // The following must hold:
   //  -  x == x0 + x1        # counts to "exit" must stay the same.
-  //  - y0 == x - x0 == x1   # how often loop was entered at all.
+  //  - y0 == x - x0 + z     # how often loop was entered at all.
+  //       == x1 + z
   //  - y1 == y - y0         # How often loop was repeated (after first iter.).
   //
   // We cannot generally deduce how often we had a zero-trip count loop so we
@@ -255,19 +292,12 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
     if (HasConditionalPreHeader) {
       // Here we cannot know how many 0-trip count loops we have, so we guess:
       if (OrigLoopBackedgeWeight >= OrigLoopExitWeight) {
-        // If the loop count is bigger than the exit count then we set
-        // probabilities as if 0-trip count nearly never happens.
-        ExitWeight0 = ZeroTripCountWeights[0];
-        // Scale up counts if necessary so we can match `ZeroTripCountWeights`
-        // for the `ExitWeight0`:`ExitWeight1` (aka `x0`:`x1` ratio`) ratio.
-        while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
-          // ... but don't overflow.
-          uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
-          if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
-              (OrigLoopExitWeight & HighBit) != 0)
-            break;
-          OrigLoopBackedgeWeight <<= 1;
-          OrigLoopExitWeight <<= 1;
+        ExitWeight0 =
+            (OrigLoopExitWeight * (OrigLoopExitWeight + OtherLoopExitWeight)) /
+            (OrigLoopExitWeight + OrigLoopEnterWeight);
+        // Clamp ExitWeight0 to a minimum of 1.
+        if (ExitWeight0 == 0) {
+          ExitWeight0 = 1;
         }
       } else {
         // If there's a higher exit-count than backedge-count then we set
@@ -280,36 +310,38 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
       // weight collected by sampling-based PGO may be not very accurate due to
       // sampling. Therefore this workaround is required here to avoid underflow
       // of unsigned in following update of branch weight.
-      if (OrigLoopExitWeight > OrigLoopBackedgeWeight)
+      if (OrigLoopExitWeight > OrigLoopBackedgeWeight) {
         OrigLoopBackedgeWeight = OrigLoopExitWeight;
+        OrigLoopEnterWeight = OrigLoopBackedgeWeight + OtherLoopExitWeight;
+      }
     }
     assert(OrigLoopExitWeight >= ExitWeight0 && "Bad branch weight");
     ExitWeight1 = OrigLoopExitWeight - ExitWeight0;
-    EnterWeight = ExitWeight1;
-    assert(OrigLoopBackedgeWeight >= EnterWeight && "Bad branch weight");
-    LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight;
+    EnterWeight = ExitWeight1 + OtherLoopExitWeight;
+    assert(OrigLoopEnterWeight >= EnterWeight && "Bad branch weight");
+    LoopBackWeight = OrigLoopEnterWeight - EnterWeight;
   } else if (OrigLoopExitWeight == 0) {
     if (OrigLoopBackedgeWeight == 0) {
       // degenerate case... keep everything zero...
       ExitWeight0 = 0;
       ExitWeight1 = 0;
-      EnterWeight = 0;
+      EnterWeight = OtherLoopExitWeight;
       LoopBackWeight = 0;
     } else {
       // Special case "LoopExitWeight == 0" weights which behaves like an
-      // endless where we don't want loop-enttry (y0) to be the same as
+      // endless where we don't want loop-entry (y0) to be the same as
       // loop-exit (x1).
       ExitWeight0 = 0;
       ExitWeight1 = 0;
-      EnterWeight = 1;
+      EnterWeight = (OtherLoopExitWeight != 0) ? OtherLoopExitWeight : 1;
       LoopBackWeight = OrigLoopBackedgeWeight;
     }
   } else {
     // loop is never entered.
     assert(OrigLoopBackedgeWeight == 0 && "remaining case is backedge zero");
-    ExitWeight0 = 1;
+    ExitWeight0 = OrigLoopExitWeight;
     ExitWeight1 = 1;
-    EnterWeight = 0;
+    EnterWeight = OtherLoopExitWeight;
     LoopBackWeight = 0;
   }
 
@@ -748,7 +780,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
       !isa<ConstantInt>(Cond) ||
       PHBI->getSuccessor(cast<ConstantInt>(Cond)->isZero()) != NewHeader;
 
-  updateBranchWeights(*PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);
+  updateBranchWeights(L, *PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);
 
   if (HasConditionalPreHeader) {
     // The conditional branch can't be folded, handle the general case.
diff --git a/llvm/test/Transforms/LoopRotate/update-branch-weights.ll b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
index 9a1f36ec5ff2b..77157a1f45e8a 100644
--- a/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
+++ b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
@@ -70,9 +70,9 @@ outer_loop_exit:
 
 ; BFI_AFTER-LABEL: block-frequency-info: func1
 ; BFI_AFTER: - entry: {{.*}} count = 1024
-; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 1016
+; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 512
 ; BFI_AFTER: - loop_body: {{.*}} count = 20480
-; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 1016
+; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 512
 ; BFI_AFTER: - loop_exit: {{.*}} count = 1024
 
 ; IR-LABEL: define void @func1
@@ -285,8 +285,8 @@ loop_exit:
 
 ; IR: [[PROF_FUNC0_0]] = !{!"branch_weights", i32 2000, i32 1000}
 ; IR: [[PROF_FUNC0_1]] = !{!"branch_weights", i32 999, i32 1}
-; IR: [[PROF_FUNC1_0]] = !{!"branch_weights", i32 127, i32 1}
-; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 2433, i32 127}
+; IR: [[PROF_FUNC1_0]] = !{!"branch_weights", i32 1, i32 1}
+; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 39, i32 1}
 ; IR: [[PROF_FUNC2_0]] = !{!"branch_weights", i32 9920, i32 320}
 ; IR: [[PROF_FUNC2_1]] = !{!"branch_weights", i32 320, i32 0}
 ; IR: [[PROF_FUNC3_0]] = !{!"branch_weights", i32 0, i32 1}
diff --git a/llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights b/llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights
new file mode 100644
index 0000000000000..1a62a47d65810
--- /dev/null
+++ b/llvm/test/Transforms/LoopRotate/update-multi-exit-loop-weights
@@ -0,0 +1,115 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
+; RUN: opt < %s -passes='loop(loop-rotate)' -S | FileCheck %s
+
+@g = global i64 0
+
+define void @func_branch_weight(i64 %n) !prof !0 {
+; CHECK-LABEL: define void @func_branch_weight(
+; CHECK-SAME: i64 [[N:%.*]]) !prof [[PROF0:![0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i64 0, [[N]]
+; CHECK-NEXT:    br i1 [[CMP1]], label %[[LOOP_EXIT:.*]], label %[[LOOP_BODY_LR_PH:.*]], !prof [[PROF1:![0-9]+]]
+; CHECK:       [[LOOP_BODY_LR_PH]]:
+; CHECK-NEXT:    br label %[[LOOP_BODY:.*]]
+; CHECK:       [[LOOP_HEADER:.*]]:
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_INC:%.*]], %[[LOOP_BODY]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[I]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP]], label %[[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE:.*]], label %[[LOOP_BODY]], !prof [[PROF2:![0-9]+]]
+; CHECK:       [[LOOP_BODY]]:
+; CHECK-NEXT:    [[I2:%.*]] = phi i64 [ 0, %[[LOOP_BODY_LR_PH]] ], [ [[I]], %[[LOOP_HEADER]] ]
+; CHECK-NEXT:    [[GP:%.*]] = getelementptr inbounds i8, ptr @g, i64 [[I2]]
+; CHECK-NEXT:    [[GI:%.*]] = load i64, ptr [[GP]], align 8
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i64 [[GI]], 0
+; CHECK-NEXT:    [[I_INC]] = add i64 [[I2]], 1
+; CHECK-NEXT:    br i1 [[CMP_NOT]], label %[[LOOP_HEADER]], label %[[LOOP_BODY_LOOP_EXIT_CRIT_EDGE:.*]], !prof [[PROF3:![0-9]+]]
+; CHECK:       [[LOOP_BODY_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT:    br label %[[LOOP_EXIT]]
+; CHECK:       [[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT:    br label %[[LOOP_EXIT]]
+; CHECK:       [[LOOP_EXIT]]:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop_header
+
+loop_header:
+  %i = phi i64 [0, %entry], [%i_inc, %if_then]
+  %cmp = icmp slt i64 %i, %n
+  br i1 %cmp, label %loop_exit, label %loop_body, !prof !1
+
+loop_body:
+  %gp = getelementptr inbounds i8, ptr @g, i64 %i
+  %gi = load i64, ptr %gp, align 8
+  %cmp.not = icmp eq i64 %gi, 0
+  br i1 %cmp.not, label %if_then, label %loop_exit, !prof !2
+
+if_then:
+  %i_inc = add i64 %i, 1
+  br label %loop_header
+
+loop_exit:
+  ret void
+}
+
+
+define void @func_zero_backage_weight(i64 %n) !prof !0 {
+; CHECK-LABEL: define void @func_zero_backage_weight(
+; CHECK-SAME: i64 [[N:%.*]]) !prof [[PROF0]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i64 0, [[N]]
+; CHECK-NEXT:    br i1 [[CMP1]], label %[[LOOP_EXIT:.*]], label %[[LOOP_BODY_LR_PH:.*]], !prof [[PROF1]]
+; CHECK:       [[LOOP_BODY_LR_PH]]:
+; CHECK-NEXT:    br label %[[LOOP_BODY:.*]]
+; CHECK:       [[LOOP_HEADER:.*]]:
+; CHECK-NEXT:    [[I:%.*]] = phi i64 [ [[I_INC:%.*]], %[[LOOP_BODY]] ]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i64 [[I]], [[N]]
+; CHECK-NEXT:    br i1 [[CMP]], label %[[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE:.*]], label %[[LOOP_BODY]], !prof [[PROF4:![0-9]+]]
+; CHECK:       [[LOOP_BODY]]:
+; CHECK-NEXT:    [[I2:%.*]] = phi i64 [ 0, %[[LOOP_BODY_LR_PH]] ], [ [[I]], %[[LOOP_HEADER]] ]
+; CHECK-NEXT:    [[GP:%.*]] = getelementptr inbounds i8, ptr @g, i64 [[I2]]
+; CHECK-NEXT:    [[GI:%.*]] = load i64, ptr [[GP]], align 8
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i64 [[GI]], 0
+; CHECK-NEXT:    [[I_INC]] = add i64 [[I2]], 1
+; CHECK-NEXT:    br i1 [[CMP_NOT]], label %[[LOOP_HEADER]], label %[[LOOP_BODY_LOOP_EXIT_CRIT_EDGE:.*]], !prof [[PROF5:![0-9]+]]
+; CHECK:       [[LOOP_BODY_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT:    br label %[[LOOP_EXIT]]
+; CHECK:       [[LOOP_HEADER_LOOP_EXIT_CRIT_EDGE]]:
+; CHECK-NEXT:    br label %[[LOOP_EXIT]]
+; CHECK:       [[LOOP_EXIT]]:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %loop_header
+
+loop_header:
+  %i = phi i64 [0, %entry], [%i_inc, %if_then]
+  %cmp = icmp slt i64 %i, %n
+  br i1 %cmp, label %loop_exit, label %loop_body, !prof !3
+
+loop_body:
+  %gp = getelementptr inbounds i8, ptr @g, i64 %i
+  %gi = load i64, ptr %gp, align 8
+  %cmp.not = icmp eq i64 %gi, 0
+  br i1 %cmp.not, label %if_then, label %loop_exit, !prof !4
+
+if_then:
+  %i_inc = add i64 %i, 1
+  br label %loop_header
+
+loop_exit:
+  ret void
+}
+
+!0 = !{!"function_entry_count", i64 1000}
+!1 = !{!"branch_weights", i32 200, i32 900}
+!2 = !{!"branch_weights", i32 100, i32 800}
+!3 = !{!"branch_weights", i32 100, i32 900}
+!4 = !{!"branch_weights", i32 0, i32 900}
+;.
+; CHECK: [[PROF0]] = !{!"function_entry_count", i64 1000}
+; CHECK: [[PROF1]] = !{!"branch_weights", i32 100, i32 900}
+; CHECK: [[PROF2]] = !{!"branch_weights", i32 100, i32 0}
+; CHECK: [[PROF3]] = !{!"branch_weights", i32 100, i32 800}
+; CHECK: [[PROF4]] = !{!"branch_weights", i32 1, i32 0}
+; CHECK: [[PROF5]] = !{!"branch_weights", i32 0, i32 900}
+;.



More information about the llvm-commits mailing list