[llvm] 148cceb - [CSSPGO] Refactoring SampleProfileMatcher::runOnFunction

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 30 18:01:34 PDT 2023


Author: wlei
Date: 2023-08-30T18:00:23-07:00
New Revision: 148cceb0d6b5229fd57edd81a3c795da1a0cdbb1

URL: https://github.com/llvm/llvm-project/commit/148cceb0d6b5229fd57edd81a3c795da1a0cdbb1
DIFF: https://github.com/llvm/llvm-project/commit/148cceb0d6b5229fd57edd81a3c795da1a0cdbb1.diff

LOG: [CSSPGO] Refactoring SampleProfileMatcher::runOnFunction

- rename `IRLocation` --> `IRAnchors`,  `ProfileLocation` --> `ProfileAnchors`
- reorganize runOnFunction, factor out the IR-anchor-finding code into `findIRAnchors`
- introduce a new function `findProfileAnchors` to populate the profile-related anchors; the result is saved into `ProfileAnchors` and later used for both the mismatch report and the matching, which avoids parsing `getBodySamples` and `getCallsiteSamples` multiple times.
- move the `MatchedCallsiteLocs` logic from `findIRAnchors` to `countProfileMismatches` so that all the staleness metrics for the report are computed in one function.
- move all matching-related code into `runStaleProfileMatching`, and move all mismatch reporting into `countProfileMismatches`

Reviewed By: wenlei

Differential Revision: https://reviews.llvm.org/D158817

Added: 
    

Modified: 
    llvm/lib/Transforms/IPO/SampleProfile.cpp
    llvm/test/Transforms/SampleProfile/profile-mismatch.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 85431fb32b4b4d..38a21240c47dd8 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -458,6 +458,11 @@ class SampleProfileMatcher {
   uint64_t MismatchedFuncHashSamples = 0;
   uint64_t TotalFuncHashSamples = 0;
 
+  // A dummy name for unknown indirect callee, used to differentiate from a
+  // non-call instruction that also has an empty callee name.
+  static constexpr const char *UnknownIndirectCallee =
+      "unknown.indirect.callee";
+
 public:
   SampleProfileMatcher(Module &M, SampleProfileReader &Reader,
                        const PseudoProbeManager *ProbeManager)
@@ -478,12 +483,20 @@ class SampleProfileMatcher {
     return nullptr;
   }
   void runOnFunction(const Function &F, const FunctionSamples &FS);
+  void findIRAnchors(const Function &F,
+                     std::map<LineLocation, StringRef> &IRAnchors);
+  void findProfileAnchors(const FunctionSamples &FS,
+                          std::map<LineLocation, StringSet<>> &ProfileAnchors);
   void countProfileMismatches(
+      const Function &F, const FunctionSamples &FS,
+      const std::map<LineLocation, StringRef> &IRAnchors,
+      const std::map<LineLocation, StringSet<>> &ProfileAnchors);
+  void countProfileCallsiteMismatches(
       const FunctionSamples &FS,
-      const std::unordered_set<LineLocation, LineLocationHash>
-          &MatchedCallsiteLocs,
-      uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites);
+      const std::map<LineLocation, StringRef> &IRAnchors,
+      const std::map<LineLocation, StringSet<>> &ProfileAnchors,
 
+      uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites);
   LocToLocMap &getIRToProfileLocationMap(const Function &F) {
     auto Ret = FuncMappings.try_emplace(
         FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap());
@@ -491,12 +504,9 @@ class SampleProfileMatcher {
   }
   void distributeIRToProfileLocationMap();
   void distributeIRToProfileLocationMap(FunctionSamples &FS);
-  void populateProfileCallsites(
-      const FunctionSamples &FS,
-      StringMap<std::set<LineLocation>> &CalleeToCallsitesMap);
   void runStaleProfileMatching(
-      const std::map<LineLocation, StringRef> &IRLocations,
-      StringMap<std::set<LineLocation>> &CalleeToCallsitesMap,
+      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+      const std::map<LineLocation, StringSet<>> &ProfileAnchors,
       LocToLocMap &IRToProfileLocationMap);
 };
 
@@ -2108,77 +2118,147 @@ bool SampleProfileLoader::doInitialization(Module &M,
   return true;
 }
 
-void SampleProfileMatcher::countProfileMismatches(
-    const FunctionSamples &FS,
-    const std::unordered_set<LineLocation, LineLocationHash>
-        &MatchedCallsiteLocs,
-    uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) {
+void SampleProfileMatcher::findIRAnchors(
+    const Function &F, std::map<LineLocation, StringRef> &IRAnchors) {
+  // Extract profile matching anchors in the IR.
+  for (auto &BB : F) {
+    for (auto &I : BB) {
+      // TODO: Support line-number based location(AutoFDO).
+      if (FunctionSamples::ProfileIsProbeBased && isa<PseudoProbeInst>(&I)) {
+        if (std::optional<PseudoProbe> Probe = extractProbe(I))
+          IRAnchors.emplace(LineLocation(Probe->Id, 0), StringRef());
+      }
 
-  auto isInvalidLineOffset = [](uint32_t LineOffset) {
-    return LineOffset & 0x8000;
-  };
+      if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I))
+        continue;
 
-  // Check if there are any callsites in the profile that does not match to any
-  // IR callsites, those callsite samples will be discarded.
-  for (auto &I : FS.getBodySamples()) {
-    const LineLocation &Loc = I.first;
-    if (isInvalidLineOffset(Loc.LineOffset))
-      continue;
+      const auto *CB = dyn_cast<CallBase>(&I);
+      if (auto &DLoc = I.getDebugLoc()) {
+        LineLocation IRCallsite = FunctionSamples::getCallSiteIdentifier(DLoc);
+        StringRef CalleeName = UnknownIndirectCallee;
+        if (Function *Callee = CB->getCalledFunction())
+          CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName());
 
-    uint64_t Count = I.second.getSamples();
-    if (!I.second.getCallTargets().empty()) {
-      TotalCallsiteSamples += Count;
-      FuncProfiledCallsites++;
-      if (!MatchedCallsiteLocs.count(Loc)) {
-        MismatchedCallsiteSamples += Count;
-        FuncMismatchedCallsites++;
+        // Force to overwrite the callee name in case any non-call location was
+        // written before.
+        auto R = IRAnchors.emplace(IRCallsite, CalleeName);
+        R.first->second = CalleeName;
+        assert((!FunctionSamples::ProfileIsProbeBased || R.second ||
+                R.first->second == CalleeName) &&
+               "Overwrite non-call or different callee name location for "
+               "pseudo probe callsite");
       }
     }
   }
+}
 
-  for (auto &I : FS.getCallsiteSamples()) {
-    const LineLocation &Loc = I.first;
-    if (isInvalidLineOffset(Loc.LineOffset))
-      continue;
+void SampleProfileMatcher::countProfileMismatches(
+    const Function &F, const FunctionSamples &FS,
+    const std::map<LineLocation, StringRef> &IRAnchors,
+    const std::map<LineLocation, StringSet<>> &ProfileAnchors) {
+  bool IsFuncHashMismatch = false;
+  if (FunctionSamples::ProfileIsProbeBased) {
+    uint64_t Count = FS.getTotalSamples();
+    TotalFuncHashSamples += Count;
+    TotalProfiledFunc++;
+    if (!ProbeManager->profileIsValid(F, FS)) {
+      MismatchedFuncHashSamples += Count;
+      NumMismatchedFuncHash++;
+      IsFuncHashMismatch = true;
+    }
+  }
 
-    uint64_t Count = 0;
-    for (auto &FM : I.second) {
-      Count += FM.second.getHeadSamplesEstimate();
+  uint64_t FuncMismatchedCallsites = 0;
+  uint64_t FuncProfiledCallsites = 0;
+  countProfileCallsiteMismatches(FS, IRAnchors, ProfileAnchors,
+                                 FuncMismatchedCallsites,
+                                 FuncProfiledCallsites);
+  TotalProfiledCallsites += FuncProfiledCallsites;
+  NumMismatchedCallsites += FuncMismatchedCallsites;
+  LLVM_DEBUG({
+    if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch &&
+        FuncMismatchedCallsites)
+      dbgs() << "Function checksum is matched but there are "
+             << FuncMismatchedCallsites << "/" << FuncProfiledCallsites
+             << " mismatched callsites.\n";
+  });
+}
+
+void SampleProfileMatcher::countProfileCallsiteMismatches(
+    const FunctionSamples &FS,
+    const std::map<LineLocation, StringRef> &IRAnchors,
+    const std::map<LineLocation, StringSet<>> &ProfileAnchors,
+    uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) {
+
+  // Check if there are any callsites in the profile that do not match any
+  // IR callsites, those callsite samples will be discarded.
+  for (const auto &I : ProfileAnchors) {
+    const auto &Loc = I.first;
+    const auto &Callees = I.second;
+    assert(!Callees.empty() && "Callees should not be empty");
+
+    StringRef IRCalleeName;
+    const auto &IR = IRAnchors.find(Loc);
+    if (IR != IRAnchors.end())
+      IRCalleeName = IR->second;
+
+    // Compute number of samples in the original profile.
+    uint64_t CallsiteSamples = 0;
+    auto CTM = FS.findCallTargetMapAt(Loc);
+    if (CTM) {
+      for (const auto &I : CTM.get())
+        CallsiteSamples += I.second;
     }
-    TotalCallsiteSamples += Count;
+    const auto *FSMap = FS.findFunctionSamplesMapAt(Loc);
+    if (FSMap) {
+      for (const auto &I : *FSMap)
+        CallsiteSamples += I.second.getTotalSamples();
+    }
+
+    bool CallsiteIsMatched = false;
+    // Since indirect call does not have CalleeName, check conservatively if
+    // callsite in the profile is a callsite location. This is to reduce num of
+    // false positive since otherwise all the indirect call samples will be
+    // reported as mismatching.
+    if (IRCalleeName == UnknownIndirectCallee)
+      CallsiteIsMatched = true;
+    else if (Callees.size() == 1 && Callees.count(IRCalleeName))
+      CallsiteIsMatched = true;
+
     FuncProfiledCallsites++;
-    if (!MatchedCallsiteLocs.count(Loc)) {
-      MismatchedCallsiteSamples += Count;
+    TotalCallsiteSamples += CallsiteSamples;
+    if (!CallsiteIsMatched) {
       FuncMismatchedCallsites++;
+      MismatchedCallsiteSamples += CallsiteSamples;
     }
   }
 }
 
-// Populate the anchors(direct callee name) from profile.
-void SampleProfileMatcher::populateProfileCallsites(
+void SampleProfileMatcher::findProfileAnchors(
     const FunctionSamples &FS,
-    StringMap<std::set<LineLocation>> &CalleeToCallsitesMap) {
+    std::map<LineLocation, StringSet<>> &ProfileAnchors) {
+  auto isInvalidLineOffset = [](uint32_t LineOffset) {
+    return LineOffset & 0x8000;
+  };
+
   for (const auto &I : FS.getBodySamples()) {
-    const auto &Loc = I.first;
-    const auto &CTM = I.second.getCallTargets();
-    // Filter out possible indirect calls, use direct callee name as anchor.
-    if (CTM.size() == 1) {
-      StringRef CalleeName = CTM.begin()->first();
-      const auto &Candidates = CalleeToCallsitesMap.try_emplace(
-          CalleeName, std::set<LineLocation>());
-      Candidates.first->second.insert(Loc);
+    const LineLocation &Loc = I.first;
+    if (isInvalidLineOffset(Loc.LineOffset))
+      continue;
+    for (const auto &I : I.second.getCallTargets()) {
+      auto Ret = ProfileAnchors.try_emplace(Loc, StringSet<>());
+      Ret.first->second.insert(I.first());
     }
   }
 
   for (const auto &I : FS.getCallsiteSamples()) {
     const LineLocation &Loc = I.first;
+    if (isInvalidLineOffset(Loc.LineOffset))
+      continue;
     const auto &CalleeMap = I.second;
-    // Filter out possible indirect calls, use direct callee name as anchor.
-    if (CalleeMap.size() == 1) {
-      StringRef CalleeName = CalleeMap.begin()->first;
-      const auto &Candidates = CalleeToCallsitesMap.try_emplace(
-          CalleeName, std::set<LineLocation>());
-      Candidates.first->second.insert(Loc);
+    for (const auto &I : CalleeMap) {
+      auto Ret = ProfileAnchors.try_emplace(Loc, StringSet<>());
+      Ret.first->second.insert(I.first);
     }
   }
 }
@@ -2201,12 +2281,27 @@ void SampleProfileMatcher::populateProfileCallsites(
 //   [1, 2, 3(foo), 4,  7,  8(bar), 9]
 // The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9].
 void SampleProfileMatcher::runStaleProfileMatching(
-    const std::map<LineLocation, StringRef> &IRLocations,
-    StringMap<std::set<LineLocation>> &CalleeToCallsitesMap,
+    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+    const std::map<LineLocation, StringSet<>> &ProfileAnchors,
     LocToLocMap &IRToProfileLocationMap) {
+  LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName()
+                    << "\n");
   assert(IRToProfileLocationMap.empty() &&
          "Run stale profile matching only once per function");
 
+  StringMap<std::set<LineLocation>> CalleeToCallsitesMap;
+  for (const auto &I : ProfileAnchors) {
+    const auto &Loc = I.first;
+    const auto &Callees = I.second;
+    // Filter out possible indirect calls, use direct callee name as anchor.
+    if (Callees.size() == 1) {
+      StringRef CalleeName = Callees.begin()->first();
+      const auto &Candidates = CalleeToCallsitesMap.try_emplace(
+          CalleeName, std::set<LineLocation>());
+      Candidates.first->second.insert(Loc);
+    }
+  }
+
   auto InsertMatching = [&](const LineLocation &From, const LineLocation &To) {
     // Skip the unchanged location mapping to save memory.
     if (From != To)
@@ -2217,18 +2312,18 @@ void SampleProfileMatcher::runStaleProfileMatching(
   int32_t LocationDelta = 0;
   SmallVector<LineLocation> LastMatchedNonAnchors;
 
-  for (const auto &IR : IRLocations) {
+  for (const auto &IR : IRAnchors) {
     const auto &Loc = IR.first;
     StringRef CalleeName = IR.second;
     bool IsMatchedAnchor = false;
     // Match the anchor location in lexical order.
     if (!CalleeName.empty()) {
-      auto ProfileAnchors = CalleeToCallsitesMap.find(CalleeName);
-      if (ProfileAnchors != CalleeToCallsitesMap.end() &&
-          !ProfileAnchors->second.empty()) {
-        auto CI = ProfileAnchors->second.begin();
+      auto CandidateAnchors = CalleeToCallsitesMap.find(CalleeName);
+      if (CandidateAnchors != CalleeToCallsitesMap.end() &&
+          !CandidateAnchors->second.empty()) {
+        auto CI = CandidateAnchors->second.begin();
         const auto Candidate = *CI;
-        ProfileAnchors->second.erase(CI);
+        CandidateAnchors->second.erase(CI);
         InsertMatching(Loc, Candidate);
         LLVM_DEBUG(dbgs() << "Callsite with callee:" << CalleeName
                           << " is matched from " << Loc << " to " << Candidate
@@ -2268,105 +2363,29 @@ void SampleProfileMatcher::runStaleProfileMatching(
 
 void SampleProfileMatcher::runOnFunction(const Function &F,
                                          const FunctionSamples &FS) {
-  bool IsFuncHashMismatch = false;
-  if (FunctionSamples::ProfileIsProbeBased) {
-    uint64_t Count = FS.getTotalSamples();
-    TotalFuncHashSamples += Count;
-    TotalProfiledFunc++;
-    if (!ProbeManager->profileIsValid(F, FS)) {
-      MismatchedFuncHashSamples += Count;
-      NumMismatchedFuncHash++;
-      IsFuncHashMismatch = true;
-    }
-  }
-
-  std::unordered_set<LineLocation, LineLocationHash> MatchedCallsiteLocs;
-  // The value of the map is the name of direct callsite and use empty StringRef
-  // for non-direct-call site.
-  std::map<LineLocation, StringRef> IRLocations;
-
-  // Extract profile matching anchors and profile mismatch metrics in the IR.
-  for (auto &BB : F) {
-    for (auto &I : BB) {
-      // TODO: Support line-number based location(AutoFDO).
-      if (FunctionSamples::ProfileIsProbeBased && isa<PseudoProbeInst>(&I)) {
-        if (std::optional<PseudoProbe> Probe = extractProbe(I))
-          IRLocations.emplace(LineLocation(Probe->Id, 0), StringRef());
-      }
-
-      if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I))
-        continue;
-
-      const auto *CB = dyn_cast<CallBase>(&I);
-      if (auto &DLoc = I.getDebugLoc()) {
-        LineLocation IRCallsite = FunctionSamples::getCallSiteIdentifier(DLoc);
-
-        StringRef CalleeName;
-        if (Function *Callee = CB->getCalledFunction())
-          CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName());
-
-        // Force to overwrite the callee name in case any non-call location was
-        // written before.
-        auto R = IRLocations.emplace(IRCallsite, CalleeName);
-        R.first->second = CalleeName;
-        assert((!FunctionSamples::ProfileIsProbeBased || R.second ||
-                R.first->second == CalleeName) &&
-               "Overwrite non-call or different callee name location for "
-               "pseudo probe callsite");
-
-        // Go through all the callsites on the IR and flag the callsite if the
-        // target name is the same as the one in the profile.
-        const auto CTM = FS.findCallTargetMapAt(IRCallsite);
-        const auto CallsiteFS = FS.findFunctionSamplesMapAt(IRCallsite);
-
-        // Indirect call case.
-        if (CalleeName.empty()) {
-          // Since indirect call does not have the CalleeName, check
-          // conservatively if callsite in the profile is a callsite location.
-          // This is to avoid nums of false positive since otherwise all the
-          // indirect call samples will be reported as mismatching.
-          if ((CTM && !CTM->empty()) || (CallsiteFS && !CallsiteFS->empty()))
-            MatchedCallsiteLocs.insert(IRCallsite);
-        } else {
-          // Check if the call target name is matched for direct call case.
-          if ((CTM && CTM->count(CalleeName)) ||
-              (CallsiteFS && CallsiteFS->count(CalleeName)))
-            MatchedCallsiteLocs.insert(IRCallsite);
-        }
-      }
-    }
-  }
+  // Anchors for IR. It's a map from IR location to callee name, callee name is
+  // empty for non-call instruction and use a dummy name(UnknownIndirectCallee)
+  // for unknown indirect callee name.
+  std::map<LineLocation, StringRef> IRAnchors;
+  findIRAnchors(F, IRAnchors);
+  // Anchors for profile. It's a map from callsite location to a set of callee
+  // name.
+  std::map<LineLocation, StringSet<>> ProfileAnchors;
+  findProfileAnchors(FS, ProfileAnchors);
 
   // Detect profile mismatch for profile staleness metrics report.
   if (ReportProfileStaleness || PersistProfileStaleness) {
-    uint64_t FuncMismatchedCallsites = 0;
-    uint64_t FuncProfiledCallsites = 0;
-    countProfileMismatches(FS, MatchedCallsiteLocs, FuncMismatchedCallsites,
-                           FuncProfiledCallsites);
-    TotalProfiledCallsites += FuncProfiledCallsites;
-    NumMismatchedCallsites += FuncMismatchedCallsites;
-    LLVM_DEBUG({
-      if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch &&
-          FuncMismatchedCallsites)
-        dbgs() << "Function checksum is matched but there are "
-               << FuncMismatchedCallsites << "/" << FuncProfiledCallsites
-               << " mismatched callsites.\n";
-    });
-  }
-
-  if (IsFuncHashMismatch && SalvageStaleProfile) {
-    LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName()
-                      << "\n");
-
-    StringMap<std::set<LineLocation>> CalleeToCallsitesMap;
-    populateProfileCallsites(FS, CalleeToCallsitesMap);
+    countProfileMismatches(F, FS, IRAnchors, ProfileAnchors);
+  }
 
+  // Run profile matching for checksum mismatched profile, currently only
+  // support for pseudo-probe.
+  if (SalvageStaleProfile && FunctionSamples::ProfileIsProbeBased &&
+      !ProbeManager->profileIsValid(F, FS)) {
     // The matching result will be saved to IRToProfileLocationMap, create a new
     // map for each function.
-    auto &IRToProfileLocationMap = getIRToProfileLocationMap(F);
-
-    runStaleProfileMatching(IRLocations, CalleeToCallsitesMap,
-                            IRToProfileLocationMap);
+    runStaleProfileMatching(F, IRAnchors, ProfileAnchors,
+                            getIRToProfileLocationMap(F));
   }
 }
 

diff  --git a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
index 8340c3b0e62d5e..4ce24f4491f79d 100644
--- a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
@@ -6,9 +6,9 @@
 ; RUN: llvm-objdump --section-headers %t.obj | FileCheck %s --check-prefix=CHECK-OBJ
 ; RUN: llc < %t.ll -filetype=asm -o - | FileCheck %s --check-prefix=CHECK-ASM
 
-; CHECK: (2/3) of callsites' profile are invalid and (15/25) of samples are discarded due to callsite location mismatch.
+; CHECK: (2/3) of callsites' profile are invalid and (25/35) of samples are discarded due to callsite location mismatch.
 
-; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 15, !"TotalCallsiteSamples", i64 25}
+; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 25, !"TotalCallsiteSamples", i64 35}
 
 ; CHECK-OBJ: .llvm_stats
 
@@ -24,11 +24,11 @@
 ; CHECK-ASM: .byte 25
 ; CHECK-ASM: .ascii  "MismatchedCallsiteSamples"
 ; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MTU="
+; CHECK-ASM: .ascii  "MjU="
 ; CHECK-ASM: .byte 20
 ; CHECK-ASM: .ascii  "TotalCallsiteSamples"
 ; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MjU="
+; CHECK-ASM: .ascii  "MzU="
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"


        


More information about the llvm-commits mailing list