[llvm] Change how branch weight is annotated for direct call (PR #90315)
William Junda Huang via llvm-commits
llvm-commits at lists.llvm.org
Wed May 1 14:58:10 PDT 2024
https://github.com/huangjd updated https://github.com/llvm/llvm-project/pull/90315
>From fee9217f9600643fb0884c69ee2f4322d1fab82a Mon Sep 17 00:00:00 2001
From: William Huang <williamjhuang at google.com>
Date: Fri, 26 Apr 2024 06:35:22 -0400
Subject: [PATCH 1/4] More accurate branch weight annotation for direct call if
it matches an exact call target in the sample
---
llvm/lib/Transforms/IPO/SampleProfile.cpp | 31 +++++++++++++++++++----
1 file changed, 26 insertions(+), 5 deletions(-)
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 0b3a6931e779b6..0c3f3c0c161a69 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1617,12 +1617,13 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
for (auto &I : *BB) {
if (!isa<CallInst>(I) && !isa<InvokeInst>(I))
continue;
- if (!cast<CallBase>(I).getCalledFunction()) {
- const DebugLoc &DLoc = I.getDebugLoc();
+ const DebugLoc &DLoc = I.getDebugLoc();
+ const DILocation *DIL = DLoc;
+ const FunctionSamples *FS = findFunctionSamples(I);
+ Function *Callee = cast<CallBase>(I).getCalledFunction();
+ if (!Callee) {
if (!DLoc)
continue;
- const DILocation *DIL = DLoc;
- const FunctionSamples *FS = findFunctionSamples(I);
if (!FS)
continue;
auto CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
@@ -1659,7 +1660,26 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
else if (OverwriteExistingWeights)
I.setMetadata(LLVMContext::MD_prof, nullptr);
} else if (!isa<IntrinsicInst>(&I)) {
- setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
+ uint32_t BranchWeight = static_cast<uint32_t>(BlockWeights[BB]);
+ // If I is a direct call and we found it has a sample with a matching
+ // call target, we should use its count instead because it is more
+ // precise.
+ if (DLoc && FS) {
+ auto Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
+ // Account for stale profile matching.
+ Callsite = FS->mapIRLocToProfileLoc(Callsite);
+ auto CallTargetMap = FS->findCallTargetMapAt(Callsite);
+ if (CallTargetMap) {
+ auto FindRes = CallTargetMap->find(
+ FunctionSamples::UseMD5 ?
+ FunctionId(MD5Hash(Callee->getName())) :
+ FunctionId(Callee->getName()));
+ if (FindRes != CallTargetMap->end()) {
+ BranchWeight = FindRes->second;
+ }
+ }
+ }
+ setBranchWeights(I, {BranchWeight});
}
}
} else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -2201,6 +2221,7 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) {
LLVM_DEBUG(dbgs() << "\n\nProcessing Function " << F.getName() << "\n");
+
DILocation2SampleMap.clear();
// By default the entry count is initialized to -1, which will be treated
// conservatively by getEntryCount as the same as unknown (None). This is
>From 25dd4042586b02284b34ba8f02da5be1278aed61 Mon Sep 17 00:00:00 2001
From: William Huang <williamjhuang at google.com>
Date: Sat, 27 Apr 2024 02:35:57 -0400
Subject: [PATCH 2/4] Move logic to SampleProf, and encapsulate lookup methods
for future use
---
llvm/include/llvm/ProfileData/SampleProf.h | 7 +++++++
llvm/lib/ProfileData/SampleProf.cpp | 19 +++++++++++++++++++
llvm/lib/Transforms/IPO/SampleProfile.cpp | 21 ++++++++++-----------
3 files changed, 36 insertions(+), 11 deletions(-)
diff --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h
index 51d590be124f10..10c4f0f4bfe120 100644
--- a/llvm/include/llvm/ProfileData/SampleProf.h
+++ b/llvm/include/llvm/ProfileData/SampleProf.h
@@ -902,6 +902,13 @@ class FunctionSamples {
return Ret->second.getCallTargets();
}
+ /// Returns the call target count of a specific function \p CalleeName at a
+ /// given location \p Callsite. Returns nullptr if not found. A \p Remapper
+ /// can be optionally provided to look up a name equivalent to \p CalleeName.
+ const uint64_t *
+ findCallTargetAt(const LineLocation &Callsite, StringRef CalleeName,
+ SampleProfileReaderItaniumRemapper *Remapper) const;
+
/// Return the function samples at the given callsite location.
FunctionSamplesMap &functionSamplesAt(const LineLocation &Loc) {
return CallsiteSamples[mapIRLocToProfileLoc(Loc)];
diff --git a/llvm/lib/ProfileData/SampleProf.cpp b/llvm/lib/ProfileData/SampleProf.cpp
index 59fa71899ed47b..8d3b52075641ae 100644
--- a/llvm/lib/ProfileData/SampleProf.cpp
+++ b/llvm/lib/ProfileData/SampleProf.cpp
@@ -275,6 +275,25 @@ void FunctionSamples::findAllNames(DenseSet<FunctionId> &NameSet) const {
}
}
+const uint64_t *FunctionSamples::findCallTargetAt(const LineLocation &Callsite,
+ StringRef CalleeName, SampleProfileReaderItaniumRemapper *Remapper) const {
+ const auto &FindRes = BodySamples.find(mapIRLocToProfileLoc(Callsite));
+ if (FindRes == BodySamples.end())
+ return nullptr;
+ const auto &CallTargets = FindRes->second.getCallTargets();
+ const auto &Ret = CallTargets.find(getRepInFormat(CalleeName));
+ if (Ret != CallTargets.end())
+ return &Ret->second;
+ if (Remapper && !UseMD5) {
+ if (auto RemappedName = Remapper->lookUpNameInProfile(CalleeName)) {
+ const auto &Ret = CallTargets.find(getRepInFormat(*RemappedName));
+ if (Ret != CallTargets.end())
+ return &Ret->second;
+ }
+ }
+ return nullptr;
+}
+
const FunctionSamples *FunctionSamples::findFunctionSamplesAt(
const LineLocation &Loc, StringRef CalleeName,
SampleProfileReaderItaniumRemapper *Remapper) const {
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 0c3f3c0c161a69..a927d5ac2627fb 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1666,16 +1666,16 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
// precise.
if (DLoc && FS) {
auto Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
- // Account for stale profile matching.
- Callsite = FS->mapIRLocToProfileLoc(Callsite);
- auto CallTargetMap = FS->findCallTargetMapAt(Callsite);
- if (CallTargetMap) {
- auto FindRes = CallTargetMap->find(
- FunctionSamples::UseMD5 ?
- FunctionId(MD5Hash(Callee->getName())) :
- FunctionId(Callee->getName()));
- if (FindRes != CallTargetMap->end()) {
- BranchWeight = FindRes->second;
+ if (const uint64_t *CallTargetCount = FS->findCallTargetAt(
+ Callsite, Callee->getName(), Reader->getRemapper())) {
+ BranchWeight = static_cast<uint32_t>(*CallTargetCount);
+ if (!FunctionSamples::ProfileIsCS) {
+ if (const FunctionSamples *InlinedCallee =
+ FS->findFunctionSamplesAt(Callsite, Callee->getName(),
+ Reader->getRemapper())) {
+ BranchWeight += static_cast<uint32_t>(
+ InlinedCallee->getHeadSamplesEstimate());
+ }
}
}
}
@@ -2221,7 +2221,6 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) {
LLVM_DEBUG(dbgs() << "\n\nProcessing Function " << F.getName() << "\n");
-
DILocation2SampleMap.clear();
// By default the entry count is initialized to -1, which will be treated
// conservatively by getEntryCount as the same as unknown (None). This is
>From 5b0d19ed558b61f76c34656ff1a58d97a4e91f17 Mon Sep 17 00:00:00 2001
From: William Huang <williamjhuang at google.com>
Date: Tue, 30 Apr 2024 01:47:34 -0400
Subject: [PATCH 3/4] Added test case
---
.../Inputs/direct-call-accurate-count.prof | 4 +++
.../direct-call-accurate-count.ll | 36 +++++++++++++++++++
2 files changed, 40 insertions(+)
create mode 100644 llvm/test/Transforms/SampleProfile/Inputs/direct-call-accurate-count.prof
create mode 100644 llvm/test/Transforms/SampleProfile/direct-call-accurate-count.ll
diff --git a/llvm/test/Transforms/SampleProfile/Inputs/direct-call-accurate-count.prof b/llvm/test/Transforms/SampleProfile/Inputs/direct-call-accurate-count.prof
new file mode 100644
index 00000000000000..ae1e6f8263786c
--- /dev/null
+++ b/llvm/test/Transforms/SampleProfile/Inputs/direct-call-accurate-count.prof
@@ -0,0 +1,4 @@
+test:10000:1000
+ 2: 456 callee:123
+test2:20000:2000
+ 3: 50 callee:30
diff --git a/llvm/test/Transforms/SampleProfile/direct-call-accurate-count.ll b/llvm/test/Transforms/SampleProfile/direct-call-accurate-count.ll
new file mode 100644
index 00000000000000..38b84e28b0560e
--- /dev/null
+++ b/llvm/test/Transforms/SampleProfile/direct-call-accurate-count.ll
@@ -0,0 +1,36 @@
+; RUN: opt -S %s -passes=sample-profile -sample-profile-file=%S/Inputs/direct-call-accurate-count.prof -salvage-stale-profile | FileCheck %s
+; RUN: llvm-profdata merge --sample --extbinary --use-md5 -output=%t %S/Inputs/direct-call-accurate-count.prof
+; RUN: opt -S %s -passes=sample-profile -sample-profile-file=%t -salvage-stale-profile | FileCheck %s
+
+declare void @callee() #0
+
+; CHECK-LABEL: @test
+define dso_local void @test() #1 !dbg !3 {
+ call void @callee(), !dbg !4
+; CHECK: call void @callee(), !dbg !{{[0-9]+}}, !prof ![[BRANCH_WEIGHT1:[0-9]+]]
+ ret void
+}
+
+; With stale profile
+; CHECK-LABEL: @test2
+define dso_local void @test2() #1 !dbg !5 {
+ call void @callee(), !dbg !6
+; CHECK: call void @callee(), !dbg !{{[0-9]+}}, !prof ![[BRANCH_WEIGHT2:[0-9]+]]
+ ret void
+}
+
+attributes #0 = { "use-sample-profile" }
+attributes #1 = { "use-sample-profile" }
+
+!llvm.dbg.cu = !{!0}
+!llvm.module.flags = !{!2}
+
+!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1)
+!1 = !DIFile(filename: "test.cpp", directory: "/")
+!2 = !{i32 2, !"Debug Info Version", i32 3}
+!3 = distinct !DISubprogram(name: "test", scope: !1, file: !1, line: 1, unit: !0)
+!4 = !DILocation(line: 3, column: 4, scope: !3)
+!5 = distinct !DISubprogram(name: "test2", scope: !1, file: !1, line: 11, unit: !0)
+!6 = !DILocation(line: 15, column: 4, scope: !5)
+; CHECK-DAG: ![[BRANCH_WEIGHT1]] = !{!"branch_weights", i32 123}
+; CHECK-DAG: ![[BRANCH_WEIGHT2]] = !{!"branch_weights", i32 30}
>From 648689357ca4c07baac1c7310529f13ca8e380a3 Mon Sep 17 00:00:00 2001
From: William Huang <williamjhuang at google.com>
Date: Wed, 1 May 2024 17:56:42 -0400
Subject: [PATCH 4/4] Move branch weight logic back to getInstWeight because in
branch weight annotation the entire CFG is considered for weight propagation,
so it cannot simply be overridden there.
---
.../Utils/SampleProfileLoaderBaseImpl.h | 13 ++++++++
llvm/lib/Transforms/IPO/SampleProfile.cpp | 30 ++++---------------
2 files changed, 18 insertions(+), 25 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h b/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h
index 7c725a3c1216cb..844531d8c2db96 100644
--- a/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h
+++ b/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h
@@ -408,6 +408,19 @@ SampleProfileLoaderBaseImpl<BT>::getInstWeightImpl(const InstructionT &Inst) {
Discriminator = DIL->getBaseDiscriminator();
ErrorOr<uint64_t> R = FS->findSamplesAt(LineOffset, Discriminator);
+ if constexpr (std::is_base_of_v<llvm::Instruction, InstructionT>) {
+ // If Inst is a direct function call and matches a sample, we should check
+ // if the sample contains call target count of the matching function, and
+ // use that count value instead of sample count, because sample count may
+ // contain superfluous numbers from other non-matching call targets as a
+ // result of merging profiles.
+ if (const CallInst *Call = dyn_cast<CallInst>(&Inst))
+ if (const Function *Callee = Call->getCalledFunction())
+ if (const uint64_t *CallTargetCount =
+ FS->findCallTargetAt(LineLocation(LineOffset, Discriminator),
+ Callee->getName(), Reader->getRemapper()))
+ R.get() = *CallTargetCount;
+ }
if (R) {
bool FirstMark =
CoverageTracker.markSamplesUsed(FS, LineOffset, Discriminator, R.get());
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index a927d5ac2627fb..0b3a6931e779b6 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1617,13 +1617,12 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
for (auto &I : *BB) {
if (!isa<CallInst>(I) && !isa<InvokeInst>(I))
continue;
- const DebugLoc &DLoc = I.getDebugLoc();
- const DILocation *DIL = DLoc;
- const FunctionSamples *FS = findFunctionSamples(I);
- Function *Callee = cast<CallBase>(I).getCalledFunction();
- if (!Callee) {
+ if (!cast<CallBase>(I).getCalledFunction()) {
+ const DebugLoc &DLoc = I.getDebugLoc();
if (!DLoc)
continue;
+ const DILocation *DIL = DLoc;
+ const FunctionSamples *FS = findFunctionSamples(I);
if (!FS)
continue;
auto CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
@@ -1660,26 +1659,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
else if (OverwriteExistingWeights)
I.setMetadata(LLVMContext::MD_prof, nullptr);
} else if (!isa<IntrinsicInst>(&I)) {
- uint32_t BranchWeight = static_cast<uint32_t>(BlockWeights[BB]);
- // If I is a direct call and we found it has a sample with a matching
- // call target, we should use its count instead because it is more
- // precise.
- if (DLoc && FS) {
- auto Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
- if (const uint64_t *CallTargetCount = FS->findCallTargetAt(
- Callsite, Callee->getName(), Reader->getRemapper())) {
- BranchWeight = static_cast<uint32_t>(*CallTargetCount);
- if (!FunctionSamples::ProfileIsCS) {
- if (const FunctionSamples *InlinedCallee =
- FS->findFunctionSamplesAt(Callsite, Callee->getName(),
- Reader->getRemapper())) {
- BranchWeight += static_cast<uint32_t>(
- InlinedCallee->getHeadSamplesEstimate());
- }
- }
- }
- }
- setBranchWeights(I, {BranchWeight});
+ setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
}
}
} else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
More information about the llvm-commits
mailing list