[llvm] [NFC][PGO] Factor downscaling of branch weights out of `Instrumentation` into `ProfileData` (PR #153735)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 14 22:32:18 PDT 2025


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/153735

>From 5ffa9b30afce7564f908a66d2a41541d5a2a45b5 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 14 Aug 2025 20:26:06 -0700
Subject: [PATCH] [NFC][PGO] Factor downscaling of branch weights out of
 `Instrumentation` into `ProfileData`

---
 llvm/include/llvm/IR/ProfDataUtils.h          | 27 +++++++++++++++++++
 .../llvm/Transforms/Utils/Instrumentation.h   | 20 --------------
 llvm/lib/IR/ProfDataUtils.cpp                 | 12 +++++++++
 .../Instrumentation/PGOInstrumentation.cpp    |  8 ++----
 4 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index ca56e4aa81575..404875285beae 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -144,6 +144,33 @@ LLVM_ABI bool extractProfTotalWeight(const Instruction &I,
 LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
                                bool IsExpected);
 
+/// Downscale \p Weights to 32 bits, approximately preserving their ratios.
+/// The divisor comes from \p KnownMaxCount if provided, else from the largest
+/// element of \p Weights (which must then be non-empty). The max must be > 0.
+LLVM_ABI SmallVector<uint32_t>
+downscaleWeights(ArrayRef<uint64_t> Weights,
+                 std::optional<uint64_t> KnownMaxCount = std::nullopt);
+
+/// Compute the divisor that brings 64-bit counts into 32-bit range.
+///
+/// Given the largest count \p MaxCount, return the smallest divisor such that
+/// every count up to \p MaxCount, divided by it, is below UINT32_MAX.
+inline uint64_t calculateCountScale(uint64_t MaxCount) {
+  const uint64_t Limit = std::numeric_limits<uint32_t>::max();
+  // Counts that already fit need no scaling.
+  return MaxCount >= Limit ? MaxCount / Limit + 1 : 1;
+}
+
+/// Scale an individual branch count.
+///
+/// Divide the 64-bit \p Count by \p Scale and narrow the quotient to 32 bits.
+/// \p Scale typically comes from calculateCountScale(); the fit is asserted.
+inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
+  const uint64_t Result = Count / Scale;
+  assert(Result <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
+  return static_cast<uint32_t>(Result);
+}
+
 /// Specify that the branch weights for this terminator cannot be known at
 /// compile time. This should only be called by passes, and never as a default
 /// behavior in e.g. MDBuilder. The goal is to use this info to validate passes
diff --git a/llvm/include/llvm/Transforms/Utils/Instrumentation.h b/llvm/include/llvm/Transforms/Utils/Instrumentation.h
index 962d9e734a40a..93ab8c693607f 100644
--- a/llvm/include/llvm/Transforms/Utils/Instrumentation.h
+++ b/llvm/include/llvm/Transforms/Utils/Instrumentation.h
@@ -169,26 +169,6 @@ struct SanitizerCoverageOptions {
   SanitizerCoverageOptions() = default;
 };
 
-/// Calculate what to divide by to scale counts.
-///
-/// Given the maximum count, calculate a divisor that will scale all the
-/// weights to strictly less than std::numeric_limits<uint32_t>::max().
-static inline uint64_t calculateCountScale(uint64_t MaxCount) {
-  return MaxCount < std::numeric_limits<uint32_t>::max()
-             ? 1
-             : MaxCount / std::numeric_limits<uint32_t>::max() + 1;
-}
-
-/// Scale an individual branch count.
-///
-/// Scale a 64-bit weight down to 32-bits using \c Scale.
-///
-static inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
-  uint64_t Scaled = Count / Scale;
-  assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
-  return Scaled;
-}
-
 // Use to ensure the inserted instrumentation has a DebugLocation; if none is
 // attached to the source instruction, try to use a DILocation with offset 0
 // scoped to surrounding function (if it has a DebugLocation).
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b1b5f67689e6d..489fbfef00e4d 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -270,6 +270,18 @@ void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
   I.setMetadata(LLVMContext::MD_prof, BranchWeights);
 }
 
+SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
+                                       std::optional<uint64_t> KnownMaxCount) {
+  uint64_t MaxCount = KnownMaxCount ? *KnownMaxCount
+                      : Weights.empty() ? 0 : *llvm::max_element(Weights);
+  assert(MaxCount > 0 && "Bad max count");
+  uint64_t Scale = calculateCountScale(MaxCount);
+  SmallVector<uint32_t> DownscaledWeights;
+  for (uint64_t Weight : Weights)
+    DownscaledWeights.push_back(scaleBranchCount(Weight, Scale));
+  return DownscaledWeights;
+}
+
 void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
   assert(T != 0 && "Caller should guarantee");
   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index e0b22ef94d064..d9e850e7a2bf3 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -2409,11 +2409,7 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
 
 void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
                            uint64_t MaxCount) {
-  assert(MaxCount > 0 && "Bad max count");
-  uint64_t Scale = calculateCountScale(MaxCount);
-  SmallVector<unsigned, 4> Weights;
-  for (const auto &ECI : EdgeCounts)
-    Weights.push_back(scaleBranchCount(ECI, Scale));
+  auto Weights = downscaleWeights(EdgeCounts, MaxCount);
 
   LLVM_DEBUG(dbgs() << "Weight is: "; for (const auto &W
                                            : Weights) {
@@ -2434,7 +2430,7 @@ void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
     uint64_t TotalCount =
         std::accumulate(EdgeCounts.begin(), EdgeCounts.end(), (uint64_t)0,
                         [](uint64_t c1, uint64_t c2) { return c1 + c2; });
-    Scale = calculateCountScale(WSum);
+    uint64_t Scale = calculateCountScale(WSum);
     BranchProbability BP(scaleBranchCount(Weights[0], Scale),
                          scaleBranchCount(WSum, Scale));
     std::string BranchProbStr;



More information about the llvm-commits mailing list