[llvm] 3b4775d - [NFC][PGO] Factor downscaling of branch weights out of `Instrumentation` into `ProfileData` (#153735)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 14 22:44:40 PDT 2025


Author: Mircea Trofin
Date: 2025-08-14T22:44:36-07:00
New Revision: 3b4775d31d8b5f4005e6d32539834bec990dd481

URL: https://github.com/llvm/llvm-project/commit/3b4775d31d8b5f4005e6d32539834bec990dd481
DIFF: https://github.com/llvm/llvm-project/commit/3b4775d31d8b5f4005e6d32539834bec990dd481.diff

LOG: [NFC][PGO] Factor downscaling of branch weights out of `Instrumentation` into `ProfileData` (#153735)

The logic isn’t instrumentation-specific, and the refactoring allows users to avoid a dependency on `Instrumentation` and just take one on `ProfileData` (which is a fairly low-level dependency).

Added: 
    

Modified: 
    llvm/include/llvm/IR/ProfDataUtils.h
    llvm/include/llvm/Transforms/Utils/Instrumentation.h
    llvm/lib/IR/ProfDataUtils.cpp
    llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index ca56e4aa81575..404875285beae 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -144,6 +144,33 @@ LLVM_ABI bool extractProfTotalWeight(const Instruction &I,
 LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
                                bool IsExpected);
 
+/// downscale the given weights preserving the ratio. If the maximum value is
+/// not already known and not provided via \param KnownMaxCount , it will be
+/// obtained from \param Weights.
+LLVM_ABI SmallVector<uint32_t>
+downscaleWeights(ArrayRef<uint64_t> Weights,
+                 std::optional<uint64_t> KnownMaxCount = std::nullopt);
+
+/// Calculate what to divide by to scale counts.
+///
+/// Given the maximum count, calculate a divisor that will scale all the
+/// weights to strictly less than std::numeric_limits<uint32_t>::max().
+inline uint64_t calculateCountScale(uint64_t MaxCount) {
+  return MaxCount < std::numeric_limits<uint32_t>::max()
+             ? 1
+             : MaxCount / std::numeric_limits<uint32_t>::max() + 1;
+}
+
+/// Scale an individual branch count.
+///
+/// Scale a 64-bit weight down to 32-bits using \c Scale.
+///
+inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
+  uint64_t Scaled = Count / Scale;
+  assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
+  return Scaled;
+}
+
 /// Specify that the branch weights for this terminator cannot be known at
 /// compile time. This should only be called by passes, and never as a default
 /// behavior in e.g. MDBuilder. The goal is to use this info to validate passes

diff  --git a/llvm/include/llvm/Transforms/Utils/Instrumentation.h b/llvm/include/llvm/Transforms/Utils/Instrumentation.h
index 962d9e734a40a..93ab8c693607f 100644
--- a/llvm/include/llvm/Transforms/Utils/Instrumentation.h
+++ b/llvm/include/llvm/Transforms/Utils/Instrumentation.h
@@ -169,26 +169,6 @@ struct SanitizerCoverageOptions {
   SanitizerCoverageOptions() = default;
 };
 
-/// Calculate what to divide by to scale counts.
-///
-/// Given the maximum count, calculate a divisor that will scale all the
-/// weights to strictly less than std::numeric_limits<uint32_t>::max().
-static inline uint64_t calculateCountScale(uint64_t MaxCount) {
-  return MaxCount < std::numeric_limits<uint32_t>::max()
-             ? 1
-             : MaxCount / std::numeric_limits<uint32_t>::max() + 1;
-}
-
-/// Scale an individual branch count.
-///
-/// Scale a 64-bit weight down to 32-bits using \c Scale.
-///
-static inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
-  uint64_t Scaled = Count / Scale;
-  assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
-  return Scaled;
-}
-
 // Use to ensure the inserted instrumentation has a DebugLocation; if none is
 // attached to the source instruction, try to use a DILocation with offset 0
 // scoped to surrounding function (if it has a DebugLocation).

diff  --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b1b5f67689e6d..489fbfef00e4d 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -270,6 +270,18 @@ void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
   I.setMetadata(LLVMContext::MD_prof, BranchWeights);
 }
 
+SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
+                                       std::optional<uint64_t> KnownMaxCount) {
+  uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()
+                                                : *llvm::max_element(Weights);
+  assert(MaxCount > 0 && "Bad max count");
+  uint64_t Scale = calculateCountScale(MaxCount);
+  SmallVector<unsigned, 4> DownscaledWeights;
+  for (const auto &ECI : Weights)
+    DownscaledWeights.push_back(scaleBranchCount(ECI, Scale));
+  return DownscaledWeights;
+}
+
 void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
   assert(T != 0 && "Caller should guarantee");
   auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);

diff  --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index e0b22ef94d064..d9e850e7a2bf3 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -2409,11 +2409,7 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
 
 void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
                            uint64_t MaxCount) {
-  assert(MaxCount > 0 && "Bad max count");
-  uint64_t Scale = calculateCountScale(MaxCount);
-  SmallVector<unsigned, 4> Weights;
-  for (const auto &ECI : EdgeCounts)
-    Weights.push_back(scaleBranchCount(ECI, Scale));
+  auto Weights = downscaleWeights(EdgeCounts, MaxCount);
 
   LLVM_DEBUG(dbgs() << "Weight is: "; for (const auto &W
                                            : Weights) {
@@ -2434,7 +2430,7 @@ void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
     uint64_t TotalCount =
         std::accumulate(EdgeCounts.begin(), EdgeCounts.end(), (uint64_t)0,
                         [](uint64_t c1, uint64_t c2) { return c1 + c2; });
-    Scale = calculateCountScale(WSum);
+    uint64_t Scale = calculateCountScale(WSum);
     BranchProbability BP(scaleBranchCount(Weights[0], Scale),
                          scaleBranchCount(WSum, Scale));
     std::string BranchProbStr;


        


More information about the llvm-commits mailing list