[llvm] [CSSPGO] Compute and report profile matching recovered callsites and samples (PR #79090)

Lei Wang via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 13 18:27:00 PST 2024


https://github.com/wlei-llvm updated https://github.com/llvm/llvm-project/pull/79090

>From b398feedef8a89cb2e58d09b7cefa48d8efc568c Mon Sep 17 00:00:00 2001
From: wlei <wlei at fb.com>
Date: Mon, 22 Jan 2024 19:16:26 -0800
Subject: [PATCH 1/9] [CSSPGO] Support post-match profile staleness metrics

---
 llvm/lib/Transforms/IPO/SampleProfile.cpp     | 440 +++++++++++-------
 .../Inputs/profile-mismatch.prof              |   7 +-
 .../SampleProfile/profile-mismatch.ll         |  12 +-
 .../pseudo-probe-profile-mismatch-thinlto.ll  |   6 +-
 .../pseudo-probe-profile-mismatch.ll          |  76 +--
 5 files changed, 324 insertions(+), 217 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 2fd8668d15e200..a7170faa65dc07 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -433,12 +433,19 @@ using CandidateQueue =
     PriorityQueue<InlineCandidate, std::vector<InlineCandidate>,
                   CandidateComparer>;
 
+using IRAnchorMap = std::map<LineLocation, StringRef>;
+using ProfileAnchorMap = std::map<LineLocation, std::unordered_set<FunctionId>>;
+
 // Sample profile matching - fuzzy match.
 class SampleProfileMatcher {
   Module &M;
   SampleProfileReader &Reader;
   const PseudoProbeManager *ProbeManager;
   SampleProfileMap FlattenedProfiles;
+
+  std::unordered_map<const Function *, IRAnchorMap> FuncIRAnchors;
+  std::unordered_map<const Function *, ProfileAnchorMap> FuncProfileAnchors;
+
   // For each function, the matcher generates a map, of which each entry is a
   // mapping from the source location of current build to the source location in
   // the profile.
@@ -448,6 +455,8 @@ class SampleProfileMatcher {
   uint64_t TotalProfiledCallsites = 0;
   uint64_t NumMismatchedCallsites = 0;
   uint64_t MismatchedCallsiteSamples = 0;
+  uint64_t PostMatchNumMismatchedCallsites = 0;
+  uint64_t PostMatchMismatchedCallsiteSamples = 0;
   uint64_t TotalCallsiteSamples = 0;
   uint64_t TotalProfiledFunc = 0;
   uint64_t NumMismatchedFuncHash = 0;
@@ -474,24 +483,22 @@ class SampleProfileMatcher {
     return nullptr;
   }
   void runOnFunction(const Function &F);
-  void findIRAnchors(const Function &F,
-                     std::map<LineLocation, StringRef> &IRAnchors);
-  void findProfileAnchors(
+  void findFuncAnchors();
+  void UpdateIRAnchors();
+  void findIRAnchors(const Function &F, IRAnchorMap &IRAnchors);
+  void findProfileAnchors(const FunctionSamples &FS,
+                          ProfileAnchorMap &ProfileAnchors);
+  void countMismatchedHashSamples(const FunctionSamples &FS);
+  void countProfileMismatches(bool IsPreMatch);
+  void countMismatchedHashes(const Function &F, const FunctionSamples &FS);
+  void countMismatchedCallsites(
+      const Function &F,
+      StringMap<std::set<LineLocation>> &FuncToMismatchCallsites,
+      uint64_t &FuncProfiledCallsites, uint64_t &FuncMismatchedCallsites) const;
+  void countMismatchedCallsiteSamples(
       const FunctionSamples &FS,
-      std::map<LineLocation, std::unordered_set<FunctionId>>
-          &ProfileAnchors);
-  void countMismatchedSamples(const FunctionSamples &FS);
-  void countProfileMismatches(
-      const Function &F, const FunctionSamples &FS,
-      const std::map<LineLocation, StringRef> &IRAnchors,
-      const std::map<LineLocation, std::unordered_set<FunctionId>>
-          &ProfileAnchors);
-  void countProfileCallsiteMismatches(
-      const FunctionSamples &FS,
-      const std::map<LineLocation, StringRef> &IRAnchors,
-      const std::map<LineLocation, std::unordered_set<FunctionId>>
-          &ProfileAnchors,
-      uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites);
+      StringMap<std::set<LineLocation>> &FuncToMismatchCallsites,
+      uint64_t &FuncMismatchedCallsiteSamples) const;
   LocToLocMap &getIRToProfileLocationMap(const Function &F) {
     auto Ret = FuncMappings.try_emplace(
         FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap());
@@ -499,11 +506,10 @@ class SampleProfileMatcher {
   }
   void distributeIRToProfileLocationMap();
   void distributeIRToProfileLocationMap(FunctionSamples &FS);
-  void runStaleProfileMatching(
-      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
-      const std::map<LineLocation, std::unordered_set<FunctionId>>
-          &ProfileAnchors,
-      LocToLocMap &IRToProfileLocationMap);
+  void runStaleProfileMatching();
+  void runStaleProfileMatching(const Function &F, const IRAnchorMap &IRAnchors,
+                               const ProfileAnchorMap &ProfileAnchors,
+                               LocToLocMap &IRToProfileLocationMap);
 };
 
 /// Sample profile pass.
@@ -1129,7 +1135,7 @@ void SampleProfileLoader::findExternalInlineCandidate(
         CalleeSample->getContext().hasAttribute(ContextShouldBeInlined);
     if (!PreInline && CalleeSample->getHeadSamplesEstimate() < Threshold)
       continue;
-    
+
     Function *Func = SymbolMap.lookup(CalleeSample->getFunction());
     // Add to the import list only when it's defined out of module.
     if (!Func || Func->isDeclaration())
@@ -2123,8 +2129,8 @@ bool SampleProfileLoader::doInitialization(Module &M,
   return true;
 }
 
-void SampleProfileMatcher::findIRAnchors(
-    const Function &F, std::map<LineLocation, StringRef> &IRAnchors) {
+void SampleProfileMatcher::findIRAnchors(const Function &F,
+                                         IRAnchorMap &IRAnchors) {
   // For inlined code, recover the original callsite and callee by finding the
   // top-level inline frame. e.g. For frame stack "main:1 @ foo:2 @ bar:3", the
   // top-level frame is "main:1", the callsite is "1" and the callee is "foo".
@@ -2190,7 +2196,8 @@ void SampleProfileMatcher::findIRAnchors(
   }
 }
 
-void SampleProfileMatcher::countMismatchedSamples(const FunctionSamples &FS) {
+void SampleProfileMatcher::countMismatchedHashSamples(
+    const FunctionSamples &FS) {
   const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID());
   // Skip the function that is external or renamed.
   if (!FuncDesc)
@@ -2202,96 +2209,11 @@ void SampleProfileMatcher::countMismatchedSamples(const FunctionSamples &FS) {
   }
   for (const auto &I : FS.getCallsiteSamples())
     for (const auto &CS : I.second)
-      countMismatchedSamples(CS.second);
-}
-
-void SampleProfileMatcher::countProfileMismatches(
-    const Function &F, const FunctionSamples &FS,
-    const std::map<LineLocation, StringRef> &IRAnchors,
-    const std::map<LineLocation, std::unordered_set<FunctionId>>
-        &ProfileAnchors) {
-  [[maybe_unused]] bool IsFuncHashMismatch = false;
-  if (FunctionSamples::ProfileIsProbeBased) {
-    TotalFuncHashSamples += FS.getTotalSamples();
-    TotalProfiledFunc++;
-    const auto *FuncDesc = ProbeManager->getDesc(F);
-    if (FuncDesc) {
-      if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
-        NumMismatchedFuncHash++;
-        IsFuncHashMismatch = true;
-      }
-      countMismatchedSamples(FS);
-    }
-  }
-
-  uint64_t FuncMismatchedCallsites = 0;
-  uint64_t FuncProfiledCallsites = 0;
-  countProfileCallsiteMismatches(FS, IRAnchors, ProfileAnchors,
-                                 FuncMismatchedCallsites,
-                                 FuncProfiledCallsites);
-  TotalProfiledCallsites += FuncProfiledCallsites;
-  NumMismatchedCallsites += FuncMismatchedCallsites;
-  LLVM_DEBUG({
-    if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch &&
-        FuncMismatchedCallsites)
-      dbgs() << "Function checksum is matched but there are "
-             << FuncMismatchedCallsites << "/" << FuncProfiledCallsites
-             << " mismatched callsites.\n";
-  });
+      countMismatchedHashSamples(CS.second);
 }
 
-void SampleProfileMatcher::countProfileCallsiteMismatches(
-    const FunctionSamples &FS,
-    const std::map<LineLocation, StringRef> &IRAnchors,
-    const std::map<LineLocation, std::unordered_set<FunctionId>>
-        &ProfileAnchors,
-    uint64_t &FuncMismatchedCallsites, uint64_t &FuncProfiledCallsites) {
-
-  // Check if there are any callsites in the profile that does not match to any
-  // IR callsites, those callsite samples will be discarded.
-  for (const auto &I : ProfileAnchors) {
-    const auto &Loc = I.first;
-    const auto &Callees = I.second;
-    assert(!Callees.empty() && "Callees should not be empty");
-
-    StringRef IRCalleeName;
-    const auto &IR = IRAnchors.find(Loc);
-    if (IR != IRAnchors.end())
-      IRCalleeName = IR->second;
-
-    // Compute number of samples in the original profile.
-    uint64_t CallsiteSamples = 0;
-    if (auto CTM = FS.findCallTargetMapAt(Loc)) {
-      for (const auto &I : *CTM)
-        CallsiteSamples += I.second;
-    }
-    const auto *FSMap = FS.findFunctionSamplesMapAt(Loc);
-    if (FSMap) {
-      for (const auto &I : *FSMap)
-        CallsiteSamples += I.second.getTotalSamples();
-    }
-
-    bool CallsiteIsMatched = false;
-    // Since indirect call does not have CalleeName, check conservatively if
-    // callsite in the profile is a callsite location. This is to reduce num of
-    // false positive since otherwise all the indirect call samples will be
-    // reported as mismatching.
-    if (IRCalleeName == UnknownIndirectCallee)
-      CallsiteIsMatched = true;
-    else if (Callees.size() == 1 && Callees.count(getRepInFormat(IRCalleeName)))
-      CallsiteIsMatched = true;
-
-    FuncProfiledCallsites++;
-    TotalCallsiteSamples += CallsiteSamples;
-    if (!CallsiteIsMatched) {
-      FuncMismatchedCallsites++;
-      MismatchedCallsiteSamples += CallsiteSamples;
-    }
-  }
-}
-
-void SampleProfileMatcher::findProfileAnchors(const FunctionSamples &FS,
-                                              std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors) {
+void SampleProfileMatcher::findProfileAnchors(
+    const FunctionSamples &FS, ProfileAnchorMap &ProfileAnchors) {
   auto isInvalidLineOffset = [](uint32_t LineOffset) {
     return LineOffset & 0x8000;
   };
@@ -2338,10 +2260,8 @@ void SampleProfileMatcher::findProfileAnchors(const FunctionSamples &FS,
 //   [1, 2, 3(foo), 4,  7,  8(bar), 9]
 // The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9].
 void SampleProfileMatcher::runStaleProfileMatching(
-    const Function &F,
-    const std::map<LineLocation, StringRef> &IRAnchors,
-    const std::map<LineLocation, std::unordered_set<FunctionId>>
-        &ProfileAnchors,
+    const Function &F, const IRAnchorMap &IRAnchors,
+    const ProfileAnchorMap &ProfileAnchors,
     LocToLocMap &IRToProfileLocationMap) {
   LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName()
                     << "\n");
@@ -2422,59 +2342,226 @@ void SampleProfileMatcher::runStaleProfileMatching(
   }
 }
 
-void SampleProfileMatcher::runOnFunction(const Function &F) {
-  // We need to use flattened function samples for matching.
-  // Unlike IR, which includes all callsites from the source code, the callsites
-  // in profile only show up when they are hit by samples, i,e. the profile
-  // callsites in one context may differ from those in another context. To get
-  // the maximum number of callsites, we merge the function profiles from all
-  // contexts, aka, the flattened profile to find profile anchors.
-  const auto *FSFlattened = getFlattenedSamplesFor(F);
-  if (!FSFlattened)
-    return;
+void SampleProfileMatcher::runStaleProfileMatching() {
+  for (const auto &F : M) {
+    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
+      continue;
+    const auto *FSFlattened = getFlattenedSamplesFor(F);
+    if (!FSFlattened)
+      continue;
+    auto IR = FuncIRAnchors.find(&F);
+    auto P = FuncProfileAnchors.find(&F);
+    if (IR == FuncIRAnchors.end() || P == FuncProfileAnchors.end())
+      continue;
 
-  // Anchors for IR. It's a map from IR location to callee name, callee name is
-  // empty for non-call instruction and use a dummy name(UnknownIndirectCallee)
-  // for unknown indrect callee name.
-  std::map<LineLocation, StringRef> IRAnchors;
-  findIRAnchors(F, IRAnchors);
-  // Anchors for profile. It's a map from callsite location to a set of callee
-  // name.
-  std::map<LineLocation, std::unordered_set<FunctionId>> ProfileAnchors;
-  findProfileAnchors(*FSFlattened, ProfileAnchors);
-
-  // Detect profile mismatch for profile staleness metrics report.
-  // Skip reporting the metrics for imported functions.
-  if (!GlobalValue::isAvailableExternallyLinkage(F.getLinkage()) &&
-      (ReportProfileStaleness || PersistProfileStaleness)) {
-    // Use top-level nested FS for counting profile mismatch metrics since
-    // currently once a callsite is mismatched, all its children profiles are
-    // dropped.
-    if (const auto *FS = Reader.getSamplesFor(F))
-      countProfileMismatches(F, *FS, IRAnchors, ProfileAnchors);
+    // Run profile matching for checksum-mismatched profiles; currently only
+    // supported for pseudo-probe.
+    if (FunctionSamples::ProfileIsProbeBased &&
+        !ProbeManager->profileIsValid(F, *FSFlattened)) {
+      runStaleProfileMatching(F, IR->second, P->second,
+                              getIRToProfileLocationMap(F));
+    }
   }
 
-  // Run profile matching for checksum mismatched profile, currently only
-  // support for pseudo-probe.
-  if (SalvageStaleProfile && FunctionSamples::ProfileIsProbeBased &&
-      !ProbeManager->profileIsValid(F, *FSFlattened)) {
-    // The matching result will be saved to IRToProfileLocationMap, create a new
-    // map for each function.
-    runStaleProfileMatching(F, IRAnchors, ProfileAnchors,
-                            getIRToProfileLocationMap(F));
-  }
+  distributeIRToProfileLocationMap();
 }
 
-void SampleProfileMatcher::runOnModule() {
+void SampleProfileMatcher::findFuncAnchors() {
   ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
                                    FunctionSamples::ProfileIsCS);
-  for (auto &F : M) {
+  for (const auto &F : M) {
     if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
       continue;
-    runOnFunction(F);
+    // We need to use flattened function samples for matching.
+    // Unlike IR, which includes all callsites from the source code, the
+    // callsites in profile only show up when they are hit by samples, i.e. the
+    // profile callsites in one context may differ from those in another
+    // context. To get the maximum number of callsites, we merge the function
+    // profiles from all contexts, aka, the flattened profile to find profile
+    // anchors.
+    const auto *FSFlattened = getFlattenedSamplesFor(F);
+    if (!FSFlattened)
+      continue;
+
+    // Anchors for IR. It's a map from IR location to callee name, callee name
+    // is empty for non-call instruction and use a dummy
+    // name(UnknownIndirectCallee) for unknown indirect callee name.
+    auto IR = FuncIRAnchors.emplace(&F, IRAnchorMap());
+    findIRAnchors(F, IR.first->second);
+
+    // Anchors for profile. It's a map from callsite location to a set of callee
+    // name.
+    auto P = FuncProfileAnchors.emplace(&F, ProfileAnchorMap());
+    findProfileAnchors(*FSFlattened, P.first->second);
+  }
+}
+
+void SampleProfileMatcher::countMismatchedCallsiteSamples(
+    const FunctionSamples &FS,
+    StringMap<std::set<LineLocation>> &FuncToMismatchCallsites,
+    uint64_t &FuncMismatchedCallsiteSamples) const {
+  auto It = FuncToMismatchCallsites.find(FS.getFuncName());
+  // Skip it if no mismatched callsite or this is an external function.
+  if (It == FuncToMismatchCallsites.end() || It->second.empty())
+    return;
+  const auto &MismatchCallsites = It->second;
+  for (const auto &I : FS.getBodySamples()) {
+    if (MismatchCallsites.count(I.first))
+      FuncMismatchedCallsiteSamples += I.second.getSamples();
+  }
+
+  for (const auto &I : FS.getCallsiteSamples()) {
+    const auto &Loc = I.first;
+    if (MismatchCallsites.count(Loc)) {
+      for (const auto &CS : I.second)
+        FuncMismatchedCallsiteSamples += CS.second.getTotalSamples();
+      continue;
+    }
+
+    // count mismatched samples for inlined samples.
+    for (const auto &CS : I.second)
+      countMismatchedCallsiteSamples(CS.second, FuncToMismatchCallsites,
+                                     FuncMismatchedCallsiteSamples);
+  }
+}
+
+void SampleProfileMatcher::countMismatchedCallsites(
+    const Function &F,
+    StringMap<std::set<LineLocation>> &FuncToMismatchCallsites,
+    uint64_t &FuncProfiledCallsites, uint64_t &FuncMismatchedCallsites) const {
+  auto IR = FuncIRAnchors.find(&F);
+  auto P = FuncProfileAnchors.find(&F);
+  if (IR == FuncIRAnchors.end() || P == FuncProfileAnchors.end())
+    return;
+  const auto &IRAnchors = IR->second;
+  const auto &ProfileAnchors = P->second;
+
+  auto &MismatchCallsites =
+      FuncToMismatchCallsites[FunctionSamples::getCanonicalFnName(F.getName())];
+
+  // Check if there are any callsites in the profile that do not match any IR
+  // callsites; those callsite samples will be discarded.
+  for (const auto &I : ProfileAnchors) {
+    const auto &Loc = I.first;
+    const auto &Callees = I.second;
+    assert(!Callees.empty() && "Callees should not be empty");
+
+    StringRef IRCalleeName;
+    const auto &IR = IRAnchors.find(Loc);
+    if (IR != IRAnchors.end())
+      IRCalleeName = IR->second;
+    bool CallsiteIsMatched = false;
+    // Since indirect call does not have CalleeName, check conservatively if
+    // callsite in the profile is a callsite location. This is to reduce num of
+    // false positive since otherwise all the indirect call samples will be
+    // reported as mismatching.
+    if (IRCalleeName == UnknownIndirectCallee)
+      CallsiteIsMatched = true;
+    else if (Callees.count(FunctionId(IRCalleeName)))
+      CallsiteIsMatched = true;
+
+    FuncProfiledCallsites++;
+    if (!CallsiteIsMatched) {
+      FuncMismatchedCallsites++;
+      MismatchCallsites.insert(Loc);
+    }
+  }
+}
+
+void SampleProfileMatcher::countMismatchedHashes(const Function &F,
+                                                 const FunctionSamples &FS) {
+  if (!FunctionSamples::ProfileIsProbeBased)
+    return;
+  const auto *FuncDesc = ProbeManager->getDesc(F);
+  if (FuncDesc) {
+    if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
+      NumMismatchedFuncHash++;
+    }
+    countMismatchedHashSamples(FS);
+  }
+}
+
+void SampleProfileMatcher::UpdateIRAnchors() {
+  for (auto &I : FuncIRAnchors) {
+    const auto *F = I.first;
+    auto &IRAnchors = I.second;
+    const auto Mapping =
+        FuncMappings.find(FunctionSamples::getCanonicalFnName(F->getName()));
+    if (Mapping == FuncMappings.end())
+      continue;
+    IRAnchorMap UpdatedIRAnchors;
+    const auto &LocToLocMapping = Mapping->second;
+    for (const auto L : LocToLocMapping) {
+      UpdatedIRAnchors[L.second] = IRAnchors[L.first];
+      IRAnchors.erase(L.first);
+    }
+
+    for (const auto &IR : UpdatedIRAnchors) {
+      IRAnchors[IR.first] = IR.second;
+    }
+  }
+}
+
+void SampleProfileMatcher::countProfileMismatches(bool IsPreMatch) {
+  if (!ReportProfileStaleness && !PersistProfileStaleness)
+    return;
+
+  if (!IsPreMatch) {
+    // Use the profile matching results to update the IR anchors.
+    UpdateIRAnchors();
+  }
+
+  uint64_t UnusedCounter = 0;
+  uint64_t *TotalProfiledCallsitesPtr =
+      IsPreMatch ? &TotalProfiledCallsites : &UnusedCounter;
+  uint64_t *NumMismatchedCallsitesPtr =
+      IsPreMatch ? &NumMismatchedCallsites : &PostMatchNumMismatchedCallsites;
+  uint64_t *MismatchedCallsiteSamplesPtr =
+      IsPreMatch ? &MismatchedCallsiteSamples
+                 : &PostMatchMismatchedCallsiteSamples;
+
+  auto SkipFunctionForReport = [](const Function &F) {
+    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
+      return true;
+    // Skip reporting the metrics for imported functions.
+    if (GlobalValue::isAvailableExternallyLinkage(F.getLinkage()))
+      return true;
+    return false;
+  };
+
+  StringMap<std::set<LineLocation>> FuncToMismatchCallsites;
+  for (const auto &F : M) {
+    if (SkipFunctionForReport(F))
+      continue;
+    const auto *FS = Reader.getSamplesFor(F);
+    if (FS && IsPreMatch) {
+      // Only count the total function metrics once in pre-match time.
+      TotalFuncHashSamples += FS->getTotalSamples();
+      TotalProfiledFunc++;
+      countMismatchedHashes(F, *FS);
+    }
+    countMismatchedCallsites(F, FuncToMismatchCallsites,
+                             *TotalProfiledCallsitesPtr,
+                             *NumMismatchedCallsitesPtr);
+  }
+
+  for (const auto &F : M) {
+    if (SkipFunctionForReport(F))
+      continue;
+    if (const auto *FS = Reader.getSamplesFor(F))
+      countMismatchedCallsiteSamples(*FS, FuncToMismatchCallsites,
+                                     *MismatchedCallsiteSamplesPtr);
+  }
+}
+
+void SampleProfileMatcher::runOnModule() {
+  findFuncAnchors();
+  countProfileMismatches(true);
+
+  if (SalvageStaleProfile) {
+    runStaleProfileMatching();
+    countProfileMismatches(false);
   }
-  if (SalvageStaleProfile)
-    distributeIRToProfileLocationMap();
 
   if (ReportProfileStaleness) {
     if (FunctionSamples::ProfileIsProbeBased) {
@@ -2487,9 +2574,18 @@ void SampleProfileMatcher::runOnModule() {
     errs() << "(" << NumMismatchedCallsites << "/" << TotalProfiledCallsites
            << ")"
            << " of callsites' profile are invalid and "
-           << "(" << MismatchedCallsiteSamples << "/" << TotalCallsiteSamples
+           << "(" << MismatchedCallsiteSamples << "/" << TotalFuncHashSamples
            << ")"
            << " of samples are discarded due to callsite location mismatch.\n";
+    if (SalvageStaleProfile) {
+      errs() << "(" << PostMatchNumMismatchedCallsites << "/"
+             << TotalProfiledCallsites << ")"
+             << " of callsites' profile are invalid and "
+             << "(" << PostMatchMismatchedCallsiteSamples << "/"
+             << TotalFuncHashSamples << ")"
+             << " of samples are discarded due to callsite location mismatch "
+                "after stale profile matching.\n";
+    }
   }
 
   if (PersistProfileStaleness) {
@@ -2497,19 +2593,23 @@ void SampleProfileMatcher::runOnModule() {
     MDBuilder MDB(Ctx);
 
     SmallVector<std::pair<StringRef, uint64_t>> ProfStatsVec;
+    ProfStatsVec.emplace_back("NumMismatchedCallsites", NumMismatchedCallsites);
+    ProfStatsVec.emplace_back("TotalProfiledCallsites", TotalProfiledCallsites);
+    ProfStatsVec.emplace_back("MismatchedCallsiteSamples",
+                              MismatchedCallsiteSamples);
+    ProfStatsVec.emplace_back("TotalFuncHashSamples", TotalFuncHashSamples);
     if (FunctionSamples::ProfileIsProbeBased) {
-      ProfStatsVec.emplace_back("NumMismatchedFuncHash", NumMismatchedFuncHash);
       ProfStatsVec.emplace_back("TotalProfiledFunc", TotalProfiledFunc);
+      ProfStatsVec.emplace_back("NumMismatchedFuncHash", NumMismatchedFuncHash);
       ProfStatsVec.emplace_back("MismatchedFuncHashSamples",
                                 MismatchedFuncHashSamples);
-      ProfStatsVec.emplace_back("TotalFuncHashSamples", TotalFuncHashSamples);
     }
-
-    ProfStatsVec.emplace_back("NumMismatchedCallsites", NumMismatchedCallsites);
-    ProfStatsVec.emplace_back("TotalProfiledCallsites", TotalProfiledCallsites);
-    ProfStatsVec.emplace_back("MismatchedCallsiteSamples",
-                              MismatchedCallsiteSamples);
-    ProfStatsVec.emplace_back("TotalCallsiteSamples", TotalCallsiteSamples);
+    if (SalvageStaleProfile) {
+      ProfStatsVec.emplace_back("PostMatchNumMismatchedCallsites",
+                                PostMatchNumMismatchedCallsites);
+      ProfStatsVec.emplace_back("PostMatchMismatchedCallsiteSamples",
+                                PostMatchMismatchedCallsiteSamples);
+    }
 
     auto *MD = MDB.createLLVMStats(ProfStatsVec);
     auto *NMD = M.getOrInsertNamedMetadata("llvm.stats");
diff --git a/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof b/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
index 818a048b8cabb8..f2a00e789b8b66 100644
--- a/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
+++ b/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
@@ -2,14 +2,15 @@ main:30:0
  0: 0
  1.1: 0
  3: 10 matched:10
- 4: 10
- 5: 10 bar_mismatch:10
+ 7: 10
  8: 0
- 7: foo:15
+ 4: foo:15
   1: 5
   2: 5
   3: inlinee_mismatch:5
    1: 5
+ 5: bar_mismatch:10
+  1: 10
 bar:10:10
  1: 10
 matched:10:10
diff --git a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
index d86175c02dbb42..e7c5dece1235b5 100644
--- a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
@@ -6,9 +6,9 @@
 ; RUN: llvm-objdump --section-headers %t.obj | FileCheck %s --check-prefix=CHECK-OBJ
 ; RUN: llc < %t.ll -filetype=asm -o - | FileCheck %s --check-prefix=CHECK-ASM
 
-; CHECK: (2/3) of callsites' profile are invalid and (25/35) of samples are discarded due to callsite location mismatch.
+; CHECK: (2/4) of callsites' profile are invalid and (15/50) of samples are discarded due to callsite location mismatch.
 
-; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 25, !"TotalCallsiteSamples", i64 35}
+; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 15, !"TotalFuncHashSamples", i64 50}
 
 ; CHECK-OBJ: .llvm_stats
 
@@ -20,15 +20,15 @@
 ; CHECK-ASM: .byte 22
 ; CHECK-ASM: .ascii  "TotalProfiledCallsites"
 ; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "Mw=="
+; CHECK-ASM: .ascii  "NA=="
 ; CHECK-ASM: .byte 25
 ; CHECK-ASM: .ascii  "MismatchedCallsiteSamples"
 ; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MjU="
+; CHECK-ASM: .ascii  "MTU="
 ; CHECK-ASM: .byte 20
-; CHECK-ASM: .ascii  "TotalCallsiteSamples"
+; CHECK-ASM: .ascii  "TotalFuncHashSamples"
 ; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MzU="
+; CHECK-ASM: .ascii  "NTA="
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
diff --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
index 29c3a142cc68f8..7f848da74a53ce 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
@@ -4,7 +4,7 @@
 ; RUN: FileCheck %s --input-file %t.ll -check-prefix=CHECK-MD
 
 ; CHECK: (1/1) of functions' profile are invalid and  (6822/6822) of samples are discarded due to function hash mismatch.
-; CHECK: (4/4) of callsites' profile are invalid and (5026/5026) of samples are discarded due to callsite location mismatch.
+; CHECK: (4/4) of callsites' profile are invalid and (5026/6822) of samples are discarded due to callsite location mismatch.
+; CHECK: (0/4) of callsites' profile are invalid and (0/6822) of samples are discarded due to callsite location mismatch after stale profile matching.
 
-
-; CHECK-MD: ![[#]] = !{!"NumMismatchedFuncHash", i64 1, !"TotalProfiledFunc", i64 1, !"MismatchedFuncHashSamples", i64 6822, !"TotalFuncHashSamples", i64 6822, !"NumMismatchedCallsites", i64 4, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 5026, !"TotalCallsiteSamples", i64 5026}
+; CHECK-MD: !{!"NumMismatchedCallsites", i64 4, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 5026, !"TotalFuncHashSamples", i64 6822, !"TotalProfiledFunc", i64 1, !"NumMismatchedFuncHash", i64 1, !"MismatchedFuncHashSamples", i64 6822, !"PostMatchNumMismatchedCallsites", i64 0, !"PostMatchMismatchedCallsiteSamples", i64 0}
diff --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
index 4b6edf821376c0..5c5bb1f0fae647 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
@@ -10,45 +10,51 @@
 
 
 ; CHECK: (1/3) of functions' profile are invalid and (10/50) of samples are discarded due to function hash mismatch.
-; CHECK: (2/3) of callsites' profile are invalid and (20/30) of samples are discarded due to callsite location mismatch.
+; CHECK: (2/3) of callsites' profile are invalid and (20/50) of samples are discarded due to callsite location mismatch.
+; CHECK: (2/3) of callsites' profile are invalid and (20/50) of samples are discarded due to callsite location mismatch after stale profile matching.
 
-; CHECK-MD: ![[#]] = !{!"NumMismatchedFuncHash", i64 1, !"TotalProfiledFunc", i64 3, !"MismatchedFuncHashSamples", i64 10, !"TotalFuncHashSamples", i64 50, !"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 20, !"TotalCallsiteSamples", i64 30}
+; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 20, !"TotalFuncHashSamples", i64 50, !"TotalProfiledFunc", i64 3, !"NumMismatchedFuncHash", i64 1, !"MismatchedFuncHashSamples", i64 10, !"PostMatchNumMismatchedCallsites", i64 2, !"PostMatchMismatchedCallsiteSamples", i64 20}
 
 ; CHECK-OBJ: .llvm_stats
 
-; CHECK-ASM: .section  .llvm_stats,"",@progbits
-; CHECK-ASM: .byte 21
-; CHECK-ASM: .ascii  "NumMismatchedFuncHash"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MQ=="
-; CHECK-ASM: .byte 17
-; CHECK-ASM: .ascii  "TotalProfiledFunc"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "Mw=="
-; CHECK-ASM: .byte 25
-; CHECK-ASM: .ascii  "MismatchedFuncHashSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MTA="
-; CHECK-ASM: .byte 20
-; CHECK-ASM: .ascii  "TotalFuncHashSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "NTA="
-; CHECK-ASM: .byte 22
-; CHECK-ASM: .ascii  "NumMismatchedCallsites"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "Mg=="
-; CHECK-ASM: .byte 22
-; CHECK-ASM: .ascii  "TotalProfiledCallsites"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "Mw=="
-; CHECK-ASM: .byte 25
-; CHECK-ASM: .ascii  "MismatchedCallsiteSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MjA="
-; CHECK-ASM: .byte 20
-; CHECK-ASM: .ascii  "TotalCallsiteSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MzA="
+
+; CHECK-ASM: .section	.llvm_stats,"",@progbits
+; CHECK-ASM: .byte	22
+; CHECK-ASM: .ascii	"NumMismatchedCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"Mg=="
+; CHECK-ASM: .byte	22
+; CHECK-ASM: .ascii	"TotalProfiledCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"Mw=="
+; CHECK-ASM: .byte	25
+; CHECK-ASM: .ascii	"MismatchedCallsiteSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MjA="
+; CHECK-ASM: .byte	20
+; CHECK-ASM: .ascii	"TotalFuncHashSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"NTA="
+; CHECK-ASM: .byte	17
+; CHECK-ASM: .ascii	"TotalProfiledFunc"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"Mw=="
+; CHECK-ASM: .byte	21
+; CHECK-ASM: .ascii	"NumMismatchedFuncHash"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MQ=="
+; CHECK-ASM: .byte	25
+; CHECK-ASM: .ascii	"MismatchedFuncHashSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MTA="
+; CHECK-ASM: .byte	31
+; CHECK-ASM: .ascii	"PostMatchNumMismatchedCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"Mg=="
+; CHECK-ASM: .byte	34
+; CHECK-ASM: .ascii	"PostMatchMismatchedCallsiteSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MjA="
 
 ; CHECK-NESTED: (1/2) of functions' profile are invalid and (211/311) of samples are discarded due to function hash mismatch.
 

>From c76af25acc4497fb502657f1903834f0af74e052 Mon Sep 17 00:00:00 2001
From: wlei <wlei at fb.com>
Date: Fri, 26 Jan 2024 17:52:12 -0800
Subject: [PATCH 2/9] [CSSPGO] Support post-match profile staleness metrics

---
 llvm/lib/Transforms/IPO/SampleProfile.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index a7170faa65dc07..c232b9339146a8 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -457,7 +457,6 @@ class SampleProfileMatcher {
   uint64_t MismatchedCallsiteSamples = 0;
   uint64_t PostMatchNumMismatchedCallsites = 0;
   uint64_t PostMatchMismatchedCallsiteSamples = 0;
-  uint64_t TotalCallsiteSamples = 0;
   uint64_t TotalProfiledFunc = 0;
   uint64_t NumMismatchedFuncHash = 0;
   uint64_t MismatchedFuncHashSamples = 0;

>From 4b46e1e73bd9f5e51f70a3acd139b8176f96d693 Mon Sep 17 00:00:00 2001
From: wlei <wlei at fb.com>
Date: Fri, 26 Jan 2024 10:14:35 -0800
Subject: [PATCH 3/9] Encapsulate mismatch counting into a new class
 ProfileMatchStats

---
 llvm/lib/Transforms/IPO/SampleProfile.cpp     | 590 +++++++++---------
 .../Inputs/profile-mismatch.prof              |   1 -
 .../SampleProfile/profile-mismatch.ll         |   4 +-
 .../pseudo-probe-profile-mismatch-thinlto.ll  |   4 +-
 .../pseudo-probe-profile-mismatch.ll          |  19 +-
 5 files changed, 317 insertions(+), 301 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index c232b9339146a8..0743cb8f78204c 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -433,8 +433,43 @@ using CandidateQueue =
     PriorityQueue<InlineCandidate, std::vector<InlineCandidate>,
                   CandidateComparer>;
 
-using IRAnchorMap = std::map<LineLocation, StringRef>;
-using ProfileAnchorMap = std::map<LineLocation, std::unordered_set<FunctionId>>;
+// Profile matching statistics.
+class ProfileMatchStats {
+  const Module &M;
+  SampleProfileReader &Reader;
+  const PseudoProbeManager *ProbeManager;
+
+public:
+  ProfileMatchStats(const Module &M, SampleProfileReader &Reader,
+                    const PseudoProbeManager *ProbeManager)
+      : M(M), Reader(Reader), ProbeManager(ProbeManager) {}
+
+  uint64_t NumMismatchedCallsites = 0;
+  uint64_t TotalProfiledCallsites = 0;
+  uint64_t MismatchedCallsiteSamples = 0;
+  uint64_t NumMismatchedFuncHash = 0;
+  uint64_t TotalProfiledFunc = 0;
+  uint64_t MismatchedFuncHashSamples = 0;
+  uint64_t TotalFunctionSamples = 0;
+
+  // A map from function name to a set of mismatched callsite locations.
+  StringMap<std::set<LineLocation>> FuncMismatchedCallsites;
+
+  void countMismatchedSamples(const FunctionSamples &FS);
+  void countProfileMismatches(
+      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+      const std::map<LineLocation, std::unordered_set<FunctionId>>
+          &ProfileAnchors);
+  void countMismatchedCallsites(
+      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+      const std::map<LineLocation, std::unordered_set<FunctionId>>
+          &ProfileAnchors,
+      const LocToLocMap &IRToProfileLocationMap);
+  void countMismatchedCallsiteSamples(const FunctionSamples &FS);
+  void countMismatchedCallsiteSamples();
+  void copyUnchangedCallsiteMismatches(
+      const StringMap<std::set<LineLocation>> &InputMismatchedCallsites);
+};
 
 // Sample profile matching - fuzzy match.
 class SampleProfileMatcher {
@@ -442,37 +477,27 @@ class SampleProfileMatcher {
   SampleProfileReader &Reader;
   const PseudoProbeManager *ProbeManager;
   SampleProfileMap FlattenedProfiles;
-
-  std::unordered_map<const Function *, IRAnchorMap> FuncIRAnchors;
-  std::unordered_map<const Function *, ProfileAnchorMap> FuncProfileAnchors;
-
   // For each function, the matcher generates a map, of which each entry is a
   // mapping from the source location of current build to the source location in
   // the profile.
   StringMap<LocToLocMap> FuncMappings;
 
-  // Profile mismatching statstics.
-  uint64_t TotalProfiledCallsites = 0;
-  uint64_t NumMismatchedCallsites = 0;
-  uint64_t MismatchedCallsiteSamples = 0;
-  uint64_t PostMatchNumMismatchedCallsites = 0;
-  uint64_t PostMatchMismatchedCallsiteSamples = 0;
-  uint64_t TotalProfiledFunc = 0;
-  uint64_t NumMismatchedFuncHash = 0;
-  uint64_t MismatchedFuncHashSamples = 0;
-  uint64_t TotalFuncHashSamples = 0;
-
-  // A dummy name for unknown indirect callee, used to differentiate from a
-  // non-call instruction that also has an empty callee name.
-  static constexpr const char *UnknownIndirectCallee =
-      "unknown.indirect.callee";
+  ProfileMatchStats PreMatchStats;
+  ProfileMatchStats PostMatchStats;
 
 public:
   SampleProfileMatcher(Module &M, SampleProfileReader &Reader,
                        const PseudoProbeManager *ProbeManager)
-      : M(M), Reader(Reader), ProbeManager(ProbeManager){};
+      : M(M), Reader(Reader), ProbeManager(ProbeManager),
+        PreMatchStats(M, Reader, ProbeManager),
+        PostMatchStats(M, Reader, ProbeManager){};
   void runOnModule();
 
+  // A dummy name for unknown indirect callee, used to differentiate from a
+  // non-call instruction that also has an empty callee name.
+  static constexpr const char *UnknownIndirectCallee =
+      "unknown.indirect.callee";
+
 private:
   FunctionSamples *getFlattenedSamplesFor(const Function &F) {
     StringRef CanonFName = FunctionSamples::getCanonicalFnName(F);
@@ -482,22 +507,11 @@ class SampleProfileMatcher {
     return nullptr;
   }
   void runOnFunction(const Function &F);
-  void findFuncAnchors();
-  void UpdateIRAnchors();
-  void findIRAnchors(const Function &F, IRAnchorMap &IRAnchors);
-  void findProfileAnchors(const FunctionSamples &FS,
-                          ProfileAnchorMap &ProfileAnchors);
-  void countMismatchedHashSamples(const FunctionSamples &FS);
-  void countProfileMismatches(bool IsPreMatch);
-  void countMismatchedHashes(const Function &F, const FunctionSamples &FS);
-  void countMismatchedCallsites(
-      const Function &F,
-      StringMap<std::set<LineLocation>> &FuncToMismatchCallsites,
-      uint64_t &FuncProfiledCallsites, uint64_t &FuncMismatchedCallsites) const;
-  void countMismatchedCallsiteSamples(
+  void findIRAnchors(const Function &F,
+                     std::map<LineLocation, StringRef> &IRAnchors);
+  void findProfileAnchors(
       const FunctionSamples &FS,
-      StringMap<std::set<LineLocation>> &FuncToMismatchCallsites,
-      uint64_t &FuncMismatchedCallsiteSamples) const;
+      std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors);
   LocToLocMap &getIRToProfileLocationMap(const Function &F) {
     auto Ret = FuncMappings.try_emplace(
         FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap());
@@ -505,10 +519,12 @@ class SampleProfileMatcher {
   }
   void distributeIRToProfileLocationMap();
   void distributeIRToProfileLocationMap(FunctionSamples &FS);
-  void runStaleProfileMatching();
-  void runStaleProfileMatching(const Function &F, const IRAnchorMap &IRAnchors,
-                               const ProfileAnchorMap &ProfileAnchors,
-                               LocToLocMap &IRToProfileLocationMap);
+  void runStaleProfileMatching(
+      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+      const std::map<LineLocation, std::unordered_set<FunctionId>>
+          &ProfileAnchors,
+      LocToLocMap &IRToProfileLocationMap);
+  void reportOrPersistProfileStats();
 };
 
 /// Sample profile pass.
@@ -695,6 +711,10 @@ void SampleProfileLoaderBaseImpl<Function>::computeDominanceAndLoopInfo(
 }
 } // namespace llvm
 
+bool ShouldSkipProfileLoading(const Function &F) {
+  return F.isDeclaration() || !F.hasFnAttribute("use-sample-profile");
+}
+
 ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) {
   if (FunctionSamples::ProfileIsProbeBased)
     return getProbeWeight(Inst);
@@ -2128,8 +2148,8 @@ bool SampleProfileLoader::doInitialization(Module &M,
   return true;
 }
 
-void SampleProfileMatcher::findIRAnchors(const Function &F,
-                                         IRAnchorMap &IRAnchors) {
+void SampleProfileMatcher::findIRAnchors(
+    const Function &F, std::map<LineLocation, StringRef> &IRAnchors) {
   // For inlined code, recover the original callsite and callee by finding the
   // top-level inline frame. e.g. For frame stack "main:1 @ foo:2 @ bar:3", the
   // top-level frame is "main:1", the callsite is "1" and the callee is "foo".
@@ -2195,8 +2215,7 @@ void SampleProfileMatcher::findIRAnchors(const Function &F,
   }
 }
 
-void SampleProfileMatcher::countMismatchedHashSamples(
-    const FunctionSamples &FS) {
+void ProfileMatchStats::countMismatchedSamples(const FunctionSamples &FS) {
   const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID());
   // Skip the function that is external or renamed.
   if (!FuncDesc)
@@ -2208,11 +2227,144 @@ void SampleProfileMatcher::countMismatchedHashSamples(
   }
   for (const auto &I : FS.getCallsiteSamples())
     for (const auto &CS : I.second)
-      countMismatchedHashSamples(CS.second);
+      countMismatchedSamples(CS.second);
+}
+
+void ProfileMatchStats::countMismatchedCallsites(
+    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+    const std::map<LineLocation, std::unordered_set<FunctionId>>
+        &ProfileAnchors,
+    const LocToLocMap &IRToProfileLocationMap) {
+  auto &MismatchedCallsites =
+      FuncMismatchedCallsites[FunctionSamples::getCanonicalFnName(F.getName())];
+
+  auto MapIRLocToProfileLoc = [&](const LineLocation &IRLoc) {
+    const auto &ProfileLoc = IRToProfileLocationMap.find(IRLoc);
+    if (ProfileLoc != IRToProfileLocationMap.end())
+      return ProfileLoc->second;
+    else
+      return IRLoc;
+  };
+
+  std::set<LineLocation> MatchedCallsites;
+  for (const auto &I : IRAnchors) {
+    // In post-match, use the matching result to remap the current IR callsite.
+    const auto &Loc = MapIRLocToProfileLoc(I.first);
+    const auto &IRCalleeName = I.second;
+    const auto &It = ProfileAnchors.find(Loc);
+    if (It == ProfileAnchors.end())
+      continue;
+    const auto &Callees = It->second;
+
+    // Since indirect call does not have CalleeName, check conservatively if
+    // callsite in the profile is a callsite location. This is to reduce num of
+    // false positive since otherwise all the indirect call samples will be
+    // reported as mismatching.
+    if (IRCalleeName == SampleProfileMatcher::UnknownIndirectCallee)
+      MatchedCallsites.insert(Loc);
+    else if (Callees.count(getRepInFormat(IRCalleeName)))
+      MatchedCallsites.insert(Loc);
+  }
+
+  // Check if there are any callsites in the profile that do not match any IR
+  // callsites; those callsite samples will be discarded.
+  for (const auto &I : ProfileAnchors) {
+    const auto &Loc = I.first;
+    [[maybe_unused]] const auto &Callees = I.second;
+    assert(!Callees.empty() && "Callees should not be empty");
+    TotalProfiledCallsites++;
+    if (!MatchedCallsites.count(Loc)) {
+      NumMismatchedCallsites++;
+      MismatchedCallsites.insert(Loc);
+    }
+  }
+}
+
+void ProfileMatchStats::countProfileMismatches(
+    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+    const std::map<LineLocation, std::unordered_set<FunctionId>>
+        &ProfileAnchors) {
+  [[maybe_unused]] bool IsFuncHashMismatch = false;
+  // Use top-level nested FS for counting profile mismatch metrics since
+  // currently once a callsite is mismatched, all its children profiles are
+  // dropped.
+  if (const auto *FS = Reader.getSamplesFor(F)) {
+    TotalProfiledFunc++;
+    TotalFunctionSamples += FS->getTotalSamples();
+    if (FunctionSamples::ProfileIsProbeBased) {
+      const auto *FuncDesc = ProbeManager->getDesc(F);
+      if (FuncDesc) {
+        if (ProbeManager->profileIsHashMismatched(*FuncDesc, *FS)) {
+          NumMismatchedFuncHash++;
+          IsFuncHashMismatch = true;
+        }
+        countMismatchedSamples(*FS);
+      }
+    }
+  }
+
+  countMismatchedCallsites(F, IRAnchors, ProfileAnchors, LocToLocMap());
+  LLVM_DEBUG({
+    auto It = FuncMismatchedCallsites.find(
+        FunctionSamples::getCanonicalFnName(F.getName()));
+    if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch &&
+        It != FuncMismatchedCallsites.end() && !It->second.empty())
+      dbgs() << "Function checksum is matched but there are "
+             << It->second.size() << " mismatched callsites.\n";
+  });
+}
+
+void ProfileMatchStats::countMismatchedCallsiteSamples(
+    const FunctionSamples &FS) {
+  auto It = FuncMismatchedCallsites.find(FS.getFuncName());
+  // Skip it if no mismatched callsite or this is an external function.
+  if (It == FuncMismatchedCallsites.end() || It->second.empty())
+    return;
+  const auto &MismatchCallsites = It->second;
+
+  for (const auto &I : FS.getBodySamples()) {
+    if (MismatchCallsites.count(I.first))
+      MismatchedCallsiteSamples += I.second.getSamples();
+  }
+
+  for (const auto &I : FS.getCallsiteSamples()) {
+    const auto &Loc = I.first;
+    if (MismatchCallsites.count(Loc)) {
+      for (const auto &CS : I.second)
+        MismatchedCallsiteSamples += CS.second.getTotalSamples();
+      continue;
+    }
+
+    // Count mismatched samples for inlined functions.
+    for (const auto &CS : I.second)
+      countMismatchedCallsiteSamples(CS.second);
+  }
+}
+
+void ProfileMatchStats::countMismatchedCallsiteSamples() {
+  if (FuncMismatchedCallsites.empty())
+    return;
+  for (const auto &F : M) {
+    if (ShouldSkipProfileLoading(F))
+      continue;
+    if (const auto *FS = Reader.getSamplesFor(F))
+      countMismatchedCallsiteSamples(*FS);
+  }
+}
+
+void ProfileMatchStats::copyUnchangedCallsiteMismatches(
+    const StringMap<std::set<LineLocation>> &InputMismatchedCallsites) {
+  for (const auto &I : InputMismatchedCallsites) {
+    auto It = FuncMismatchedCallsites.find(I.first());
+    if (It != FuncMismatchedCallsites.end())
+      continue;
+    FuncMismatchedCallsites.try_emplace(I.first(), I.second);
+  }
 }
 
 void SampleProfileMatcher::findProfileAnchors(
-    const FunctionSamples &FS, ProfileAnchorMap &ProfileAnchors) {
+    const FunctionSamples &FS,
+    std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors) {
   auto isInvalidLineOffset = [](uint32_t LineOffset) {
     return LineOffset & 0x8000;
   };
@@ -2259,8 +2411,9 @@ void SampleProfileMatcher::findProfileAnchors(
 //   [1, 2, 3(foo), 4,  7,  8(bar), 9]
 // The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9].
 void SampleProfileMatcher::runStaleProfileMatching(
-    const Function &F, const IRAnchorMap &IRAnchors,
-    const ProfileAnchorMap &ProfileAnchors,
+    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+    const std::map<LineLocation, std::unordered_set<FunctionId>>
+        &ProfileAnchors,
     LocToLocMap &IRToProfileLocationMap) {
   LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName()
                     << "\n");
@@ -2341,249 +2494,79 @@ void SampleProfileMatcher::runStaleProfileMatching(
   }
 }
 
-void SampleProfileMatcher::runStaleProfileMatching() {
-  for (const auto &F : M) {
-    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
-      continue;
-    const auto *FSFlattened = getFlattenedSamplesFor(F);
-    if (!FSFlattened)
-      continue;
-    auto IR = FuncIRAnchors.find(&F);
-    auto P = FuncProfileAnchors.find(&F);
-    if (IR == FuncIRAnchors.end() || P == FuncProfileAnchors.end())
-      continue;
-
-    // Run profile matching for checksum mismatched profile, currently only
-    // support for pseudo-probe.
-    if (FunctionSamples::ProfileIsProbeBased &&
-        !ProbeManager->profileIsValid(F, *FSFlattened)) {
-      runStaleProfileMatching(F, IR->second, P->second,
-                              getIRToProfileLocationMap(F));
-    }
-  }
-
-  distributeIRToProfileLocationMap();
-}
-
-void SampleProfileMatcher::findFuncAnchors() {
-  ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
-                                   FunctionSamples::ProfileIsCS);
-  for (const auto &F : M) {
-    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
-      continue;
-    // We need to use flattened function samples for matching.
-    // Unlike IR, which includes all callsites from the source code, the
-    // callsites in profile only show up when they are hit by samples, i,e. the
-    // profile callsites in one context may differ from those in another
-    // context. To get the maximum number of callsites, we merge the function
-    // profiles from all contexts, aka, the flattened profile to find profile
-    // anchors.
-    const auto *FSFlattened = getFlattenedSamplesFor(F);
-    if (!FSFlattened)
-      continue;
-
-    // Anchors for IR. It's a map from IR location to callee name, callee name
-    // is empty for non-call instruction and use a dummy
-    // name(UnknownIndirectCallee) for unknown indrect callee name.
-    auto IR = FuncIRAnchors.emplace(&F, IRAnchorMap());
-    findIRAnchors(F, IR.first->second);
-
-    // Anchors for profile. It's a map from callsite location to a set of callee
-    // name.
-    auto P = FuncProfileAnchors.emplace(&F, ProfileAnchorMap());
-    findProfileAnchors(*FSFlattened, P.first->second);
-  }
-}
-
-void SampleProfileMatcher::countMismatchedCallsiteSamples(
-    const FunctionSamples &FS,
-    StringMap<std::set<LineLocation>> &FuncToMismatchCallsites,
-    uint64_t &FuncMismatchedCallsiteSamples) const {
-  auto It = FuncToMismatchCallsites.find(FS.getFuncName());
-  // Skip it if no mismatched callsite or this is an external function.
-  if (It == FuncToMismatchCallsites.end() || It->second.empty())
-    return;
-  const auto &MismatchCallsites = It->second;
-  for (const auto &I : FS.getBodySamples()) {
-    if (MismatchCallsites.count(I.first))
-      FuncMismatchedCallsiteSamples += I.second.getSamples();
-  }
-
-  for (const auto &I : FS.getCallsiteSamples()) {
-    const auto &Loc = I.first;
-    if (MismatchCallsites.count(Loc)) {
-      for (const auto &CS : I.second)
-        FuncMismatchedCallsiteSamples += CS.second.getTotalSamples();
-      continue;
-    }
-
-    // count mismatched samples for inlined samples.
-    for (const auto &CS : I.second)
-      countMismatchedCallsiteSamples(CS.second, FuncToMismatchCallsites,
-                                     FuncMismatchedCallsiteSamples);
-  }
-}
-
-void SampleProfileMatcher::countMismatchedCallsites(
-    const Function &F,
-    StringMap<std::set<LineLocation>> &FuncToMismatchCallsites,
-    uint64_t &FuncProfiledCallsites, uint64_t &FuncMismatchedCallsites) const {
-  auto IR = FuncIRAnchors.find(&F);
-  auto P = FuncProfileAnchors.find(&F);
-  if (IR == FuncIRAnchors.end() || P == FuncProfileAnchors.end())
-    return;
-  const auto &IRAnchors = IR->second;
-  const auto &ProfileAnchors = P->second;
-
-  auto &MismatchCallsites =
-      FuncToMismatchCallsites[FunctionSamples::getCanonicalFnName(F.getName())];
-
-  // Check if there are any callsites in the profile that does not match to any
-  // IR callsites, those callsite samples will be discarded.
-  for (const auto &I : ProfileAnchors) {
-    const auto &Loc = I.first;
-    const auto &Callees = I.second;
-    assert(!Callees.empty() && "Callees should not be empty");
-
-    StringRef IRCalleeName;
-    const auto &IR = IRAnchors.find(Loc);
-    if (IR != IRAnchors.end())
-      IRCalleeName = IR->second;
-    bool CallsiteIsMatched = false;
-    // Since indirect call does not have CalleeName, check conservatively if
-    // callsite in the profile is a callsite location. This is to reduce num of
-    // false positive since otherwise all the indirect call samples will be
-    // reported as mismatching.
-    if (IRCalleeName == UnknownIndirectCallee)
-      CallsiteIsMatched = true;
-    else if (Callees.count(FunctionId(IRCalleeName)))
-      CallsiteIsMatched = true;
-
-    FuncProfiledCallsites++;
-    if (!CallsiteIsMatched) {
-      FuncMismatchedCallsites++;
-      MismatchCallsites.insert(Loc);
-    }
-  }
-}
-
-void SampleProfileMatcher::countMismatchedHashes(const Function &F,
-                                                 const FunctionSamples &FS) {
-  if (!FunctionSamples::ProfileIsProbeBased)
+void SampleProfileMatcher::runOnFunction(const Function &F) {
+  // We need to use flattened function samples for matching.
+  // Unlike IR, which includes all callsites from the source code, the callsites
+  // in profile only show up when they are hit by samples, i.e. the profile
+  // callsites in one context may differ from those in another context. To get
+  // the maximum number of callsites, we merge the function profiles from all
+  // contexts, aka, the flattened profile to find profile anchors.
+  const auto *FSFlattened = getFlattenedSamplesFor(F);
+  if (!FSFlattened)
     return;
-  const auto *FuncDesc = ProbeManager->getDesc(F);
-  if (FuncDesc) {
-    if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
-      NumMismatchedFuncHash++;
-    }
-    countMismatchedHashSamples(FS);
-  }
-}
-
-void SampleProfileMatcher::UpdateIRAnchors() {
-  for (auto &I : FuncIRAnchors) {
-    const auto *F = I.first;
-    auto &IRAnchors = I.second;
-    const auto Mapping =
-        FuncMappings.find(FunctionSamples::getCanonicalFnName(F->getName()));
-    if (Mapping == FuncMappings.end())
-      continue;
-    IRAnchorMap UpdatedIRAnchors;
-    const auto &LocToLocMapping = Mapping->second;
-    for (const auto L : LocToLocMapping) {
-      UpdatedIRAnchors[L.second] = IRAnchors[L.first];
-      IRAnchors.erase(L.first);
-    }
-
-    for (const auto &IR : UpdatedIRAnchors) {
-      IRAnchors[IR.first] = IR.second;
-    }
-  }
-}
-
-void SampleProfileMatcher::countProfileMismatches(bool IsPreMatch) {
-  if (!ReportProfileStaleness && !PersistProfileStaleness)
-    return;
-
-  if (!IsPreMatch) {
-    // Use the profile matching results to update to the IR anchors.
-    UpdateIRAnchors();
-  }
-
-  uint64_t UnusedCounter = 0;
-  uint64_t *TotalProfiledCallsitesPtr =
-      IsPreMatch ? &TotalProfiledCallsites : &UnusedCounter;
-  uint64_t *NumMismatchedCallsitesPtr =
-      IsPreMatch ? &NumMismatchedCallsites : &PostMatchNumMismatchedCallsites;
-  uint64_t *MismatchedCallsiteSamplesPtr =
-      IsPreMatch ? &MismatchedCallsiteSamples
-                 : &PostMatchMismatchedCallsiteSamples;
-
-  auto SkipFunctionForReport = [](const Function &F) {
-    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
-      return true;
-    // Skip reporting the metrics for imported functions.
-    if (GlobalValue::isAvailableExternallyLinkage(F.getLinkage()))
-      return true;
-    return false;
-  };
 
-  StringMap<std::set<LineLocation>> FuncToMismatchCallsites;
-  for (const auto &F : M) {
-    if (SkipFunctionForReport(F))
-      continue;
-    const auto *FS = Reader.getSamplesFor(F);
-    if (FS && IsPreMatch) {
-      // Only count the total function metrics once in pre-match time.
-      TotalFuncHashSamples += FS->getTotalSamples();
-      TotalProfiledFunc++;
-      countMismatchedHashes(F, *FS);
-    }
-    countMismatchedCallsites(F, FuncToMismatchCallsites,
-                             *TotalProfiledCallsitesPtr,
-                             *NumMismatchedCallsitesPtr);
-  }
-
-  for (const auto &F : M) {
-    if (SkipFunctionForReport(F))
-      continue;
-    if (const auto *FS = Reader.getSamplesFor(F))
-      countMismatchedCallsiteSamples(*FS, FuncToMismatchCallsites,
-                                     *MismatchedCallsiteSamplesPtr);
+  // Anchors for IR. It's a map from IR location to callee name, callee name is
+  // empty for non-call instruction and use a dummy name(UnknownIndirectCallee)
+  // for unknown indirect callee name.
+  std::map<LineLocation, StringRef> IRAnchors;
+  findIRAnchors(F, IRAnchors);
+  // Anchors for profile. It's a map from callsite location to a set of callee
+  // name.
+  std::map<LineLocation, std::unordered_set<FunctionId>> ProfileAnchors;
+  findProfileAnchors(*FSFlattened, ProfileAnchors);
+
+  // Detect profile mismatch for profile staleness metrics report.
+  // Skip reporting the metrics for imported functions.
+  if (!GlobalValue::isAvailableExternallyLinkage(F.getLinkage()) &&
+      (ReportProfileStaleness || PersistProfileStaleness)) {
+    PreMatchStats.countProfileMismatches(F, IRAnchors, ProfileAnchors);
+  }
+
+  // Run profile matching for checksum mismatched profile, currently only
+  // support for pseudo-probe.
+  if (SalvageStaleProfile && FunctionSamples::ProfileIsProbeBased &&
+      !ProbeManager->profileIsValid(F, *FSFlattened)) {
+    // The matching result will be saved to IRToProfileLocationMap, create a new
+    // map for each function.
+    auto &IRToProfileLocationMap = getIRToProfileLocationMap(F);
+    runStaleProfileMatching(F, IRAnchors, ProfileAnchors,
+                            IRToProfileLocationMap);
+    PostMatchStats.countMismatchedCallsites(F, IRAnchors, ProfileAnchors,
+                                            IRToProfileLocationMap);
   }
 }
 
-void SampleProfileMatcher::runOnModule() {
-  findFuncAnchors();
-  countProfileMismatches(true);
-
-  if (SalvageStaleProfile) {
-    runStaleProfileMatching();
-    countProfileMismatches(false);
-  }
-
+void SampleProfileMatcher::reportOrPersistProfileStats() {
   if (ReportProfileStaleness) {
     if (FunctionSamples::ProfileIsProbeBased) {
-      errs() << "(" << NumMismatchedFuncHash << "/" << TotalProfiledFunc << ")"
+      errs() << "(" << PreMatchStats.NumMismatchedFuncHash << "/"
+             << PreMatchStats.TotalProfiledFunc << ")"
              << " of functions' profile are invalid and "
-             << " (" << MismatchedFuncHashSamples << "/" << TotalFuncHashSamples
-             << ")"
+             << " (" << PreMatchStats.MismatchedFuncHashSamples << "/"
+             << PreMatchStats.TotalFunctionSamples << ")"
              << " of samples are discarded due to function hash mismatch.\n";
     }
-    errs() << "(" << NumMismatchedCallsites << "/" << TotalProfiledCallsites
-           << ")"
+    errs() << "(" << PreMatchStats.NumMismatchedCallsites << "/"
+           << PreMatchStats.TotalProfiledCallsites << ")"
            << " of callsites' profile are invalid and "
-           << "(" << MismatchedCallsiteSamples << "/" << TotalFuncHashSamples
-           << ")"
+           << "(" << PreMatchStats.MismatchedCallsiteSamples << "/"
+           << PreMatchStats.TotalFunctionSamples << ")"
            << " of samples are discarded due to callsite location mismatch.\n";
     if (SalvageStaleProfile) {
-      errs() << "(" << PostMatchNumMismatchedCallsites << "/"
-             << TotalProfiledCallsites << ")"
-             << " of callsites' profile are invalid and "
-             << "(" << PostMatchMismatchedCallsiteSamples << "/"
-             << TotalFuncHashSamples << ")"
-             << " of samples are discarded due to callsite location mismatch "
-                "after stale profile matching.\n";
+      uint64_t NumRecoveredCallsites = PostMatchStats.TotalProfiledCallsites -
+                                       PostMatchStats.NumMismatchedCallsites;
+      uint64_t NumMismatchedCallsites =
+          PreMatchStats.NumMismatchedCallsites - NumRecoveredCallsites;
+      errs() << "Out of " << PostMatchStats.TotalProfiledCallsites
+             << " callsites used for profile matching, "
+             << NumRecoveredCallsites
+             << " callsites have been recovered. After the matching, ("
+             << NumMismatchedCallsites << "/"
+             << PreMatchStats.TotalProfiledCallsites
+             << ") of callsites are still invalid ("
+             << PostMatchStats.MismatchedCallsiteSamples << "/"
+             << PreMatchStats.TotalFunctionSamples << ")"
+             << " of samples are still discarded.\n";
     }
   }
 
@@ -2592,22 +2575,29 @@ void SampleProfileMatcher::runOnModule() {
     MDBuilder MDB(Ctx);
 
     SmallVector<std::pair<StringRef, uint64_t>> ProfStatsVec;
-    ProfStatsVec.emplace_back("NumMismatchedCallsites", NumMismatchedCallsites);
-    ProfStatsVec.emplace_back("TotalProfiledCallsites", TotalProfiledCallsites);
+    ProfStatsVec.emplace_back("NumMismatchedCallsites",
+                              PreMatchStats.NumMismatchedCallsites);
+    ProfStatsVec.emplace_back("TotalProfiledCallsites",
+                              PreMatchStats.TotalProfiledCallsites);
     ProfStatsVec.emplace_back("MismatchedCallsiteSamples",
-                              MismatchedCallsiteSamples);
-    ProfStatsVec.emplace_back("TotalFuncHashSamples", TotalFuncHashSamples);
+                              PreMatchStats.MismatchedCallsiteSamples);
+    ProfStatsVec.emplace_back("TotalProfiledFunc",
+                              PreMatchStats.TotalProfiledFunc);
+    ProfStatsVec.emplace_back("TotalFunctionSamples",
+                              PreMatchStats.TotalFunctionSamples);
     if (FunctionSamples::ProfileIsProbeBased) {
-      ProfStatsVec.emplace_back("TotalProfiledFunc", TotalProfiledFunc);
-      ProfStatsVec.emplace_back("NumMismatchedFuncHash", NumMismatchedFuncHash);
+      ProfStatsVec.emplace_back("NumMismatchedFuncHash",
+                                PreMatchStats.NumMismatchedFuncHash);
       ProfStatsVec.emplace_back("MismatchedFuncHashSamples",
-                                MismatchedFuncHashSamples);
+                                PreMatchStats.MismatchedFuncHashSamples);
     }
     if (SalvageStaleProfile) {
       ProfStatsVec.emplace_back("PostMatchNumMismatchedCallsites",
-                                PostMatchNumMismatchedCallsites);
+                                PostMatchStats.NumMismatchedCallsites);
+      ProfStatsVec.emplace_back("NumCallsitesForMatching",
+                                PostMatchStats.TotalProfiledCallsites);
       ProfStatsVec.emplace_back("PostMatchMismatchedCallsiteSamples",
-                                PostMatchMismatchedCallsiteSamples);
+                                PostMatchStats.MismatchedCallsiteSamples);
     }
 
     auto *MD = MDB.createLLVMStats(ProfStatsVec);
@@ -2616,6 +2606,30 @@ void SampleProfileMatcher::runOnModule() {
   }
 }
 
+void SampleProfileMatcher::runOnModule() {
+  ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
+                                   FunctionSamples::ProfileIsCS);
+  for (auto &F : M) {
+    if (ShouldSkipProfileLoading(F))
+      continue;
+    runOnFunction(F);
+  }
+
+  if (SalvageStaleProfile)
+    distributeIRToProfileLocationMap();
+
+  PreMatchStats.countMismatchedCallsiteSamples();
+  if (SalvageStaleProfile) {
+    // If a function doesn't run the matching but has mismatched callsites,
+    // there won't be any data for that function in post-match stats, so just
+    // reuse the pre-match stats.
+    PostMatchStats.copyUnchangedCallsiteMismatches(
+        PreMatchStats.FuncMismatchedCallsites);
+    PostMatchStats.countMismatchedCallsiteSamples();
+  }
+  reportOrPersistProfileStats();
+}
+
 void SampleProfileMatcher::distributeIRToProfileLocationMap(
     FunctionSamples &FS) {
   const auto ProfileMappings = FuncMappings.find(FS.getFuncName());
diff --git a/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof b/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
index f2a00e789b8b66..241d0914a37641 100644
--- a/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
+++ b/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
@@ -14,4 +14,3 @@ main:30:0
 bar:10:10
  1: 10
 matched:10:10
- 1: 10
diff --git a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
index e7c5dece1235b5..14e384d7964ab0 100644
--- a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
@@ -8,7 +8,7 @@
 
 ; CHECK: (2/4) of callsites' profile are invalid and (15/50) of samples are discarded due to callsite location mismatch.
 
-; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 15, !"TotalFuncHashSamples", i64 50}
+; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 15, !"TotalProfiledFunc", i64 3, !"TotalFunctionSamples", i64 50}
 
 ; CHECK-OBJ: .llvm_stats
 
@@ -26,7 +26,7 @@
 ; CHECK-ASM: .byte 4
 ; CHECK-ASM: .ascii  "MTU="
 ; CHECK-ASM: .byte 20
-; CHECK-ASM: .ascii  "TotalFuncHashSamples"
+; CHECK-ASM: .ascii  "TotalFunctionSamples"
 ; CHECK-ASM: .byte 4
 ; CHECK-ASM: .ascii  "NTA="
 
diff --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
index 7f848da74a53ce..768fe5509f33a9 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
@@ -5,6 +5,6 @@
 
 ; CHECK: (1/1) of functions' profile are invalid and  (6822/6822) of samples are discarded due to function hash mismatch.
 ; CHECK: (4/4) of callsites' profile are invalid and (5026/6822) of samples are discarded due to callsite location mismatch.
-; CHECK: (0/4) of callsites' profile are invalid and (0/6822) of samples are discarded due to callsite location mismatch after stale profile matching.
+; CHECK: Out of 4 callsites used for profile matching, 4 callsites have been recovered. After the matching, (0/4) of callsites are still invalid (0/6822) of samples are still discarded.
 
-; CHECK-MD: !{!"NumMismatchedCallsites", i64 4, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 5026, !"TotalFuncHashSamples", i64 6822, !"TotalProfiledFunc", i64 1, !"NumMismatchedFuncHash", i64 1, !"MismatchedFuncHashSamples", i64 6822, !"PostMatchNumMismatchedCallsites", i64 0, !"PostMatchMismatchedCallsiteSamples", i64 0}
+; CHECK-MD: !{!"NumMismatchedCallsites", i64 4, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 5026, !"TotalProfiledFunc", i64 1, !"TotalFunctionSamples", i64 6822, !"NumMismatchedFuncHash", i64 1, !"MismatchedFuncHashSamples", i64 6822, !"PostMatchNumMismatchedCallsites", i64 0, !"NumCallsitesForMatching", i64 4, !"PostMatchMismatchedCallsiteSamples", i64 0}
diff --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
index 5c5bb1f0fae647..9949b5fd41f407 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
@@ -11,13 +11,12 @@
 
 ; CHECK: (1/3) of functions' profile are invalid and (10/50) of samples are discarded due to function hash mismatch.
 ; CHECK: (2/3) of callsites' profile are invalid and (20/50) of samples are discarded due to callsite location mismatch.
-; CHECK: (2/3) of callsites' profile are invalid and (20/50) of samples are discarded due to callsite location mismatch after stale profile matching.
+; CHECK: Out of 0 callsites used for profile matching, 0 callsites have been recovered. After the matching, (2/3) of callsites are still invalid (20/50) of samples are still discarded.
 
-; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 20, !"TotalFuncHashSamples", i64 50, !"TotalProfiledFunc", i64 3, !"NumMismatchedFuncHash", i64 1, !"MismatchedFuncHashSamples", i64 10, !"PostMatchNumMismatchedCallsites", i64 2, !"PostMatchMismatchedCallsiteSamples", i64 20}
+; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 20, !"TotalProfiledFunc", i64 3, !"TotalFunctionSamples", i64 50, !"NumMismatchedFuncHash", i64 1, !"MismatchedFuncHashSamples", i64 10, !"PostMatchNumMismatchedCallsites", i64 0, !"NumCallsitesForMatching", i64 0, !"PostMatchMismatchedCallsiteSamples", i64 20}
 
 ; CHECK-OBJ: .llvm_stats
 
-
 ; CHECK-ASM: .section	.llvm_stats,"", at progbits
 ; CHECK-ASM: .byte	22
 ; CHECK-ASM: .ascii	"NumMismatchedCallsites"
@@ -31,14 +30,14 @@
 ; CHECK-ASM: .ascii	"MismatchedCallsiteSamples"
 ; CHECK-ASM: .byte	4
 ; CHECK-ASM: .ascii	"MjA="
-; CHECK-ASM: .byte	20
-; CHECK-ASM: .ascii	"TotalFuncHashSamples"
-; CHECK-ASM: .byte	4
-; CHECK-ASM: .ascii	"NTA="
 ; CHECK-ASM: .byte	17
 ; CHECK-ASM: .ascii	"TotalProfiledFunc"
 ; CHECK-ASM: .byte	4
 ; CHECK-ASM: .ascii	"Mw=="
+; CHECK-ASM: .byte	20
+; CHECK-ASM: .ascii	"TotalFunctionSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"NTA="
 ; CHECK-ASM: .byte	21
 ; CHECK-ASM: .ascii	"NumMismatchedFuncHash"
 ; CHECK-ASM: .byte	4
@@ -50,7 +49,11 @@
 ; CHECK-ASM: .byte	31
 ; CHECK-ASM: .ascii	"PostMatchNumMismatchedCallsites"
 ; CHECK-ASM: .byte	4
-; CHECK-ASM: .ascii	"Mg=="
+; CHECK-ASM: .ascii	"MA=="
+; CHECK-ASM: .byte	23
+; CHECK-ASM: .ascii	"NumCallsitesForMatching"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MA=="
 ; CHECK-ASM: .byte	34
 ; CHECK-ASM: .ascii	"PostMatchMismatchedCallsiteSamples"
 ; CHECK-ASM: .byte	4

>From f32a98a2826dd052050dea19ac2689026354b6e0 Mon Sep 17 00:00:00 2001
From: wlei <wlei at fb.com>
Date: Mon, 29 Jan 2024 11:18:11 -0800
Subject: [PATCH 4/9] tmp

---
 llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof | 1 +
 1 file changed, 1 insertion(+)

diff --git a/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof b/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
index 241d0914a37641..f2a00e789b8b66 100644
--- a/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
+++ b/llvm/test/Transforms/SampleProfile/Inputs/profile-mismatch.prof
@@ -14,3 +14,4 @@ main:30:0
 bar:10:10
  1: 10
 matched:10:10
+ 1: 10

>From 6c9f4fc4574f8db88d2b7f636a65e758aadb5ecd Mon Sep 17 00:00:00 2001
From: wlei <wlei at fb.com>
Date: Wed, 31 Jan 2024 15:50:44 -0800
Subject: [PATCH 5/9] [PseudoProbe] Compute and report profile matching
 recovered callsites and samples

---
 llvm/lib/Transforms/IPO/SampleProfile.cpp     | 523 +++++++++---------
 .../SampleProfile/profile-mismatch.ll         |   2 +-
 .../pseudo-probe-profile-mismatch-thinlto.ll  |   4 +-
 .../pseudo-probe-profile-mismatch.ll          |   2 +-
 4 files changed, 266 insertions(+), 265 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 0743cb8f78204c..9ba23d8f490803 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -433,44 +433,6 @@ using CandidateQueue =
     PriorityQueue<InlineCandidate, std::vector<InlineCandidate>,
                   CandidateComparer>;
 
-// Profile matching statstics.
-class ProfileMatchStats {
-  const Module &M;
-  SampleProfileReader &Reader;
-  const PseudoProbeManager *ProbeManager;
-
-public:
-  ProfileMatchStats(const Module &M, SampleProfileReader &Reader,
-                    const PseudoProbeManager *ProbeManager)
-      : M(M), Reader(Reader), ProbeManager(ProbeManager) {}
-
-  uint64_t NumMismatchedCallsites = 0;
-  uint64_t TotalProfiledCallsites = 0;
-  uint64_t MismatchedCallsiteSamples = 0;
-  uint64_t NumMismatchedFuncHash = 0;
-  uint64_t TotalProfiledFunc = 0;
-  uint64_t MismatchedFuncHashSamples = 0;
-  uint64_t TotalFunctionSamples = 0;
-
-  // A map from function name to a set of mismatched callsite locations.
-  StringMap<std::set<LineLocation>> FuncMismatchedCallsites;
-
-  void countMismatchedSamples(const FunctionSamples &FS);
-  void countProfileMismatches(
-      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
-      const std::map<LineLocation, std::unordered_set<FunctionId>>
-          &ProfileAnchors);
-  void countMismatchedCallsites(
-      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
-      const std::map<LineLocation, std::unordered_set<FunctionId>>
-          &ProfileAnchors,
-      const LocToLocMap &IRToProfileLocationMap);
-  void countMismatchedCallsiteSamples(const FunctionSamples &FS);
-  void countMismatchedCallsiteSamples();
-  void copyUnchangedCallsiteMismatches(
-      const StringMap<std::set<LineLocation>> &InputMismatchedCallsites);
-};
-
 // Sample profile matching - fuzzy match.
 class SampleProfileMatcher {
   Module &M;
@@ -482,22 +444,46 @@ class SampleProfileMatcher {
   // the profile.
   StringMap<LocToLocMap> FuncMappings;
 
-  ProfileMatchStats PreMatchStats;
-  ProfileMatchStats PostMatchStats;
+  // Match state for an anchor/callsite.
+  enum class MatchState {
+    Matched = 0,
+    Mismatched = 0x1,
+    Recovered = 0x1,
+  };
 
-public:
-  SampleProfileMatcher(Module &M, SampleProfileReader &Reader,
-                       const PseudoProbeManager *ProbeManager)
-      : M(M), Reader(Reader), ProbeManager(ProbeManager),
-        PreMatchStats(M, Reader, ProbeManager),
-        PostMatchStats(M, Reader, ProbeManager){};
-  void runOnModule();
+  // For each function, store every callsite state into a map, of which each
+  // entry is a pair of callsite location and MatchState. This is used for
+  // profile staleness computation and report.
+  StringMap<std::unordered_map<LineLocation, MatchState, LineLocationHash>>
+      FuncCallsiteMatchStates;
+
+  /// Profile mismatch statistics:
+  uint64_t TotalProfiledFunc = 0;
+  // Num of function whose checksum is mismatched.
+  uint64_t NumMismatchedFunc = 0;
+  uint64_t TotalProfiledCallsites = 0;
+  uint64_t NumMismatchedCallsites = 0;
+  uint64_t NumRecoveredCallsites = 0;
+
+  /// Weighted profile samples mismatch statistics:
+  uint64_t TotalFunctionSamples = 0;
+  // Samples for the mismatched checksum functions;
+  uint64_t MismatchedFunctionSamples = 0;
+
+  uint64_t MismatchedCallsiteSamples = 0;
+  uint64_t RecoveredCallsiteSamples = 0;
 
   // A dummy name for unknown indirect callee, used to differentiate from a
   // non-call instruction that also has an empty callee name.
   static constexpr const char *UnknownIndirectCallee =
       "unknown.indirect.callee";
 
+public:
+  SampleProfileMatcher(Module &M, SampleProfileReader &Reader,
+                       const PseudoProbeManager *ProbeManager)
+      : M(M), Reader(Reader), ProbeManager(ProbeManager){};
+  void runOnModule();
+
 private:
   FunctionSamples *getFlattenedSamplesFor(const Function &F) {
     StringRef CanonFName = FunctionSamples::getCanonicalFnName(F);
@@ -512,6 +498,24 @@ class SampleProfileMatcher {
   void findProfileAnchors(
       const FunctionSamples &FS,
       std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors);
+  // Compute the callsite match states for profile staleness report; the result
+  // is saved in FuncCallsiteMatchStates.
+  void computeCallsiteMatchStates(
+      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+      const std::map<LineLocation, std::unordered_set<FunctionId>>
+          &ProfileAnchors,
+      const LocToLocMap &IRToProfileLocationMap);
+
+  // Count the samples of checksum mismatched function for the top-level
+  // function and all inlinees.
+  void countMismatchedFuncSamples(const FunctionSamples &FS);
+  // Count the number of mismatched or recovered callsites.
+  void countMismatchCallsites(const FunctionSamples &FS);
+  // Count the samples of mismatched or recovered callsites for top-level
+  // function and all inlinees.
+  void countMismatchedCallsiteSamples(const FunctionSamples &FS);
+  void computeAndReportProfileStaleness();
+
   LocToLocMap &getIRToProfileLocationMap(const Function &F) {
     auto Ret = FuncMappings.try_emplace(
         FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap());
@@ -2215,153 +2219,6 @@ void SampleProfileMatcher::findIRAnchors(
   }
 }
 
-void ProfileMatchStats::countMismatchedSamples(const FunctionSamples &FS) {
-  const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID());
-  // Skip the function that is external or renamed.
-  if (!FuncDesc)
-    return;
-
-  if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
-    MismatchedFuncHashSamples += FS.getTotalSamples();
-    return;
-  }
-  for (const auto &I : FS.getCallsiteSamples())
-    for (const auto &CS : I.second)
-      countMismatchedSamples(CS.second);
-}
-
-void ProfileMatchStats::countMismatchedCallsites(
-    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
-    const std::map<LineLocation, std::unordered_set<FunctionId>>
-        &ProfileAnchors,
-    const LocToLocMap &IRToProfileLocationMap) {
-  auto &MismatchedCallsites =
-      FuncMismatchedCallsites[FunctionSamples::getCanonicalFnName(F.getName())];
-
-  auto MapIRLocToProfileLoc = [&](const LineLocation &IRLoc) {
-    const auto &ProfileLoc = IRToProfileLocationMap.find(IRLoc);
-    if (ProfileLoc != IRToProfileLocationMap.end())
-      return ProfileLoc->second;
-    else
-      return IRLoc;
-  };
-
-  std::set<LineLocation> MatchedCallsites;
-  for (const auto &I : IRAnchors) {
-    // In post-match, use the matching result to remap the current IR callsite.
-    const auto &Loc = MapIRLocToProfileLoc(I.first);
-    const auto &IRCalleeName = I.second;
-    const auto &It = ProfileAnchors.find(Loc);
-    if (It == ProfileAnchors.end())
-      continue;
-    const auto &Callees = It->second;
-
-    // Since indirect call does not have CalleeName, check conservatively if
-    // callsite in the profile is a callsite location. This is to reduce num of
-    // false positive since otherwise all the indirect call samples will be
-    // reported as mismatching.
-    if (IRCalleeName == SampleProfileMatcher::UnknownIndirectCallee)
-      MatchedCallsites.insert(Loc);
-    else if (Callees.count(getRepInFormat(IRCalleeName)))
-      MatchedCallsites.insert(Loc);
-  }
-
-  // Check if there are any callsites in the profile that does not match to any
-  // IR callsites, those callsite samples will be discarded.
-  for (const auto &I : ProfileAnchors) {
-    const auto &Loc = I.first;
-    [[maybe_unused]] const auto &Callees = I.second;
-    assert(!Callees.empty() && "Callees should not be empty");
-    TotalProfiledCallsites++;
-    if (!MatchedCallsites.count(Loc)) {
-      NumMismatchedCallsites++;
-      MismatchedCallsites.insert(Loc);
-    }
-  }
-}
-
-void ProfileMatchStats::countProfileMismatches(
-    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
-    const std::map<LineLocation, std::unordered_set<FunctionId>>
-        &ProfileAnchors) {
-  [[maybe_unused]] bool IsFuncHashMismatch = false;
-  // Use top-level nested FS for counting profile mismatch metrics since
-  // currently once a callsite is mismatched, all its children profiles are
-  // dropped.
-  if (const auto *FS = Reader.getSamplesFor(F)) {
-    TotalProfiledFunc++;
-    TotalFunctionSamples += FS->getTotalSamples();
-    if (FunctionSamples::ProfileIsProbeBased) {
-      const auto *FuncDesc = ProbeManager->getDesc(F);
-      if (FuncDesc) {
-        if (ProbeManager->profileIsHashMismatched(*FuncDesc, *FS)) {
-          NumMismatchedFuncHash++;
-          IsFuncHashMismatch = true;
-        }
-        countMismatchedSamples(*FS);
-      }
-    }
-  }
-
-  countMismatchedCallsites(F, IRAnchors, ProfileAnchors, LocToLocMap());
-  LLVM_DEBUG({
-    auto It = FuncMismatchedCallsites.find(
-        FunctionSamples::getCanonicalFnName(F.getName()));
-    if (FunctionSamples::ProfileIsProbeBased && !IsFuncHashMismatch &&
-        It != FuncMismatchedCallsites.end() && !It->second.empty())
-      dbgs() << "Function checksum is matched but there are "
-             << It->second.size() << " mismatched callsites.\n";
-  });
-}
-
-void ProfileMatchStats::countMismatchedCallsiteSamples(
-    const FunctionSamples &FS) {
-  auto It = FuncMismatchedCallsites.find(FS.getFuncName());
-  // Skip it if no mismatched callsite or this is an external function.
-  if (It == FuncMismatchedCallsites.end() || It->second.empty())
-    return;
-  const auto &MismatchCallsites = It->second;
-
-  for (const auto &I : FS.getBodySamples()) {
-    if (MismatchCallsites.count(I.first))
-      MismatchedCallsiteSamples += I.second.getSamples();
-  }
-
-  for (const auto &I : FS.getCallsiteSamples()) {
-    const auto &Loc = I.first;
-    if (MismatchCallsites.count(Loc)) {
-      for (const auto &CS : I.second)
-        MismatchedCallsiteSamples += CS.second.getTotalSamples();
-      continue;
-    }
-
-    // Count mismatched samples for inlined functions.
-    for (const auto &CS : I.second)
-      countMismatchedCallsiteSamples(CS.second);
-  }
-}
-
-void ProfileMatchStats::countMismatchedCallsiteSamples() {
-  if (FuncMismatchedCallsites.empty())
-    return;
-  for (const auto &F : M) {
-    if (ShouldSkipProfileLoading(F))
-      continue;
-    if (const auto *FS = Reader.getSamplesFor(F))
-      countMismatchedCallsiteSamples(*FS);
-  }
-}
-
-void ProfileMatchStats::copyUnchangedCallsiteMismatches(
-    const StringMap<std::set<LineLocation>> &InputMismatchedCallsites) {
-  for (const auto &I : InputMismatchedCallsites) {
-    auto It = FuncMismatchedCallsites.find(I.first());
-    if (It != FuncMismatchedCallsites.end())
-      continue;
-    FuncMismatchedCallsites.try_emplace(I.first(), I.second);
-  }
-}
-
 void SampleProfileMatcher::findProfileAnchors(
     const FunctionSamples &FS,
     std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors) {
@@ -2515,12 +2372,9 @@ void SampleProfileMatcher::runOnFunction(const Function &F) {
   std::map<LineLocation, std::unordered_set<FunctionId>> ProfileAnchors;
   findProfileAnchors(*FSFlattened, ProfileAnchors);
 
-  // Detect profile mismatch for profile staleness metrics report.
-  // Skip reporting the metrics for imported functions.
-  if (!GlobalValue::isAvailableExternallyLinkage(F.getLinkage()) &&
-      (ReportProfileStaleness || PersistProfileStaleness)) {
-    PreMatchStats.countProfileMismatches(F, IRAnchors, ProfileAnchors);
-  }
+  // Compute the callsite match states for profile staleness report.
+  if (ReportProfileStaleness || PersistProfileStaleness)
+    computeCallsiteMatchStates(F, IRAnchors, ProfileAnchors, LocToLocMap());
 
   // Run profile matching for checksum mismatched profile, currently only
   // support for pseudo-probe.
@@ -2531,43 +2385,209 @@ void SampleProfileMatcher::runOnFunction(const Function &F) {
     auto &IRToProfileLocationMap = getIRToProfileLocationMap(F);
     runStaleProfileMatching(F, IRAnchors, ProfileAnchors,
                             IRToProfileLocationMap);
-    PostMatchStats.countMismatchedCallsites(F, IRAnchors, ProfileAnchors,
-                                            IRToProfileLocationMap);
+    // Find and update callsite match states after matching.
+    if ((ReportProfileStaleness || PersistProfileStaleness) &&
+        !IRToProfileLocationMap.empty())
+      computeCallsiteMatchStates(F, IRAnchors, ProfileAnchors,
+                                 IRToProfileLocationMap);
+  }
+}
+
+void SampleProfileMatcher::computeCallsiteMatchStates(
+    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+    const std::map<LineLocation, std::unordered_set<FunctionId>>
+        &ProfileAnchors,
+    const LocToLocMap &IRToProfileLocationMap) {
+  // Use the matching result to determine if it's in post-match phase.
+  bool IsPostMatch = !IRToProfileLocationMap.empty();
+  auto &MismatchedCallsites =
+      FuncCallsiteMatchStates[FunctionSamples::getCanonicalFnName(F.getName())];
+
+  auto MapIRLocToProfileLoc = [&](const LineLocation &IRLoc) {
+    const auto &ProfileLoc = IRToProfileLocationMap.find(IRLoc);
+    if (ProfileLoc != IRToProfileLocationMap.end())
+      return ProfileLoc->second;
+    else
+      return IRLoc;
+  };
+
+  std::set<LineLocation> MatchedCallsites;
+  for (const auto &I : IRAnchors) {
+    // In post-match, use the matching result to remap the current IR callsite.
+    const auto &Loc = MapIRLocToProfileLoc(I.first);
+    const auto &IRCalleeName = I.second;
+    const auto &It = ProfileAnchors.find(Loc);
+    if (It == ProfileAnchors.end())
+      continue;
+    const auto &Callees = It->second;
+
+    // Since indirect call does not have CalleeName, check conservatively if
+    // callsite in the profile is a callsite location. This is to reduce num of
+    // false positive since otherwise all the indirect call samples will be
+    // reported as mismatching.
+    if (IRCalleeName == SampleProfileMatcher::UnknownIndirectCallee)
+      MatchedCallsites.insert(Loc);
+    // TODO : Ideally, we should ensure it's a direct callsite location(Callees
+    // size is 1). However, there may be a bug for profile merge(like ODR
+    // violation) that causes the callees size to be more than 1. After we fix
+    // the bug, we can remove this check.
+    else if (Callees.count(getRepInFormat(IRCalleeName)))
+      MatchedCallsites.insert(Loc);
+  }
+
+  // Check if there are any callsites in the profile that does not match to any
+  // IR callsites, those callsite samples will be discarded.
+  for (const auto &I : ProfileAnchors) {
+    const auto &Loc = I.first;
+    [[maybe_unused]] const auto &Callees = I.second;
+    assert(!Callees.empty() && "Callees should not be empty");
+    if (IsPostMatch) {
+      if (MatchedCallsites.count(Loc)) {
+        auto It = MismatchedCallsites.find(Loc);
+        if (It != MismatchedCallsites.end() &&
+            It->second == MatchState::Mismatched)
+          MismatchedCallsites.emplace(Loc, MatchState::Recovered);
+      } else
+        MismatchedCallsites.emplace(Loc, MatchState::Mismatched);
+    } else {
+      if (MatchedCallsites.count(Loc))
+        MismatchedCallsites.emplace(Loc, MatchState::Matched);
+      else
+        MismatchedCallsites.emplace(Loc, MatchState::Mismatched);
+    }
+  }
+}
+
+void SampleProfileMatcher::countMismatchedFuncSamples(
+    const FunctionSamples &FS) {
+  const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID());
+  // Skip the function that is external or renamed.
+  if (!FuncDesc)
+    return;
+
+  if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
+    MismatchedFunctionSamples += FS.getTotalSamples();
+    return;
+  }
+  for (const auto &I : FS.getCallsiteSamples())
+    for (const auto &CS : I.second)
+      countMismatchedFuncSamples(CS.second);
+}
+
+void SampleProfileMatcher::countMismatchedCallsiteSamples(
+    const FunctionSamples &FS) {
+  auto It = FuncCallsiteMatchStates.find(FS.getFuncName());
+  // Skip it if no mismatched callsite or this is an external function.
+  if (It == FuncCallsiteMatchStates.end() || It->second.empty())
+    return;
+  const auto &MismatchCallsites = It->second;
+
+  auto IsCallsiteMismatched = [&](const LineLocation &Loc) {
+    auto It = MismatchCallsites.find(Loc);
+    if (It == MismatchCallsites.end())
+      return false;
+    return It->second == MatchState::Mismatched;
+  };
+
+  auto CountSamples = [&](const LineLocation &Loc, uint64_t Samples) {
+    auto It = MismatchCallsites.find(Loc);
+    if (It == MismatchCallsites.end())
+      return;
+    if (It->second == MatchState::Mismatched)
+      MismatchedCallsiteSamples += Samples;
+    else if (It->second == MatchState::Recovered)
+      RecoveredCallsiteSamples += Samples;
+  };
+
+  for (const auto &I : FS.getBodySamples())
+    CountSamples(I.first, I.second.getSamples());
+
+  for (const auto &I : FS.getCallsiteSamples()) {
+    uint64_t Samples = 0;
+    for (const auto &CS : I.second)
+      Samples += CS.second.getTotalSamples();
+
+    CountSamples(I.first, Samples);
+
+    if (IsCallsiteMismatched(I.first))
+      continue;
+
+    // Count mismatched samples for matched inlines.
+    for (const auto &CS : I.second)
+      countMismatchedCallsiteSamples(CS.second);
+  }
+}
+
+void SampleProfileMatcher::countMismatchCallsites(const FunctionSamples &FS) {
+  auto It = FuncCallsiteMatchStates.find(FS.getFuncName());
+  // Skip it if no mismatched callsite or this is an external function.
+  if (It == FuncCallsiteMatchStates.end() || It->second.empty())
+    return;
+  const auto &MatchStates = It->second;
+  for (const auto &I : MatchStates) {
+    TotalProfiledCallsites++;
+    if (I.second == MatchState::Mismatched)
+      NumMismatchedCallsites++;
+    else if (I.second == MatchState::Recovered)
+      NumRecoveredCallsites++;
   }
 }
 
-void SampleProfileMatcher::reportOrPersistProfileStats() {
+void SampleProfileMatcher::computeAndReportProfileStaleness() {
+  if (!ReportProfileStaleness && !PersistProfileStaleness)
+    return;
+
+  // Count profile mismatches for profile staleness report.
+  for (const auto &F : M) {
+    if (ShouldSkipProfileLoading(F))
+      continue;
+    // As the stats will be merged by linker, skip reporting the metrics for
+    // imported functions to avoid repeated counting.
+    if (GlobalValue::isAvailableExternallyLinkage(F.getLinkage()))
+      continue;
+    // Use top-level nested FS for counting profile mismatch metrics since
+    // currently once a callsite is mismatched, all its children profiles are
+    // dropped.
+    const auto *FS = Reader.getSamplesFor(F);
+    if (!FS)
+      continue;
+
+    TotalProfiledFunc++;
+    TotalFunctionSamples += FS->getTotalSamples();
+
+    if (FunctionSamples::ProfileIsProbeBased) {
+      const auto *FuncDesc = ProbeManager->getDesc(F);
+      if (FuncDesc && ProbeManager->profileIsHashMismatched(*FuncDesc, *FS))
+        NumMismatchedFunc++;
+
+      countMismatchedFuncSamples(*FS);
+    }
+
+    // Count mismatches and samples for callsite.
+    countMismatchCallsites(*FS);
+    countMismatchedCallsiteSamples(*FS);
+  }
+
   if (ReportProfileStaleness) {
     if (FunctionSamples::ProfileIsProbeBased) {
-      errs() << "(" << PreMatchStats.NumMismatchedFuncHash << "/"
-             << PreMatchStats.TotalProfiledFunc << ")"
+      errs() << "(" << NumMismatchedFunc << "/" << TotalProfiledFunc << ")"
              << " of functions' profile are invalid and "
-             << " (" << PreMatchStats.MismatchedFuncHashSamples << "/"
-             << PreMatchStats.TotalFunctionSamples << ")"
+             << " (" << MismatchedFunctionSamples << "/" << TotalFunctionSamples
+             << ")"
              << " of samples are discarded due to function hash mismatch.\n";
     }
-    errs() << "(" << PreMatchStats.NumMismatchedCallsites << "/"
-           << PreMatchStats.TotalProfiledCallsites << ")"
+    errs() << "(" << NumMismatchedCallsites << "/" << TotalProfiledCallsites
+           << ")"
            << " of callsites' profile are invalid and "
-           << "(" << PreMatchStats.MismatchedCallsiteSamples << "/"
-           << PreMatchStats.TotalFunctionSamples << ")"
+           << "(" << MismatchedCallsiteSamples << "/" << TotalFunctionSamples
+           << ")"
            << " of samples are discarded due to callsite location mismatch.\n";
-    if (SalvageStaleProfile) {
-      uint64_t NumRecoveredCallsites = PostMatchStats.TotalProfiledCallsites -
-                                       PostMatchStats.NumMismatchedCallsites;
-      uint64_t NumMismatchedCallsites =
-          PreMatchStats.NumMismatchedCallsites - NumRecoveredCallsites;
-      errs() << "Out of " << PostMatchStats.TotalProfiledCallsites
-             << " callsites used for profile matching, "
-             << NumRecoveredCallsites
-             << " callsites have been recovered. After the matching, ("
-             << NumMismatchedCallsites << "/"
-             << PreMatchStats.TotalProfiledCallsites
-             << ") of callsites are still invalid ("
-             << PostMatchStats.MismatchedCallsiteSamples << "/"
-             << PreMatchStats.TotalFunctionSamples << ")"
-             << " of samples are still discarded.\n";
-    }
+    errs() << "(" << NumRecoveredCallsites << "/" << TotalProfiledCallsites
+           << ")"
+           << " of callsites and "
+           << "(" << RecoveredCallsiteSamples << "/" << TotalFunctionSamples
+           << ")"
+           << " of samples are recovered by stale profile matching.\n";
   }
 
   if (PersistProfileStaleness) {
@@ -2575,31 +2595,22 @@ void SampleProfileMatcher::reportOrPersistProfileStats() {
     MDBuilder MDB(Ctx);
 
     SmallVector<std::pair<StringRef, uint64_t>> ProfStatsVec;
-    ProfStatsVec.emplace_back("NumMismatchedCallsites",
-                              PreMatchStats.NumMismatchedCallsites);
-    ProfStatsVec.emplace_back("TotalProfiledCallsites",
-                              PreMatchStats.TotalProfiledCallsites);
-    ProfStatsVec.emplace_back("MismatchedCallsiteSamples",
-                              PreMatchStats.MismatchedCallsiteSamples);
-    ProfStatsVec.emplace_back("TotalProfiledFunc",
-                              PreMatchStats.TotalProfiledFunc);
-    ProfStatsVec.emplace_back("TotalFunctionSamples",
-                              PreMatchStats.TotalFunctionSamples);
     if (FunctionSamples::ProfileIsProbeBased) {
-      ProfStatsVec.emplace_back("NumMismatchedFuncHash",
-                                PreMatchStats.NumMismatchedFuncHash);
-      ProfStatsVec.emplace_back("MismatchedFuncHashSamples",
-                                PreMatchStats.MismatchedFuncHashSamples);
-    }
-    if (SalvageStaleProfile) {
-      ProfStatsVec.emplace_back("PostMatchNumMismatchedCallsites",
-                                PostMatchStats.NumMismatchedCallsites);
-      ProfStatsVec.emplace_back("NumCallsitesForMatching",
-                                PostMatchStats.TotalProfiledCallsites);
-      ProfStatsVec.emplace_back("PostMatchMismatchedCallsiteSamples",
-                                PostMatchStats.MismatchedCallsiteSamples);
+      ProfStatsVec.emplace_back("NumMismatchedFunc", NumMismatchedFunc);
+      ProfStatsVec.emplace_back("TotalProfiledFunc", TotalProfiledFunc);
+      ProfStatsVec.emplace_back("MismatchedFunctionSamples",
+                                MismatchedFunctionSamples);
+      ProfStatsVec.emplace_back("TotalFunctionSamples", TotalFunctionSamples);
     }
 
+    ProfStatsVec.emplace_back("NumMismatchedCallsites", NumMismatchedCallsites);
+    ProfStatsVec.emplace_back("NumRecoveredCallsites", NumRecoveredCallsites);
+    ProfStatsVec.emplace_back("TotalProfiledCallsites", TotalProfiledCallsites);
+    ProfStatsVec.emplace_back("MismatchedCallsiteSamples",
+                              MismatchedCallsiteSamples);
+    ProfStatsVec.emplace_back("RecoveredCallsiteSamples",
+                              RecoveredCallsiteSamples);
+
     auto *MD = MDB.createLLVMStats(ProfStatsVec);
     auto *NMD = M.getOrInsertNamedMetadata("llvm.stats");
     NMD->addOperand(MD);
@@ -2610,24 +2621,14 @@ void SampleProfileMatcher::runOnModule() {
   ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
                                    FunctionSamples::ProfileIsCS);
   for (auto &F : M) {
-    if (ShouldSkipProfileLoading(F))
+    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
       continue;
     runOnFunction(F);
   }
-
   if (SalvageStaleProfile)
     distributeIRToProfileLocationMap();
 
-  PreMatchStats.countMismatchedCallsiteSamples();
-  if (SalvageStaleProfile) {
-    // If a function doesn't run the matching but has mismatched callsites, this
-    // won't be any data for that function in post-match stats, so just reuse
-    // the pre-match stats.
-    PostMatchStats.copyUnchangedCallsiteMismatches(
-        PreMatchStats.FuncMismatchedCallsites);
-    PostMatchStats.countMismatchedCallsiteSamples();
-  }
-  reportOrPersistProfileStats();
+  computeAndReportProfileStaleness();
 }
 
 void SampleProfileMatcher::distributeIRToProfileLocationMap(
diff --git a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
index 14e384d7964ab0..35f20dcb85ed66 100644
--- a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
@@ -6,7 +6,7 @@
 ; RUN: llvm-objdump --section-headers %t.obj | FileCheck %s --check-prefix=CHECK-OBJ
 ; RUN: llc < %t.ll -filetype=asm -o - | FileCheck %s --check-prefix=CHECK-ASM
 
-; CHECK: (2/4) of callsites' profile are invalid and (15/50) of samples are discarded due to callsite location mismatch.
+; CHECK: (1/3) of callsites' profile are invalid and (15/50) of samples are discarded due to callsite location mismatch.
 
 ; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 15, !"TotalProfiledFunc", i64 3, !"TotalFunctionSamples", i64 50}
 
diff --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
index 768fe5509f33a9..4f8b13982c6107 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
@@ -5,6 +5,6 @@
 
 ; CHECK: (1/1) of functions' profile are invalid and  (6822/6822) of samples are discarded due to function hash mismatch.
 ; CHECK: (4/4) of callsites' profile are invalid and (5026/6822) of samples are discarded due to callsite location mismatch.
-; CHECK: Out of 4 callsites used for profile matching, 4 callsites have been recovered. After the matching, (0/4) of callsites are still invalid (0/6822) of samples are still discarded.
+; CHECK: (0/4) of callsites and (0/6822) of samples are recovered by stale profile matching.
 
-; CHECK-MD: !{!"NumMismatchedCallsites", i64 4, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 5026, !"TotalProfiledFunc", i64 1, !"TotalFunctionSamples", i64 6822, !"NumMismatchedFuncHash", i64 1, !"MismatchedFuncHashSamples", i64 6822, !"PostMatchNumMismatchedCallsites", i64 0, !"NumCallsitesForMatching", i64 4, !"PostMatchMismatchedCallsiteSamples", i64 0}
+; CHECK-MD: ![[#]] = !{!"NumMismatchedFuncHash", i64 1, !"TotalProfiledFunc", i64 1, !"MismatchedFuncHashSamples", i64 6822, !"TotalFuncHashSamples", i64 6822, !"NumMismatchedCallsites", i64 4, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 5026, !"TotalCallsiteSamples", i64 5026}
diff --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
index 9949b5fd41f407..1158287835799e 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
@@ -11,7 +11,7 @@
 
 ; CHECK: (1/3) of functions' profile are invalid and (10/50) of samples are discarded due to function hash mismatch.
 ; CHECK: (2/3) of callsites' profile are invalid and (20/50) of samples are discarded due to callsite location mismatch.
-; CHECK: Out of 0 callsites used for profile matching, 0 callsites have been recovered. After the matching, (2/3) of callsites are still invalid (20/50) of samples are still discarded.
+; CHECK: (0/3) of callsites and (0/50) of samples are recovered by stale profile matching.
 
 ; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 20, !"TotalProfiledFunc", i64 3, !"TotalFunctionSamples", i64 50, !"NumMismatchedFuncHash", i64 1, !"MismatchedFuncHashSamples", i64 10, !"PostMatchNumMismatchedCallsites", i64 0, !"NumCallsitesForMatching", i64 0, !"PostMatchMismatchedCallsiteSamples", i64 20}
 

>From 854eaf4c02ba4b99cbcea37c776229caaa523d03 Mon Sep 17 00:00:00 2001
From: wlei <wlei at fb.com>
Date: Thu, 1 Feb 2024 16:03:47 -0800
Subject: [PATCH 6/9] Addressing feedback

---
 llvm/lib/Transforms/IPO/SampleProfile.cpp     | 194 +++++++++---------
 .../SampleProfile/profile-mismatch.ll         |  39 ++--
 .../pseudo-probe-profile-mismatch-thinlto.ll  |   4 +-
 .../pseudo-probe-profile-mismatch.ll          |  52 +++--
 4 files changed, 149 insertions(+), 140 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 9ba23d8f490803..3d7e36de0e0922 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -447,8 +447,12 @@ class SampleProfileMatcher {
   // Match state for an anchor/callsite.
   enum class MatchState {
     Matched = 0,
-    Mismatched = 0x1,
-    Recovered = 0x1,
+    Mismatched = 1,
+    // Stay Matched after profile matching.
+    StayMatched = 2,
+    // Recovered from Mismatched after profile matching.
+    Recovered = 3,
+    Unknown = 32,
   };
 
   // For each function, store every callsite state into a map, of which each
@@ -457,19 +461,17 @@ class SampleProfileMatcher {
   StringMap<std::unordered_map<LineLocation, MatchState, LineLocationHash>>
       FuncCallsiteMatchStates;
 
-  /// Profile mismatch statstics:
+  // Profile mismatch statstics:
   uint64_t TotalProfiledFunc = 0;
-  // Num of function whose checksum is mismatched.
-  uint64_t NumMismatchedFunc = 0;
+  // Num of checksum-mismatched function.
+  uint64_t NumStaleProfileFunc = 0;
   uint64_t TotalProfiledCallsites = 0;
   uint64_t NumMismatchedCallsites = 0;
   uint64_t NumRecoveredCallsites = 0;
-
-  /// Weigted profile samples mismatch statstics:
+  // Total samples for all profiled functions.
   uint64_t TotalFunctionSamples = 0;
-  // Samples for the mismatched checksum functions;
+  // Total samples for all checksum-mismatched functions.
   uint64_t MismatchedFunctionSamples = 0;
-
   uint64_t MismatchedCallsiteSamples = 0;
   uint64_t RecoveredCallsiteSamples = 0;
 
@@ -504,11 +506,11 @@ class SampleProfileMatcher {
       const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
       const std::map<LineLocation, std::unordered_set<FunctionId>>
           &ProfileAnchors,
-      const LocToLocMap &IRToProfileLocationMap);
+      const LocToLocMap *IRToProfileLocationMap);
 
   // Count the samples of checksum mismatched function for the top-level
   // function and all inlinees.
-  void countMismatchedFuncSamples(const FunctionSamples &FS);
+  void countMismatchedFuncSamples(const FunctionSamples &FS, bool IsTopLevel);
   // Count the number of mismatched or recovered callsites.
   void countMismatchCallsites(const FunctionSamples &FS);
   // Count the samples of mismatched or recovered callsites for top-level
@@ -715,7 +717,7 @@ void SampleProfileLoaderBaseImpl<Function>::computeDominanceAndLoopInfo(
 }
 } // namespace llvm
 
-bool ShouldSkipProfileLoading(const Function &F) {
+static bool skipProfileForFunction(const Function &F) {
   return F.isDeclaration() || !F.hasFnAttribute("use-sample-profile");
 }
 
@@ -1903,7 +1905,7 @@ SampleProfileLoader::buildProfiledCallGraph(Module &M) {
   // the profile. This makes sure functions missing from the profile still
   // gets a chance to be processed.
   for (Function &F : M) {
-    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
+    if (skipProfileForFunction(F))
       continue;
     ProfiledCG->addProfiledFunction(
           getRepInFormat(FunctionSamples::getCanonicalFnName(F)));
@@ -1932,7 +1934,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) {
     }
 
     for (Function &F : M)
-      if (!F.isDeclaration() && F.hasFnAttribute("use-sample-profile"))
+      if (!skipProfileForFunction(F))
         FunctionOrderList.push_back(&F);
     return FunctionOrderList;
   }
@@ -1998,7 +2000,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) {
       }
       for (auto *Node : Range) {
         Function *F = SymbolMap.lookup(Node->Name);
-        if (F && !F->isDeclaration() && F->hasFnAttribute("use-sample-profile"))
+        if (F && !skipProfileForFunction(*F))
           FunctionOrderList.push_back(F);
       }
       ++CGI;
@@ -2009,7 +2011,7 @@ SampleProfileLoader::buildFunctionOrder(Module &M, LazyCallGraph &CG) {
       for (LazyCallGraph::SCC &C : RC) {
         for (LazyCallGraph::Node &N : C) {
           Function &F = N.getFunction();
-          if (!F.isDeclaration() && F.hasFnAttribute("use-sample-profile"))
+          if (!skipProfileForFunction(F))
             FunctionOrderList.push_back(&F);
         }
       }
@@ -2374,7 +2376,7 @@ void SampleProfileMatcher::runOnFunction(const Function &F) {
 
   // Compute the callsite match states for profile staleness report.
   if (ReportProfileStaleness || PersistProfileStaleness)
-    computeCallsiteMatchStates(F, IRAnchors, ProfileAnchors, LocToLocMap());
+    computeCallsiteMatchStates(F, IRAnchors, ProfileAnchors, nullptr);
 
   // Run profile matching for checksum mismatched profile, currently only
   // support for pseudo-probe.
@@ -2386,10 +2388,9 @@ void SampleProfileMatcher::runOnFunction(const Function &F) {
     runStaleProfileMatching(F, IRAnchors, ProfileAnchors,
                             IRToProfileLocationMap);
     // Find and update callsite match states after matching.
-    if ((ReportProfileStaleness || PersistProfileStaleness) &&
-        !IRToProfileLocationMap.empty())
+    if (ReportProfileStaleness || PersistProfileStaleness)
       computeCallsiteMatchStates(F, IRAnchors, ProfileAnchors,
-                                 IRToProfileLocationMap);
+                                 &IRToProfileLocationMap);
   }
 }
 
@@ -2397,21 +2398,22 @@ void SampleProfileMatcher::computeCallsiteMatchStates(
     const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
     const std::map<LineLocation, std::unordered_set<FunctionId>>
         &ProfileAnchors,
-    const LocToLocMap &IRToProfileLocationMap) {
-  // Use the matching result to determine if it's in post-match phrase.
-  bool IsPostMatch = !IRToProfileLocationMap.empty();
-  auto &MismatchedCallsites =
+    const LocToLocMap *IRToProfileLocationMap) {
+  bool IsPostMatch = IRToProfileLocationMap != nullptr;
+  auto &CallsiteMatchStates =
       FuncCallsiteMatchStates[FunctionSamples::getCanonicalFnName(F.getName())];
 
+  // IRToProfileLocationMap is null before the matching.
   auto MapIRLocToProfileLoc = [&](const LineLocation &IRLoc) {
-    const auto &ProfileLoc = IRToProfileLocationMap.find(IRLoc);
-    if (ProfileLoc != IRToProfileLocationMap.end())
+    if (!IRToProfileLocationMap)
+      return IRLoc;
+    const auto &ProfileLoc = IRToProfileLocationMap->find(IRLoc);
+    if (ProfileLoc != IRToProfileLocationMap->end())
       return ProfileLoc->second;
     else
       return IRLoc;
   };
 
-  std::set<LineLocation> MatchedCallsites;
   for (const auto &I : IRAnchors) {
     // In post-match, use the matching result to remap the current IR callsite.
     const auto &Loc = MapIRLocToProfileLoc(I.first);
@@ -2421,18 +2423,27 @@ void SampleProfileMatcher::computeCallsiteMatchStates(
       continue;
     const auto &Callees = It->second;
 
+    bool IsCallsiteMatched = false;
     // Since indirect call does not have CalleeName, check conservatively if
     // callsite in the profile is a callsite location. This is to reduce num of
     // false positive since otherwise all the indirect call samples will be
     // reported as mismatching.
     if (IRCalleeName == SampleProfileMatcher::UnknownIndirectCallee)
-      MatchedCallsites.insert(Loc);
-    // TODO : Ideally, we should ensure it's a direct callsite location(Callees
-    // size is 1). However, there may be a bug for profile merge(like ODR
-    // violation) that causes the callees size to be more than 1. After we fix
-    // the bug, we can remove this check.
-    else if (Callees.count(getRepInFormat(IRCalleeName)))
-      MatchedCallsites.insert(Loc);
+      IsCallsiteMatched = true;
+    else if (Callees.size() == 1 && Callees.count(getRepInFormat(IRCalleeName)))
+      IsCallsiteMatched = true;
+
+    if (IsCallsiteMatched) {
+      auto R = CallsiteMatchStates.emplace(Loc, MatchState::Matched);
+      // Update the post-match state when there is a existing state indicateing
+      // it's in post-match phrase.
+      if (!R.second) {
+        if (R.first->second == MatchState::Mismatched)
+          R.first->second = MatchState::Recovered;
+        if (R.first->second == MatchState::Matched)
+          R.first->second = MatchState::StayMatched;
+      }
+    }
   }
 
   // Check if there are any callsites in the profile that does not match to any
@@ -2441,37 +2452,40 @@ void SampleProfileMatcher::computeCallsiteMatchStates(
     const auto &Loc = I.first;
     [[maybe_unused]] const auto &Callees = I.second;
     assert(!Callees.empty() && "Callees should not be empty");
-    if (IsPostMatch) {
-      if (MatchedCallsites.count(Loc)) {
-        auto It = MismatchedCallsites.find(Loc);
-        if (It != MismatchedCallsites.end() &&
-            It->second == MatchState::Mismatched)
-          MismatchedCallsites.emplace(Loc, MatchState::Recovered);
-      } else
-        MismatchedCallsites.emplace(Loc, MatchState::Mismatched);
-    } else {
-      if (MatchedCallsites.count(Loc))
-        MismatchedCallsites.emplace(Loc, MatchState::Matched);
-      else
-        MismatchedCallsites.emplace(Loc, MatchState::Mismatched);
-    }
+    auto It = CallsiteMatchStates.find(Loc);
+    if (It == CallsiteMatchStates.end())
+      CallsiteMatchStates.emplace(Loc, MatchState::Mismatched);
+    // If in post-match, the state is not updated to Recovered or StayMatched,
+    // update it to Mismatched.
+    else if (IsPostMatch && It->second == MatchState::Matched)
+      CallsiteMatchStates.emplace(Loc, MatchState::Mismatched);
   }
 }
 
-void SampleProfileMatcher::countMismatchedFuncSamples(
-    const FunctionSamples &FS) {
+void SampleProfileMatcher::countMismatchedFuncSamples(const FunctionSamples &FS,
+                                                      bool IsTopLevel) {
   const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID());
   // Skip the function that is external or renamed.
   if (!FuncDesc)
     return;
 
   if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
+    if (IsTopLevel)
+      NumStaleProfileFunc++;
+    // Once the checksum is mismatched, it's likely all the callites are
+    // mismatched and dropped, we conservatively count all the samples as
+    // mismatched samples and stop counting the inlinee profile.
     MismatchedFunctionSamples += FS.getTotalSamples();
     return;
   }
+
+  // Even the current function checksum is matched, it's possible that the
+  // inlinees' checksums are mismatched, we need to go deeper to check the
+  // inlinee's function samples. Similarly, if the inlinee's checksum is
+  // mismatched, we stop and count all the samples as mismatched samples.
   for (const auto &I : FS.getCallsiteSamples())
     for (const auto &CS : I.second)
-      countMismatchedFuncSamples(CS.second);
+      countMismatchedFuncSamples(CS.second, false);
 }
 
 void SampleProfileMatcher::countMismatchedCallsiteSamples(
@@ -2480,39 +2494,41 @@ void SampleProfileMatcher::countMismatchedCallsiteSamples(
   // Skip it if no mismatched callsite or this is an external function.
   if (It == FuncCallsiteMatchStates.end() || It->second.empty())
     return;
-  const auto &MismatchCallsites = It->second;
+  const auto &CallsiteMatchStates = It->second;
 
-  auto IsCallsiteMismatched = [&](const LineLocation &Loc) {
-    auto It = MismatchCallsites.find(Loc);
-    if (It == MismatchCallsites.end())
-      return false;
-    return It->second == MatchState::Mismatched;
+  auto findMatchState = [&](const LineLocation &Loc) {
+    auto It = CallsiteMatchStates.find(Loc);
+    if (It == CallsiteMatchStates.end())
+      return MatchState::Unknown;
+    return It->second;
   };
 
-  auto CountSamples = [&](const LineLocation &Loc, uint64_t Samples) {
-    auto It = MismatchCallsites.find(Loc);
-    if (It == MismatchCallsites.end())
-      return;
-    if (It->second == MatchState::Mismatched)
+  auto AttributeMismatchedSamples = [&](const enum MatchState &State,
+                                        uint64_t Samples) {
+    if (State == MatchState::Mismatched)
       MismatchedCallsiteSamples += Samples;
-    else if (It->second == MatchState::Recovered)
+    else if (State == MatchState::Recovered)
       RecoveredCallsiteSamples += Samples;
   };
 
+  // The non-inlined callsites are saved in the body samples of function
+  // profile.
   for (const auto &I : FS.getBodySamples())
-    CountSamples(I.first, I.second.getSamples());
+    AttributeMismatchedSamples(findMatchState(I.first), I.second.getSamples());
 
   for (const auto &I : FS.getCallsiteSamples()) {
-    uint64_t Samples = 0;
+    auto State = findMatchState(I.first);
+    uint64_t CallsiteSamples = 0;
     for (const auto &CS : I.second)
-      Samples += CS.second.getTotalSamples();
-
-    CountSamples(I.first, Samples);
+      CallsiteSamples += CS.second.getTotalSamples();
+    AttributeMismatchedSamples(State, CallsiteSamples);
 
-    if (IsCallsiteMismatched(I.first))
+    if (State == MatchState::Mismatched)
       continue;
 
-    // Count mismatched samples for matched inlines.
+    // When the current level of inlined call site matches the profiled call
+    // site, we need to go deeper along the inline tree to count mismatches from
+    // lower level inlinees.
     for (const auto &CS : I.second)
       countMismatchedCallsiteSamples(CS.second);
   }
@@ -2539,29 +2555,21 @@ void SampleProfileMatcher::computeAndReportProfileStaleness() {
 
   // Count profile mismatches for profile staleness report.
   for (const auto &F : M) {
-    if (ShouldSkipProfileLoading(F))
+    if (skipProfileForFunction(F))
       continue;
     // As the stats will be merged by linker, skip reporting the metrics for
     // imported functions to avoid repeated counting.
     if (GlobalValue::isAvailableExternallyLinkage(F.getLinkage()))
       continue;
-    // Use top-level nested FS for counting profile mismatch metrics since
-    // currently once a callsite is mismatched, all its children profiles are
-    // dropped.
     const auto *FS = Reader.getSamplesFor(F);
     if (!FS)
       continue;
-
     TotalProfiledFunc++;
     TotalFunctionSamples += FS->getTotalSamples();
 
-    if (FunctionSamples::ProfileIsProbeBased) {
-      const auto *FuncDesc = ProbeManager->getDesc(F);
-      if (FuncDesc && ProbeManager->profileIsHashMismatched(*FuncDesc, *FS))
-        NumMismatchedFunc++;
-
-      countMismatchedFuncSamples(*FS);
-    }
+    // Checksum mismatch is only used in pseudo-probe mode.
+    if (FunctionSamples::ProfileIsProbeBased)
+      countMismatchedFuncSamples(*FS, true);
 
     // Count mismatches and samples for calliste.
     countMismatchCallsites(*FS);
@@ -2570,23 +2578,23 @@ void SampleProfileMatcher::computeAndReportProfileStaleness() {
 
   if (ReportProfileStaleness) {
     if (FunctionSamples::ProfileIsProbeBased) {
-      errs() << "(" << NumMismatchedFunc << "/" << TotalProfiledFunc << ")"
+      errs() << "(" << NumStaleProfileFunc << "/" << TotalProfiledFunc << ")"
              << " of functions' profile are invalid and "
              << " (" << MismatchedFunctionSamples << "/" << TotalFunctionSamples
              << ")"
              << " of samples are discarded due to function hash mismatch.\n";
     }
-    errs() << "(" << NumMismatchedCallsites << "/" << TotalProfiledCallsites
-           << ")"
+    errs() << "(" << (NumMismatchedCallsites + NumRecoveredCallsites) << "/"
+           << TotalProfiledCallsites << ")"
            << " of callsites' profile are invalid and "
-           << "(" << MismatchedCallsiteSamples << "/" << TotalFunctionSamples
-           << ")"
+           << "(" << (MismatchedCallsiteSamples + RecoveredCallsiteSamples)
+           << "/" << TotalFunctionSamples << ")"
            << " of samples are discarded due to callsite location mismatch.\n";
-    errs() << "(" << NumRecoveredCallsites << "/" << TotalProfiledCallsites
-           << ")"
+    errs() << "(" << NumRecoveredCallsites << "/"
+           << (NumRecoveredCallsites + NumMismatchedCallsites) << ")"
            << " of callsites and "
-           << "(" << RecoveredCallsiteSamples << "/" << TotalFunctionSamples
-           << ")"
+           << "(" << RecoveredCallsiteSamples << "/"
+           << (RecoveredCallsiteSamples + MismatchedCallsiteSamples) << ")"
            << " of samples are recovered by stale profile matching.\n";
   }
 
@@ -2596,7 +2604,7 @@ void SampleProfileMatcher::computeAndReportProfileStaleness() {
 
     SmallVector<std::pair<StringRef, uint64_t>> ProfStatsVec;
     if (FunctionSamples::ProfileIsProbeBased) {
-      ProfStatsVec.emplace_back("NumMismatchedFunc", NumMismatchedFunc);
+      ProfStatsVec.emplace_back("NumStaleProfileFunc", NumStaleProfileFunc);
       ProfStatsVec.emplace_back("TotalProfiledFunc", TotalProfiledFunc);
       ProfStatsVec.emplace_back("MismatchedFunctionSamples",
                                 MismatchedFunctionSamples);
@@ -2621,7 +2629,7 @@ void SampleProfileMatcher::runOnModule() {
   ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
                                    FunctionSamples::ProfileIsCS);
   for (auto &F : M) {
-    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
+    if (skipProfileForFunction(F))
       continue;
     runOnFunction(F);
   }
diff --git a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
index 35f20dcb85ed66..42bc1b81f67059 100644
--- a/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/profile-mismatch.ll
@@ -8,27 +8,30 @@
 
 ; CHECK: (1/3) of callsites' profile are invalid and (15/50) of samples are discarded due to callsite location mismatch.
 
-; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 15, !"TotalProfiledFunc", i64 3, !"TotalFunctionSamples", i64 50}
+; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 1, !"NumRecoveredCallsites", i64 0, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 15, !"RecoveredCallsiteSamples", i64 0}
 
 ; CHECK-OBJ: .llvm_stats
 
-; CHECK-ASM: .section  .llvm_stats,"",@progbits
-; CHECK-ASM: .byte 22
-; CHECK-ASM: .ascii  "NumMismatchedCallsites"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "Mg=="
-; CHECK-ASM: .byte 22
-; CHECK-ASM: .ascii  "TotalProfiledCallsites"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "NA=="
-; CHECK-ASM: .byte 25
-; CHECK-ASM: .ascii  "MismatchedCallsiteSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "MTU="
-; CHECK-ASM: .byte 20
-; CHECK-ASM: .ascii  "TotalFunctionSamples"
-; CHECK-ASM: .byte 4
-; CHECK-ASM: .ascii  "NTA="
+; CHECK-ASM: .ascii	"NumMismatchedCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MQ=="
+; CHECK-ASM: .byte	21
+; CHECK-ASM: .ascii	"NumRecoveredCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MA=="
+; CHECK-ASM: .byte	22
+; CHECK-ASM: .ascii	"TotalProfiledCallsites"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"Mw=="
+; CHECK-ASM: .byte	25
+; CHECK-ASM: .ascii	"MismatchedCallsiteSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MTU="
+; CHECK-ASM: .byte	24
+; CHECK-ASM: .ascii	"RecoveredCallsiteSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MA=="
+
 
 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
diff --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
index 4f8b13982c6107..bf33df2481044b 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch-thinlto.ll
@@ -5,6 +5,6 @@
 
 ; CHECK: (1/1) of functions' profile are invalid and  (6822/6822) of samples are discarded due to function hash mismatch.
 ; CHECK: (4/4) of callsites' profile are invalid and (5026/6822) of samples are discarded due to callsite location mismatch.
-; CHECK: (0/4) of callsites and (0/6822) of samples are recovered by stale profile matching.
+; CHECK: (4/4) of callsites and (5026/5026) of samples are recovered by stale profile matching.
 
-; CHECK-MD: ![[#]] = !{!"NumMismatchedFuncHash", i64 1, !"TotalProfiledFunc", i64 1, !"MismatchedFuncHashSamples", i64 6822, !"TotalFuncHashSamples", i64 6822, !"NumMismatchedCallsites", i64 4, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 5026, !"TotalCallsiteSamples", i64 5026}
+; CHECK-MD: ![[#]] = !{!"NumStaleProfileFunc", i64 1, !"TotalProfiledFunc", i64 1, !"MismatchedFunctionSamples", i64 6822, !"TotalFunctionSamples", i64 6822, !"NumMismatchedCallsites", i64 0, !"NumRecoveredCallsites", i64 4, !"TotalProfiledCallsites", i64 4, !"MismatchedCallsiteSamples", i64 0, !"RecoveredCallsiteSamples", i64 5026}
diff --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
index 1158287835799e..de1c4154a67e89 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-profile-mismatch.ll
@@ -11,53 +11,51 @@
 
 ; CHECK: (1/3) of functions' profile are invalid and (10/50) of samples are discarded due to function hash mismatch.
 ; CHECK: (2/3) of callsites' profile are invalid and (20/50) of samples are discarded due to callsite location mismatch.
-; CHECK: (0/3) of callsites and (0/50) of samples are recovered by stale profile matching.
+; CHECK: (0/2) of callsites and (0/20) of samples are recovered by stale profile matching.
+
+; CHECK-MD: ![[#]] = !{!"NumStaleProfileFunc", i64 1, !"TotalProfiledFunc", i64 3, !"MismatchedFunctionSamples", i64 10, !"TotalFunctionSamples", i64 50, !"NumMismatchedCallsites", i64 2, !"NumRecoveredCallsites", i64 0, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 20, !"RecoveredCallsiteSamples", i64 0}
 
-; CHECK-MD: ![[#]] = !{!"NumMismatchedCallsites", i64 2, !"TotalProfiledCallsites", i64 3, !"MismatchedCallsiteSamples", i64 20, !"TotalProfiledFunc", i64 3, !"TotalFunctionSamples", i64 50, !"NumMismatchedFuncHash", i64 1, !"MismatchedFuncHashSamples", i64 10, !"PostMatchNumMismatchedCallsites", i64 0, !"NumCallsitesForMatching", i64 0, !"PostMatchMismatchedCallsiteSamples", i64 20}
 
 ; CHECK-OBJ: .llvm_stats
 
 ; CHECK-ASM: .section	.llvm_stats,"",@progbits
-; CHECK-ASM: .byte	22
-; CHECK-ASM: .ascii	"NumMismatchedCallsites"
+; CHECK-ASM: .byte	19
+; CHECK-ASM: .ascii	"NumStaleProfileFunc"
 ; CHECK-ASM: .byte	4
-; CHECK-ASM: .ascii	"Mg=="
-; CHECK-ASM: .byte	22
-; CHECK-ASM: .ascii	"TotalProfiledCallsites"
-; CHECK-ASM: .byte	4
-; CHECK-ASM: .ascii	"Mw=="
-; CHECK-ASM: .byte	25
-; CHECK-ASM: .ascii	"MismatchedCallsiteSamples"
-; CHECK-ASM: .byte	4
-; CHECK-ASM: .ascii	"MjA="
+; CHECK-ASM: .ascii	"MQ=="
 ; CHECK-ASM: .byte	17
 ; CHECK-ASM: .ascii	"TotalProfiledFunc"
 ; CHECK-ASM: .byte	4
 ; CHECK-ASM: .ascii	"Mw=="
+; CHECK-ASM: .byte	25
+; CHECK-ASM: .ascii	"MismatchedFunctionSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MTA="
 ; CHECK-ASM: .byte	20
 ; CHECK-ASM: .ascii	"TotalFunctionSamples"
 ; CHECK-ASM: .byte	4
 ; CHECK-ASM: .ascii	"NTA="
-; CHECK-ASM: .byte	21
-; CHECK-ASM: .ascii	"NumMismatchedFuncHash"
-; CHECK-ASM: .byte	4
-; CHECK-ASM: .ascii	"MQ=="
-; CHECK-ASM: .byte	25
-; CHECK-ASM: .ascii	"MismatchedFuncHashSamples"
+; CHECK-ASM: .byte	22
+; CHECK-ASM: .ascii	"NumMismatchedCallsites"
 ; CHECK-ASM: .byte	4
-; CHECK-ASM: .ascii	"MTA="
-; CHECK-ASM: .byte	31
-; CHECK-ASM: .ascii	"PostMatchNumMismatchedCallsites"
+; CHECK-ASM: .ascii	"Mg=="
+; CHECK-ASM: .byte	21
+; CHECK-ASM: .ascii	"NumRecoveredCallsites"
 ; CHECK-ASM: .byte	4
 ; CHECK-ASM: .ascii	"MA=="
-; CHECK-ASM: .byte	23
-; CHECK-ASM: .ascii	"NumCallsitesForMatching"
+; CHECK-ASM: .byte	22
+; CHECK-ASM: .ascii	"TotalProfiledCallsites"
 ; CHECK-ASM: .byte	4
-; CHECK-ASM: .ascii	"MA=="
-; CHECK-ASM: .byte	34
-; CHECK-ASM: .ascii	"PostMatchMismatchedCallsiteSamples"
+; CHECK-ASM: .ascii	"Mw=="
+; CHECK-ASM: .byte	25
+; CHECK-ASM: .ascii	"MismatchedCallsiteSamples"
 ; CHECK-ASM: .byte	4
 ; CHECK-ASM: .ascii	"MjA="
+; CHECK-ASM: .byte	24
+; CHECK-ASM: .ascii	"RecoveredCallsiteSamples"
+; CHECK-ASM: .byte	4
+; CHECK-ASM: .ascii	"MA=="
+
 
 ; CHECK-NESTED: (1/2) of functions' profile are invalid and (211/311) of samples are discarded due to function hash mismatch.
 

>From 15ca43a0e6bbd9677ba56460305e88f467942a0e Mon Sep 17 00:00:00 2001
From: wlei <wlei at fb.com>
Date: Tue, 6 Feb 2024 15:46:15 -0800
Subject: [PATCH 7/9] Update comments

---
 llvm/lib/Transforms/IPO/SampleProfile.cpp | 37 ++++++++++++-----------
 1 file changed, 20 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 3d7e36de0e0922..c55496c165158d 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -455,9 +455,9 @@ class SampleProfileMatcher {
     Unknown = 32,
   };
 
-  // For each function, store every callsite state into a map, of which each
-  // entry is a pair of callsite location and MatchState. This is used for
-  // profile stalness computation and report.
+  // For each function, store every callsite a matching state into this map, of
+  // which each entry is a pair of callsite location and MatchState. This is
+  // used for profile stalness computation and report.
   StringMap<std::unordered_map<LineLocation, MatchState, LineLocationHash>>
       FuncCallsiteMatchStates;
 
@@ -501,7 +501,7 @@ class SampleProfileMatcher {
       const FunctionSamples &FS,
       std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors);
   // Compute the callsite match states for profile staleness report, the result
-  // is save in FuncCallsiteMatchStates.
+  // is saved in FuncCallsiteMatchStates.
   void computeCallsiteMatchStates(
       const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
       const std::map<LineLocation, std::unordered_set<FunctionId>>
@@ -2403,8 +2403,8 @@ void SampleProfileMatcher::computeCallsiteMatchStates(
   auto &CallsiteMatchStates =
       FuncCallsiteMatchStates[FunctionSamples::getCanonicalFnName(F.getName())];
 
-  // IRToProfileLocationMap is null before the matching.
   auto MapIRLocToProfileLoc = [&](const LineLocation &IRLoc) {
+    // IRToProfileLocationMap is null in pre-match phrase.
     if (!IRToProfileLocationMap)
       return IRLoc;
     const auto &ProfileLoc = IRToProfileLocationMap->find(IRLoc);
@@ -2435,8 +2435,8 @@ void SampleProfileMatcher::computeCallsiteMatchStates(
 
     if (IsCallsiteMatched) {
       auto R = CallsiteMatchStates.emplace(Loc, MatchState::Matched);
-      // Update the post-match state when there is a existing state indicateing
-      // it's in post-match phrase.
+      // When there is an existing state, we know it's in post-match phrase.
+      // Update the matching state accordingly.
       if (!R.second) {
         if (R.first->second == MatchState::Mismatched)
           R.first->second = MatchState::Recovered;
@@ -2447,7 +2447,7 @@ void SampleProfileMatcher::computeCallsiteMatchStates(
   }
 
   // Check if there are any callsites in the profile that does not match to any
-  // IR callsites, those callsite samples will be discarded.
+  // IR callsites.
   for (const auto &I : ProfileAnchors) {
     const auto &Loc = I.first;
     [[maybe_unused]] const auto &Callees = I.second;
@@ -2455,7 +2455,7 @@ void SampleProfileMatcher::computeCallsiteMatchStates(
     auto It = CallsiteMatchStates.find(Loc);
     if (It == CallsiteMatchStates.end())
       CallsiteMatchStates.emplace(Loc, MatchState::Mismatched);
-    // If in post-match, the state is not updated to Recovered or StayMatched,
+    // In post-match, if the state is not updated to Recovered or StayMatched,
     // update it to Mismatched.
     else if (IsPostMatch && It->second == MatchState::Matched)
       CallsiteMatchStates.emplace(Loc, MatchState::Mismatched);
@@ -2472,17 +2472,19 @@ void SampleProfileMatcher::countMismatchedFuncSamples(const FunctionSamples &FS,
   if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
     if (IsTopLevel)
       NumStaleProfileFunc++;
-    // Once the checksum is mismatched, it's likely all the callites are
-    // mismatched and dropped, we conservatively count all the samples as
-    // mismatched samples and stop counting the inlinee profile.
+    // Given currently all probe ids are after block probe ids, once the
+    // checksum is mismatched, it's likely all the callsites are mismatched and
+    // dropped. We conservatively count all the samples as mismatched and stop
+    // counting the inlinees' profiles.
     MismatchedFunctionSamples += FS.getTotalSamples();
     return;
   }
 
-  // Even the current function checksum is matched, it's possible that the
-  // inlinees' checksums are mismatched, we need to go deeper to check the
-  // inlinee's function samples. Similarly, if the inlinee's checksum is
-  // mismatched, we stop and count all the samples as mismatched samples.
+  // Even if the current-level function checksum is matched, the nested
+  // inlinees' checksums may still be mismatched, which affects the inlinees'
+  // sample loading, so we need to go deeper and check the inlinees' function
+  // samples. Similarly, this recursive call counts all of an inlinee's samples
+  // as mismatched if that inlinee's checksum is mismatched.
   for (const auto &I : FS.getCallsiteSamples())
     for (const auto &CS : I.second)
       countMismatchedFuncSamples(CS.second, false);
@@ -2512,10 +2514,11 @@ void SampleProfileMatcher::countMismatchedCallsiteSamples(
   };
 
   // The non-inlined callsites are saved in the body samples of function
-  // profile.
+  // profile, go through it to count the non-inlined callsite samples.
   for (const auto &I : FS.getBodySamples())
     AttributeMismatchedSamples(findMatchState(I.first), I.second.getSamples());
 
+  // Count the inlined callsite samples.
   for (const auto &I : FS.getCallsiteSamples()) {
     auto State = findMatchState(I.first);
     uint64_t CallsiteSamples = 0;

>From 0dc2c79aa20c793669110750116a6b421b5338a2 Mon Sep 17 00:00:00 2001
From: wlei <wlei at fb.com>
Date: Mon, 12 Feb 2024 23:43:52 -0800
Subject: [PATCH 8/9] addressing comments

---
 llvm/lib/Transforms/IPO/SampleProfile.cpp | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index c55496c165158d..f8267bec4bcf16 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -455,9 +455,9 @@ class SampleProfileMatcher {
     Unknown = 32,
   };
 
-  // For each function, store every callsite a matching state into this map, of
-  // which each entry is a pair of callsite location and MatchState. This is
-  // used for profile stalness computation and report.
+  // For each function, store every callsite and its matching state into this
+  // map, of which each entry is a pair of callsite location and MatchState.
+  // This is used for profile staleness computation and report.
   StringMap<std::unordered_map<LineLocation, MatchState, LineLocationHash>>
       FuncCallsiteMatchStates;
 
@@ -2416,9 +2416,9 @@ void SampleProfileMatcher::computeCallsiteMatchStates(
 
   for (const auto &I : IRAnchors) {
     // In post-match, use the matching result to remap the current IR callsite.
-    const auto &Loc = MapIRLocToProfileLoc(I.first);
+    const auto &ProfileLoc = MapIRLocToProfileLoc(I.first);
     const auto &IRCalleeName = I.second;
-    const auto &It = ProfileAnchors.find(Loc);
+    const auto &It = ProfileAnchors.find(ProfileLoc);
     if (It == ProfileAnchors.end())
       continue;
     const auto &Callees = It->second;
@@ -2434,7 +2434,7 @@ void SampleProfileMatcher::computeCallsiteMatchStates(
       IsCallsiteMatched = true;
 
     if (IsCallsiteMatched) {
-      auto R = CallsiteMatchStates.emplace(Loc, MatchState::Matched);
+      auto R = CallsiteMatchStates.emplace(ProfileLoc, MatchState::Matched);
       // When there is an existing state, we know it's in post-match phrase.
       // Update the matching state accordingly.
       if (!R.second) {
@@ -2584,8 +2584,7 @@ void SampleProfileMatcher::computeAndReportProfileStaleness() {
       errs() << "(" << NumStaleProfileFunc << "/" << TotalProfiledFunc << ")"
              << " of functions' profile are invalid and "
              << " (" << MismatchedFunctionSamples << "/" << TotalFunctionSamples
-             << ")"
-             << " of samples are discarded due to function hash mismatch.\n";
+             << ") of samples are discarded due to function hash mismatch.\n";
     }
     errs() << "(" << (NumMismatchedCallsites + NumRecoveredCallsites) << "/"
            << TotalProfiledCallsites << ")"

>From 353f0ac7347ab7b9455c612539e62b794ec1bd28 Mon Sep 17 00:00:00 2001
From: wlei <wlei at fb.com>
Date: Tue, 13 Feb 2024 18:06:17 -0800
Subject: [PATCH 9/9] Rename the MatchState value

---
 llvm/lib/Transforms/IPO/SampleProfile.cpp | 60 ++++++++++++-----------
 1 file changed, 31 insertions(+), 29 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index f8267bec4bcf16..53beb7260a84bd 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -446,12 +446,12 @@ class SampleProfileMatcher {
 
   // Match state for an anchor/callsite.
   enum class MatchState {
-    Matched = 0,
-    Mismatched = 1,
-    // Stay Matched after profile matching.
-    StayMatched = 2,
-    // Recovered from Mismatched after profile matching.
-    Recovered = 3,
+    InitialMatch = 0,
+    InitialMismatch = 1,
+    // From initial mismatch to final match.
+    RecoveredMismatch = 2,
+    // From initial match to final mismatch.
+    RemovedMatch = 3,
     Unknown = 32,
   };
 
@@ -500,9 +500,9 @@ class SampleProfileMatcher {
   void findProfileAnchors(
       const FunctionSamples &FS,
       std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors);
-  // Compute the callsite match states for profile staleness report, the result
+  // Record the callsite match states for profile staleness report, the result
   // is saved in FuncCallsiteMatchStates.
-  void computeCallsiteMatchStates(
+  void recordCallsiteMatchStates(
       const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
       const std::map<LineLocation, std::unordered_set<FunctionId>>
           &ProfileAnchors,
@@ -2376,7 +2376,7 @@ void SampleProfileMatcher::runOnFunction(const Function &F) {
 
   // Compute the callsite match states for profile staleness report.
   if (ReportProfileStaleness || PersistProfileStaleness)
-    computeCallsiteMatchStates(F, IRAnchors, ProfileAnchors, nullptr);
+    recordCallsiteMatchStates(F, IRAnchors, ProfileAnchors, nullptr);
 
   // Run profile matching for checksum mismatched profile, currently only
   // support for pseudo-probe.
@@ -2389,12 +2389,12 @@ void SampleProfileMatcher::runOnFunction(const Function &F) {
                             IRToProfileLocationMap);
     // Find and update callsite match states after matching.
     if (ReportProfileStaleness || PersistProfileStaleness)
-      computeCallsiteMatchStates(F, IRAnchors, ProfileAnchors,
-                                 &IRToProfileLocationMap);
+      recordCallsiteMatchStates(F, IRAnchors, ProfileAnchors,
+                                &IRToProfileLocationMap);
   }
 }
 
-void SampleProfileMatcher::computeCallsiteMatchStates(
+void SampleProfileMatcher::recordCallsiteMatchStates(
     const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
     const std::map<LineLocation, std::unordered_set<FunctionId>>
         &ProfileAnchors,
@@ -2434,15 +2434,12 @@ void SampleProfileMatcher::computeCallsiteMatchStates(
       IsCallsiteMatched = true;
 
     if (IsCallsiteMatched) {
-      auto R = CallsiteMatchStates.emplace(ProfileLoc, MatchState::Matched);
+      auto R =
+          CallsiteMatchStates.emplace(ProfileLoc, MatchState::InitialMatch);
       // When there is an existing state, we know it's in post-match phrase.
       // Update the matching state accordingly.
-      if (!R.second) {
-        if (R.first->second == MatchState::Mismatched)
-          R.first->second = MatchState::Recovered;
-        if (R.first->second == MatchState::Matched)
-          R.first->second = MatchState::StayMatched;
-      }
+      if (!R.second && R.first->second == MatchState::InitialMismatch)
+        R.first->second = MatchState::RecoveredMismatch;
     }
   }
 
@@ -2454,11 +2451,10 @@ void SampleProfileMatcher::computeCallsiteMatchStates(
     assert(!Callees.empty() && "Callees should not be empty");
     auto It = CallsiteMatchStates.find(Loc);
     if (It == CallsiteMatchStates.end())
-      CallsiteMatchStates.emplace(Loc, MatchState::Mismatched);
-    // In post-match, if the state is not updated to Recovered or StayMatched,
-    // update it to Mismatched.
-    else if (IsPostMatch && It->second == MatchState::Matched)
-      CallsiteMatchStates.emplace(Loc, MatchState::Mismatched);
+      CallsiteMatchStates.emplace(Loc, MatchState::InitialMismatch);
+    // The initial match is removed in post-match.
+    else if (IsPostMatch && It->second == MatchState::InitialMatch)
+      CallsiteMatchStates.emplace(Loc, MatchState::RemovedMatch);
   }
 }
 
@@ -2505,11 +2501,16 @@ void SampleProfileMatcher::countMismatchedCallsiteSamples(
     return It->second;
   };
 
+  auto IsMismatchState = [&](const enum MatchState &State) {
+    return State == MatchState::InitialMismatch ||
+           State == MatchState::RemovedMatch;
+  };
+
   auto AttributeMismatchedSamples = [&](const enum MatchState &State,
                                         uint64_t Samples) {
-    if (State == MatchState::Mismatched)
+    if (IsMismatchState(State))
       MismatchedCallsiteSamples += Samples;
-    else if (State == MatchState::Recovered)
+    else if (State == MatchState::RecoveredMismatch)
       RecoveredCallsiteSamples += Samples;
   };
 
@@ -2526,7 +2527,7 @@ void SampleProfileMatcher::countMismatchedCallsiteSamples(
       CallsiteSamples += CS.second.getTotalSamples();
     AttributeMismatchedSamples(State, CallsiteSamples);
 
-    if (State == MatchState::Mismatched)
+    if (IsMismatchState(State))
       continue;
 
     // When the current level of inlined call site matches the profiled call
@@ -2545,9 +2546,10 @@ void SampleProfileMatcher::countMismatchCallsites(const FunctionSamples &FS) {
   const auto &MatchStates = It->second;
   for (const auto &I : MatchStates) {
     TotalProfiledCallsites++;
-    if (I.second == MatchState::Mismatched)
+    if (I.second == MatchState::InitialMismatch ||
+        I.second == MatchState::RemovedMatch)
       NumMismatchedCallsites++;
-    else if (I.second == MatchState::Recovered)
+    else if (I.second == MatchState::RecoveredMismatch)
       NumRecoveredCallsites++;
   }
 }



More information about the llvm-commits mailing list