[llvm] [NFC][PGO] Factor downscaling of branch weights out of `Instrumentation` into `ProfileData` (PR #153735)
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 14 22:32:18 PDT 2025
https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/153735
>From 5ffa9b30afce7564f908a66d2a41541d5a2a45b5 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 14 Aug 2025 20:26:06 -0700
Subject: [PATCH] [NFC][PGO] Factor downscaling of branch weights out of
`Instrumentation` into `ProfileData`
---
llvm/include/llvm/IR/ProfDataUtils.h | 27 +++++++++++++++++++
.../llvm/Transforms/Utils/Instrumentation.h | 20 --------------
llvm/lib/IR/ProfDataUtils.cpp | 12 +++++++++
.../Instrumentation/PGOInstrumentation.cpp | 8 ++----
4 files changed, 41 insertions(+), 26 deletions(-)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index ca56e4aa81575..404875285beae 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -144,6 +144,33 @@ LLVM_ABI bool extractProfTotalWeight(const Instruction &I,
LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
bool IsExpected);
+/// Downscale \p Weights to 32 bits, preserving their ratios. The divisor is
+/// derived from \p KnownMaxCount when provided (it must then be >= every
+/// element), otherwise from the maximum of \p Weights. The max must be > 0.
+LLVM_ABI SmallVector<uint32_t>
+downscaleWeights(ArrayRef<uint64_t> Weights,
+                 std::optional<uint64_t> KnownMaxCount = std::nullopt);
+
+/// Compute the divisor used to bring 64-bit counts into 32-bit range.
+///
+/// Returns 1 when \p MaxCount is already below UINT32_MAX; otherwise returns
+/// a divisor that maps every count <= \p MaxCount strictly below UINT32_MAX.
+inline uint64_t calculateCountScale(uint64_t MaxCount) {
+  constexpr uint64_t Limit = std::numeric_limits<uint32_t>::max();
+  // Dividing by (MaxCount / Limit + 1) keeps MaxCount's quotient < Limit.
+  return MaxCount >= Limit ? MaxCount / Limit + 1 : 1;
+}
+
+/// Scale the 64-bit branch weight \p Count down to 32 bits by dividing it by
+/// \p Scale (e.g. a value from calculateCountScale()). \p Scale must be
+/// non-zero, and the quotient is asserted to fit in 32 bits.
+inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
+  assert(Scale != 0 && "Scale must be non-zero");
+  uint64_t Scaled = Count / Scale;
+  assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
+  return static_cast<uint32_t>(Scaled);
+}
+
/// Specify that the branch weights for this terminator cannot be known at
/// compile time. This should only be called by passes, and never as a default
/// behavior in e.g. MDBuilder. The goal is to use this info to validate passes
diff --git a/llvm/include/llvm/Transforms/Utils/Instrumentation.h b/llvm/include/llvm/Transforms/Utils/Instrumentation.h
index 962d9e734a40a..93ab8c693607f 100644
--- a/llvm/include/llvm/Transforms/Utils/Instrumentation.h
+++ b/llvm/include/llvm/Transforms/Utils/Instrumentation.h
@@ -169,26 +169,6 @@ struct SanitizerCoverageOptions {
SanitizerCoverageOptions() = default;
};
-/// Calculate what to divide by to scale counts.
-///
-/// Given the maximum count, calculate a divisor that will scale all the
-/// weights to strictly less than std::numeric_limits<uint32_t>::max().
-static inline uint64_t calculateCountScale(uint64_t MaxCount) {
- return MaxCount < std::numeric_limits<uint32_t>::max()
- ? 1
- : MaxCount / std::numeric_limits<uint32_t>::max() + 1;
-}
-
-/// Scale an individual branch count.
-///
-/// Scale a 64-bit weight down to 32-bits using \c Scale.
-///
-static inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
- uint64_t Scaled = Count / Scale;
- assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
- return Scaled;
-}
-
// Use to ensure the inserted instrumentation has a DebugLocation; if none is
// attached to the source instruction, try to use a DILocation with offset 0
// scoped to surrounding function (if it has a DebugLocation).
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b1b5f67689e6d..489fbfef00e4d 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -270,6 +270,24 @@ void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}
+SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
+                                       std::optional<uint64_t> KnownMaxCount) {
+  // Nothing to scale. Bailing out here also avoids dereferencing the end
+  // iterator that llvm::max_element returns for an empty range.
+  if (Weights.empty())
+    return {};
+  uint64_t MaxCount =
+      KnownMaxCount ? *KnownMaxCount : *llvm::max_element(Weights);
+  assert(MaxCount > 0 && "Bad max count");
+  uint64_t Scale = calculateCountScale(MaxCount);
+  // Same type as the return value, so returning it needs no conversion.
+  SmallVector<uint32_t> DownscaledWeights;
+  DownscaledWeights.reserve(Weights.size());
+  for (uint64_t Weight : Weights)
+    DownscaledWeights.push_back(scaleBranchCount(Weight, Scale));
+  return DownscaledWeights;
+}
+
void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
assert(T != 0 && "Caller should guarantee");
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index e0b22ef94d064..d9e850e7a2bf3 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -2409,11 +2409,7 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
uint64_t MaxCount) {
- assert(MaxCount > 0 && "Bad max count");
- uint64_t Scale = calculateCountScale(MaxCount);
- SmallVector<unsigned, 4> Weights;
- for (const auto &ECI : EdgeCounts)
- Weights.push_back(scaleBranchCount(ECI, Scale));
+ auto Weights = downscaleWeights(EdgeCounts, MaxCount);
LLVM_DEBUG(dbgs() << "Weight is: "; for (const auto &W
: Weights) {
@@ -2434,7 +2430,7 @@ void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
uint64_t TotalCount =
std::accumulate(EdgeCounts.begin(), EdgeCounts.end(), (uint64_t)0,
[](uint64_t c1, uint64_t c2) { return c1 + c2; });
- Scale = calculateCountScale(WSum);
+ uint64_t Scale = calculateCountScale(WSum);
BranchProbability BP(scaleBranchCount(Weights[0], Scale),
scaleBranchCount(WSum, Scale));
std::string BranchProbStr;
More information about the llvm-commits mailing list