[llvm] [CSSPGO] Compute and report profile matching recovered callsites and samples (PR #79090)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 12 20:51:24 PST 2024


================
@@ -2443,53 +2374,231 @@ void SampleProfileMatcher::runOnFunction(const Function &F) {
   std::map<LineLocation, std::unordered_set<FunctionId>> ProfileAnchors;
   findProfileAnchors(*FSFlattened, ProfileAnchors);
 
-  // Detect profile mismatch for profile staleness metrics report.
-  // Skip reporting the metrics for imported functions.
-  if (!GlobalValue::isAvailableExternallyLinkage(F.getLinkage()) &&
-      (ReportProfileStaleness || PersistProfileStaleness)) {
-    // Use top-level nested FS for counting profile mismatch metrics since
-    // currently once a callsite is mismatched, all its children profiles are
-    // dropped.
-    if (const auto *FS = Reader.getSamplesFor(F))
-      countProfileMismatches(F, *FS, IRAnchors, ProfileAnchors);
-  }
+  // Compute the callsite match states for profile staleness report.
+  if (ReportProfileStaleness || PersistProfileStaleness)
+    computeCallsiteMatchStates(F, IRAnchors, ProfileAnchors, nullptr);
 
   // Run profile matching for checksum mismatched profile, currently only
   // support for pseudo-probe.
   if (SalvageStaleProfile && FunctionSamples::ProfileIsProbeBased &&
       !ProbeManager->profileIsValid(F, *FSFlattened)) {
     // The matching result will be saved to IRToProfileLocationMap, create a new
     // map for each function.
+    auto &IRToProfileLocationMap = getIRToProfileLocationMap(F);
     runStaleProfileMatching(F, IRAnchors, ProfileAnchors,
-                            getIRToProfileLocationMap(F));
+                            IRToProfileLocationMap);
+    // Find and update callsite match states after matching.
+    if (ReportProfileStaleness || PersistProfileStaleness)
+      computeCallsiteMatchStates(F, IRAnchors, ProfileAnchors,
+                                 &IRToProfileLocationMap);
   }
 }
 
-void SampleProfileMatcher::runOnModule() {
-  ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
-                                   FunctionSamples::ProfileIsCS);
-  for (auto &F : M) {
-    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
+void SampleProfileMatcher::computeCallsiteMatchStates(
+    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+    const std::map<LineLocation, std::unordered_set<FunctionId>>
+        &ProfileAnchors,
+    const LocToLocMap *IRToProfileLocationMap) {
+  bool IsPostMatch = IRToProfileLocationMap != nullptr;
+  auto &CallsiteMatchStates =
+      FuncCallsiteMatchStates[FunctionSamples::getCanonicalFnName(F.getName())];
+
+  auto MapIRLocToProfileLoc = [&](const LineLocation &IRLoc) {
+    // IRToProfileLocationMap is null in pre-match phase.
+    if (!IRToProfileLocationMap)
+      return IRLoc;
+    const auto &ProfileLoc = IRToProfileLocationMap->find(IRLoc);
+    if (ProfileLoc != IRToProfileLocationMap->end())
+      return ProfileLoc->second;
+    else
+      return IRLoc;
+  };
+
+  for (const auto &I : IRAnchors) {
+    // In post-match, use the matching result to remap the current IR callsite.
+    const auto &Loc = MapIRLocToProfileLoc(I.first);
+    const auto &IRCalleeName = I.second;
+    const auto &It = ProfileAnchors.find(Loc);
+    if (It == ProfileAnchors.end())
       continue;
-    runOnFunction(F);
+    const auto &Callees = It->second;
+
+    bool IsCallsiteMatched = false;
+    // Since indirect call does not have CalleeName, check conservatively if
+    // callsite in the profile is a callsite location. This is to reduce the
+    // number of false positives since otherwise all the indirect call samples
+    // will be reported as mismatching.
+    if (IRCalleeName == SampleProfileMatcher::UnknownIndirectCallee)
+      IsCallsiteMatched = true;
+    else if (Callees.size() == 1 && Callees.count(getRepInFormat(IRCalleeName)))
+      IsCallsiteMatched = true;
+
+    if (IsCallsiteMatched) {
+      auto R = CallsiteMatchStates.emplace(Loc, MatchState::Matched);
+      // When there is an existing state, we know it's in post-match phase.
+      // Update the matching state accordingly.
+      if (!R.second) {
+        if (R.first->second == MatchState::Mismatched)
+          R.first->second = MatchState::Recovered;
+        if (R.first->second == MatchState::Matched)
+          R.first->second = MatchState::StayMatched;
+      }
+    }
+  }
+
+  // Check if there are any callsites in the profile that do not match any
+  // IR callsites.
+  for (const auto &I : ProfileAnchors) {
+    const auto &Loc = I.first;
+    [[maybe_unused]] const auto &Callees = I.second;
+    assert(!Callees.empty() && "Callees should not be empty");
+    auto It = CallsiteMatchStates.find(Loc);
+    if (It == CallsiteMatchStates.end())
+      CallsiteMatchStates.emplace(Loc, MatchState::Mismatched);
+    // In post-match, if the state is not updated to Recovered or StayMatched,
+    // update it to Mismatched.
+    else if (IsPostMatch && It->second == MatchState::Matched)
+      CallsiteMatchStates.emplace(Loc, MatchState::Mismatched);
+  }
+}
+
+void SampleProfileMatcher::countMismatchedFuncSamples(const FunctionSamples &FS,
+                                                      bool IsTopLevel) {
+  const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID());
+  // Skip the function that is external or renamed.
+  if (!FuncDesc)
+    return;
+
+  if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
+    if (IsTopLevel)
+      NumStaleProfileFunc++;
+    // Given currently all probe ids are after block probe ids, once the
+    // checksum is mismatched, it's likely all the callsites are mismatched and
+    // dropped. We conservatively count all the samples as mismatched and stop
+    // counting the inlinees' profiles.
+    MismatchedFunctionSamples += FS.getTotalSamples();
+    return;
+  }
+
+  // Even if the current-level function checksum is matched, it's possible that
+  // the nested inlinees' checksums are mismatched, which affects the inlinees'
+  // sample loading, so we need to go deeper to check the inlinees' function
+  // samples. Similarly, count all the samples as mismatched if the inlinee's
+  // checksum is mismatched using this recursive function.
+  for (const auto &I : FS.getCallsiteSamples())
+    for (const auto &CS : I.second)
+      countMismatchedFuncSamples(CS.second, false);
+}
+
+void SampleProfileMatcher::countMismatchedCallsiteSamples(
+    const FunctionSamples &FS) {
+  auto It = FuncCallsiteMatchStates.find(FS.getFuncName());
+  // Skip it if no mismatched callsite or this is an external function.
+  if (It == FuncCallsiteMatchStates.end() || It->second.empty())
+    return;
+  const auto &CallsiteMatchStates = It->second;
+
+  auto findMatchState = [&](const LineLocation &Loc) {
+    auto It = CallsiteMatchStates.find(Loc);
+    if (It == CallsiteMatchStates.end())
+      return MatchState::Unknown;
+    return It->second;
+  };
+
+  auto AttributeMismatchedSamples = [&](const enum MatchState &State,
+                                        uint64_t Samples) {
+    if (State == MatchState::Mismatched)
+      MismatchedCallsiteSamples += Samples;
+    else if (State == MatchState::Recovered)
+      RecoveredCallsiteSamples += Samples;
+  };
+
+  // The non-inlined callsites are saved in the body samples of function
+  // profile, go through it to count the non-inlined callsite samples.
+  for (const auto &I : FS.getBodySamples())
+    AttributeMismatchedSamples(findMatchState(I.first), I.second.getSamples());
+
+  // Count the inlined callsite samples.
+  for (const auto &I : FS.getCallsiteSamples()) {
+    auto State = findMatchState(I.first);
+    uint64_t CallsiteSamples = 0;
+    for (const auto &CS : I.second)
+      CallsiteSamples += CS.second.getTotalSamples();
+    AttributeMismatchedSamples(State, CallsiteSamples);
+
+    if (State == MatchState::Mismatched)
+      continue;
+
+    // When the current level of inlined call site matches the profiled call
+    // site, we need to go deeper along the inline tree to count mismatches from
+    // lower level inlinees.
+    for (const auto &CS : I.second)
+      countMismatchedCallsiteSamples(CS.second);
+  }
+}
+
+void SampleProfileMatcher::countMismatchCallsites(const FunctionSamples &FS) {
+  auto It = FuncCallsiteMatchStates.find(FS.getFuncName());
+  // Skip it if no mismatched callsite or this is an external function.
+  if (It == FuncCallsiteMatchStates.end() || It->second.empty())
+    return;
+  const auto &MatchStates = It->second;
+  for (const auto &I : MatchStates) {
+    TotalProfiledCallsites++;
+    if (I.second == MatchState::Mismatched)
+      NumMismatchedCallsites++;
+    else if (I.second == MatchState::Recovered)
+      NumRecoveredCallsites++;
+  }
+}
+
+void SampleProfileMatcher::computeAndReportProfileStaleness() {
+  if (!ReportProfileStaleness && !PersistProfileStaleness)
+    return;
+
+  // Count profile mismatches for profile staleness report.
+  for (const auto &F : M) {
+    if (skipProfileForFunction(F))
+      continue;
+    // As the stats will be merged by linker, skip reporting the metrics for
+    // imported functions to avoid repeated counting.
+    if (GlobalValue::isAvailableExternallyLinkage(F.getLinkage()))
+      continue;
+    const auto *FS = Reader.getSamplesFor(F);
+    if (!FS)
+      continue;
+    TotalProfiledFunc++;
+    TotalFunctionSamples += FS->getTotalSamples();
+
+    // Checksum mismatch is only used in pseudo-probe mode.
+    if (FunctionSamples::ProfileIsProbeBased)
+      countMismatchedFuncSamples(*FS, true);
+
+    // Count mismatches and samples for callsites.
+    countMismatchCallsites(*FS);
+    countMismatchedCallsiteSamples(*FS);
   }
-  if (SalvageStaleProfile)
-    distributeIRToProfileLocationMap();
 
   if (ReportProfileStaleness) {
     if (FunctionSamples::ProfileIsProbeBased) {
-      errs() << "(" << NumMismatchedFuncHash << "/" << TotalProfiledFunc << ")"
+      errs() << "(" << NumStaleProfileFunc << "/" << TotalProfiledFunc << ")"
              << " of functions' profile are invalid and "
-             << " (" << MismatchedFuncHashSamples << "/" << TotalFuncHashSamples
+             << " (" << MismatchedFunctionSamples << "/" << TotalFunctionSamples
              << ")"
              << " of samples are discarded due to function hash mismatch.\n";
----------------
WenleiHe wrote:

nit: merge the two lines?

https://github.com/llvm/llvm-project/pull/79090


More information about the llvm-commits mailing list