[llvm] 5d7f84e - LoopRotate: Add code to update branch weights

Matthias Braun via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 11 10:39:40 PDT 2023


Author: Matthias Braun
Date: 2023-09-11T10:38:06-07:00
New Revision: 5d7f84ee17f3f601c49f6124a3a51e557de3ab53

URL: https://github.com/llvm/llvm-project/commit/5d7f84ee17f3f601c49f6124a3a51e557de3ab53
DIFF: https://github.com/llvm/llvm-project/commit/5d7f84ee17f3f601c49f6124a3a51e557de3ab53.diff

LOG: LoopRotate: Add code to update branch weights

This adds code to the loop rotation transformation to ensure that the
computed block execution counts for the loop bodies are the same before
and after the transformation. This isn't always true in practice, but I
believe this is because of numeric inaccuracies in the BlockFrequency
computation.

The invariants this is modeled on and the heuristic choice of the 0-trip loop
amount are explained in a lengthy comment in the new
`updateBranchWeights()` function.

Differential Revision: https://reviews.llvm.org/D157462

Added: 
    llvm/test/Transforms/LoopRotate/update-branch-weights.ll

Modified: 
    llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
    llvm/test/Transforms/LoopSimplify/merge-exits.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index d81db5647c608d0..22effcf7d88afd2 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -25,6 +25,8 @@
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -50,6 +52,9 @@ static cl::opt<bool>
                 cl::desc("Allow loop rotation multiple times in order to reach "
                          "a better latch exit"));
 
+// Probability that a rotated loop has zero trip count / is never entered.
+static constexpr uint32_t ZeroTripCountWeights[] = {1, 127};
+
 namespace {
 /// A simple loop rotation transformation.
 class LoopRotate {
@@ -244,6 +249,93 @@ static bool canRotateDeoptimizingLatchExit(Loop *L) {
   return false;
 }
 
+static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
+                                bool HasConditionalPreHeader,
+                                bool SuccsSwapped) {
+  MDNode *WeightMD = getBranchWeightMDNode(PreHeaderBI);
+  if (WeightMD == nullptr)
+    return; // No branch weights attached, nothing to update.
+
+  // LoopBI should currently be a clone of PreHeaderBI carrying the same
+  // metadata node. Double check anyway and bail out in the degenerate case
+  // where instsimplify changed the instructions.
+  if (WeightMD != getBranchWeightMDNode(LoopBI))
+    return;
+
+  SmallVector<uint32_t, 2> Weights;
+  extractFromBranchWeightMD(WeightMD, Weights);
+  if (Weights.size() != 2)
+    return;
+  uint32_t OrigLoopExitWeight = Weights[0]; // Exit edge, unless SuccsSwapped.
+  uint32_t OrigLoopBackedgeWeight = Weights[1];
+
+  if (SuccsSwapped)
+    std::swap(OrigLoopExitWeight, OrigLoopBackedgeWeight);
+
+  // Split the original header weights between the pre-header branch
+  // (PreHeaderBI) and the rotated loop branch (LoopBI). Edge-counts:
+  //    |  |--------             |
+  //    V  V       |             V
+  //   Br i1 ...   |            Br i1 ...
+  //   |       |   |            |     |
+  //  x|      y|   |  becomes:  |   y0|  |-----
+  //   V       V   |            |     V  V    |
+  // Exit    Loop  |            |    Loop     |
+  //           |   |            |   Br i1 ... |
+  //           -----            |   |      |  |
+  //                          x0| x1|   y1 |  |
+  //                            V   V      ----
+  //                            Exit
+  //
+  // The following must hold:
+  //  -  x == x0 + x1        # counts to "exit" must stay the same.
+  //  - y0 == x - x0 == x1   # how often loop was entered at all.
+  //  - y1 == y - y0         # How often loop was repeated (after first iter.).
+  //
+  // We cannot generally deduce how often we had a zero-trip count loop so we
+  // have to make a guess for how to distribute x among the new x0 and x1.
+
+  uint32_t ExitWeight0 = 0; // aka x0; stays 0 without a conditional pre-header
+  if (HasConditionalPreHeader) {
+    // Here we cannot know how many 0-trip count loops we have, so we guess:
+    if (OrigLoopBackedgeWeight > OrigLoopExitWeight) {
+      // If the backedge count is bigger than the exit count then we set
+      // probabilities as if 0-trip count nearly never happens.
+      ExitWeight0 = ZeroTripCountWeights[0];
+      // Scale up both counts if necessary so that the `ExitWeight0`:`ExitWeight1`
+      // (aka `x0`:`x1`) ratio can match `ZeroTripCountWeights`.
+      while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
+        // ... but stop doubling once either count has its top bit set.
+        uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
+        if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
+            (OrigLoopExitWeight & HighBit) != 0)
+          break;
+        OrigLoopBackedgeWeight <<= 1;
+        OrigLoopExitWeight <<= 1;
+      }
+    } else {
+      // If there's a higher exit-count than backedge-count then we set
+      // probabilities as if there are only 0-trip and 1-trip cases.
+      ExitWeight0 = OrigLoopExitWeight - OrigLoopBackedgeWeight;
+    }
+  }
+  uint32_t ExitWeight1 = OrigLoopExitWeight - ExitWeight0;        // aka x1; NOTE(review): unsigned, wraps if x0 > x (e.g. x == 0)
+  uint32_t EnterWeight = ExitWeight1;                             // aka y0
+  uint32_t LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight; // aka y1; NOTE(review): unsigned, wraps if y0 > y
+
+  MDBuilder MDB(LoopBI.getContext());
+  MDNode *LoopWeightMD =
+      MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
+                              SuccsSwapped ? ExitWeight1 : LoopBackWeight);
+  LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
+  if (HasConditionalPreHeader) {
+    MDNode *PreHeaderWeightMD =
+        MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
+                                SuccsSwapped ? ExitWeight0 : EnterWeight);
+    PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
+  }
+}
+
 /// Rotate loop LP. Return true if the loop is rotated.
 ///
 /// \param SimplifiedLatch is true if the latch was just folded into the final
@@ -363,7 +455,8 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
     // loop.  Otherwise loop is not suitable for rotation.
     BasicBlock *Exit = BI->getSuccessor(0);
     BasicBlock *NewHeader = BI->getSuccessor(1);
-    if (L->contains(Exit))
+    bool BISuccsSwapped = L->contains(Exit);
+    if (BISuccsSwapped)
       std::swap(Exit, NewHeader);
     assert(NewHeader && "Unable to determine new loop header");
     assert(L->contains(NewHeader) && !L->contains(Exit) &&
@@ -605,9 +698,14 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
     // to split as many edges.
     BranchInst *PHBI = cast<BranchInst>(OrigPreheader->getTerminator());
     assert(PHBI->isConditional() && "Should be clone of BI condbr!");
-    if (!isa<ConstantInt>(PHBI->getCondition()) ||
-        PHBI->getSuccessor(cast<ConstantInt>(PHBI->getCondition())->isZero()) !=
-        NewHeader) {
+    const Value *Cond = PHBI->getCondition();
+    const bool HasConditionalPreHeader =
+        !isa<ConstantInt>(Cond) ||
+        PHBI->getSuccessor(cast<ConstantInt>(Cond)->isZero()) != NewHeader;
+
+    updateBranchWeights(*PHBI, *BI, HasConditionalPreHeader, BISuccsSwapped);
+
+    if (HasConditionalPreHeader) {
       // The conditional branch can't be folded, handle the general case.
       // Split edges as necessary to preserve LoopSimplify form.
 

diff  --git a/llvm/test/Transforms/LoopRotate/update-branch-weights.ll b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
new file mode 100644
index 000000000000000..9af6cfab4a2411d
--- /dev/null
+++ b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
@@ -0,0 +1,156 @@
+; RUN: opt < %s -passes='print<block-freq>' -disable-output 2>&1 | FileCheck %s --check-prefixes=BFI_BEFORE
+; RUN: opt < %s -passes='loop(loop-rotate),print<block-freq>' -disable-output 2>&1 | FileCheck %s --check-prefixes=BFI_AFTER
+; RUN: opt < %s -passes='loop(loop-rotate)' -S | FileCheck %s --check-prefixes=IR
+
+@g = global i32 0
+
+; We should get the same "count =" results for "outer_loop_body" and
+; "inner_loop_body" before and after the transformation.
+
+; BFI_BEFORE-LABEL: block-frequency-info: func0
+; BFI_BEFORE: - entry: {{.*}} count = 1
+; BFI_BEFORE: - outer_loop_header: {{.*}} count = 1001
+; BFI_BEFORE: - outer_loop_body: {{.*}} count = 1000
+; BFI_BEFORE: - inner_loop_header: {{.*}} count = 4000
+; BFI_BEFORE: - inner_loop_body: {{.*}} count = 3000
+; BFI_BEFORE: - inner_loop_exit: {{.*}} count = 1000
+; BFI_BEFORE: - outer_loop_exit: {{.*}} count = 1
+
+; BFI_AFTER-LABEL: block-frequency-info: func0
+; BFI_AFTER: - entry: {{.*}} count = 1
+; BFI_AFTER: - outer_loop_body: {{.*}} count = 1000
+; BFI_AFTER: - inner_loop_body: {{.*}} count = 3000
+; BFI_AFTER: - inner_loop_exit: {{.*}} count = 1000
+; BFI_AFTER: - outer_loop_exit: {{.*}} count = 1
+
+; IR: inner_loop_body:
+; IR:   br i1 %cmp1, label %inner_loop_body, label %inner_loop_exit, !prof [[PROF_FUNC0_0:![0-9]+]]
+; IR: inner_loop_exit:
+; IR:   br i1 %cmp0, label %outer_loop_body, label %outer_loop_exit, !prof [[PROF_FUNC0_1:![0-9]+]]
+;
+; A function with known loop-bounds where after loop-rotation we end with an
+; unconditional branch in the pre-header.
+define void @func0() !prof !0 {
+entry:
+  br label %outer_loop_header
+
+outer_loop_header:
+  %i0 = phi i32 [0, %entry], [%i0_inc, %inner_loop_exit]
+  %cmp0 = icmp slt i32 %i0, 1000
+  br i1 %cmp0, label %outer_loop_body, label %outer_loop_exit, !prof !1
+
+outer_loop_body:
+  store volatile i32 %i0, ptr @g, align 4
+  br label %inner_loop_header
+
+inner_loop_header:
+  %i1 = phi i32 [0, %outer_loop_body], [%i1_inc, %inner_loop_body]
+  %cmp1 = icmp slt i32 %i1, 3
+  br i1 %cmp1, label %inner_loop_body, label %inner_loop_exit, !prof !2
+
+inner_loop_body:
+  store volatile i32 %i1, ptr @g, align 4
+  %i1_inc = add i32 %i1, 1
+  br label %inner_loop_header
+
+inner_loop_exit:
+  %i0_inc = add i32 %i0, 1
+  br label %outer_loop_header
+
+outer_loop_exit:
+  ret void
+}
+
+; BFI_BEFORE-LABEL: block-frequency-info: func1
+; BFI_BEFORE: - entry: {{.*}} count = 1024
+; BFI_BEFORE: - loop_header: {{.*}} count = 21504
+; BFI_BEFORE: - loop_body: {{.*}} count = 20480
+; BFI_BEFORE: - loop_exit: {{.*}} count = 1024
+
+; BFI_AFTER-LABEL: block-frequency-info: func1
+; BFI_AFTER: - entry: {{.*}} count = 1024
+; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 1024
+; BFI_AFTER: - loop_body: {{.*}} count = 20608
+; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 1024
+; BFI_AFTER: - loop_exit: {{.*}} count = 1024
+
+; IR: entry:
+; IR:   br i1 %cmp1, label %loop_body.lr.ph, label %loop_exit, !prof [[PROF_FUNC1_0:![0-9]+]]
+
+; IR: loop_body:
+; IR:   br i1 %cmp, label %loop_body, label %loop_header.loop_exit_crit_edge, !prof [[PROF_FUNC1_1:![0-9]+]]
+
+; A function with unknown loop-bounds so loop-rotation ends up with a
+; conditional jump in pre-header and loop body. The branch_weights show the
+; backedge is taken more often than the loop is exited.
+define void @func1(i32 %n) !prof !3 {
+entry:
+  br label %loop_header
+
+loop_header:
+  %i = phi i32 [0, %entry], [%i_inc, %loop_body]
+  %cmp = icmp slt i32 %i, %n
+  br i1 %cmp, label %loop_body, label %loop_exit, !prof !4
+
+loop_body:
+  store volatile i32 %i, ptr @g, align 4
+  %i_inc = add i32 %i, 1
+  br label %loop_header
+
+loop_exit:
+  ret void
+}
+
+; BFI_BEFORE-LABEL: block-frequency-info: func2
+; BFI_BEFORE: - entry: {{.*}} count = 1024
+; BFI_BEFORE: - loop_header: {{.*}} count = 1056
+; BFI_BEFORE: - loop_body: {{.*}} count = 32
+; BFI_BEFORE: - loop_exit: {{.*}} count = 1024
+
+; BFI_AFTER-LABEL: block-frequency-info: func2
+; - entry: {{.*}} count = 1024
+; - loop_body.lr.ph: {{.*}} count = 32
+; - loop_body: {{.*}} count = 32
+; - loop_header.loop_exit_crit_edge: {{.*}} count = 32
+; - loop_exit: {{.*}} count = 1024
+
+; IR: entry:
+; IR:   br i1 %cmp1, label %loop_exit, label %loop_body.lr.ph, !prof [[PROF_FUNC2_0:![0-9]+]]
+
+; IR: loop_body:
+; IR:   br i1 %cmp, label %loop_header.loop_exit_crit_edge, label %loop_body, !prof [[PROF_FUNC2_1:![0-9]+]]
+
+; A function with unknown loop-bounds so loop-rotation ends up with a
+; conditional jump in pre-header and loop body. Similar to `func1` but here
+; loop-exit count is higher than backedge count.
+define void @func2(i32 %n) !prof !3 {
+entry:
+  br label %loop_header
+
+loop_header:
+  %i = phi i32 [0, %entry], [%i_inc, %loop_body]
+  %cmp = icmp slt i32 %i, %n
+  br i1 %cmp, label %loop_exit, label %loop_body, !prof !5
+
+loop_body:
+  store volatile i32 %i, ptr @g, align 4
+  %i_inc = add i32 %i, 1
+  br label %loop_header
+
+loop_exit:
+  ret void
+}
+
+!0 = !{!"function_entry_count", i64 1}
+!1 = !{!"branch_weights", i32 1000, i32 1}
+!2 = !{!"branch_weights", i32 3000, i32 1000}
+!3 = !{!"function_entry_count", i64 1024}
+!4 = !{!"branch_weights", i32 40, i32 2}
+!5 = !{!"branch_weights", i32 10240, i32 320}
+
+; IR: [[PROF_FUNC0_0]] = !{!"branch_weights", i32 2000, i32 1000}
+; IR: [[PROF_FUNC0_1]] = !{!"branch_weights", i32 999, i32 1}
+; IR: [[PROF_FUNC1_0]] = !{!"branch_weights", i32 127, i32 1}
+; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 2433, i32 127}
+; IR: [[PROF_FUNC2_0]] = !{!"branch_weights", i32 9920, i32 320}
+; IR: [[PROF_FUNC2_1]] = !{!"branch_weights", i32 320, i32 0}

diff  --git a/llvm/test/Transforms/LoopSimplify/merge-exits.ll b/llvm/test/Transforms/LoopSimplify/merge-exits.ll
index f3c685647050507..10e634fd47a1aea 100644
--- a/llvm/test/Transforms/LoopSimplify/merge-exits.ll
+++ b/llvm/test/Transforms/LoopSimplify/merge-exits.ll
@@ -103,7 +103,7 @@ define float @merge_branches_profile_metadata(ptr %pTmp1, ptr %peakWeight, i32 %
 ; CHECK-NEXT:    [[T10:%.*]] = fcmp olt float [[T4]], 2.500000e+00
 ; CHECK-NEXT:    [[T12:%.*]] = icmp sgt i64 [[TMP0]], [[INDVARS_IV_NEXT]]
 ; CHECK-NEXT:    [[OR_COND:%.*]] = and i1 [[T10]], [[T12]]
-; CHECK-NEXT:    br i1 [[OR_COND]], label [[BB]], label [[BB1_BB3_CRIT_EDGE:%.*]], !prof [[PROF0]]
+; CHECK-NEXT:    br i1 [[OR_COND]], label [[BB]], label [[BB1_BB3_CRIT_EDGE:%.*]], !prof [[PROF1:![0-9]+]]
 ; CHECK:       bb1.bb3_crit_edge:
 ; CHECK-NEXT:    [[T4_LCSSA:%.*]] = phi float [ [[T4]], [[BB]] ]
 ; CHECK-NEXT:    [[T9_LCSSA:%.*]] = phi float [ [[T9]], [[BB]] ]


        


More information about the llvm-commits mailing list