[llvm] [SimplifyCFG] Set branch weights when merging conditional store to address (PR #154841)
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 11 10:01:08 PDT 2025
https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/154841
>From 5b30a51d5e9977d0853dc2e869950df3cc7fa639 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 21 Aug 2025 13:54:49 -0700
Subject: [PATCH 1/4] [SimplifyCFG] Set branch weights when merging conditional
store to address
---
llvm/include/llvm/IR/ProfDataUtils.h | 27 +++++++++++++++++++
llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 16 +++++++++++
.../SimplifyCFG/merge-cond-stores.ll | 14 +++++++---
3 files changed, 53 insertions(+), 4 deletions(-)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index b8386ddc86ca8..967df2ec9e29f 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -15,6 +15,7 @@
#ifndef LLVM_IR_PROFDATAUTILS_H
#define LLVM_IR_PROFDATAUTILS_H
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Metadata.h"
@@ -190,5 +191,31 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
/// Scaling the profile data attached to 'I' using the ratio of S/T.
LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
+/// get the branch weights of a branch conditioned on b1 | b2, where b1 and b2
+/// are 2 booleans that are the condition of 2 branches for which we have the
+/// branch weights B1 and B2, respectivelly.
+inline SmallVector<uint64_t, 2>
+getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
+ const SmallVector<uint32_t, 2> &B2) {
+ // for the first conditional branch, the probability the "true" case is taken
+ // is p(b1) = B1[0] / (B1[0] + B2[0]). The "false" case's probability is
+ // p(not b1) = B1[1] / (B1[0] + B1[1]).
+ // Similarly for the second conditional branch and B2.
+ //
+ // the probability of the new branch NOT being taken is:
+ // not P = p((not b1) and (not b2)) =
+ // = B1[1] / (B1[0]+B1[1]) * B2[1] / (B2[0]+B2[1]) =
+ // = B1[1] * B2[1] / (B1[0] + B1[1]) * (B2[0] + B2[1])
+ // then the probability of it being taken is: P = 1 - (not P).
+ // The denominator will be the same as above, and the numerator of P will be
+ // (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
+ // Which then reduces to what's shown below (out of the 4 terms coming out of
+ // the product of sums, the subtracted one cancels out)
+ assert(B1.size() == 2);
+ assert(B2.size() == 2);
+ // Widen to 64 bits first: a uint32_t * uint32_t product wraps modulo 2^32.
+ const uint64_t T1 = B1[0], F1 = B1[1], T2 = B2[0], F2 = B2[1];
+ return {T1 * T2 + T1 * F2 + F1 * T2, F1 * F2};
+}
} // namespace llvm
#endif
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 970f85378d3d2..850e57e6b0b14 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -203,6 +203,8 @@ static cl::opt<unsigned> MaxJumpThreadingLiveBlocks(
cl::desc("Limit number of blocks a define in a threaded block is allowed "
"to be live in"));
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
STATISTIC(NumLinearMaps,
"Number of switch instructions turned into linear mapping");
@@ -4438,6 +4440,20 @@ static bool mergeConditionalStoreToAddress(
auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
/*Unreachable=*/false,
/*BranchWeights=*/nullptr, DTU);
+ if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch) &&
+ !ProfcheckDisableMetadataFixes) {
+ SmallVector<uint32_t, 2> PWeights, QWeights;
+ extractBranchWeights(*PBranch, PWeights);
+ extractBranchWeights(*QBranch, QWeights);
+ if (InvertPCond)
+ std::swap(PWeights[0], PWeights[1]);
+ if (InvertQCond)
+ std::swap(QWeights[0], QWeights[1]);
+ auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
+ setBranchWeights(PostBB->getTerminator(), CombinedWeights[0],
+ CombinedWeights[1],
+ /*IsExpected=*/false);
+ }
QB.SetInsertPoint(T);
StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));
diff --git a/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
index e1bd7916b3be0..b1cce4484bbab 100644
--- a/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
+++ b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
; RUN: opt -passes=simplifycfg,instcombine -simplifycfg-require-and-preserve-domtree=1 < %s -simplifycfg-merge-cond-stores=true -simplifycfg-merge-cond-stores-aggressively=false -phi-node-folding-threshold=2 -S | FileCheck %s
; This test should succeed and end up if-converted.
@@ -43,7 +43,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
; CHECK-NEXT: [[X2:%.*]] = icmp eq i32 [[B:%.*]], 0
; CHECK-NEXT: [[X3:%.*]] = icmp eq i32 [[B1:%.*]], 0
; CHECK-NEXT: [[TMP0:%.*]] = or i1 [[X2]], [[X3]]
-; CHECK-NEXT: br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]]
+; CHECK-NEXT: br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]], !prof [[PROF0:![0-9]+]]
; CHECK: 1:
; CHECK-NEXT: [[SPEC_SELECT:%.*]] = zext i1 [[X3]] to i32
; CHECK-NEXT: store i32 [[SPEC_SELECT]], ptr [[P:%.*]], align 4
@@ -53,7 +53,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
;
entry:
%x1 = icmp eq i32 %a, 0
- br i1 %x1, label %yes1, label %fallthrough
+ br i1 %x1, label %yes1, label %fallthrough, !prof !0
yes1:
store i32 0, ptr %p
@@ -61,7 +61,7 @@ yes1:
fallthrough:
%x2 = icmp eq i32 %b, 0
- br i1 %x2, label %yes2, label %end
+ br i1 %x2, label %yes2, label %end, !prof !1
yes2:
store i32 1, ptr %p
@@ -406,3 +406,9 @@ yes2:
end:
ret void
}
+
+!0 = !{!"branch_weights", i32 7, i32 13}
+!1 = !{!"branch_weights", i32 3, i32 11}
+;.
+; CHECK: [[PROF0]] = !{!"branch_weights", i32 137, i32 143}
+;.
>From 60a6f9428cd9134156f7e8bc480873ac7fd7bf33 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 10 Sep 2025 15:23:20 -0700
Subject: [PATCH 2/4] Update ProfDataUtils.h
---
llvm/include/llvm/IR/ProfDataUtils.h | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 967df2ec9e29f..bfd00e5322c7d 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -191,9 +191,9 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
/// Scaling the profile data attached to 'I' using the ratio of S/T.
LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
-/// get the branch weights of a branch conditioned on b1 | b2, where b1 and b2
-/// are 2 booleans that are the condition of 2 branches for which we have the
-/// branch weights B1 and B2, respectivelly.
+/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
+/// are 2 booleans that are the conditions of 2 branches for which we have the
+/// branch weights B1 and B2, respectively.
inline SmallVector<uint64_t, 2>
getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
const SmallVector<uint32_t, 2> &B2) {
>From f3540ec759a04aeec404268c05fc73aca3717e31 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 10 Sep 2025 15:24:04 -0700
Subject: [PATCH 3/4] Update ProfDataUtils.h
---
llvm/include/llvm/IR/ProfDataUtils.h | 16 +++++++++-------
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index bfd00e5322c7d..3480a4ea49435 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -191,22 +191,24 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
/// Scaling the profile data attached to 'I' using the ratio of S/T.
LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
-/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
-/// are 2 booleans that are the conditions of 2 branches for which we have the
-/// branch weights B1 and B2, respectively.
+/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
+/// are 2 booleans that are the conditions of 2 branches for which we have the
+/// branch weights B1 and B2, respectively. In both B1 and B2, the first
+/// position (index 0) is for the 'true' branch, and the second position (index
+/// 1) is for the 'false' branch.
inline SmallVector<uint64_t, 2>
getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
const SmallVector<uint32_t, 2> &B2) {
- // for the first conditional branch, the probability the "true" case is taken
- // is p(b1) = B1[0] / (B1[0] + B2[0]). The "false" case's probability is
+ // For the first conditional branch, the probability the "true" case is taken
+ // is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is
// p(not b1) = B1[1] / (B1[0] + B1[1]).
// Similarly for the second conditional branch and B2.
//
- // the probability of the new branch NOT being taken is:
+ // The probability of the new branch NOT being taken is:
// not P = p((not b1) and (not b2)) =
// = B1[1] / (B1[0]+B1[1]) * B2[1] / (B2[0]+B2[1]) =
// = B1[1] * B2[1] / (B1[0] + B1[1]) * (B2[0] + B2[1])
- // then the probability of it being taken is: P = 1 - (not P).
+ // Then the probability of it being taken is: P = 1 - (not P).
// The denominator will be the same as above, and the numerator of P will be
// (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
// Which then reduces to what's shown below (out of the 4 terms coming out of
>From 557c37c9e557a80326528bde4c3f1945a6d5e2ce Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 11 Sep 2025 10:00:56 -0700
Subject: [PATCH 4/4] Update ProfDataUtils.h
---
llvm/include/llvm/IR/ProfDataUtils.h | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 3480a4ea49435..8fcd913c92997 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -209,10 +209,10 @@ getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
// = B1[1] / (B1[0]+B1[1]) * B2[1] / (B2[0]+B2[1]) =
// = B1[1] * B2[1] / (B1[0] + B1[1]) * (B2[0] + B2[1])
// Then the probability of it being taken is: P = 1 - (not P).
- // The denominator will be the same as above, and the numerator of P will be
- // (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
- // Which then reduces to what's shown below (out of the 4 terms coming out of
- // the product of sums, the subtracted one cancels out)
+ // The denominator will be the same as above, and the numerator of P will be:
+ // (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
+ // Which then reduces to what's shown below (out of the 4 terms coming out of
+ // the product of sums, the subtracted one cancels out).
assert(B1.size() == 2);
assert(B2.size() == 2);
auto FalseWeight = B1[1] * B2[1];
More information about the llvm-commits
mailing list