[llvm-branch-commits] [llvm] [SimplifyCFG] Set branch weights when merging conditional store to address (PR #154841)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Aug 25 13:16:42 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Mircea Trofin (mtrofin)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/154841.diff
3 Files Affected:
- (modified) llvm/include/llvm/IR/ProfDataUtils.h (+27)
- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+16)
- (modified) llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll (+10-4)
``````````diff
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 404875285beae..ebf8559cd3d91 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -15,6 +15,7 @@
#ifndef LLVM_IR_PROFDATAUTILS_H
#define LLVM_IR_PROFDATAUTILS_H
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Metadata.h"
@@ -186,5 +187,31 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
/// Scaling the profile data attached to 'I' using the ratio of S/T.
LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
+/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
+/// are 2 booleans that are the conditions of 2 branches for which we have the
+/// branch weights B1 and B2, respectively. In B1, B2 and the returned vector,
+/// index 0 holds the weight of the "true" successor and index 1 the weight of
+/// the "false" successor.
+///
+/// The products are formed in 64-bit arithmetic, so the returned weights may
+/// not fit in 32 bits; callers that attach them as branch_weights metadata
+/// must scale them down first. Should the "true" weight exceed 64 bits, it
+/// saturates at UINT64_MAX instead of wrapping around.
+inline SmallVector<uint64_t, 2>
+getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
+                      const SmallVector<uint32_t, 2> &B2) {
+  // For the first conditional branch, the probability the "true" case is taken
+  // is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is
+  // p(not b1) = B1[1] / (B1[0] + B1[1]).
+  // Similarly for the second conditional branch and B2.
+  //
+  // The probability of the new branch NOT being taken is:
+  //   not P = p((not b1) and (not b2)) =
+  //         = B1[1] / (B1[0] + B1[1]) * B2[1] / (B2[0] + B2[1]) =
+  //         = B1[1] * B2[1] / ((B1[0] + B1[1]) * (B2[0] + B2[1]))
+  // Then the probability of it being taken is: P = 1 - (not P).
+  // The denominator is the same as above, and the numerator of P is
+  //   (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1] * B2[1]
+  // which reduces to the sum computed below (of the 4 terms coming out of the
+  // product of sums, the subtracted one cancels out). Weights only need to be
+  // proportional to the probabilities, so the common denominator is dropped.
+  assert(B1.size() == 2 && "expected exactly two branch weights");
+  assert(B2.size() == 2 && "expected exactly two branch weights");
+  // Widen before multiplying: a uint32_t * uint32_t product is not computed
+  // in 64 bits and would overflow for large weights.
+  const uint64_t B1True = B1[0], B1False = B1[1];
+  const uint64_t B2True = B2[0], B2False = B2[1];
+  // Each single product fits in 64 bits since both factors are below 2^32.
+  const uint64_t FalseWeight = B1False * B2False;
+  // The sum of three such products may not, so accumulate with saturation.
+  uint64_t TrueWeight = 0;
+  for (uint64_t Term : {B1True * B2True, B1True * B2False, B1False * B2True})
+    TrueWeight =
+        Term > UINT64_MAX - TrueWeight ? UINT64_MAX : TrueWeight + Term;
+  return {TrueWeight, FalseWeight};
+}
} // namespace llvm
#endif
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 270598e2b674b..370b282d1b14d 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -203,6 +203,8 @@ static cl::opt<unsigned> MaxJumpThreadingLiveBlocks(
cl::desc("Limit number of blocks a define in a threaded block is allowed "
"to be live in"));
+// Command-line flag defined in another translation unit (hence `extern`).
+// When it is set, mergeConditionalStoreToAddress skips attaching combined
+// branch weights to the branch that guards the merged store.
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
STATISTIC(NumLinearMaps,
"Number of switch instructions turned into linear mapping");
@@ -4431,6 +4433,20 @@ static bool mergeConditionalStoreToAddress(
auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
/*Unreachable=*/false,
/*BranchWeights=*/nullptr, DTU);
+ if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch) &&
+ !ProfcheckDisableMetadataFixes) {
+ SmallVector<uint32_t, 2> PWeights, QWeights;
+ extractBranchWeights(*PBranch, PWeights);
+ extractBranchWeights(*QBranch, QWeights);
+ if (InvertPCond)
+ std::swap(PWeights[0], PWeights[1]);
+ if (InvertQCond)
+ std::swap(QWeights[0], QWeights[1]);
+ auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
+ setBranchWeights(PostBB->getTerminator(), CombinedWeights[0],
+ CombinedWeights[1],
+ /*IsExpected=*/false);
+ }
QB.SetInsertPoint(T);
StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));
diff --git a/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
index b5c4b8aa51db4..ee723463d4b06 100644
--- a/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
+++ b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
; RUN: opt -passes=simplifycfg,instcombine -simplifycfg-require-and-preserve-domtree=1 < %s -simplifycfg-merge-cond-stores=true -simplifycfg-merge-cond-stores-aggressively=false -phi-node-folding-threshold=2 -S | FileCheck %s
; This test should succeed and end up if-converted.
@@ -43,7 +43,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
; CHECK-NEXT: [[X1_NOT:%.*]] = icmp eq i32 [[A:%.*]], 0
; CHECK-NEXT: [[X2:%.*]] = icmp eq i32 [[B:%.*]], 0
; CHECK-NEXT: [[TMP0:%.*]] = or i1 [[X1_NOT]], [[X2]]
-; CHECK-NEXT: br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]]
+; CHECK-NEXT: br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]], !prof [[PROF0:![0-9]+]]
; CHECK: 1:
; CHECK-NEXT: [[SPEC_SELECT:%.*]] = zext i1 [[X2]] to i32
; CHECK-NEXT: store i32 [[SPEC_SELECT]], ptr [[P:%.*]], align 4
@@ -53,7 +53,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
;
entry:
%x1 = icmp eq i32 %a, 0
- br i1 %x1, label %yes1, label %fallthrough
+ br i1 %x1, label %yes1, label %fallthrough, !prof !0
yes1:
store i32 0, ptr %p
@@ -61,7 +61,7 @@ yes1:
fallthrough:
%x2 = icmp eq i32 %b, 0
- br i1 %x2, label %yes2, label %end
+ br i1 %x2, label %yes2, label %end, !prof !1
yes2:
store i32 1, ptr %p
@@ -406,3 +406,9 @@ yes2:
end:
ret void
}
+
+!0 = !{!"branch_weights", i32 7, i32 13}
+!1 = !{!"branch_weights", i32 3, i32 11}
+;.
+; CHECK: [[PROF0]] = !{!"branch_weights", i32 259, i32 21}
+;.
``````````
</details>
https://github.com/llvm/llvm-project/pull/154841
More information about the llvm-branch-commits
mailing list