[llvm] c98da37 - [CSSPGO] Compute and report profile matching recovered callsites and samples (#79090)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 19 11:36:24 PST 2024


Author: Lei Wang
Date: 2024-02-19T11:36:20-08:00
New Revision: c98da372cb08cd3b3c513a6a86592b0f2892fb49

URL: https://github.com/llvm/llvm-project/commit/c98da372cb08cd3b3c513a6a86592b0f2892fb49
DIFF: https://github.com/llvm/llvm-project/commit/c98da372cb08cd3b3c513a6a86592b0f2892fb49.diff

LOG: [CSSPGO] Compute and report profile matching recovered callsites and samples (#79090)

This change adds support for computing and reporting the staleness metrics
after stale profile matching, so that we can know how effective the fuzzy
matching is, i.e. how many callsites and samples are recovered by the
matching.
Some implementation notes:
- The function checksum mismatch metrics are not applicable here, as they
are function-level metrics: the checksum mismatch remains the same before
and after matching, so we need to compute based on the callsite samples.
- Added two new counters, `NumRecoveredCallsites` and
`RecoveredCallsiteSamples`, for this and removed `TotalCallsiteSamples`,
as we can now use the `TotalFuncHashSamples` as base, and renamed
some counters.
- In profile matching, we changed to use a state machine to represent
the callsite's matching state changes. See `MatchState` for the
states. A new function, `recordCallsiteMatchStates`, computes
and records the callsite's match state changes before and after the
matching; the result is compressed and saved into a
`FuncCallsiteMatchStates` map for later counting use.
- Changed the counting function to run on module-level and moved it to
the end of the whole process (`computeAndReportProfileStaleness`). The
reason is that previously callsites were only counted on top-level
functions; this change extends it to count (recursively) on the inlined
functions and samples, which is more accurate.

Added: 
    

Modified: 
    llvm/lib/Transforms/IPO/SampleProfile.cpp
    llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
    llvm/test/Transforms/SampleProfile/Inputs/pseudo-probe-profile-mismatch.prof
    llvm/test/Transforms/SampleProfile/profile-mismatch.ll
    llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
    llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 2fd8668d15e200..ffc26f1a0a972d 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -444,15 +444,43 @@ class SampleProfileMatcher {
   // the profile.
   StringMap<LocToLocMap> FuncMappings;
 
-  // Profile mismatching statstics.
+  // Match state for an anchor/callsite.
+  enum class MatchState {
+    Unknown = 0,
+    // Initial match between input profile and current IR.
+    InitialMatch = 1,
+    // Initial mismatch between input profile and current IR.
+    InitialMismatch = 2,
+    // InitialMatch stays matched after fuzzy profile matching.
+    UnchangedMatch = 3,
+    // InitialMismatch stays mismatched after fuzzy profile matching.
+    UnchangedMismatch = 4,
+    // InitialMismatch is recovered after fuzzy profile matching.
+    RecoveredMismatch = 5,
+    // InitialMatch is removed and becomes mismatched after fuzzy profile
+    // matching.
+    RemovedMatch = 6,
+  };
+
+  // For each function, store every callsite and its matching state into this
+  // map, of which each entry is a pair of callsite location and MatchState.
+  // This is used for profile staleness computation and report.
+  StringMap<std::unordered_map<LineLocation, MatchState, LineLocationHash>>
+      FuncCallsiteMatchStates;
+
+  // Profile mismatch statstics:
+  uint64_t TotalProfiledFunc = 0;
+  // Num of checksum-mismatched function.
+  uint64_t NumStaleProfileFunc = 0;
   uint64_t TotalProfiledCallsites = 0;
   uint64_t NumMismatchedCallsites = 0;
+  uint64_t NumRecoveredCallsites = 0;
+  // Total samples for all profiled functions.
+  uint64_t TotalFunctionSamples = 0;
+  // Total samples for all checksum-mismatched functions.
+  uint64_t MismatchedFunctionSamples = 0;
   uint64_t MismatchedCallsiteSamples = 0;
-  uint64_t TotalCallsiteSamples = 0;
-  uint64_t TotalProfiledFunc = 0;
-  uint64_t NumMismatchedFuncHash = 0;
-  uint64_t MismatchedFuncHashSamples = 0;
-  uint64_t TotalFuncHashSamples = 0;
+  uint64_t RecoveredCallsiteSamples = 0;
 
   // A dummy name for unknown indirect callee, used to differentiate from a
   // non-call instruction that also has an empty callee name.
@@ -464,6 +492,11 @@ class SampleProfileMatcher {
                        const PseudoProbeManager *ProbeManager)
       : M(M), Reader(Reader), ProbeManager(ProbeManager){};
   void runOnModule();
+  void clearMatchingData() {
+    // Do not clear FuncMappings, it stores IRLoc to ProfLoc remappings which
+    // will be used for sample loader.
+    FuncCallsiteMatchStates.clear();
+  }
 
 private:
   FunctionSamples *getFlattenedSamplesFor(const Function &F) {
@@ -478,20 +511,43 @@ class SampleProfileMatcher {
                      std::map<LineLocation, StringRef> &IRAnchors);
   void findProfileAnchors(
       const FunctionSamples &FS,
-      std::map<LineLocation, std::unordered_set<FunctionId>>
-          &ProfileAnchors);
-  void countMismatchedSamples(const FunctionSamples &FS);
-  void countProfileMismatches(
-      const Function &F, const FunctionSamples &FS,
-      const std::map<LineLocation, StringRef> &IRAnchors,
-      const std::map<LineLocation, std::unordered_set<FunctionId>>
-          &ProfileAnchors);
-  void countProfileCallsiteMismatches(
-      const FunctionSamples &FS,
-      const std::map<LineLocation, StringRef> &IRAnchors,
+      std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors);
+  // Record the callsite match states for profile staleness report, the result
+  // is saved in FuncCallsiteMatchStates.
+  void recordCallsiteMatchStates(
+      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
       const std::map<LineLocation, std::unordered_set<FunctionId>>
           &ProfileAnchors,
-      uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites);
+      const LocToLocMap *IRToProfileLocationMap);
+
+  bool isMismatchState(const enum MatchState &State) {
+    return State == MatchState::InitialMismatch ||
+           State == MatchState::UnchangedMismatch ||
+           State == MatchState::RemovedMatch;
+  };
+
+  bool isInitialState(const enum MatchState &State) {
+    return State == MatchState::InitialMatch ||
+           State == MatchState::InitialMismatch;
+  };
+
+  bool isFinalState(const enum MatchState &State) {
+    return State == MatchState::UnchangedMatch ||
+           State == MatchState::UnchangedMismatch ||
+           State == MatchState::RecoveredMismatch ||
+           State == MatchState::RemovedMatch;
+  };
+
+  // Count the samples of checksum mismatched function for the top-level
+  // function and all inlinees.
+  void countMismatchedFuncSamples(const FunctionSamples &FS, bool IsTopLevel);
+  // Count the number of mismatched or recovered callsites.
+  void countMismatchCallsites(const FunctionSamples &FS);
+  // Count the samples of mismatched or recovered callsites for top-level
+  // function and all inlinees.
+  void countMismatchedCallsiteSamples(const FunctionSamples &FS);
+  void computeAndReportProfileStaleness();
+
   LocToLocMap &getIRToProfileLocationMap(const Function &F) {
     auto Ret = FuncMappings.try_emplace(
         FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap());
@@ -504,6 +560,7 @@ class SampleProfileMatcher {
       const std::map<LineLocation, std::unordered_set<FunctionId>>
           &ProfileAnchors,
       LocToLocMap &IRToProfileLocationMap);
+  void reportOrPersistProfileStats();
 };
 
 /// Sample profile pass.
@@ -690,6 +747,10 @@ void SampleProfileLoaderBaseImpl<Function>::computeDominanceAndLoopInfo(
 }
 } // namespace llvm
 
+static bool skipProfileForFunction(const Function &F) {
+  return F.isDeclaration() || !F.hasFnAttribute("use-sample-profile");
+}
+
 ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) {
   if (FunctionSamples::ProfileIsProbeBased)
     return getProbeWeight(Inst);
@@ -1129,7 +1190,7 @@ void SampleProfileLoader::findExternalInlineCandidate(
         CalleeSample->getContext().hasAttribute(ContextShouldBeInlined);
     if (!PreInline && CalleeSample->getHeadSamplesEstimate() < Threshold)
       continue;
-    
+
     Function *Func = SymbolMap.lookup(CalleeSample->getFunction());
     // Add to the import list only when it's defined out of module.
     if (!Func || Func->isDeclaration())
@@ -1874,7 +1935,7 @@ SampleProfileLoader::buildProfiledCallGraph(Module &M) {
   // the profile. This makes sure functions missing from the profile still
   // gets a chance to be processed.
   for (Function &F : M) {
-    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
+    if (skipProfileForFunction(F))
       continue;
     ProfiledCG->addProfiledFunction(
           getRepInFormat(FunctionSamples::getCanonicalFnName(F)));
@@ -1903,7 +1964,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) {
     }
 
     for (Function &F : M)
-      if (!F.isDeclaration() && F.hasFnAttribute("use-sample-profile"))
+      if (!skipProfileForFunction(F))
         FunctionOrderList.push_back(&F);
     return FunctionOrderList;
   }
@@ -1969,7 +2030,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) {
       }
       for (auto *Node : Range) {
         Function *F = SymbolMap.lookup(Node->Name);
-        if (F && !F->isDeclaration() && F->hasFnAttribute("use-sample-profile"))
+        if (F && !skipProfileForFunction(*F))
           FunctionOrderList.push_back(F);
       }
       ++CGI;
@@ -1980,7 +2041,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) {
       for (LazyCallGraph::SCC &C : RC) {
         for (LazyCallGraph::Node &N : C) {
           Function &F = N.getFunction();
-          if (!F.isDeclaration() && F.hasFnAttribute("use-sample-profile"))
+          if (!skipProfileForFunction(F))
             FunctionOrderList.push_back(&F);
         }
       }
@@ -2190,108 +2251,9 @@ void SampleProfileMatcher::findIRAnchors(
   }
 }
 
-void SampleProfileMatcher::countMismatchedSamples(const FunctionSamples &FS) {
-  const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID());
-  // Skip the function that is external or renamed.
-  if (!FuncDesc)
-    return;
-
-  if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
-    MismatchedFuncHashSamples += FS.getTotalSamples();
-    return;
-  }
-  for (const auto &I : FS.getCallsiteSamples())
-    for (const auto &CS : I.second)
-      countMismatchedSamples(CS.second);
-}
-
-void SampleProfileMatcher::countProfileMismatches(
-    const Function &F, const FunctionSamples &FS,
-    const std::map<LineLocation, StringRef> &IRAnchors,
-    const std::map<LineLocation, std::unordered_set<FunctionId>>
-        &ProfileAnchors) {
-  [[maybe_unused]] bool IsFuncHashMismatch = false;
-  if (FunctionSamples::ProfileIsProbeBased) {
-    TotalFuncHashSamples += FS.getTotalSamples();
-    TotalProfiledFunc++;
-    const auto *FuncDesc = ProbeManager->getDesc(F);
-    if (FuncDesc) {
-      if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
-        NumMismatchedFuncHash++;
-        IsFuncHashMismatch = true;
-      }
-      countMismatchedSamples(FS);
-    }
-  }
-
-  uint64_t FuncMismatchedCallsites = 0;
-  uint64_t FuncProfiledCallsites = 0;
-  countProfileCallsiteMismatches(FS, IRAnchors, ProfileAnchors,
-                                 FuncMismatchedCallsites,
-                                 FuncProfiledCallsites);
-  TotalProfiledCallsites += FuncProfiledCallsites;
-  NumMismatchedCallsites += FuncMismatchedCallsites;
-  LLVM_DEBUG({
-    if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch &&
-        FuncMismatchedCallsites)
-      dbgs() << "Function checksum is matched but there are "
-             << FuncMismatchedCallsites << "/" << FuncProfiledCallsites
-             << " mismatched callsites.\n";
-  });
-}
-
-void SampleProfileMatcher::countProfileCallsiteMismatches(
+void SampleProfileMatcher::findProfileAnchors(
     const FunctionSamples &FS,
-    const std::map<LineLocation, StringRef> &IRAnchors,
-    const std::map<LineLocation, std::unordered_set<FunctionId>>
-        &ProfileAnchors,
-    uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) {
-
-  // Check if there are any callsites in the profile that does not match to any
-  // IR callsites, those callsite samples will be discarded.
-  for (const auto &I : ProfileAnchors) {
-    const auto &Loc = I.first;
-    const auto &Callees = I.second;
-    assert(!Callees.empty() && "Callees should not be empty");
-
-    StringRef IRCalleeName;
-    const auto &IR = IRAnchors.find(Loc);
-    if (IR != IRAnchors.end())
-      IRCalleeName = IR->second;
-
-    // Compute number of samples in the original profile.
-    uint64_t CallsiteSamples = 0;
-    if (auto CTM = FS.findCallTargetMapAt(Loc)) {
-      for (const auto &I : *CTM)
-        CallsiteSamples += I.second;
-    }
-    const auto *FSMap = FS.findFunctionSamplesMapAt(Loc);
-    if (FSMap) {
-      for (const auto &I : *FSMap)
-        CallsiteSamples += I.second.getTotalSamples();
-    }
-
-    bool CallsiteIsMatched = false;
-    // Since indirect call does not have CalleeName, check conservatively if
-    // callsite in the profile is a callsite location. This is to reduce num of
-    // false positive since otherwise all the indirect call samples will be
-    // reported as mismatching.
-    if (IRCalleeName == UnknownIndirectCallee)
-      CallsiteIsMatched = true;
-    else if (Callees.size() == 1 && Callees.count(getRepInFormat(IRCalleeName)))
-      CallsiteIsMatched = true;
-
-    FuncProfiledCallsites++;
-    TotalCallsiteSamples += CallsiteSamples;
-    if (!CallsiteIsMatched) {
-      FuncMismatchedCallsites++;
-      MismatchedCallsiteSamples += CallsiteSamples;
-    }
-  }
-}
-
-void SampleProfileMatcher::findProfileAnchors(const FunctionSamples &FS,
-                                              std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors) {
+    std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors) {
   auto isInvalidLineOffset = [](uint32_t LineOffset) {
     return LineOffset & 0x8000;
   };
@@ -2338,8 +2300,7 @@ void SampleProfileMatcher::findProfileAnchors(const FunctionSamples &FS,
 //   [1, 2, 3(foo), 4,  7,  8(bar), 9]
 // The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9].
 void SampleProfileMatcher::runStaleProfileMatching(
-    const Function &F,
-    const std::map<LineLocation, StringRef> &IRAnchors,
+    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
     const std::map<LineLocation, std::unordered_set<FunctionId>>
         &ProfileAnchors,
     LocToLocMap &IRToProfileLocationMap) {
@@ -2443,16 +2404,9 @@ void SampleProfileMatcher::runOnFunction(const Function &F) {
   std::map<LineLocation, std::unordered_set<FunctionId>> ProfileAnchors;
   findProfileAnchors(*FSFlattened, ProfileAnchors);
 
-  // Detect profile mismatch for profile staleness metrics report.
-  // Skip reporting the metrics for imported functions.
-  if (!GlobalValue::isAvailableExternallyLinkage(F.getLinkage()) &&
-      (ReportProfileStaleness || PersistProfileStaleness)) {
-    // Use top-level nested FS for counting profile mismatch metrics since
-    // currently once a callsite is mismatched, all its children profiles are
-    // dropped.
-    if (const auto *FS = Reader.getSamplesFor(F))
-      countProfileMismatches(F, *FS, IRAnchors, ProfileAnchors);
-  }
+  // Compute the callsite match states for profile staleness report.
+  if (ReportProfileStaleness || PersistProfileStaleness)
+    recordCallsiteMatchStates(F, IRAnchors, ProfileAnchors, nullptr);
 
   // Run profile matching for checksum mismatched profile, currently only
   // support for pseudo-probe.
@@ -2460,36 +2414,231 @@ void SampleProfileMatcher::runOnFunction(const Function &F) {
       !ProbeManager->profileIsValid(F, *FSFlattened)) {
     // The matching result will be saved to IRToProfileLocationMap, create a new
     // map for each function.
+    auto &IRToProfileLocationMap = getIRToProfileLocationMap(F);
     runStaleProfileMatching(F, IRAnchors, ProfileAnchors,
-                            getIRToProfileLocationMap(F));
+                            IRToProfileLocationMap);
+    // Find and update callsite match states after matching.
+    if (ReportProfileStaleness || PersistProfileStaleness)
+      recordCallsiteMatchStates(F, IRAnchors, ProfileAnchors,
+                                &IRToProfileLocationMap);
   }
 }
 
-void SampleProfileMatcher::runOnModule() {
-  ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
-                                   FunctionSamples::ProfileIsCS);
-  for (auto &F : M) {
-    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
+void SampleProfileMatcher::recordCallsiteMatchStates(
+    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+    const std::map<LineLocation, std::unordered_set<FunctionId>>
+        &ProfileAnchors,
+    const LocToLocMap *IRToProfileLocationMap) {
+  bool IsPostMatch = IRToProfileLocationMap != nullptr;
+  auto &CallsiteMatchStates =
+      FuncCallsiteMatchStates[FunctionSamples::getCanonicalFnName(F.getName())];
+
+  auto MapIRLocToProfileLoc = [&](const LineLocation &IRLoc) {
+    // IRToProfileLocationMap is null in pre-match phrase.
+    if (!IRToProfileLocationMap)
+      return IRLoc;
+    const auto &ProfileLoc = IRToProfileLocationMap->find(IRLoc);
+    if (ProfileLoc != IRToProfileLocationMap->end())
+      return ProfileLoc->second;
+    else
+      return IRLoc;
+  };
+
+  for (const auto &I : IRAnchors) {
+    // After fuzzy profile matching, use the matching result to remap the
+    // current IR callsite.
+    const auto &ProfileLoc = MapIRLocToProfileLoc(I.first);
+    const auto &IRCalleeName = I.second;
+    const auto &It = ProfileAnchors.find(ProfileLoc);
+    if (It == ProfileAnchors.end())
       continue;
-    runOnFunction(F);
+    const auto &Callees = It->second;
+
+    bool IsCallsiteMatched = false;
+    // Since indirect call does not have CalleeName, check conservatively if
+    // callsite in the profile is a callsite location. This is to reduce num of
+    // false positive since otherwise all the indirect call samples will be
+    // reported as mismatching.
+    if (IRCalleeName == SampleProfileMatcher::UnknownIndirectCallee)
+      IsCallsiteMatched = true;
+    else if (Callees.size() == 1 && Callees.count(getRepInFormat(IRCalleeName)))
+      IsCallsiteMatched = true;
+
+    if (IsCallsiteMatched) {
+      auto It = CallsiteMatchStates.find(ProfileLoc);
+      if (It == CallsiteMatchStates.end())
+        CallsiteMatchStates.emplace(ProfileLoc, MatchState::InitialMatch);
+      else if (IsPostMatch) {
+        if (It->second == MatchState::InitialMatch)
+          It->second = MatchState::UnchangedMatch;
+        else if (It->second == MatchState::InitialMismatch)
+          It->second = MatchState::RecoveredMismatch;
+      }
+    }
+  }
+
+  // Check if there are any callsites in the profile that does not match to any
+  // IR callsites.
+  for (const auto &I : ProfileAnchors) {
+    const auto &Loc = I.first;
+    [[maybe_unused]] const auto &Callees = I.second;
+    assert(!Callees.empty() && "Callees should not be empty");
+    auto It = CallsiteMatchStates.find(Loc);
+    if (It == CallsiteMatchStates.end())
+      CallsiteMatchStates.emplace(Loc, MatchState::InitialMismatch);
+    else if (IsPostMatch) {
+      // Update the state if it's not matched(UnchangedMatch or
+      // RecoveredMismatch).
+      if (It->second == MatchState::InitialMismatch)
+        It->second = MatchState::UnchangedMismatch;
+      else if (It->second == MatchState::InitialMatch)
+        It->second = MatchState::RemovedMatch;
+    }
+  }
+}
+
+void SampleProfileMatcher::countMismatchedFuncSamples(const FunctionSamples &FS,
+                                                      bool IsTopLevel) {
+  const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID());
+  // Skip the function that is external or renamed.
+  if (!FuncDesc)
+    return;
+
+  if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
+    if (IsTopLevel)
+      NumStaleProfileFunc++;
+    // Given currently all probe ids are after block probe ids, once the
+    // checksum is mismatched, it's likely all the callites are mismatched and
+    // dropped. We conservatively count all the samples as mismatched and stop
+    // counting the inlinees' profiles.
+    MismatchedFunctionSamples += FS.getTotalSamples();
+    return;
+  }
+
+  // Even the current-level function checksum is matched, it's possible that the
+  // nested inlinees' checksums are mismatched that affect the inlinee's sample
+  // loading, we need to go deeper to check the inlinees' function samples.
+  // Similarly, count all the samples as mismatched if the inlinee's checksum is
+  // mismatched using this recursive function.
+  for (const auto &I : FS.getCallsiteSamples())
+    for (const auto &CS : I.second)
+      countMismatchedFuncSamples(CS.second, false);
+}
+
+void SampleProfileMatcher::countMismatchedCallsiteSamples(
+    const FunctionSamples &FS) {
+  auto It = FuncCallsiteMatchStates.find(FS.getFuncName());
+  // Skip it if no mismatched callsite or this is an external function.
+  if (It == FuncCallsiteMatchStates.end() || It->second.empty())
+    return;
+  const auto &CallsiteMatchStates = It->second;
+
+  auto findMatchState = [&](const LineLocation &Loc) {
+    auto It = CallsiteMatchStates.find(Loc);
+    if (It == CallsiteMatchStates.end())
+      return MatchState::Unknown;
+    return It->second;
+  };
+
+  auto AttributeMismatchedSamples = [&](const enum MatchState &State,
+                                        uint64_t Samples) {
+    if (isMismatchState(State))
+      MismatchedCallsiteSamples += Samples;
+    else if (State == MatchState::RecoveredMismatch)
+      RecoveredCallsiteSamples += Samples;
+  };
+
+  // The non-inlined callsites are saved in the body samples of function
+  // profile, go through it to count the non-inlined callsite samples.
+  for (const auto &I : FS.getBodySamples())
+    AttributeMismatchedSamples(findMatchState(I.first), I.second.getSamples());
+
+  // Count the inlined callsite samples.
+  for (const auto &I : FS.getCallsiteSamples()) {
+    auto State = findMatchState(I.first);
+    uint64_t CallsiteSamples = 0;
+    for (const auto &CS : I.second)
+      CallsiteSamples += CS.second.getTotalSamples();
+    AttributeMismatchedSamples(State, CallsiteSamples);
+
+    if (isMismatchState(State))
+      continue;
+
+    // When the current level of inlined call site matches the profiled call
+    // site, we need to go deeper along the inline tree to count mismatches from
+    // lower level inlinees.
+    for (const auto &CS : I.second)
+      countMismatchedCallsiteSamples(CS.second);
+  }
+}
+
+void SampleProfileMatcher::countMismatchCallsites(const FunctionSamples &FS) {
+  auto It = FuncCallsiteMatchStates.find(FS.getFuncName());
+  // Skip it if no mismatched callsite or this is an external function.
+  if (It == FuncCallsiteMatchStates.end() || It->second.empty())
+    return;
+  const auto &MatchStates = It->second;
+  [[maybe_unused]] bool OnInitialState =
+      isInitialState(MatchStates.begin()->second);
+  for (const auto &I : MatchStates) {
+    TotalProfiledCallsites++;
+    assert(
+        (OnInitialState ? isInitialState(I.second) : isFinalState(I.second)) &&
+        "Profile matching state is inconsistent");
+
+    if (isMismatchState(I.second))
+      NumMismatchedCallsites++;
+    else if (I.second == MatchState::RecoveredMismatch)
+      NumRecoveredCallsites++;
+  }
+}
+
+void SampleProfileMatcher::computeAndReportProfileStaleness() {
+  if (!ReportProfileStaleness && !PersistProfileStaleness)
+    return;
+
+  // Count profile mismatches for profile staleness report.
+  for (const auto &F : M) {
+    if (skipProfileForFunction(F))
+      continue;
+    // As the stats will be merged by linker, skip reporting the metrics for
+    // imported functions to avoid repeated counting.
+    if (GlobalValue::isAvailableExternallyLinkage(F.getLinkage()))
+      continue;
+    const auto *FS = Reader.getSamplesFor(F);
+    if (!FS)
+      continue;
+    TotalProfiledFunc++;
+    TotalFunctionSamples += FS->getTotalSamples();
+
+    // Checksum mismatch is only used in pseudo-probe mode.
+    if (FunctionSamples::ProfileIsProbeBased)
+      countMismatchedFuncSamples(*FS, true);
+
+    // Count mismatches and samples for calliste.
+    countMismatchCallsites(*FS);
+    countMismatchedCallsiteSamples(*FS);
   }
-  if (SalvageStaleProfile)
-    distributeIRToProfileLocationMap();
 
   if (ReportProfileStaleness) {
     if (FunctionSamples::ProfileIsProbeBased) {
-      errs() << "(" << NumMismatchedFuncHash << "/" << TotalProfiledFunc << ")"
+      errs() << "(" << NumStaleProfileFunc << "/" << TotalProfiledFunc << ")"
              << " of functions' profile are invalid and "
-             << " (" << MismatchedFuncHashSamples << "/" << TotalFuncHashSamples
-             << ")"
-             << " of samples are discarded due to function hash mismatch.\n";
+             << " (" << MismatchedFunctionSamples << "/" << TotalFunctionSamples
+             << ") of samples are discarded due to function hash mismatch.\n";
     }
-    errs() << "(" << NumMismatchedCallsites << "/" << TotalProfiledCallsites
-           << ")"
+    errs() << "(" << (NumMismatchedCallsites + NumRecoveredCallsites) << "/"
+           << TotalProfiledCallsites << ")"
            << " of callsites' profile are invalid and "
-           << "(" << MismatchedCallsiteSamples << "/" << TotalCallsiteSamples
-           << ")"
+           << "(" << (MismatchedCallsiteSamples + RecoveredCallsiteSamples)
+           << "/" << TotalFunctionSamples << ")"
            << " of samples are discarded due to callsite location mismatch.\n";
+    errs() << "(" << NumRecoveredCallsites << "/"
+           << (NumRecoveredCallsites + NumMismatchedCallsites) << ")"
+           << " of callsites and "
+           << "(" << RecoveredCallsiteSamples << "/"
+           << (RecoveredCallsiteSamples + MismatchedCallsiteSamples) << ")"
+           << " of samples are recovered by stale profile matching.\n";
   }
 
   if (PersistProfileStaleness) {
@@ -2498,18 +2647,20 @@ void SampleProfileMatcher::runOnModule() {
 
     SmallVector<std::pair<StringRef, uint64_t>> ProfStatsVec;
     if (FunctionSamples::ProfileIsProbeBased) {
-      ProfStatsVec.emplace_back("NumMismatchedFuncHash", NumMismatchedFuncHash);
+      ProfStatsVec.emplace_back("NumStaleProfileFunc", NumStaleProfileFunc);
       ProfStatsVec.emplace_back("TotalProfiledFunc", TotalProfiledFunc);
-      ProfStatsVec.emplace_back("MismatchedFuncHashSamples",
-                                MismatchedFuncHashSamples);
-      ProfStatsVec.emplace_back("TotalFuncHashSamples", TotalFuncHashSamples);
+      ProfStatsVec.emplace_back("MismatchedFunctionSamples",
+                                MismatchedFunctionSamples);
+      ProfStatsVec.emplace_back("TotalFunctionSamples", TotalFunctionSamples);
     }
 
     ProfStatsVec.emplace_back("NumMismatchedCallsites", NumMismatchedCallsites);
+    ProfStatsVec.emplace_back("NumRecoveredCallsites", NumRecoveredCallsites);
     ProfStatsVec.emplace_back("TotalProfiledCallsites", TotalProfiledCallsites);
     ProfStatsVec.emplace_back("MismatchedCallsiteSamples",
                               MismatchedCallsiteSamples);
-    ProfStatsVec.emplace_back("TotalCallsiteSamples", TotalCallsiteSamples);
+    ProfStatsVec.emplace_back("RecoveredCallsiteSamples",
+                              RecoveredCallsiteSamples);
 
     auto *MD = MDB.createLLVMStats(ProfStatsVec);
     auto *NMD = M.getOrInsertNamedMetadata("llvm.stats");
@@ -2517,6 +2668,20 @@ void SampleProfileMatcher::runOnModule() {
   }
 }
 
+void SampleProfileMatcher::runOnModule() {
+  ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
+                                   FunctionSamples::ProfileIsCS);
+  for (auto &F : M) {
+    if (skipProfileForFunction(F))
+      continue;
+    runOnFunction(F);
+  }
+  if (SalvageStaleProfile)
+    distributeIRToProfileLocationMap();
+
+  computeAndReportProfileStaleness();
+}
+
 void SampleProfileMatcher::distributeIRToProfileLocationMap(
     FunctionSamples &FS) {
   const auto ProfileMappings = FuncMappings.find(FS.getFuncName());
@@ -2587,6 +2752,7 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
   if (ReportProfileStaleness || PersistProfileStaleness ||
       SalvageStaleProfile) {
     MatchingManager->runOnModule();
+    MatchingManager->clearMatchingData();
   }
 
   bool retval = false;

diff  --git a/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof b/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
index 818a048b8cabb8..f2a00e789b8b66 100644
--- a/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
+++ b/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
@@ -2,14 +2,15 @@ main:30:0
  0: 0
  1.1: 0
  3: 10 matched:10
- 4: 10
- 5: 10 bar_mismatch:10
+ 7: 10
  8: 0
- 7: foo:15
+ 4: foo:15
   1: 5
   2: 5
   3: inlinee_mismatch:5
    1: 5
+ 5: bar_mismatch:10
+  1: 10
 bar:10:10
  1: 10
 matched:10:10

diff  --git a/llvm/test/Transforms/SampleProfile/Inputs/pseudo-probe-profile-mismatch.prof b/llvm/test/Transforms/SampleProfile/Inputs/pseudo-probe-profile-mismatch.prof
index 5dc266941b2f22..903b30f249ddfe 100644
--- a/llvm/test/Transforms/SampleProfile/Inputs/pseudo-probe-profile-mismatch.prof
+++ b/llvm/test/Transforms/SampleProfile/Inputs/pseudo-probe-profile-mismatch.prof
@@ -5,7 +5,7 @@ main:30:0
  13: foo_mismatch:10
   1: 10
   !CFGChecksum: 4294967295
- !CFGChecksum: 844635331715433
+ !CFGChecksum: 84463533171543
 bar:10:10
  1: 10
  !CFGChecksum: 42949671295

diff  --git a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
index d86175c02dbb42..42bc1b81f67059 100644
--- a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
@@ -6,29 +6,32 @@
 ; RUN: llvm-objdump --section-headers %t.obj | FileCheck %s --check-prefix=CHECK-OBJ
 ; RUN: llc < %t.ll -filetype=asm -o - | FileCheck %s --check-prefix=CHECK-ASM
 
-; CHECK: (2/3) of callsites' profile are invalid and (25/35) of samples are discarded due to callsite location mismatch.
+; CHECK: (1/3) of callsites' profile are invalid and (15/50) of samples are discarded due to callsite location mismatch.
 
-; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 25, !"TotalCallsiteSamples", i64 35}
+; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 1, !"NumRecoveredCallsites", i64 0, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 15, !"RecoveredCallsiteSamples", i64 0}
 
 ; CHECK-OBJ: .llvm_stats
 
-; CHECK-ASM: .section  .llvm_stats,"",@progbits
-; CHECK-ASM: .byte 22
-; CHECK-ASM: .ascii  "NumMismatchedCallsites"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "Mg=="
-; CHECK-ASM: .byte 22
-; CHECK-ASM: .ascii  "TotalProfiledCallsites"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "Mw=="
-; CHECK-ASM: .byte 25
-; CHECK-ASM: .ascii  "MismatchedCallsiteSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MjU="
-; CHECK-ASM: .byte 20
-; CHECK-ASM: .ascii  "TotalCallsiteSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MzU="
+; CHECK-ASM: .ascii	"NumMismatchedCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MQ=="
+; CHECK-ASM: .byte	21
+; CHECK-ASM: .ascii	"NumRecoveredCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MA=="
+; CHECK-ASM: .byte	22
+; CHECK-ASM: .ascii	"TotalProfiledCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"Mw=="
+; CHECK-ASM: .byte	25
+; CHECK-ASM: .ascii	"MismatchedCallsiteSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MTU="
+; CHECK-ASM: .byte	24
+; CHECK-ASM: .ascii	"RecoveredCallsiteSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MA=="
+
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"

diff  --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
index 29c3a142cc68f8..bf33df2481044b 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
@@ -4,7 +4,7 @@
 ; RUN: FileCheck %s --input-file %t.ll -check-prefix=CHECK-MD
 
 ; CHECK: (1/1) of functions' profile are invalid and  (6822/6822) of samples are discarded due to function hash mismatch.
-; CHECK: (4/4) of callsites' profile are invalid and (5026/5026) of samples are discarded due to callsite location mismatch.
+; CHECK: (4/4) of callsites' profile are invalid and (5026/6822) of samples are discarded due to callsite location mismatch.
+; CHECK: (4/4) of callsites and (5026/5026) of samples are recovered by stale profile matching.
 
-
-; CHECK-MD: ![[#]] = !{!"NumMismatchedFuncHash", i64 1, !"TotalProfiledFunc", i64 1, !"MismatchedFuncHashSamples", i64 6822, !"TotalFuncHashSamples", i64 6822, !"NumMismatchedCallsites", i64 4, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 5026, !"TotalCallsiteSamples", i64 5026}
+; CHECK-MD: ![[#]] = !{!"NumStaleProfileFunc", i64 1, !"TotalProfiledFunc", i64 1, !"MismatchedFunctionSamples", i64 6822, !"TotalFunctionSamples", i64 6822, !"NumMismatchedCallsites", i64 0, !"NumRecoveredCallsites", i64 4, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 0, !"RecoveredCallsiteSamples", i64 5026}

diff  --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
index 4b6edf821376c0..6efd9f511e1372 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
@@ -9,46 +9,53 @@
 ; RUN: opt < %s -passes=sample-profile -sample-profile-file=%S/Inputs/pseudo-probe-profile-mismatch-nested.prof -report-profile-staleness -persist-profile-staleness -S 2>&1 | FileCheck %s --check-prefix=CHECK-NESTED
 
 
-; CHECK: (1/3) of functions' profile are invalid and (10/50) of samples are discarded due to function hash mismatch.
-; CHECK: (2/3) of callsites' profile are invalid and (20/30) of samples are discarded due to callsite location mismatch.
+; CHECK: (2/3) of functions' profile are invalid and (40/50) of samples are discarded due to function hash mismatch.
+; CHECK: (2/3) of callsites' profile are invalid and (20/50) of samples are discarded due to callsite location mismatch.
+; CHECK: (1/2) of callsites and (10/20) of samples are recovered by stale profile matching.
+
+; CHECK-MD: ![[#]] = !{!"NumStaleProfileFunc", i64 2, !"TotalProfiledFunc", i64 3, !"MismatchedFunctionSamples", i64 40, !"TotalFunctionSamples", i64 50, !"NumMismatchedCallsites", i64 1, !"NumRecoveredCallsites", i64 1, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 10, !"RecoveredCallsiteSamples", i64 10}
 
-; CHECK-MD: ![[#]] = !{!"NumMismatchedFuncHash", i64 1, !"TotalProfiledFunc", i64 3, !"MismatchedFuncHashSamples", i64 10, !"TotalFuncHashSamples", i64 50, !"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 20, !"TotalCallsiteSamples", i64 30}
 
 ; CHECK-OBJ: .llvm_stats
 
-; CHECK-ASM: .section  .llvm_stats,"",@progbits
-; CHECK-ASM: .byte 21
-; CHECK-ASM: .ascii  "NumMismatchedFuncHash"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MQ=="
-; CHECK-ASM: .byte 17
-; CHECK-ASM: .ascii  "TotalProfiledFunc"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "Mw=="
-; CHECK-ASM: .byte 25
-; CHECK-ASM: .ascii  "MismatchedFuncHashSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MTA="
-; CHECK-ASM: .byte 20
-; CHECK-ASM: .ascii  "TotalFuncHashSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "NTA="
-; CHECK-ASM: .byte 22
-; CHECK-ASM: .ascii  "NumMismatchedCallsites"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "Mg=="
-; CHECK-ASM: .byte 22
-; CHECK-ASM: .ascii  "TotalProfiledCallsites"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "Mw=="
-; CHECK-ASM: .byte 25
-; CHECK-ASM: .ascii  "MismatchedCallsiteSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MjA="
-; CHECK-ASM: .byte 20
-; CHECK-ASM: .ascii  "TotalCallsiteSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MzA="
+; CHECK-ASM: .section	.llvm_stats,"",@progbits
+; CHECK-ASM: .byte	19
+; CHECK-ASM: .ascii	"NumStaleProfileFunc"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"Mg=="
+; CHECK-ASM: .byte	17
+; CHECK-ASM: .ascii	"TotalProfiledFunc"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"Mw=="
+; CHECK-ASM: .byte	25
+; CHECK-ASM: .ascii	"MismatchedFunctionSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"NDA="
+; CHECK-ASM: .byte	20
+; CHECK-ASM: .ascii	"TotalFunctionSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"NTA="
+; CHECK-ASM: .byte	22
+; CHECK-ASM: .ascii	"NumMismatchedCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MQ=="
+; CHECK-ASM: .byte	21
+; CHECK-ASM: .ascii	"NumRecoveredCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MQ=="
+; CHECK-ASM: .byte	22
+; CHECK-ASM: .ascii	"TotalProfiledCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"Mw=="
+; CHECK-ASM: .byte	25
+; CHECK-ASM: .ascii	"MismatchedCallsiteSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MTA="
+; CHECK-ASM: .byte	24
+; CHECK-ASM: .ascii	"RecoveredCallsiteSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MTA="
+
 
 ; CHECK-NESTED: (1/2) of functions' profile are invalid and (211/311) of samples are discarded due to function hash mismatch.
 


        


More information about the llvm-commits mailing list