[llvm] [SimplifyCFG] Set branch weights when merging conditional store to address (PR #154841)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 11 10:01:08 PDT 2025


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/154841

>From 5b30a51d5e9977d0853dc2e869950df3cc7fa639 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 21 Aug 2025 13:54:49 -0700
Subject: [PATCH 1/4] [SimplifyCFG] Set branch weights when merging conditional
 store to address

---
 llvm/include/llvm/IR/ProfDataUtils.h          | 27 +++++++++++++++++++
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     | 16 +++++++++++
 .../SimplifyCFG/merge-cond-stores.ll          | 14 +++++++---
 3 files changed, 53 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index b8386ddc86ca8..967df2ec9e29f 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -15,6 +15,7 @@
 #ifndef LLVM_IR_PROFDATAUTILS_H
 #define LLVM_IR_PROFDATAUTILS_H
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/IR/Metadata.h"
@@ -190,5 +191,31 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
 /// Scaling the profile data attached to 'I' using the ratio of S/T.
 LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
 
+/// get the branch weights of a branch conditioned on b1 | b2, where b1 and b2
+/// are 2 booleans that are the condition of 2 branches for which we have the
+/// branch weights B1 and B2, respectivelly.
+inline SmallVector<uint64_t, 2>
+getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
+                      const SmallVector<uint32_t, 2> &B2) {
+  // for the first conditional branch, the probability the "true" case is taken
+  // is p(b1) = B1[0] / (B1[0] + B2[0]). The "false" case's probability is
+  // p(not b1) = B1[1] / (B1[0] + B1[1]).
+  // Similarly for the second conditional branch and B2.
+  //
+  // the probability of the new branch NOT being taken is:
+  // not P = p((not b1) and (not b2)) =
+  //       = B1[1] / (B1[0]+B1[1]) * B2[1] / (B2[0]+B2[1]) =
+  //       = B1[1] * B2[1] / (B1[0] + B1[1]) * (B2[0] + B2[1])
+  // then the probability of it being taken is: P = 1 - (not P).
+  // The denominator will be the same as above, and the numerator of P will be
+  // (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
+  // Which then reduces to what's shown below (out of the 4 terms coming out of
+  // the product of sums, the subtracted one cancels out)
+  assert(B1.size() == 2);
+  assert(B2.size() == 2);
+  auto FalseWeight = B1[1] * B2[1];
+  auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0];
+  return {TrueWeight, FalseWeight};
+}
 } // namespace llvm
 #endif
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 970f85378d3d2..850e57e6b0b14 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -203,6 +203,8 @@ static cl::opt<unsigned> MaxJumpThreadingLiveBlocks(
     cl::desc("Limit number of blocks a define in a threaded block is allowed "
              "to be live in"));
 
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
 STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
 STATISTIC(NumLinearMaps,
           "Number of switch instructions turned into linear mapping");
@@ -4438,6 +4440,20 @@ static bool mergeConditionalStoreToAddress(
   auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
                                       /*Unreachable=*/false,
                                       /*BranchWeights=*/nullptr, DTU);
+  if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch) &&
+      !ProfcheckDisableMetadataFixes) {
+    SmallVector<uint32_t, 2> PWeights, QWeights;
+    extractBranchWeights(*PBranch, PWeights);
+    extractBranchWeights(*QBranch, QWeights);
+    if (InvertPCond)
+      std::swap(PWeights[0], PWeights[1]);
+    if (InvertQCond)
+      std::swap(QWeights[0], QWeights[1]);
+    auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
+    setBranchWeights(PostBB->getTerminator(), CombinedWeights[0],
+                     CombinedWeights[1],
+                     /*IsExpected=*/false);
+  }
 
   QB.SetInsertPoint(T);
   StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));
diff --git a/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
index e1bd7916b3be0..b1cce4484bbab 100644
--- a/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
+++ b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
 ; RUN: opt -passes=simplifycfg,instcombine -simplifycfg-require-and-preserve-domtree=1 < %s -simplifycfg-merge-cond-stores=true -simplifycfg-merge-cond-stores-aggressively=false -phi-node-folding-threshold=2 -S | FileCheck %s
 
 ; This test should succeed and end up if-converted.
@@ -43,7 +43,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
 ; CHECK-NEXT:    [[X2:%.*]] = icmp eq i32 [[B:%.*]], 0
 ; CHECK-NEXT:    [[X3:%.*]] = icmp eq i32 [[B1:%.*]], 0
 ; CHECK-NEXT:    [[TMP0:%.*]] = or i1 [[X2]], [[X3]]
-; CHECK-NEXT:    br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]]
+; CHECK-NEXT:    br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]], !prof [[PROF0:![0-9]+]]
 ; CHECK:       1:
 ; CHECK-NEXT:    [[SPEC_SELECT:%.*]] = zext i1 [[X3]] to i32
 ; CHECK-NEXT:    store i32 [[SPEC_SELECT]], ptr [[P:%.*]], align 4
@@ -53,7 +53,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
 ;
 entry:
   %x1 = icmp eq i32 %a, 0
-  br i1 %x1, label %yes1, label %fallthrough
+  br i1 %x1, label %yes1, label %fallthrough, !prof !0
 
 yes1:
   store i32 0, ptr %p
@@ -61,7 +61,7 @@ yes1:
 
 fallthrough:
   %x2 = icmp eq i32 %b, 0
-  br i1 %x2, label %yes2, label %end
+  br i1 %x2, label %yes2, label %end, !prof !1
 
 yes2:
   store i32 1, ptr %p
@@ -406,3 +406,9 @@ yes2:
 end:
   ret void
 }
+
+!0 = !{!"branch_weights", i32 7, i32 13}
+!1 = !{!"branch_weights", i32 3, i32 11}
+;.
+; CHECK: [[PROF0]] = !{!"branch_weights", i32 137, i32 143}
+;.

>From 60a6f9428cd9134156f7e8bc480873ac7fd7bf33 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 10 Sep 2025 15:23:20 -0700
Subject: [PATCH 2/4] Update ProfDataUtils.h

---
 llvm/include/llvm/IR/ProfDataUtils.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 967df2ec9e29f..bfd00e5322c7d 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -191,9 +191,9 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
 /// Scaling the profile data attached to 'I' using the ratio of S/T.
 LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
 
-/// get the branch weights of a branch conditioned on b1 | b2, where b1 and b2
-/// are 2 booleans that are the condition of 2 branches for which we have the
-/// branch weights B1 and B2, respectivelly.
+/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
+/// are 2 booleans that are the conditions of 2 branches for which we have the
+/// branch weights B1 and B2, respectively.
 inline SmallVector<uint64_t, 2>
 getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
                       const SmallVector<uint32_t, 2> &B2) {

>From f3540ec759a04aeec404268c05fc73aca3717e31 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 10 Sep 2025 15:24:04 -0700
Subject: [PATCH 3/4] Update ProfDataUtils.h

---
 llvm/include/llvm/IR/ProfDataUtils.h | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index bfd00e5322c7d..3480a4ea49435 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -191,22 +191,24 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
 /// Scaling the profile data attached to 'I' using the ratio of S/T.
 LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
 
-/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
-/// are 2 booleans that are the conditions of 2 branches for which we have the
-/// branch weights B1 and B2, respectively.
+/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
+/// are 2 booleans that are the conditions of 2 branches for which we have the
+/// branch weights B1 and B2, respectively. In both B1 and B2, the first
+/// position (index 0) is for the 'true' branch, and the second position (index
+/// 1) is for the 'false' branch.
 inline SmallVector<uint64_t, 2>
 getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
                       const SmallVector<uint32_t, 2> &B2) {
-  // for the first conditional branch, the probability the "true" case is taken
-  // is p(b1) = B1[0] / (B1[0] + B2[0]). The "false" case's probability is
+  // For the first conditional branch, the probability the "true" case is taken
+  // is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is
   // p(not b1) = B1[1] / (B1[0] + B1[1]).
   // Similarly for the second conditional branch and B2.
   //
-  // the probability of the new branch NOT being taken is:
+  // The probability of the new branch NOT being taken is:
   // not P = p((not b1) and (not b2)) =
   //       = B1[1] / (B1[0]+B1[1]) * B2[1] / (B2[0]+B2[1]) =
   //       = B1[1] * B2[1] / (B1[0] + B1[1]) * (B2[0] + B2[1])
-  // then the probability of it being taken is: P = 1 - (not P).
+  // Then the probability of it being taken is: P = 1 - (not P).
   // The denominator will be the same as above, and the numerator of P will be
   // (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
   // Which then reduces to what's shown below (out of the 4 terms coming out of

>From 557c37c9e557a80326528bde4c3f1945a6d5e2ce Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 11 Sep 2025 10:00:56 -0700
Subject: [PATCH 4/4] Update ProfDataUtils.h

---
 llvm/include/llvm/IR/ProfDataUtils.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 3480a4ea49435..8fcd913c92997 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -209,10 +209,10 @@ getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
   //       = B1[1] / (B1[0]+B1[1]) * B2[1] / (B2[0]+B2[1]) =
   //       = B1[1] * B2[1] / (B1[0] + B1[1]) * (B2[0] + B2[1])
   // Then the probability of it being taken is: P = 1 - (not P).
-  // The denominator will be the same as above, and the numerator of P will be
-  // (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
-  // Which then reduces to what's shown below (out of the 4 terms coming out of
-  // the product of sums, the subtracted one cancels out)
+  // The denominator will be the same as above, and the numerator of P will be:
+  // (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
+  // Expanding the product of sums yields 4 terms; the B1[1]*B2[1] term cancels
+  // against the subtracted one, leaving the 3 terms summed in TrueWeight below.
   assert(B1.size() == 2);
   assert(B2.size() == 2);
   auto FalseWeight = B1[1] * B2[1];



More information about the llvm-commits mailing list