[llvm-branch-commits] [llvm] [SimplifyCFG] Set branch weights when merging conditional store to address (PR #154841)
Mircea Trofin via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Aug 21 13:55:17 PDT 2025
https://github.com/mtrofin created https://github.com/llvm/llvm-project/pull/154841
None
>From f4441cbe5e38f6abc76604a8049f6e36fb4881a7 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 21 Aug 2025 13:54:49 -0700
Subject: [PATCH] [SimplifyCFG] Set branch weights when merging conditional
store to address
---
llvm/include/llvm/IR/ProfDataUtils.h | 22 +++++++++++++
llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 39 +++++++++++++++--------
2 files changed, 48 insertions(+), 13 deletions(-)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 404875285beae..c9284c1bc8dde 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -15,6 +15,7 @@
#ifndef LLVM_IR_PROFDATAUTILS_H
#define LLVM_IR_PROFDATAUTILS_H
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Metadata.h"
@@ -186,5 +187,26 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
/// Scaling the profile data attached to 'I' using the ratio of S/T.
LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
+/// Get the branch weights of a branch conditioned on `b1 || b2`, where b1 and
+/// b2 are the conditions of two branches whose weights are B1 and B2,
+/// respectively. In each input pair, and in the returned pair, index 0 is the
+/// weight of the condition being true and index 1 of it being false.
+///
+/// Treating b1 and b2 as independent:
+///   P(not (b1 || b2)) = P(not b1) * P(not b2)
+///                     = B1[1] * B2[1] / ((B1[0] + B1[1]) * (B2[0] + B2[1]))
+///   P(b1 || b2)       = 1 - P(not (b1 || b2))
+/// Over the shared denominator, the numerator of P(b1 || b2) is
+///   (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1] * B2[1]
+///     = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0],
+/// and the denominator is dropped because only the ratio matters.
+///
+/// The products are formed in 64-bit arithmetic. If a pair's weights sum to
+/// more than UINT32_MAX, both are halved first (keeping their ratio up to
+/// rounding), which guarantees neither returned weight overflows. Callers
+/// that need 32-bit weights must still scale the result down.
+inline SmallVector<uint64_t, 2>
+getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
+                      const SmallVector<uint32_t, 2> &B2) {
+  assert(B1.size() == 2);
+  assert(B2.size() == 2);
+  // Widen *before* multiplying: uint32_t * uint32_t is evaluated in 32-bit
+  // unsigned arithmetic and wraps around for ordinary profile counts, even
+  // though the result is returned as uint64_t.
+  uint64_t P0 = B1[0], P1 = B1[1];
+  uint64_t Q0 = B2[0], Q1 = B2[1];
+  // TrueWeight + FalseWeight == (P0 + P1) * (Q0 + Q1), so keeping each sum
+  // within 32 bits keeps both results within 64 bits.
+  auto FitPairSum = [](uint64_t &W0, uint64_t &W1) {
+    if (W0 + W1 > UINT32_MAX) {
+      W0 >>= 1;
+      W1 >>= 1;
+    }
+  };
+  FitPairSum(P0, P1);
+  FitPairSum(Q0, Q1);
+  const uint64_t FalseWeight = P1 * Q1;
+  const uint64_t TrueWeight = P0 * Q0 + P0 * Q1 + P1 * Q0;
+  return {TrueWeight, FalseWeight};
+}
} // namespace llvm
#endif
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 4847add386dc4..e26a189564d13 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1182,7 +1182,7 @@ static void cloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
// only given the branch precondition.
// Similarly strip attributes on call parameters that may cause UB in
// location the call is moved to.
- NewBonusInst->dropUBImplyingAttrsAndMetadata();
+ // Pass MD_prof so the hoisted instruction keeps its profile metadata
+ // instead of losing it along with the UB-implying metadata.
+ NewBonusInst->dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof});
NewBonusInst->insertInto(PredBlock, PTI->getIterator());
auto Range = NewBonusInst->cloneDebugInfoFrom(&BonusInst);
@@ -1808,7 +1808,8 @@ static void hoistConditionalLoadsStores(
// !annotation: Not impact semantics. Keep it.
if (const MDNode *Ranges = I->getMetadata(LLVMContext::MD_range))
MaskedLoadStore->addRangeRetAttr(getConstantRangeFromMetadata(*Ranges));
- I->dropUBImplyingAttrsAndUnknownMetadata({LLVMContext::MD_annotation});
+ // !prof: profile data, not a UB-implying annotation. Keep it as well.
+ I->dropUBImplyingAttrsAndUnknownMetadata(
+ {LLVMContext::MD_annotation, LLVMContext::MD_prof});
// FIXME: DIAssignID is not supported for masked store yet.
// (Verifier::visitDIAssignIDMetadata)
at::deleteAssignmentMarkers(I);
@@ -3366,7 +3367,7 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI,
if (!SpeculatedStoreValue || &I != SpeculatedStore) {
I.setDebugLoc(DebugLoc::getDropped());
}
- I.dropUBImplyingAttrsAndMetadata();
+ // Pass MD_prof so the speculated instruction keeps its profile metadata
+ // instead of losing it along with the UB-implying metadata.
+ I.dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof});
// Drop ephemeral values.
if (EphTracker.contains(&I)) {
@@ -4404,10 +4405,12 @@ static bool mergeConditionalStoreToAddress(
// OK, we're going to sink the stores to PostBB. The store has to be
// conditional though, so first create the predicate.
- Value *PCond = cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator())
- ->getCondition();
- Value *QCond = cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator())
- ->getCondition();
+ // Keep the branch instructions themselves, not just their conditions: their
+ // branch weights are used to weight the merged conditional store.
+ BranchInst *const PBranch =
+ cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator());
+ BranchInst *const QBranch =
+ cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator());
+ Value *const PCond = PBranch->getCondition();
+ Value *const QCond = QBranch->getCondition();
Value *PPHI = ensureValueAvailableInSuccessor(PStore->getValueOperand(),
PStore->getParent());
@@ -4418,19 +4421,29 @@ static bool mergeConditionalStoreToAddress(
IRBuilder<> QB(PostBB, PostBBFirst);
QB.SetCurrentDebugLocation(PostBBFirst->getStableDebugLoc());
- Value *PPred = PStore->getParent() == PTB ? PCond : QB.CreateNot(PCond);
- Value *QPred = QStore->getParent() == QTB ? QCond : QB.CreateNot(QCond);
+  // Fold "the store sits on the false successor" into the inversion flags so
+  // that each flag means exactly "the predicate is the negation of the branch
+  // condition". The same flags decide below whether the branch weights must
+  // be swapped, so they must keep this one meaning throughout.
+  InvertPCond ^= (PStore->getParent() != PTB);
+  InvertQCond ^= (QStore->getParent() != QTB);
+  Value *const PPred = InvertPCond ? QB.CreateNot(PCond) : PCond;
+  Value *const QPred = InvertQCond ? QB.CreateNot(QCond) : QCond;
- if (InvertPCond)
- PPred = QB.CreateNot(PPred);
- if (InvertQCond)
- QPred = QB.CreateNot(QPred);
Value *CombinedPred = QB.CreateOr(PPred, QPred);
BasicBlock::iterator InsertPt = QB.GetInsertPoint();
auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
/*Unreachable=*/false,
/*BranchWeights=*/nullptr, DTU);
+  // Weight the new "store or not" branch from the two original branches.
+  // SplitBlockAndInsertIfThen leaves PostBB as the head block holding the new
+  // conditional branch; T is the unconditional terminator of the new "then"
+  // block, so the two weights must go on PostBB's terminator, not on T.
+  SmallVector<uint32_t, 2> PWeights, QWeights;
+  if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch) &&
+      extractBranchWeights(*PBranch, PWeights) &&
+      extractBranchWeights(*QBranch, QWeights)) {
+    // Index 0 is the weight of the branch condition being true. Reorder so it
+    // is the weight of the predicate (PPred/QPred) being true.
+    if (InvertPCond)
+      std::swap(PWeights[0], PWeights[1]);
+    if (InvertQCond)
+      std::swap(QWeights[0], QWeights[1]);
+    SmallVector<uint64_t, 2> CombinedWeights =
+        getDisjunctionWeights(PWeights, QWeights);
+    // The combined weights are products of 32-bit weights; scale both by the
+    // same factor so they fit in 32 bits with their ratio intact, instead of
+    // letting each one be truncated independently.
+    const uint64_t MaxWeight =
+        std::max(CombinedWeights[0], CombinedWeights[1]);
+    if (MaxWeight > UINT32_MAX) {
+      const uint64_t Scale = MaxWeight / UINT32_MAX + 1;
+      CombinedWeights[0] /= Scale;
+      CombinedWeights[1] /= Scale;
+    }
+    setBranchWeights(PostBB->getTerminator(),
+                     static_cast<uint32_t>(CombinedWeights[0]),
+                     static_cast<uint32_t>(CombinedWeights[1]),
+                     /*IsExpected=*/false);
+  }
QB.SetInsertPoint(T);
StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));
More information about the llvm-branch-commits
mailing list