[llvm] [nfc][PGO]Factor out profile scaling into a standalone helper function (PR #83780)
Mingming Liu via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 4 00:06:36 PST 2024
https://github.com/minglotus-6 created https://github.com/llvm/llvm-project/pull/83780
- Put the helper function in `ProfDataUtils.h/cpp`, which is already a dependency of `Instructions.cpp`
- The helper function could be re-used to update profiles of `InvokeInst` (in a follow-up pull request)
>From e7c6220a4f2c42c94fa33fd8c61da569ef67d4db Mon Sep 17 00:00:00 2001
From: mingmingl <mingmingl at google.com>
Date: Mon, 4 Mar 2024 00:04:11 -0800
Subject: [PATCH] [nfc][PGO]Factor out profile scaling into a standalone
function
---
llvm/include/llvm/IR/ProfDataUtils.h | 3 ++
llvm/lib/IR/Instructions.cpp | 46 +-------------------------
llvm/lib/IR/ProfDataUtils.cpp | 48 ++++++++++++++++++++++++++++
3 files changed, 52 insertions(+), 45 deletions(-)
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 255fa2ff1c7906..c0897408986fb3 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -108,5 +108,8 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
/// a `prof` metadata reference to instruction `I`.
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
+/// Scale the profile data attached to \p I by S/T. \p T must be non-zero.
+void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
+
} // namespace llvm
#endif
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 42cdcad78228f6..9ae71acd523c36 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -825,15 +825,6 @@ CallInst *CallInst::Create(CallInst *CI, ArrayRef<OperandBundleDef> OpB,
// of S/T. The meaning of "branch_weights" meta data for call instruction is
// transfered to represent call count.
void CallInst::updateProfWeight(uint64_t S, uint64_t T) {
- auto *ProfileData = getMetadata(LLVMContext::MD_prof);
- if (ProfileData == nullptr)
- return;
-
- auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
- if (!ProfDataName || (!ProfDataName->getString().equals("branch_weights") &&
- !ProfDataName->getString().equals("VP")))
- return;
-
if (T == 0) {
LLVM_DEBUG(dbgs() << "Attempting to update profile weights will result in "
"div by 0. Ignoring. Likely the function "
@@ -842,42 +833,7 @@ void CallInst::updateProfWeight(uint64_t S, uint64_t T) {
"with non-zero prof info.");
return;
}
-
- MDBuilder MDB(getContext());
- SmallVector<Metadata *, 3> Vals;
- Vals.push_back(ProfileData->getOperand(0));
- APInt APS(128, S), APT(128, T);
- if (ProfDataName->getString().equals("branch_weights") &&
- ProfileData->getNumOperands() > 0) {
- // Using APInt::div may be expensive, but most cases should fit 64 bits.
- APInt Val(128, mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1))
- ->getValue()
- .getZExtValue());
- Val *= APS;
- Vals.push_back(MDB.createConstant(
- ConstantInt::get(Type::getInt32Ty(getContext()),
- Val.udiv(APT).getLimitedValue(UINT32_MAX))));
- } else if (ProfDataName->getString().equals("VP"))
- for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) {
- // The first value is the key of the value profile, which will not change.
- Vals.push_back(ProfileData->getOperand(i));
- uint64_t Count =
- mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1))
- ->getValue()
- .getZExtValue();
- // Don't scale the magic number.
- if (Count == NOMORE_ICP_MAGICNUM) {
- Vals.push_back(ProfileData->getOperand(i + 1));
- continue;
- }
- // Using APInt::div may be expensive, but most cases should fit 64 bits.
- APInt Val(128, Count);
- Val *= APS;
- Vals.push_back(MDB.createConstant(
- ConstantInt::get(Type::getInt64Ty(getContext()),
- Val.udiv(APT).getLimitedValue())));
- }
- setMetadata(LLVMContext::MD_prof, MDNode::get(getContext(), Vals));
+ scaleProfData(*this, S, T); // No-op without "branch_weights"/"VP" prof data.
}
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
index b1a10d0ce5a522..dc86f4204b1a1d 100644
--- a/llvm/lib/IR/ProfDataUtils.cpp
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -190,4 +190,70 @@ void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}
+// Scales "branch_weights" or "VP" `prof` metadata attached to \p I by the
+// ratio \p S / \p T; \p T must be non-zero. Other metadata kinds, and nodes
+// whose operands do not have the expected shape, are left unchanged.
+void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
+  assert(T != 0 && "Caller should guarantee");
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  if (ProfileData == nullptr)
+    return;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
+  if (!ProfDataName)
+    return;
+  StringRef Name = ProfDataName->getString();
+  const bool IsBranchWeights = Name == "branch_weights";
+  if (!IsBranchWeights && Name != "VP")
+    return;
+
+  LLVMContext &C = I.getContext();
+  MDBuilder MDB(C);
+  SmallVector<Metadata *, 3> Vals;
+  Vals.push_back(ProfileData->getOperand(0));
+  // Multiply in 128 bits so that value * S cannot overflow before dividing.
+  APInt APS(128, S), APT(128, T);
+  if (IsBranchWeights) {
+    // Operand 0 is the name and operand 1 the weight to scale; any operands
+    // after that are not copied into the new node.
+    if (ProfileData->getNumOperands() < 2)
+      return;
+    auto *Weight =
+        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1));
+    if (!Weight)
+      return;
+    // Using APInt::div may be expensive, but most cases should fit 64 bits.
+    APInt Val(128, Weight->getZExtValue());
+    Val *= APS;
+    Vals.push_back(MDB.createConstant(ConstantInt::get(
+        Type::getInt32Ty(C), Val.udiv(APT).getLimitedValue(UINT32_MAX))));
+  } else {
+    // After the name, operands are visited as (key, count) pairs, so a
+    // well-formed node has an odd operand count. Bail out otherwise rather
+    // than reading past the last operand.
+    const unsigned NumOps = ProfileData->getNumOperands();
+    if (NumOps % 2 == 0)
+      return;
+    for (unsigned i = 1; i < NumOps; i += 2) {
+      auto *Count =
+          mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i + 1));
+      if (!Count)
+        return;
+      // The key of each pair is copied through unchanged.
+      Vals.push_back(ProfileData->getOperand(i));
+      // Don't scale the magic number.
+      if (Count->getZExtValue() == NOMORE_ICP_MAGICNUM) {
+        Vals.push_back(ProfileData->getOperand(i + 1));
+        continue;
+      }
+      // Using APInt::div may be expensive, but most cases should fit 64 bits.
+      APInt Val(128, Count->getZExtValue());
+      Val *= APS;
+      Vals.push_back(MDB.createConstant(ConstantInt::get(
+          Type::getInt64Ty(C), Val.udiv(APT).getLimitedValue())));
+    }
+  }
+  I.setMetadata(LLVMContext::MD_prof, MDNode::get(C, Vals));
+}
+
} // namespace llvm
More information about the llvm-commits
mailing list