[llvm] Change how branch weight is annotated for direct call (PR #90315)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 26 23:46:10 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-pgo

Author: William Junda Huang (huangjd)

<details>
<summary>Changes</summary>

Currently, when the sample profile loader matches a line containing a direct function call, it annotates the call's branch weight with that line's sample count. However, if the sample also contains call targets and one of them exactly matches the directly called function, that call target's count is the more accurate number. The line's sample count and the (sum of the) call target counts can disagree for several reasons — for example, when the profile is the product of merging or downsampling multiple profiles.

Example: 
```
main:1000:500
 1: 100 DirectCallee: 50
```

```
void main() {
  DirectCallee();
}
```

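In this case, it makes more sense to give the call to `DirectCallee` a branch weight of 50, not 100. (For non-CS profiles, if the same callsite also has an inlined instance of the callee, that instance's head-sample estimate is added to the call target count.)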


---
Full diff: https://github.com/llvm/llvm-project/pull/90315.diff


3 Files Affected:

- (modified) llvm/include/llvm/ProfileData/SampleProf.h (+7) 
- (modified) llvm/lib/ProfileData/SampleProf.cpp (+19) 
- (modified) llvm/lib/Transforms/IPO/SampleProfile.cpp (+25-5) 


``````````diff
diff --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h
index 51d590be124f10..10c4f0f4bfe120 100644
--- a/llvm/include/llvm/ProfileData/SampleProf.h
+++ b/llvm/include/llvm/ProfileData/SampleProf.h
@@ -902,6 +902,13 @@ class FunctionSamples {
     return Ret->second.getCallTargets();
   }
 
+  /// Returns the call target count of a specific function \p CalleeName at a
+  /// given location \p Callsite. Returns nullptr if not found. A \p Remapper
+  /// can be optionally provided to look up a name equivalent to \p CalleeName.
+  const uint64_t *
+  findCallTargetAt(const LineLocation &Callsite, StringRef CalleeName,
+                   SampleProfileReaderItaniumRemapper *Remapper) const;
+
   /// Return the function samples at the given callsite location.
   FunctionSamplesMap &functionSamplesAt(const LineLocation &Loc) {
     return CallsiteSamples[mapIRLocToProfileLoc(Loc)];
diff --git a/llvm/lib/ProfileData/SampleProf.cpp b/llvm/lib/ProfileData/SampleProf.cpp
index 59fa71899ed47b..8d3b52075641ae 100644
--- a/llvm/lib/ProfileData/SampleProf.cpp
+++ b/llvm/lib/ProfileData/SampleProf.cpp
@@ -275,6 +275,25 @@ void FunctionSamples::findAllNames(DenseSet<FunctionId> &NameSet) const {
   }
 }
 
+const uint64_t *FunctionSamples::findCallTargetAt(const LineLocation &Callsite,
+    StringRef CalleeName, SampleProfileReaderItaniumRemapper *Remapper) const {
+  const auto &FindRes = BodySamples.find(mapIRLocToProfileLoc(Callsite));
+  if (FindRes == BodySamples.end())
+    return nullptr;
+  const auto &CallTargets = FindRes->second.getCallTargets();
+  const auto &Ret = CallTargets.find(getRepInFormat(CalleeName));
+  if (Ret != CallTargets.end())
+    return &Ret->second;
+  if (Remapper && !UseMD5) {
+    if (auto RemappedName = Remapper->lookUpNameInProfile(CalleeName)) {
+      const auto &Ret = CallTargets.find(getRepInFormat(*RemappedName));
+      if (Ret != CallTargets.end())
+        return &Ret->second;
+    }
+  }
+  return nullptr;
+}
+
 const FunctionSamples *FunctionSamples::findFunctionSamplesAt(
     const LineLocation &Loc, StringRef CalleeName,
     SampleProfileReaderItaniumRemapper *Remapper) const {
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 0b3a6931e779b6..a927d5ac2627fb 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1617,12 +1617,13 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
       for (auto &I : *BB) {
         if (!isa<CallInst>(I) && !isa<InvokeInst>(I))
           continue;
-        if (!cast<CallBase>(I).getCalledFunction()) {
-          const DebugLoc &DLoc = I.getDebugLoc();
+        const DebugLoc &DLoc = I.getDebugLoc();
+        const DILocation *DIL = DLoc;
+        const FunctionSamples *FS = findFunctionSamples(I);
+        Function *Callee = cast<CallBase>(I).getCalledFunction();
+        if (!Callee) {
           if (!DLoc)
             continue;
-          const DILocation *DIL = DLoc;
-          const FunctionSamples *FS = findFunctionSamples(I);
           if (!FS)
             continue;
           auto CallSite = FunctionSamples::getCallSiteIdentifier(DIL);
@@ -1659,7 +1660,26 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
           else if (OverwriteExistingWeights)
             I.setMetadata(LLVMContext::MD_prof, nullptr);
         } else if (!isa<IntrinsicInst>(&I)) {
-          setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
+          uint32_t BranchWeight = static_cast<uint32_t>(BlockWeights[BB]);
+          // If I is a direct call and we found it has a sample with a matching
+          // call target, we should use its count instead because it is more
+          // precise.
+          if (DLoc && FS) {
+            auto Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
+            if (const uint64_t *CallTargetCount = FS->findCallTargetAt(
+                    Callsite, Callee->getName(), Reader->getRemapper())) {
+              BranchWeight = static_cast<uint32_t>(*CallTargetCount);
+              if (!FunctionSamples::ProfileIsCS) {
+                if (const FunctionSamples *InlinedCallee =
+                        FS->findFunctionSamplesAt(Callsite, Callee->getName(),
+                                                  Reader->getRemapper())) {
+                  BranchWeight += static_cast<uint32_t>(
+                      InlinedCallee->getHeadSamplesEstimate());
+                }
+              }
+            }
+          }
+          setBranchWeights(I, {BranchWeight});
         }
       }
     } else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/90315


More information about the llvm-commits mailing list