[llvm] Change how branch weight is annotated for direct call (PR #90315)

William Junda Huang via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 26 23:45:50 PDT 2024


https://github.com/huangjd created https://github.com/llvm/llvm-project/pull/90315

Currently, when the sample profile loader matches a line containing a direct function call, it annotates the call's branch weight with the sample count of that line. However, if the sample also contains call targets and one of them exactly matches the direct callee, that call target's count is likely the more accurate number. The sample count and the (sum of the) call target counts can disagree for several reasons, for example when the profile is the product of merging or downsampling multiple profiles. For non-context-sensitive profiles, if an inlined instance of the callee is also recorded at that callsite, its estimated head samples are added to the matching call target count.

Example: 
```
main:1000:500
 1: 100 DirectCallee:50
```

```
void main() {
  DirectCallee();
}
```

In this case, it makes more sense to give the call to `DirectCallee` a branch weight of 50 rather than 100.


>From fee9217f9600643fb0884c69ee2f4322d1fab82a Mon Sep 17 00:00:00 2001
From: William Huang <williamjhuang at google.com>
Date: Fri, 26 Apr 2024 06:35:22 -0400
Subject: [PATCH 1/2] More accurate branch weight annotation for direct call if
 it matches an exact call target in the sample

---
 llvm/lib/Transforms/IPO/SampleProfile.cpp | 31 +++++++++++++++++++----
 1 file changed, 26 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 0b3a6931e779b6..0c3f3c0c161a69 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1617,12 +1617,13 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
       for (auto &I : *BB) {
         if (!isa<CallInst>(I) && !isa<InvokeInst>(I))
           continue;
-        if (!cast<CallBase>(I).getCalledFunction()) {
-          const DebugLoc &DLoc = I.getDebugLoc();
+        const DebugLoc &DLoc = I.getDebugLoc();
+        const DILocation *DIL = DLoc;
+        const FunctionSamples *FS = findFunctionSamples(I);
+        Function *Callee = cast<CallBase>(I).getCalledFunction();
+        if (!Callee) {
           if (!DLoc)
             continue;
-          const DILocation *DIL = DLoc;
-          const FunctionSamples *FS = findFunctionSamples(I);
           if (!FS)
             continue;
           auto CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
@@ -1659,7 +1660,26 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           else if (OverwriteExistingWeights)
             I.setMetadata(LLVMContext::MD_prof, nullptr);
         } else if (!isa<IntrinsicInst>(&I)) {
-          setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
+          uint32_t BranchWeight = static_cast<uint32_t>(BlockWeights[BB]);
+          // If I is a direct call and we found it has a sample with a matching
+          // call target, we should use its count instead because it is more
+          // precise.
+          if (DLoc && FS) {
+            auto Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
+            // Account for stale profile matching.
+            Callsite = FS->mapIRLocToProfileLoc(Callsite);
+            auto CallTargetMap = FS->findCallTargetMapAt(Callsite);
+            if (CallTargetMap) {
+              auto FindRes = CallTargetMap->find(
+                  FunctionSamples::UseMD5 ?
+                      FunctionId(MD5Hash(Callee->getName())) :
+                      FunctionId(Callee->getName()));
+              if (FindRes != CallTargetMap->end()) {
+                BranchWeight = FindRes->second;
+              }
+            }
+          }
+          setBranchWeights(I, {BranchWeight});
         }
       }
     } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -2201,6 +2221,7 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
 
 bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) {
   LLVM_DEBUG(dbgs() << "\n\nProcessing Function " << F.getName() << "\n");
+
   DILocation2SampleMap.clear();
   // By default the entry count is initialized to -1, which will be treated
   // conservatively by getEntryCount as the same as unknown (None). This is

>From 25dd4042586b02284b34ba8f02da5be1278aed61 Mon Sep 17 00:00:00 2001
From: William Huang <williamjhuang at google.com>
Date: Sat, 27 Apr 2024 02:35:57 -0400
Subject: [PATCH 2/2] Move logic to SampleProf, and encapsulate lookup methods
 for future use

---
 llvm/include/llvm/ProfileData/SampleProf.h |  7 +++++++
 llvm/lib/ProfileData/SampleProf.cpp        | 19 +++++++++++++++++++
 llvm/lib/Transforms/IPO/SampleProfile.cpp  | 21 ++++++++++-----------
 3 files changed, 36 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h
index 51d590be124f10..10c4f0f4bfe120 100644
--- a/llvm/include/llvm/ProfileData/SampleProf.h
+++ b/llvm/include/llvm/ProfileData/SampleProf.h
@@ -902,6 +902,13 @@ class FunctionSamples {
     return Ret->second.getCallTargets();
   }
 
+  /// Returns a pointer to the count of call target \p CalleeName at location
+  /// \p Callsite, or nullptr if not found. Without an exact match, an optional
+  /// \p Remapper is used to look up an equivalent name (skipped in MD5 mode).
+  const uint64_t *
+  findCallTargetAt(const LineLocation &Callsite, StringRef CalleeName,
+                   SampleProfileReaderItaniumRemapper *Remapper) const;
+
   /// Return the function samples at the given callsite location.
   FunctionSamplesMap &functionSamplesAt(const LineLocation &Loc) {
     return CallsiteSamples[mapIRLocToProfileLoc(Loc)];
diff --git a/llvm/lib/ProfileData/SampleProf.cpp b/llvm/lib/ProfileData/SampleProf.cpp
index 59fa71899ed47b..8d3b52075641ae 100644
--- a/llvm/lib/ProfileData/SampleProf.cpp
+++ b/llvm/lib/ProfileData/SampleProf.cpp
@@ -275,6 +275,25 @@ void FunctionSamples::findAllNames(DenseSet<FunctionId> &NameSet) const {
   }
 }
 
+const uint64_t *FunctionSamples::findCallTargetAt(const LineLocation &Callsite,
+    StringRef CalleeName, SampleProfileReaderItaniumRemapper *Remapper) const {
+  const auto &FindRes = BodySamples.find(mapIRLocToProfileLoc(Callsite));
+  if (FindRes == BodySamples.end()) // No samples at the (mapped) location.
+    return nullptr;
+  const auto &CallTargets = FindRes->second.getCallTargets();
+  const auto &Ret = CallTargets.find(getRepInFormat(CalleeName));
+  if (Ret != CallTargets.end()) // Exact match on CalleeName.
+    return &Ret->second; // Assumes getCallTargets() returns a reference.
+  if (Remapper && !UseMD5) { // Name remapping is skipped for MD5 profiles.
+    if (auto RemappedName = Remapper->lookUpNameInProfile(CalleeName)) {
+      const auto &Ret = CallTargets.find(getRepInFormat(*RemappedName));
+      if (Ret != CallTargets.end()) // Match on the remapped name.
+        return &Ret->second;
+    }
+  }
+  return nullptr; // Neither the exact nor the remapped name matched.
+}
+
 const FunctionSamples *FunctionSamples::findFunctionSamplesAt(
     const LineLocation &Loc, StringRef CalleeName,
     SampleProfileReaderItaniumRemapper *Remapper) const {
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 0c3f3c0c161a69..a927d5ac2627fb 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1666,16 +1666,16 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           // precise.
           if (DLoc && FS) {
             auto Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
-            // Account for stale profile matching.
-            Callsite = FS->mapIRLocToProfileLoc(Callsite);
-            auto CallTargetMap = FS->findCallTargetMapAt(Callsite);
-            if (CallTargetMap) {
-              auto FindRes = CallTargetMap->find(
-                  FunctionSamples::UseMD5 ?
-                      FunctionId(MD5Hash(Callee->getName())) :
-                      FunctionId(Callee->getName()));
-              if (FindRes != CallTargetMap->end()) {
-                BranchWeight = FindRes->second;
+            if (const uint64_t *CallTargetCount = FS->findCallTargetAt(
+                    Callsite, Callee->getName(), Reader->getRemapper())) {
+              BranchWeight = static_cast<uint32_t>(*CallTargetCount);
+              if (!FunctionSamples::ProfileIsCS) {
+                if (const FunctionSamples *InlinedCallee =
+                        FS->findFunctionSamplesAt(Callsite, Callee->getName(),
+                                                  Reader->getRemapper())) {
+                  BranchWeight += static_cast<uint32_t>(
+                      InlinedCallee->getHeadSamplesEstimate());
+                }
               }
             }
           }
@@ -2221,7 +2221,6 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
 
 bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) {
   LLVM_DEBUG(dbgs() << "\n\nProcessing Function " << F.getName() << "\n");
-
   DILocation2SampleMap.clear();
   // By default the entry count is initialized to -1, which will be treated
   // conservatively by getEntryCount as the same as unknown (None). This is



More information about the llvm-commits mailing list