[llvm] Change how branch weight is annotated for direct call (PR #90315)

William Junda Huang via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 29 22:47:53 PDT 2024


https://github.com/huangjd updated https://github.com/llvm/llvm-project/pull/90315

>From fee9217f9600643fb0884c69ee2f4322d1fab82a Mon Sep 17 00:00:00 2001
From: William Huang <williamjhuang at google.com>
Date: Fri, 26 Apr 2024 06:35:22 -0400
Subject: [PATCH 1/3] More accurate branch weight annotation for direct calls
 that match an exact call target in the sample

---
 llvm/lib/Transforms/IPO/SampleProfile.cpp | 31 +++++++++++++++++++----
 1 file changed, 26 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 0b3a6931e779b6..0c3f3c0c161a69 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1617,12 +1617,13 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
       for (auto &I : *BB) {
         if (!isa<CallInst>(I) && !isa<InvokeInst>(I))
           continue;
-        if (!cast<CallBase>(I).getCalledFunction()) {
-          const DebugLoc &DLoc = I.getDebugLoc();
+        const DebugLoc &DLoc = I.getDebugLoc();
+        const DILocation *DIL = DLoc;
+        const FunctionSamples *FS = findFunctionSamples(I);
+        Function *Callee = cast<CallBase>(I).getCalledFunction();
+        if (!Callee) {
           if (!DLoc)
             continue;
-          const DILocation *DIL = DLoc;
-          const FunctionSamples *FS = findFunctionSamples(I);
           if (!FS)
             continue;
           auto CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
@@ -1659,7 +1660,26 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           else if (OverwriteExistingWeights)
             I.setMetadata(LLVMContext::MD_prof, nullptr);
         } else if (!isa<IntrinsicInst>(&I)) {
-          setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
+          uint32_t BranchWeight = static_cast<uint32_t>(BlockWeights[BB]);
+          // If I is a direct call and we found it has a sample with a matching
+          // call target, we should use its count instead because it is more
+          // precise.
+          if (DLoc && FS) {
+            auto Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
+            // Account for stale profile matching.
+            Callsite = FS->mapIRLocToProfileLoc(Callsite);
+            auto CallTargetMap = FS->findCallTargetMapAt(Callsite);
+            if (CallTargetMap) {
+              auto FindRes = CallTargetMap->find(
+                  FunctionSamples::UseMD5 ?
+                      FunctionId(MD5Hash(Callee->getName())) :
+                      FunctionId(Callee->getName()));
+              if (FindRes != CallTargetMap->end()) {
+                BranchWeight = FindRes->second;
+              }
+            }
+          }
+          setBranchWeights(I, {BranchWeight});
         }
       }
     } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -2201,6 +2221,7 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
 
 bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) {
   LLVM_DEBUG(dbgs() << "\n\nProcessing Function " << F.getName() << "\n");
+
   DILocation2SampleMap.clear();
   // By default the entry count is initialized to -1, which will be treated
   // conservatively by getEntryCount as the same as unknown (None). This is

>From 25dd4042586b02284b34ba8f02da5be1278aed61 Mon Sep 17 00:00:00 2001
From: William Huang <williamjhuang at google.com>
Date: Sat, 27 Apr 2024 02:35:57 -0400
Subject: [PATCH 2/3] Move logic to SampleProf, and encapsulate lookup methods
 for future use

---
 llvm/include/llvm/ProfileData/SampleProf.h |  7 +++++++
 llvm/lib/ProfileData/SampleProf.cpp        | 19 +++++++++++++++++++
 llvm/lib/Transforms/IPO/SampleProfile.cpp  | 21 ++++++++++-----------
 3 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h
index 51d590be124f10..10c4f0f4bfe120 100644
--- a/llvm/include/llvm/ProfileData/SampleProf.h
+++ b/llvm/include/llvm/ProfileData/SampleProf.h
@@ -902,6 +902,13 @@ class FunctionSamples {
     return Ret->second.getCallTargets();
   }
 
+  /// Returns the call target count of \p CalleeName at the IR location
+  /// \p Callsite (mapped to the profile location first), or nullptr if absent.
+  /// \p Remapper, if non-null, is tried on non-MD5 profiles as a fallback.
+  const uint64_t *
+  findCallTargetAt(const LineLocation &Callsite, StringRef CalleeName,
+                   SampleProfileReaderItaniumRemapper *Remapper) const;
+
   /// Return the function samples at the given callsite location.
   FunctionSamplesMap &functionSamplesAt(const LineLocation &Loc) {
     return CallsiteSamples[mapIRLocToProfileLoc(Loc)];
diff --git a/llvm/lib/ProfileData/SampleProf.cpp b/llvm/lib/ProfileData/SampleProf.cpp
index 59fa71899ed47b..8d3b52075641ae 100644
--- a/llvm/lib/ProfileData/SampleProf.cpp
+++ b/llvm/lib/ProfileData/SampleProf.cpp
@@ -275,6 +275,28 @@ void FunctionSamples::findAllNames(DenseSet<FunctionId> &NameSet) const {
   }
 }
 
+const uint64_t *FunctionSamples::findCallTargetAt(
+    const LineLocation &Callsite, StringRef CalleeName,
+    SampleProfileReaderItaniumRemapper *Remapper) const {
+  // Callsite is an IR location; translate it for stale profile matching.
+  auto BodyIt = BodySamples.find(mapIRLocToProfileLoc(Callsite));
+  if (BodyIt == BodySamples.end())
+    return nullptr;
+  const auto &CallTargets = BodyIt->second.getCallTargets();
+  auto FindCount = [&CallTargets](StringRef Name) -> const uint64_t * {
+    auto It = CallTargets.find(getRepInFormat(Name));
+    return It != CallTargets.end() ? &It->second : nullptr;
+  };
+  if (const uint64_t *Count = FindCount(CalleeName))
+    return Count;
+  // Fall back to an equivalent profile name; not attempted for MD5 profiles.
+  if (Remapper && !UseMD5) {
+    if (auto RemappedName = Remapper->lookUpNameInProfile(CalleeName))
+      return FindCount(*RemappedName);
+  }
+  return nullptr;
+}
+
 const FunctionSamples *FunctionSamples::findFunctionSamplesAt(
     const LineLocation &Loc, StringRef CalleeName,
     SampleProfileReaderItaniumRemapper *Remapper) const {
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 0c3f3c0c161a69..a927d5ac2627fb 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1666,16 +1666,20 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           // precise.
           if (DLoc && FS) {
             auto Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
-            // Account for stale profile matching.
-            Callsite = FS->mapIRLocToProfileLoc(Callsite);
-            auto CallTargetMap = FS->findCallTargetMapAt(Callsite);
-            if (CallTargetMap) {
-              auto FindRes = CallTargetMap->find(
-                  FunctionSamples::UseMD5 ?
-                      FunctionId(MD5Hash(Callee->getName())) :
-                      FunctionId(Callee->getName()));
-              if (FindRes != CallTargetMap->end()) {
-                BranchWeight = FindRes->second;
+            if (const uint64_t *CallTargetCount = FS->findCallTargetAt(
+                    Callsite, Callee->getName(), Reader->getRemapper())) {
+              // Accumulate in 64 bits, the width of profile sample counts.
+              uint64_t CallCount = *CallTargetCount;
+              if (!FunctionSamples::ProfileIsCS) {
+                if (const FunctionSamples *InlinedCallee =
+                        FS->findFunctionSamplesAt(Callsite, Callee->getName(),
+                                                  Reader->getRemapper())) {
+                  CallCount = SaturatingAdd(
+                      CallCount, InlinedCallee->getHeadSamplesEstimate());
+                }
               }
+              // Branch weights are 32-bit: saturate instead of truncating.
+              BranchWeight = static_cast<uint32_t>(std::min<uint64_t>(
+                  CallCount, std::numeric_limits<uint32_t>::max()));
             }
           }
@@ -2221,7 +2221,6 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
 
 bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) {
   LLVM_DEBUG(dbgs() << "\n\nProcessing Function " << F.getName() << "\n");
-
   DILocation2SampleMap.clear();
   // By default the entry count is initialized to -1, which will be treated
   // conservatively by getEntryCount as the same as unknown (None). This is

>From 5b0d19ed558b61f76c34656ff1a58d97a4e91f17 Mon Sep 17 00:00:00 2001
From: William Huang <williamjhuang at google.com>
Date: Tue, 30 Apr 2024 01:47:34 -0400
Subject: [PATCH 3/3] Add test case

---
 .../Inputs/direct-call-accurate-count.prof    |  4 +++
 .../direct-call-accurate-count.ll             | 36 +++++++++++++++++++
 2 files changed, 40 insertions(+)
 create mode 100644 llvm/test/Transforms/SampleProfile/Inputs/direct-call-accurate-count.prof
 create mode 100644 llvm/test/Transforms/SampleProfile/direct-call-accurate-count.ll

diff --git a/llvm/test/Transforms/SampleProfile/Inputs/direct-call-accurate-count.prof b/llvm/test/Transforms/SampleProfile/Inputs/direct-call-accurate-count.prof
new file mode 100644
index 00000000000000..ae1e6f8263786c
--- /dev/null
+++ b/llvm/test/Transforms/SampleProfile/Inputs/direct-call-accurate-count.prof
@@ -0,0 +1,4 @@
+test:10000:1000
+ 2: 456 callee:123
+test2:20000:2000
+ 3: 50 callee:30
diff --git a/llvm/test/Transforms/SampleProfile/direct-call-accurate-count.ll b/llvm/test/Transforms/SampleProfile/direct-call-accurate-count.ll
new file mode 100644
index 00000000000000..38b84e28b0560e
--- /dev/null
+++ b/llvm/test/Transforms/SampleProfile/direct-call-accurate-count.ll
@@ -0,0 +1,41 @@
+; Check that a direct call whose callee is listed among the call targets at
+; its callsite is annotated with that target's count rather than the block
+; weight. Both the text profile and its MD5 extbinary form are exercised.
+; RUN: opt -S %s -passes=sample-profile -sample-profile-file=%S/Inputs/direct-call-accurate-count.prof -salvage-stale-profile | FileCheck %s
+; RUN: llvm-profdata merge --sample --extbinary --use-md5 -output=%t %S/Inputs/direct-call-accurate-count.prof
+; RUN: opt -S %s -passes=sample-profile -sample-profile-file=%t -salvage-stale-profile | FileCheck %s
+
+declare void @callee()
+
+; Exact match: the call is at line offset 2, where the profile records
+; callee:123.
+; CHECK-LABEL: @test
+define dso_local void @test() #0 !dbg !3 {
+  call void @callee(), !dbg !4
+; CHECK: call void @callee(), !dbg !{{[0-9]+}}, !prof ![[BRANCH_WEIGHT1:[0-9]+]]
+  ret void
+}
+
+; Stale profile: the call is at line offset 4 in the IR but at offset 3 in
+; the profile (callee:30); -salvage-stale-profile is expected to match them.
+; CHECK-LABEL: @test2
+define dso_local void @test2() #0 !dbg !5 {
+  call void @callee(), !dbg !6
+; CHECK: call void @callee(), !dbg !{{[0-9]+}}, !prof ![[BRANCH_WEIGHT2:[0-9]+]]
+  ret void
+}
+
+attributes #0 = { "use-sample-profile" }
+
+!llvm.dbg.cu = !{!0}
+!llvm.module.flags = !{!2}
+
+!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1)
+!1 = !DIFile(filename: "test.cpp", directory: "/")
+!2 = !{i32 2, !"Debug Info Version", i32 3}
+!3 = distinct !DISubprogram(name: "test", scope: !1, file: !1, line: 1, unit: !0)
+!4 = !DILocation(line: 3, column: 4, scope: !3)
+!5 = distinct !DISubprogram(name: "test2", scope: !1, file: !1, line: 11, unit: !0)
+!6 = !DILocation(line: 15, column: 4, scope: !5)
+; CHECK-DAG: ![[BRANCH_WEIGHT1]] = !{!"branch_weights", i32 123}
+; CHECK-DAG: ![[BRANCH_WEIGHT2]] = !{!"branch_weights", i32 30}



More information about the llvm-commits mailing list