[llvm-branch-commits] [llvm] [SimplifyCFG] Set branch weights when merging conditional store to address (PR #154841)

Mircea Trofin via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Aug 21 13:55:17 PDT 2025


https://github.com/mtrofin created https://github.com/llvm/llvm-project/pull/154841

None

>From f4441cbe5e38f6abc76604a8049f6e36fb4881a7 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 21 Aug 2025 13:54:49 -0700
Subject: [PATCH] [SimplifyCFG] Set branch weights when merging conditional
 store to address

---
 llvm/include/llvm/IR/ProfDataUtils.h      | 22 +++++++++++++
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp | 39 +++++++++++++++--------
 2 files changed, 48 insertions(+), 13 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 404875285beae..c9284c1bc8dde 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -15,6 +15,7 @@
 #ifndef LLVM_IR_PROFDATAUTILS_H
 #define LLVM_IR_PROFDATAUTILS_H
 
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/IR/Metadata.h"
@@ -186,5 +187,26 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
 /// Scaling the profile data attached to 'I' using the ratio of S/T.
 LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
 
+/// Get the {true, false} weights of a branch on b1 || b2, given the {true,
+/// false} weights B1 and B2 of branches on b1 and b2, assuming independence.
+/// Since p(not (b1 || b2)) = B1[1]B2[1] / ((B1[0]+B1[1]) * (B2[0]+B2[1])),
+/// the false weight is B1[1]*B2[1] and the true weight is denominator - that.
+inline SmallVector<uint64_t, 2>
+getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
+                      const SmallVector<uint32_t, 2> &B2) {
+  assert(B1.size() == 2 && B2.size() == 2);
+  // Widen first: a uint32_t * uint32_t product would wrap at 32 bits.
+  uint64_t P0 = B1[0], P1 = B1[1], Q0 = B2[0], Q1 = B2[1];
+  // Halve a side whose total needs 33 bits, so that the product of the two
+  // totals (an upper bound on both results) fits in 64 bits.
+  auto Fit = [](uint64_t &W0, uint64_t &W1) {
+    const unsigned Shift = (W0 + W1) >> 32; // 0 or 1: the total is < 2^33.
+    W0 >>= Shift;
+    W1 >>= Shift;
+  };
+  Fit(P0, P1);
+  Fit(Q0, Q1);
+  return {P0 * Q0 + P0 * Q1 + P1 * Q0, P1 * Q1};
+}
 } // namespace llvm
 #endif
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 4847add386dc4..e26a189564d13 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1182,7 +1182,7 @@ static void cloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
    // only given the branch precondition.
    // Similarly strip attributes on call parameters that may cause UB in
    // location the call is moved to.
-    NewBonusInst->dropUBImplyingAttrsAndMetadata();
+    NewBonusInst->dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof}); // Keep !prof.
 
     NewBonusInst->insertInto(PredBlock, PTI->getIterator());
     auto Range = NewBonusInst->cloneDebugInfoFrom(&BonusInst);
@@ -1808,7 +1808,8 @@ static void hoistConditionalLoadsStores(
     // !annotation: Not impact semantics. Keep it.
     if (const MDNode *Ranges = I->getMetadata(LLVMContext::MD_range))
       MaskedLoadStore->addRangeRetAttr(getConstantRangeFromMetadata(*Ranges));
-    I->dropUBImplyingAttrsAndUnknownMetadata({LLVMContext::MD_annotation});
+    I->dropUBImplyingAttrsAndUnknownMetadata(
+        {LLVMContext::MD_annotation, LLVMContext::MD_prof}); // Kinds to keep.
     // FIXME: DIAssignID is not supported for masked store yet.
     // (Verifier::visitDIAssignIDMetadata)
     at::deleteAssignmentMarkers(I);
@@ -3366,7 +3367,7 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI,
     if (!SpeculatedStoreValue || &I != SpeculatedStore) {
       I.setDebugLoc(DebugLoc::getDropped());
     }
-    I.dropUBImplyingAttrsAndMetadata();
+    I.dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof}); // Keep !prof.
 
     // Drop ephemeral values.
     if (EphTracker.contains(&I)) {
@@ -4404,10 +4405,12 @@ static bool mergeConditionalStoreToAddress(
 
   // OK, we're going to sink the stores to PostBB. The store has to be
   // conditional though, so first create the predicate.
-  Value *PCond = cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator())
-                     ->getCondition();
-  Value *QCond = cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator())
-                     ->getCondition();
+  BranchInst *const PBranch =
+      cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator());
+  BranchInst *const QBranch =
+      cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator());
+  Value *const PCond = PBranch->getCondition();
+  Value *const QCond = QBranch->getCondition();
 
   Value *PPHI = ensureValueAvailableInSuccessor(PStore->getValueOperand(),
                                                 PStore->getParent());
@@ -4418,19 +4421,41 @@ static bool mergeConditionalStoreToAddress(
   IRBuilder<> QB(PostBB, PostBBFirst);
   QB.SetCurrentDebugLocation(PostBBFirst->getStableDebugLoc());
 
-  Value *PPred = PStore->getParent() == PTB ? PCond : QB.CreateNot(PCond);
-  Value *QPred = QStore->getParent() == QTB ? QCond : QB.CreateNot(QCond);
+  // Fold the store's position into InvertXCond: afterwards it is true iff
+  // XPred is the negation of XCond, which is also exactly when XBranch's
+  // {true, false} weights must be swapped to describe XPred.
+  InvertPCond ^= (PStore->getParent() != PTB);
+  InvertQCond ^= (QStore->getParent() != QTB);
+  Value *const PPred = InvertPCond ? QB.CreateNot(PCond) : PCond;
+  Value *const QPred = InvertQCond ? QB.CreateNot(QCond) : QCond;
 
-  if (InvertPCond)
-    PPred = QB.CreateNot(PPred);
-  if (InvertQCond)
-    QPred = QB.CreateNot(QPred);
   Value *CombinedPred = QB.CreateOr(PPred, QPred);
 
   BasicBlock::iterator InsertPt = QB.GetInsertPoint();
   auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
                                       /*Unreachable=*/false,
                                       /*BranchWeights=*/nullptr, DTU);
+  if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch)) {
+    SmallVector<uint32_t, 2> PWeights, QWeights;
+    extractBranchWeights(*PBranch, PWeights);
+    extractBranchWeights(*QBranch, QWeights);
+    if (InvertPCond)
+      std::swap(PWeights[0], PWeights[1]);
+    if (InvertQCond)
+      std::swap(QWeights[0], QWeights[1]);
+    auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
+    // The combined weights are 64-bit; !prof weights are 32-bit, so scale
+    // both down together (keeping their ratio) instead of truncating each.
+    while (CombinedWeights[0] > UINT32_MAX ||
+           CombinedWeights[1] > UINT32_MAX) {
+      CombinedWeights[0] >>= 1;
+      CombinedWeights[1] >>= 1;
+    }
+    // The weights belong on the conditional branch on CombinedPred, which now
+    // terminates PostBB; T is the unconditional branch of the new then-block.
+    setBranchWeights(PostBB->getTerminator(), CombinedWeights[0],
+                     CombinedWeights[1], /*IsExpected=*/false);
+  }
 
   QB.SetInsertPoint(T);
   StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));



More information about the llvm-branch-commits mailing list