[llvm] LoopRotationUtils: Fix underflow for zero-branch weights (PR #66681)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 18 11:27:35 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

<details>
<summary>Changes</summary>

Add a special case for loops whose exit edge has zero branch weight, to avoid an unsigned underflow when updating branch weights during loop rotation. This fixes #66675.

---
Full diff: https://github.com/llvm/llvm-project/pull/66681.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Utils/LoopRotationUtils.cpp (+37-24) 
- (modified) llvm/test/Transforms/LoopRotate/update-branch-weights.ll (+43) 


``````````diff
diff --git a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
index 22effcf7d88afd2..fd0e43bd2433a88 100644
--- a/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
@@ -295,33 +295,46 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
   // We cannot generally deduce how often we had a zero-trip count loop so we
   // have to make a guess for how to distribute x among the new x0 and x1.
 
-  uint32_t ExitWeight0 = 0; // aka x0
-  if (HasConditionalPreHeader) {
-    // Here we cannot know how many 0-trip count loops we have, so we guess:
-    if (OrigLoopBackedgeWeight > OrigLoopExitWeight) {
-      // If the loop count is bigger than the exit count then we set
-      // probabilities as if 0-trip count nearly never happens.
-      ExitWeight0 = ZeroTripCountWeights[0];
-      // Scale up counts if necessary so we can match `ZeroTripCountWeights` for
-      // the `ExitWeight0`:`ExitWeight1` (aka `x0`:`x1` ratio`) ratio.
-      while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
-        // ... but don't overflow.
-        uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
-        if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
-            (OrigLoopExitWeight & HighBit) != 0)
-          break;
-        OrigLoopBackedgeWeight <<= 1;
-        OrigLoopExitWeight <<= 1;
+  uint32_t ExitWeight0;    // aka x0
+  uint32_t ExitWeight1;    // aka x1
+  uint32_t EnterWeight;    // aka y0
+  uint32_t LoopBackWeight; // aka y1
+  if (OrigLoopExitWeight == 0) {
+    // Special case "LoopExitWeight == 0" weights, which behave like an endless
+    // loop, where we don't want loop-entry (y0) to equal loop-exit (x1).
+    ExitWeight0 = 0;
+    ExitWeight1 = 0;
+    EnterWeight = 1;
+    LoopBackWeight = OrigLoopBackedgeWeight;
+  } else {
+    ExitWeight0 = 0;
+    if (HasConditionalPreHeader) {
+      // Here we cannot know how many 0-trip count loops we have, so we guess:
+      if (OrigLoopBackedgeWeight > OrigLoopExitWeight) {
+        // If the loop count is bigger than the exit count then we set
+        // probabilities as if 0-trip count nearly never happens.
+        ExitWeight0 = ZeroTripCountWeights[0];
+        // Scale up counts if necessary so we can match `ZeroTripCountWeights`
+        // for the `ExitWeight0`:`ExitWeight1` ratio (aka the `x0`:`x1` ratio).
+        while (OrigLoopExitWeight < ZeroTripCountWeights[1] + ExitWeight0) {
+          // ... but don't overflow.
+          uint32_t const HighBit = uint32_t{1} << (sizeof(uint32_t) * 8 - 1);
+          if ((OrigLoopBackedgeWeight & HighBit) != 0 ||
+              (OrigLoopExitWeight & HighBit) != 0)
+            break;
+          OrigLoopBackedgeWeight <<= 1;
+          OrigLoopExitWeight <<= 1;
+        }
+      } else {
+        // If there's a higher exit-count than backedge-count then we set
+        // probabilities as if there are only 0-trip and 1-trip cases.
+        ExitWeight0 = OrigLoopExitWeight - OrigLoopBackedgeWeight;
       }
-    } else {
-      // If there's a higher exit-count than backedge-count then we set
-      // probabilities as if there are only 0-trip and 1-trip cases.
-      ExitWeight0 = OrigLoopExitWeight - OrigLoopBackedgeWeight;
     }
+    ExitWeight1 = OrigLoopExitWeight - ExitWeight0;
+    EnterWeight = ExitWeight1;
+    LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight;
   }
-  uint32_t ExitWeight1 = OrigLoopExitWeight - ExitWeight0;        // aka x1
-  uint32_t EnterWeight = ExitWeight1;                             // aka y0
-  uint32_t LoopBackWeight = OrigLoopBackedgeWeight - EnterWeight; // aka y1
 
   MDBuilder MDB(LoopBI.getContext());
   MDNode *LoopWeightMD =
diff --git a/llvm/test/Transforms/LoopRotate/update-branch-weights.ll b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
index 9af6cfab4a2411d..c3882d3087345ca 100644
--- a/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
+++ b/llvm/test/Transforms/LoopRotate/update-branch-weights.ll
@@ -23,6 +23,7 @@
 ; BFI_AFTER: - inner_loop_exit: {{.*}} count = 1000
 ; BFI_AFTER: - outer_loop_exit: {{.*}} count = 1
 
+; IR-LABEL: define void @func0
 ; IR: inner_loop_body:
 ; IR:   br i1 %cmp1, label %inner_loop_body, label %inner_loop_exit, !prof [[PROF_FUNC0_0:![0-9]+]]
 ; IR: inner_loop_exit:
@@ -74,6 +75,7 @@ outer_loop_exit:
 ; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 1024
 ; BFI_AFTER: - loop_exit: {{.*}} count = 1024
 
+; IR-LABEL: define void @func1
 ; IR: entry:
 ; IR:   br i1 %cmp1, label %loop_body.lr.ph, label %loop_exit, !prof [[PROF_FUNC1_0:![0-9]+]]
 
@@ -114,6 +116,7 @@ loop_exit:
 ; - loop_header.loop_exit_crit_edge: {{.*}} count = 32
 ; - loop_exit: {{.*}} count = 1024
 
+; IR-LABEL: define void @func2
 ; IR: entry:
 ; IR:   br i1 %cmp1, label %loop_exit, label %loop_body.lr.ph, !prof [[PROF_FUNC2_0:![0-9]+]]
 
@@ -141,12 +144,51 @@ loop_exit:
   ret void
 }
 
+; BFI_BEFORE-LABEL: block-frequency-info: func3_zero_branch_weight
+; BFI_BEFORE: - entry: {{.*}} count = 1024
+; BFI_BEFORE: - loop_header: {{.*}} count = 2199023255296
+; BFI_BEFORE: - loop_body: {{.*}} count = 2199023254272
+; BFI_BEFORE: - loop_exit: {{.*}} count = 1024
+
+; BFI_AFTER-LABEL: block-frequency-info: func3_zero_branch_weight
+; BFI_AFTER: - entry: {{.*}} count = 1024
+; BFI_AFTER: - loop_body.lr.ph: {{.*}} count = 1024
+; BFI_AFTER: - loop_body: {{.*}} count = 2199023255296
+; BFI_AFTER: - loop_header.loop_exit_crit_edge: {{.*}} count = 1024
+; BFI_AFTER: - loop_exit: {{.*}} count = 1024
+
+; IR-LABEL: define void @func3_zero_branch_weight
+; IR: entry:
+; IR:   br i1 %cmp1, label %loop_exit, label %loop_body.lr.ph, !prof [[PROF_FUNC3_0:![0-9]+]]
+
+; IR: loop_body:
+; IR:   br i1 %cmp, label %loop_header.loop_exit_crit_edge, label %loop_body, !prof [[PROF_FUNC3_0]]
+
+define void @func3_zero_branch_weight(i32 %n) !prof !3 { ; Loop whose exit edge carries branch weight 0 (!6); regression input for PR 66675.
+entry:
+  br label %loop_header
+
+loop_header: ; Rotated by the pass: expected output tests %cmp1 in entry and %cmp in loop_body.
+  %i = phi i32 [0, %entry], [%i_inc, %loop_body]
+  %cmp = icmp slt i32 %i, %n
+  br i1 %cmp, label %loop_exit, label %loop_body, !prof !6 ; !6 = {0, 1}: exit edge weight 0, loop_body edge weight 1.
+
+loop_body:
+  store volatile i32 %i, ptr @g, align 4 ; Volatile store: a side effect the optimizer may not delete.
+  %i_inc = add i32 %i, 1
+  br label %loop_header
+
+loop_exit:
+  ret void
+}
+
 !0 = !{!"function_entry_count", i64 1}
 !1 = !{!"branch_weights", i32 1000, i32 1}
 !2 = !{!"branch_weights", i32 3000, i32 1000}
 !3 = !{!"function_entry_count", i64 1024}
 !4 = !{!"branch_weights", i32 40, i32 2}
 !5 = !{!"branch_weights", i32 10240, i32 320}
+!6 = !{!"branch_weights", i32 0, i32 1}
 
 ; IR: [[PROF_FUNC0_0]] = !{!"branch_weights", i32 2000, i32 1000}
 ; IR: [[PROF_FUNC0_1]] = !{!"branch_weights", i32 999, i32 1}
@@ -154,3 +196,4 @@ loop_exit:
 ; IR: [[PROF_FUNC1_1]] = !{!"branch_weights", i32 2433, i32 127}
 ; IR: [[PROF_FUNC2_0]] = !{!"branch_weights", i32 9920, i32 320}
 ; IR: [[PROF_FUNC2_1]] = !{!"branch_weights", i32 320, i32 0}
+; IR: [[PROF_FUNC3_0]] = !{!"branch_weights", i32 0, i32 1}

``````````

</details>


https://github.com/llvm/llvm-project/pull/66681


More information about the llvm-commits mailing list