[llvm] [SampleFDO][NFC] Refactoring SampleProfileMatcher (PR #86988)

Lei Wang via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 28 15:24:30 PDT 2024


https://github.com/wlei-llvm updated https://github.com/llvm/llvm-project/pull/86988

>From 0521286684aed19456d866eb794e5ff017514031 Mon Sep 17 00:00:00 2001
From: wlei <wlei at fb.com>
Date: Thu, 28 Mar 2024 11:29:11 -0700
Subject: [PATCH 1/2] [SampleFDO] Refactoring SampleProfileMatcher

---
 .../Transforms/IPO/SampleProfileMatcher.h     | 154 ++++
 .../Utils/SampleProfileLoaderBaseImpl.h       |   4 +
 llvm/lib/Transforms/IPO/CMakeLists.txt        |   1 +
 llvm/lib/Transforms/IPO/SampleProfile.cpp     | 672 +-----------------
 .../Transforms/IPO/SampleProfileMatcher.cpp   | 553 ++++++++++++++
 .../pseudo-probe-callee-profile-mismatch.ll   |   2 +-
 ...pseudo-probe-stale-profile-matching-lto.ll |   2 +-
 .../pseudo-probe-stale-profile-matching.ll    |   2 +-
 8 files changed, 719 insertions(+), 671 deletions(-)
 create mode 100644 llvm/include/llvm/Transforms/IPO/SampleProfileMatcher.h
 create mode 100644 llvm/lib/Transforms/IPO/SampleProfileMatcher.cpp

diff --git a/llvm/include/llvm/Transforms/IPO/SampleProfileMatcher.h b/llvm/include/llvm/Transforms/IPO/SampleProfileMatcher.h
new file mode 100644
index 00000000000000..7ae6194da7c9cc
--- /dev/null
+++ b/llvm/include/llvm/Transforms/IPO/SampleProfileMatcher.h
@@ -0,0 +1,154 @@
+//===- Transforms/IPO/SampleProfileMatcher.h ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file provides the interface for SampleProfileMatcher.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_IPO_SAMPLEPROFILEMATCHER_H
+#define LLVM_TRANSFORMS_IPO_SAMPLEPROFILEMATCHER_H
+
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
+
+namespace llvm {
+
+// Sample profile matching - fuzzy match.
+class SampleProfileMatcher {
+  Module &M;
+  SampleProfileReader &Reader;
+  const PseudoProbeManager *ProbeManager;
+  const ThinOrFullLTOPhase LTOPhase;
+  SampleProfileMap FlattenedProfiles;
+  // For each function, the matcher generates a map, of which each entry is a
+  // mapping from the source location of current build to the source location in
+  // the profile.
+  StringMap<LocToLocMap> FuncMappings;
+
+  // Match state for an anchor/callsite.
+  enum class MatchState {
+    Unknown = 0,
+    // Initial match between input profile and current IR.
+    InitialMatch = 1,
+    // Initial mismatch between input profile and current IR.
+    InitialMismatch = 2,
+    // InitialMatch stays matched after fuzzy profile matching.
+    UnchangedMatch = 3,
+    // InitialMismatch stays mismatched after fuzzy profile matching.
+    UnchangedMismatch = 4,
+    // InitialMismatch is recovered after fuzzy profile matching.
+    RecoveredMismatch = 5,
+    // InitialMatch is removed and becomes mismatched after fuzzy profile
+    // matching.
+    RemovedMatch = 6,
+  };
+
+  // For each function, store every callsite and its matching state into this
+  // map, of which each entry is a pair of callsite location and MatchState.
+  // This is used for profile staleness computation and report.
+  StringMap<std::unordered_map<LineLocation, MatchState, LineLocationHash>>
+      FuncCallsiteMatchStates;
+
+  // Profile mismatch statistics:
+  uint64_t TotalProfiledFunc = 0;
+  // Number of checksum-mismatched functions.
+  uint64_t NumStaleProfileFunc = 0;
+  uint64_t TotalProfiledCallsites = 0;
+  uint64_t NumMismatchedCallsites = 0;
+  uint64_t NumRecoveredCallsites = 0;
+  // Total samples for all profiled functions.
+  uint64_t TotalFunctionSamples = 0;
+  // Total samples for all checksum-mismatched functions.
+  uint64_t MismatchedFunctionSamples = 0;
+  uint64_t MismatchedCallsiteSamples = 0;
+  uint64_t RecoveredCallsiteSamples = 0;
+
+  // A dummy name for unknown indirect callee, used to differentiate from a
+  // non-call instruction that also has an empty callee name.
+  static constexpr const char *UnknownIndirectCallee =
+      "unknown.indirect.callee";
+
+public:
+  SampleProfileMatcher(Module &M, SampleProfileReader &Reader,
+                       const PseudoProbeManager *ProbeManager,
+                       ThinOrFullLTOPhase LTOPhase)
+      : M(M), Reader(Reader), ProbeManager(ProbeManager), LTOPhase(LTOPhase){};
+  void runOnModule();
+  void clearMatchingData() {
+    // Do not clear FuncMappings: it stores the IRLoc to ProfLoc remappings
+    // that will be used by the sample loader.
+    FuncCallsiteMatchStates.clear();
+  }
+
+private:
+  FunctionSamples *getFlattenedSamplesFor(const Function &F) {
+    StringRef CanonFName = FunctionSamples::getCanonicalFnName(F);
+    auto It = FlattenedProfiles.find(FunctionId(CanonFName));
+    if (It != FlattenedProfiles.end())
+      return &It->second;
+    return nullptr;
+  }
+  void runOnFunction(Function &F);
+  void findIRAnchors(const Function &F,
+                     std::map<LineLocation, StringRef> &IRAnchors);
+  void findProfileAnchors(
+      const FunctionSamples &FS,
+      std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors);
+  // Record the callsite match states for profile staleness report, the result
+  // is saved in FuncCallsiteMatchStates.
+  void recordCallsiteMatchStates(
+      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+      const std::map<LineLocation, std::unordered_set<FunctionId>>
+          &ProfileAnchors,
+      const LocToLocMap *IRToProfileLocationMap);
+
+  bool isMismatchState(const enum MatchState &State) {
+    return State == MatchState::InitialMismatch ||
+           State == MatchState::UnchangedMismatch ||
+           State == MatchState::RemovedMatch;
+  };
+
+  bool isInitialState(const enum MatchState &State) {
+    return State == MatchState::InitialMatch ||
+           State == MatchState::InitialMismatch;
+  };
+
+  bool isFinalState(const enum MatchState &State) {
+    return State == MatchState::UnchangedMatch ||
+           State == MatchState::UnchangedMismatch ||
+           State == MatchState::RecoveredMismatch ||
+           State == MatchState::RemovedMatch;
+  };
+
+  // Count the samples of checksum mismatched function for the top-level
+  // function and all inlinees.
+  void countMismatchedFuncSamples(const FunctionSamples &FS, bool IsTopLevel);
+  // Count the number of mismatched or recovered callsites.
+  void countMismatchCallsites(const FunctionSamples &FS);
+  // Count the samples of mismatched or recovered callsites for top-level
+  // function and all inlinees.
+  void countMismatchedCallsiteSamples(const FunctionSamples &FS);
+  void computeAndReportProfileStaleness();
+
+  LocToLocMap &getIRToProfileLocationMap(const Function &F) {
+    auto Ret = FuncMappings.try_emplace(
+        FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap());
+    return Ret.first->second;
+  }
+  void distributeIRToProfileLocationMap();
+  void distributeIRToProfileLocationMap(FunctionSamples &FS);
+  void runStaleProfileMatching(
+      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+      const std::map<LineLocation, std::unordered_set<FunctionId>>
+          &ProfileAnchors,
+      LocToLocMap &IRToProfileLocationMap);
+  void reportOrPersistProfileStats();
+};
+} // end namespace llvm
+#endif // LLVM_TRANSFORMS_IPO_SAMPLEPROFILEMATCHER_H
diff --git a/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h b/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h
index 048b97c34ee2ae..d898ee58307ead 100644
--- a/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h
+++ b/llvm/include/llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h
@@ -146,6 +146,10 @@ class PseudoProbeManager {
 
 extern cl::opt<bool> SampleProfileUseProfi;
 
+static inline bool skipProfileForFunction(const Function &F) {
+  return F.isDeclaration() || !F.hasFnAttribute("use-sample-profile");
+}
+
 template <typename FT> class SampleProfileLoaderBaseImpl {
 public:
   SampleProfileLoaderBaseImpl(std::string Name, std::string RemapName,
diff --git a/llvm/lib/Transforms/IPO/CMakeLists.txt b/llvm/lib/Transforms/IPO/CMakeLists.txt
index 034f1587ae8df4..5fbdbc3a014f9a 100644
--- a/llvm/lib/Transforms/IPO/CMakeLists.txt
+++ b/llvm/lib/Transforms/IPO/CMakeLists.txt
@@ -35,6 +35,7 @@ add_llvm_component_library(LLVMipo
   PartialInlining.cpp
   SampleContextTracker.cpp
   SampleProfile.cpp
+  SampleProfileMatcher.cpp
   SampleProfileProbe.cpp
   SCCP.cpp
   StripDeadPrototypes.cpp
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 7545a92c114ef2..b5f45a252c7b46 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -71,6 +71,7 @@
 #include "llvm/Transforms/IPO.h"
 #include "llvm/Transforms/IPO/ProfiledCallGraph.h"
 #include "llvm/Transforms/IPO/SampleContextTracker.h"
+#include "llvm/Transforms/IPO/SampleProfileMatcher.h"
 #include "llvm/Transforms/IPO/SampleProfileProbe.h"
 #include "llvm/Transforms/Instrumentation.h"
 #include "llvm/Transforms/Utils/CallPromotionUtils.h"
@@ -129,16 +130,16 @@ static cl::opt<std::string> SampleProfileRemappingFile(
     "sample-profile-remapping-file", cl::init(""), cl::value_desc("filename"),
     cl::desc("Profile remapping file loaded by -sample-profile"), cl::Hidden);
 
-static cl::opt<bool> SalvageStaleProfile(
+cl::opt<bool> SalvageStaleProfile(
     "salvage-stale-profile", cl::Hidden, cl::init(false),
     cl::desc("Salvage stale profile by fuzzy matching and use the remapped "
              "location for sample profile query."));
 
-static cl::opt<bool> ReportProfileStaleness(
+cl::opt<bool> ReportProfileStaleness(
     "report-profile-staleness", cl::Hidden, cl::init(false),
     cl::desc("Compute and report stale profile statistical metrics."));
 
-static cl::opt<bool> PersistProfileStaleness(
+cl::opt<bool> PersistProfileStaleness(
     "persist-profile-staleness", cl::Hidden, cl::init(false),
     cl::desc("Compute stale profile statistical metrics and write it into the "
              "native object file(.llvm_stats section)."));
@@ -448,138 +449,6 @@ using CandidateQueue =
     PriorityQueue<InlineCandidate, std::vector<InlineCandidate>,
                   CandidateComparer>;
 
-// Sample profile matching - fuzzy match.
-class SampleProfileMatcher {
-  Module &M;
-  SampleProfileReader &Reader;
-  const PseudoProbeManager *ProbeManager;
-  const ThinOrFullLTOPhase LTOPhase;
-  SampleProfileMap FlattenedProfiles;
-  // For each function, the matcher generates a map, of which each entry is a
-  // mapping from the source location of current build to the source location in
-  // the profile.
-  StringMap<LocToLocMap> FuncMappings;
-
-  // Match state for an anchor/callsite.
-  enum class MatchState {
-    Unknown = 0,
-    // Initial match between input profile and current IR.
-    InitialMatch = 1,
-    // Initial mismatch between input profile and current IR.
-    InitialMismatch = 2,
-    // InitialMatch stays matched after fuzzy profile matching.
-    UnchangedMatch = 3,
-    // InitialMismatch stays mismatched after fuzzy profile matching.
-    UnchangedMismatch = 4,
-    // InitialMismatch is recovered after fuzzy profile matching.
-    RecoveredMismatch = 5,
-    // InitialMatch is removed and becomes mismatched after fuzzy profile
-    // matching.
-    RemovedMatch = 6,
-  };
-
-  // For each function, store every callsite and its matching state into this
-  // map, of which each entry is a pair of callsite location and MatchState.
-  // This is used for profile staleness computation and report.
-  StringMap<std::unordered_map<LineLocation, MatchState, LineLocationHash>>
-      FuncCallsiteMatchStates;
-
-  // Profile mismatch statstics:
-  uint64_t TotalProfiledFunc = 0;
-  // Num of checksum-mismatched function.
-  uint64_t NumStaleProfileFunc = 0;
-  uint64_t TotalProfiledCallsites = 0;
-  uint64_t NumMismatchedCallsites = 0;
-  uint64_t NumRecoveredCallsites = 0;
-  // Total samples for all profiled functions.
-  uint64_t TotalFunctionSamples = 0;
-  // Total samples for all checksum-mismatched functions.
-  uint64_t MismatchedFunctionSamples = 0;
-  uint64_t MismatchedCallsiteSamples = 0;
-  uint64_t RecoveredCallsiteSamples = 0;
-
-  // A dummy name for unknown indirect callee, used to differentiate from a
-  // non-call instruction that also has an empty callee name.
-  static constexpr const char *UnknownIndirectCallee =
-      "unknown.indirect.callee";
-
-public:
-  SampleProfileMatcher(Module &M, SampleProfileReader &Reader,
-                       const PseudoProbeManager *ProbeManager,
-                       ThinOrFullLTOPhase LTOPhase)
-      : M(M), Reader(Reader), ProbeManager(ProbeManager), LTOPhase(LTOPhase){};
-  void runOnModule();
-  void clearMatchingData() {
-    // Do not clear FuncMappings, it stores IRLoc to ProfLoc remappings which
-    // will be used for sample loader.
-    FuncCallsiteMatchStates.clear();
-  }
-
-private:
-  FunctionSamples *getFlattenedSamplesFor(const Function &F) {
-    StringRef CanonFName = FunctionSamples::getCanonicalFnName(F);
-    auto It = FlattenedProfiles.find(FunctionId(CanonFName));
-    if (It != FlattenedProfiles.end())
-      return &It->second;
-    return nullptr;
-  }
-  void runOnFunction(Function &F);
-  void findIRAnchors(const Function &F,
-                     std::map<LineLocation, StringRef> &IRAnchors);
-  void findProfileAnchors(
-      const FunctionSamples &FS,
-      std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors);
-  // Record the callsite match states for profile staleness report, the result
-  // is saved in FuncCallsiteMatchStates.
-  void recordCallsiteMatchStates(
-      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
-      const std::map<LineLocation, std::unordered_set<FunctionId>>
-          &ProfileAnchors,
-      const LocToLocMap *IRToProfileLocationMap);
-
-  bool isMismatchState(const enum MatchState &State) {
-    return State == MatchState::InitialMismatch ||
-           State == MatchState::UnchangedMismatch ||
-           State == MatchState::RemovedMatch;
-  };
-
-  bool isInitialState(const enum MatchState &State) {
-    return State == MatchState::InitialMatch ||
-           State == MatchState::InitialMismatch;
-  };
-
-  bool isFinalState(const enum MatchState &State) {
-    return State == MatchState::UnchangedMatch ||
-           State == MatchState::UnchangedMismatch ||
-           State == MatchState::RecoveredMismatch ||
-           State == MatchState::RemovedMatch;
-  };
-
-  // Count the samples of checksum mismatched function for the top-level
-  // function and all inlinees.
-  void countMismatchedFuncSamples(const FunctionSamples &FS, bool IsTopLevel);
-  // Count the number of mismatched or recovered callsites.
-  void countMismatchCallsites(const FunctionSamples &FS);
-  // Count the samples of mismatched or recovered callsites for top-level
-  // function and all inlinees.
-  void countMismatchedCallsiteSamples(const FunctionSamples &FS);
-  void computeAndReportProfileStaleness();
-
-  LocToLocMap &getIRToProfileLocationMap(const Function &F) {
-    auto Ret = FuncMappings.try_emplace(
-        FunctionSamples::getCanonicalFnName(F.getName()), LocToLocMap());
-    return Ret.first->second;
-  }
-  void distributeIRToProfileLocationMap();
-  void distributeIRToProfileLocationMap(FunctionSamples &FS);
-  void runStaleProfileMatching(
-      const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
-      const std::map<LineLocation, std::unordered_set<FunctionId>>
-          &ProfileAnchors,
-      LocToLocMap &IRToProfileLocationMap);
-  void reportOrPersistProfileStats();
-};
-
 /// Sample profile pass.
 ///
 /// This pass reads profile data from the file specified by
@@ -766,10 +635,6 @@ void SampleProfileLoaderBaseImpl<Function>::computeDominanceAndLoopInfo(
 }
 } // namespace llvm
 
-static bool skipProfileForFunction(const Function &F) {
-  return F.isDeclaration() || !F.hasFnAttribute("use-sample-profile");
-}
-
 ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) {
   if (FunctionSamples::ProfileIsProbeBased)
     return getProbeWeight(Inst);
@@ -2262,535 +2127,6 @@ bool SampleProfileLoader::rejectHighStalenessProfile(
   return false;
 }
 
-void SampleProfileMatcher::findIRAnchors(
-    const Function &F, std::map<LineLocation, StringRef> &IRAnchors) {
-  // For inlined code, recover the original callsite and callee by finding the
-  // top-level inline frame. e.g. For frame stack "main:1 @ foo:2 @ bar:3", the
-  // top-level frame is "main:1", the callsite is "1" and the callee is "foo".
-  auto FindTopLevelInlinedCallsite = [](const DILocation *DIL) {
-    assert((DIL && DIL->getInlinedAt()) && "No inlined callsite");
-    const DILocation *PrevDIL = nullptr;
-    do {
-      PrevDIL = DIL;
-      DIL = DIL->getInlinedAt();
-    } while (DIL->getInlinedAt());
-
-    LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
-    StringRef CalleeName = PrevDIL->getSubprogramLinkageName();
-    return std::make_pair(Callsite, CalleeName);
-  };
-
-  auto GetCanonicalCalleeName = [](const CallBase *CB) {
-    StringRef CalleeName = UnknownIndirectCallee;
-    if (Function *Callee = CB->getCalledFunction())
-      CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName());
-    return CalleeName;
-  };
-
-  // Extract profile matching anchors in the IR.
-  for (auto &BB : F) {
-    for (auto &I : BB) {
-      DILocation *DIL = I.getDebugLoc();
-      if (!DIL)
-        continue;
-
-      if (FunctionSamples::ProfileIsProbeBased) {
-        if (auto Probe = extractProbe(I)) {
-          // Flatten inlined IR for the matching.
-          if (DIL->getInlinedAt()) {
-            IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL));
-          } else {
-            // Use empty StringRef for basic block probe.
-            StringRef CalleeName;
-            if (const auto *CB = dyn_cast<CallBase>(&I)) {
-              // Skip the probe inst whose callee name is "llvm.pseudoprobe".
-              if (!isa<IntrinsicInst>(&I))
-                CalleeName = GetCanonicalCalleeName(CB);
-            }
-            IRAnchors.emplace(LineLocation(Probe->Id, 0), CalleeName);
-          }
-        }
-      } else {
-        // TODO: For line-number based profile(AutoFDO), currently only support
-        // find callsite anchors. In future, we need to parse all the non-call
-        // instructions to extract the line locations for profile matching.
-        if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I))
-          continue;
-
-        if (DIL->getInlinedAt()) {
-          IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL));
-        } else {
-          LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
-          StringRef CalleeName = GetCanonicalCalleeName(dyn_cast<CallBase>(&I));
-          IRAnchors.emplace(Callsite, CalleeName);
-        }
-      }
-    }
-  }
-}
-
-void SampleProfileMatcher::findProfileAnchors(
-    const FunctionSamples &FS,
-    std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors) {
-  auto isInvalidLineOffset = [](uint32_t LineOffset) {
-    return LineOffset & 0x8000;
-  };
-
-  for (const auto &I : FS.getBodySamples()) {
-    const LineLocation &Loc = I.first;
-    if (isInvalidLineOffset(Loc.LineOffset))
-      continue;
-    for (const auto &I : I.second.getCallTargets()) {
-      auto Ret = ProfileAnchors.try_emplace(Loc,
-                                            std::unordered_set<FunctionId>());
-      Ret.first->second.insert(I.first);
-    }
-  }
-
-  for (const auto &I : FS.getCallsiteSamples()) {
-    const LineLocation &Loc = I.first;
-    if (isInvalidLineOffset(Loc.LineOffset))
-      continue;
-    const auto &CalleeMap = I.second;
-    for (const auto &I : CalleeMap) {
-      auto Ret = ProfileAnchors.try_emplace(Loc,
-                                            std::unordered_set<FunctionId>());
-      Ret.first->second.insert(I.first);
-    }
-  }
-}
-
-// Call target name anchor based profile fuzzy matching.
-// Input:
-// For IR locations, the anchor is the callee name of direct callsite; For
-// profile locations, it's the call target name for BodySamples or inlinee's
-// profile name for CallsiteSamples.
-// Matching heuristic:
-// First match all the anchors in lexical order, then split the non-anchor
-// locations between the two anchors evenly, first half are matched based on the
-// start anchor, second half are matched based on the end anchor.
-// For example, given:
-// IR locations:      [1, 2(foo), 3, 5, 6(bar), 7]
-// Profile locations: [1, 2, 3(foo), 4, 7, 8(bar), 9]
-// The matching gives:
-//   [1,    2(foo), 3,  5,  6(bar), 7]
-//    |     |       |   |     |     |
-//   [1, 2, 3(foo), 4,  7,  8(bar), 9]
-// The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9].
-void SampleProfileMatcher::runStaleProfileMatching(
-    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
-    const std::map<LineLocation, std::unordered_set<FunctionId>>
-        &ProfileAnchors,
-    LocToLocMap &IRToProfileLocationMap) {
-  LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName()
-                    << "\n");
-  assert(IRToProfileLocationMap.empty() &&
-         "Run stale profile matching only once per function");
-
-  std::unordered_map<FunctionId, std::set<LineLocation>>
-      CalleeToCallsitesMap;
-  for (const auto &I : ProfileAnchors) {
-    const auto &Loc = I.first;
-    const auto &Callees = I.second;
-    // Filter out possible indirect calls, use direct callee name as anchor.
-    if (Callees.size() == 1) {
-      FunctionId CalleeName = *Callees.begin();
-      const auto &Candidates = CalleeToCallsitesMap.try_emplace(
-          CalleeName, std::set<LineLocation>());
-      Candidates.first->second.insert(Loc);
-    }
-  }
-
-  auto InsertMatching = [&](const LineLocation &From, const LineLocation &To) {
-    // Skip the unchanged location mapping to save memory.
-    if (From != To)
-      IRToProfileLocationMap.insert({From, To});
-  };
-
-  // Use function's beginning location as the initial anchor.
-  int32_t LocationDelta = 0;
-  SmallVector<LineLocation> LastMatchedNonAnchors;
-
-  for (const auto &IR : IRAnchors) {
-    const auto &Loc = IR.first;
-    auto CalleeName = IR.second;
-    bool IsMatchedAnchor = false;
-    // Match the anchor location in lexical order.
-    if (!CalleeName.empty()) {
-      auto CandidateAnchors = CalleeToCallsitesMap.find(
-          getRepInFormat(CalleeName));
-      if (CandidateAnchors != CalleeToCallsitesMap.end() &&
-          !CandidateAnchors->second.empty()) {
-        auto CI = CandidateAnchors->second.begin();
-        const auto Candidate = *CI;
-        CandidateAnchors->second.erase(CI);
-        InsertMatching(Loc, Candidate);
-        LLVM_DEBUG(dbgs() << "Callsite with callee:" << CalleeName
-                          << " is matched from " << Loc << " to " << Candidate
-                          << "\n");
-        LocationDelta = Candidate.LineOffset - Loc.LineOffset;
-
-        // Match backwards for non-anchor locations.
-        // The locations in LastMatchedNonAnchors have been matched forwards
-        // based on the previous anchor, spilt it evenly and overwrite the
-        // second half based on the current anchor.
-        for (size_t I = (LastMatchedNonAnchors.size() + 1) / 2;
-             I < LastMatchedNonAnchors.size(); I++) {
-          const auto &L = LastMatchedNonAnchors[I];
-          uint32_t CandidateLineOffset = L.LineOffset + LocationDelta;
-          LineLocation Candidate(CandidateLineOffset, L.Discriminator);
-          InsertMatching(L, Candidate);
-          LLVM_DEBUG(dbgs() << "Location is rematched backwards from " << L
-                            << " to " << Candidate << "\n");
-        }
-
-        IsMatchedAnchor = true;
-        LastMatchedNonAnchors.clear();
-      }
-    }
-
-    // Match forwards for non-anchor locations.
-    if (!IsMatchedAnchor) {
-      uint32_t CandidateLineOffset = Loc.LineOffset + LocationDelta;
-      LineLocation Candidate(CandidateLineOffset, Loc.Discriminator);
-      InsertMatching(Loc, Candidate);
-      LLVM_DEBUG(dbgs() << "Location is matched from " << Loc << " to "
-                        << Candidate << "\n");
-      LastMatchedNonAnchors.emplace_back(Loc);
-    }
-  }
-}
-
-void SampleProfileMatcher::runOnFunction(Function &F) {
-  // We need to use flattened function samples for matching.
-  // Unlike IR, which includes all callsites from the source code, the callsites
-  // in profile only show up when they are hit by samples, i,e. the profile
-  // callsites in one context may differ from those in another context. To get
-  // the maximum number of callsites, we merge the function profiles from all
-  // contexts, aka, the flattened profile to find profile anchors.
-  const auto *FSFlattened = getFlattenedSamplesFor(F);
-  if (!FSFlattened)
-    return;
-
-  // Anchors for IR. It's a map from IR location to callee name, callee name is
-  // empty for non-call instruction and use a dummy name(UnknownIndirectCallee)
-  // for unknown indrect callee name.
-  std::map<LineLocation, StringRef> IRAnchors;
-  findIRAnchors(F, IRAnchors);
-  // Anchors for profile. It's a map from callsite location to a set of callee
-  // name.
-  std::map<LineLocation, std::unordered_set<FunctionId>> ProfileAnchors;
-  findProfileAnchors(*FSFlattened, ProfileAnchors);
-
-  // Compute the callsite match states for profile staleness report.
-  if (ReportProfileStaleness || PersistProfileStaleness)
-    recordCallsiteMatchStates(F, IRAnchors, ProfileAnchors, nullptr);
-
-  // Run profile matching for checksum mismatched profile, currently only
-  // support for pseudo-probe.
-  if (SalvageStaleProfile && FunctionSamples::ProfileIsProbeBased &&
-      !ProbeManager->profileIsValid(F, *FSFlattened)) {
-    // For imported functions, the checksum metadata(pseudo_probe_desc) are
-    // dropped, so we leverage function attribute(profile-checksum-mismatch) to
-    // transfer the info: add the attribute during pre-link phase and check it
-    // during post-link phase(see "profileIsValid").
-    if (FunctionSamples::ProfileIsProbeBased &&
-        LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink)
-      F.addFnAttr("profile-checksum-mismatch");
-
-    // The matching result will be saved to IRToProfileLocationMap, create a
-    // new map for each function.
-    auto &IRToProfileLocationMap = getIRToProfileLocationMap(F);
-    runStaleProfileMatching(F, IRAnchors, ProfileAnchors,
-                            IRToProfileLocationMap);
-    // Find and update callsite match states after matching.
-    if (ReportProfileStaleness || PersistProfileStaleness)
-      recordCallsiteMatchStates(F, IRAnchors, ProfileAnchors,
-                                &IRToProfileLocationMap);
-  }
-}
-
-void SampleProfileMatcher::recordCallsiteMatchStates(
-    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
-    const std::map<LineLocation, std::unordered_set<FunctionId>>
-        &ProfileAnchors,
-    const LocToLocMap *IRToProfileLocationMap) {
-  bool IsPostMatch = IRToProfileLocationMap != nullptr;
-  auto &CallsiteMatchStates =
-      FuncCallsiteMatchStates[FunctionSamples::getCanonicalFnName(F.getName())];
-
-  auto MapIRLocToProfileLoc = [&](const LineLocation &IRLoc) {
-    // IRToProfileLocationMap is null in pre-match phrase.
-    if (!IRToProfileLocationMap)
-      return IRLoc;
-    const auto &ProfileLoc = IRToProfileLocationMap->find(IRLoc);
-    if (ProfileLoc != IRToProfileLocationMap->end())
-      return ProfileLoc->second;
-    else
-      return IRLoc;
-  };
-
-  for (const auto &I : IRAnchors) {
-    // After fuzzy profile matching, use the matching result to remap the
-    // current IR callsite.
-    const auto &ProfileLoc = MapIRLocToProfileLoc(I.first);
-    const auto &IRCalleeName = I.second;
-    const auto &It = ProfileAnchors.find(ProfileLoc);
-    if (It == ProfileAnchors.end())
-      continue;
-    const auto &Callees = It->second;
-
-    bool IsCallsiteMatched = false;
-    // Since indirect call does not have CalleeName, check conservatively if
-    // callsite in the profile is a callsite location. This is to reduce num of
-    // false positive since otherwise all the indirect call samples will be
-    // reported as mismatching.
-    if (IRCalleeName == SampleProfileMatcher::UnknownIndirectCallee)
-      IsCallsiteMatched = true;
-    else if (Callees.size() == 1 && Callees.count(getRepInFormat(IRCalleeName)))
-      IsCallsiteMatched = true;
-
-    if (IsCallsiteMatched) {
-      auto It = CallsiteMatchStates.find(ProfileLoc);
-      if (It == CallsiteMatchStates.end())
-        CallsiteMatchStates.emplace(ProfileLoc, MatchState::InitialMatch);
-      else if (IsPostMatch) {
-        if (It->second == MatchState::InitialMatch)
-          It->second = MatchState::UnchangedMatch;
-        else if (It->second == MatchState::InitialMismatch)
-          It->second = MatchState::RecoveredMismatch;
-      }
-    }
-  }
-
-  // Check if there are any callsites in the profile that does not match to any
-  // IR callsites.
-  for (const auto &I : ProfileAnchors) {
-    const auto &Loc = I.first;
-    [[maybe_unused]] const auto &Callees = I.second;
-    assert(!Callees.empty() && "Callees should not be empty");
-    auto It = CallsiteMatchStates.find(Loc);
-    if (It == CallsiteMatchStates.end())
-      CallsiteMatchStates.emplace(Loc, MatchState::InitialMismatch);
-    else if (IsPostMatch) {
-      // Update the state if it's not matched(UnchangedMatch or
-      // RecoveredMismatch).
-      if (It->second == MatchState::InitialMismatch)
-        It->second = MatchState::UnchangedMismatch;
-      else if (It->second == MatchState::InitialMatch)
-        It->second = MatchState::RemovedMatch;
-    }
-  }
-}
-
-void SampleProfileMatcher::countMismatchedFuncSamples(const FunctionSamples &FS,
-                                                      bool IsTopLevel) {
-  const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID());
-  // Skip the function that is external or renamed.
-  if (!FuncDesc)
-    return;
-
-  if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
-    if (IsTopLevel)
-      NumStaleProfileFunc++;
-    // Given currently all probe ids are after block probe ids, once the
-    // checksum is mismatched, it's likely all the callites are mismatched and
-    // dropped. We conservatively count all the samples as mismatched and stop
-    // counting the inlinees' profiles.
-    MismatchedFunctionSamples += FS.getTotalSamples();
-    return;
-  }
-
-  // Even the current-level function checksum is matched, it's possible that the
-  // nested inlinees' checksums are mismatched that affect the inlinee's sample
-  // loading, we need to go deeper to check the inlinees' function samples.
-  // Similarly, count all the samples as mismatched if the inlinee's checksum is
-  // mismatched using this recursive function.
-  for (const auto &I : FS.getCallsiteSamples())
-    for (const auto &CS : I.second)
-      countMismatchedFuncSamples(CS.second, false);
-}
-
-void SampleProfileMatcher::countMismatchedCallsiteSamples(
-    const FunctionSamples &FS) {
-  auto It = FuncCallsiteMatchStates.find(FS.getFuncName());
-  // Skip it if no mismatched callsite or this is an external function.
-  if (It == FuncCallsiteMatchStates.end() || It->second.empty())
-    return;
-  const auto &CallsiteMatchStates = It->second;
-
-  auto findMatchState = [&](const LineLocation &Loc) {
-    auto It = CallsiteMatchStates.find(Loc);
-    if (It == CallsiteMatchStates.end())
-      return MatchState::Unknown;
-    return It->second;
-  };
-
-  auto AttributeMismatchedSamples = [&](const enum MatchState &State,
-                                        uint64_t Samples) {
-    if (isMismatchState(State))
-      MismatchedCallsiteSamples += Samples;
-    else if (State == MatchState::RecoveredMismatch)
-      RecoveredCallsiteSamples += Samples;
-  };
-
-  // The non-inlined callsites are saved in the body samples of function
-  // profile, go through it to count the non-inlined callsite samples.
-  for (const auto &I : FS.getBodySamples())
-    AttributeMismatchedSamples(findMatchState(I.first), I.second.getSamples());
-
-  // Count the inlined callsite samples.
-  for (const auto &I : FS.getCallsiteSamples()) {
-    auto State = findMatchState(I.first);
-    uint64_t CallsiteSamples = 0;
-    for (const auto &CS : I.second)
-      CallsiteSamples += CS.second.getTotalSamples();
-    AttributeMismatchedSamples(State, CallsiteSamples);
-
-    if (isMismatchState(State))
-      continue;
-
-    // When the current level of inlined call site matches the profiled call
-    // site, we need to go deeper along the inline tree to count mismatches from
-    // lower level inlinees.
-    for (const auto &CS : I.second)
-      countMismatchedCallsiteSamples(CS.second);
-  }
-}
-
-void SampleProfileMatcher::countMismatchCallsites(const FunctionSamples &FS) {
-  auto It = FuncCallsiteMatchStates.find(FS.getFuncName());
-  // Skip it if no mismatched callsite or this is an external function.
-  if (It == FuncCallsiteMatchStates.end() || It->second.empty())
-    return;
-  const auto &MatchStates = It->second;
-  [[maybe_unused]] bool OnInitialState =
-      isInitialState(MatchStates.begin()->second);
-  for (const auto &I : MatchStates) {
-    TotalProfiledCallsites++;
-    assert(
-        (OnInitialState ? isInitialState(I.second) : isFinalState(I.second)) &&
-        "Profile matching state is inconsistent");
-
-    if (isMismatchState(I.second))
-      NumMismatchedCallsites++;
-    else if (I.second == MatchState::RecoveredMismatch)
-      NumRecoveredCallsites++;
-  }
-}
-
-void SampleProfileMatcher::computeAndReportProfileStaleness() {
-  if (!ReportProfileStaleness && !PersistProfileStaleness)
-    return;
-
-  // Count profile mismatches for profile staleness report.
-  for (const auto &F : M) {
-    if (skipProfileForFunction(F))
-      continue;
-    // As the stats will be merged by linker, skip reporting the metrics for
-    // imported functions to avoid repeated counting.
-    if (GlobalValue::isAvailableExternallyLinkage(F.getLinkage()))
-      continue;
-    const auto *FS = Reader.getSamplesFor(F);
-    if (!FS)
-      continue;
-    TotalProfiledFunc++;
-    TotalFunctionSamples += FS->getTotalSamples();
-
-    // Checksum mismatch is only used in pseudo-probe mode.
-    if (FunctionSamples::ProfileIsProbeBased)
-      countMismatchedFuncSamples(*FS, true);
-
-    // Count mismatches and samples for calliste.
-    countMismatchCallsites(*FS);
-    countMismatchedCallsiteSamples(*FS);
-  }
-
-  if (ReportProfileStaleness) {
-    if (FunctionSamples::ProfileIsProbeBased) {
-      errs() << "(" << NumStaleProfileFunc << "/" << TotalProfiledFunc << ")"
-             << " of functions' profile are invalid and "
-             << " (" << MismatchedFunctionSamples << "/" << TotalFunctionSamples
-             << ") of samples are discarded due to function hash mismatch.\n";
-    }
-    errs() << "(" << (NumMismatchedCallsites + NumRecoveredCallsites) << "/"
-           << TotalProfiledCallsites << ")"
-           << " of callsites' profile are invalid and "
-           << "(" << (MismatchedCallsiteSamples + RecoveredCallsiteSamples)
-           << "/" << TotalFunctionSamples << ")"
-           << " of samples are discarded due to callsite location mismatch.\n";
-    errs() << "(" << NumRecoveredCallsites << "/"
-           << (NumRecoveredCallsites + NumMismatchedCallsites) << ")"
-           << " of callsites and "
-           << "(" << RecoveredCallsiteSamples << "/"
-           << (RecoveredCallsiteSamples + MismatchedCallsiteSamples) << ")"
-           << " of samples are recovered by stale profile matching.\n";
-  }
-
-  if (PersistProfileStaleness) {
-    LLVMContext &Ctx = M.getContext();
-    MDBuilder MDB(Ctx);
-
-    SmallVector<std::pair<StringRef, uint64_t>> ProfStatsVec;
-    if (FunctionSamples::ProfileIsProbeBased) {
-      ProfStatsVec.emplace_back("NumStaleProfileFunc", NumStaleProfileFunc);
-      ProfStatsVec.emplace_back("TotalProfiledFunc", TotalProfiledFunc);
-      ProfStatsVec.emplace_back("MismatchedFunctionSamples",
-                                MismatchedFunctionSamples);
-      ProfStatsVec.emplace_back("TotalFunctionSamples", TotalFunctionSamples);
-    }
-
-    ProfStatsVec.emplace_back("NumMismatchedCallsites", NumMismatchedCallsites);
-    ProfStatsVec.emplace_back("NumRecoveredCallsites", NumRecoveredCallsites);
-    ProfStatsVec.emplace_back("TotalProfiledCallsites", TotalProfiledCallsites);
-    ProfStatsVec.emplace_back("MismatchedCallsiteSamples",
-                              MismatchedCallsiteSamples);
-    ProfStatsVec.emplace_back("RecoveredCallsiteSamples",
-                              RecoveredCallsiteSamples);
-
-    auto *MD = MDB.createLLVMStats(ProfStatsVec);
-    auto *NMD = M.getOrInsertNamedMetadata("llvm.stats");
-    NMD->addOperand(MD);
-  }
-}
-
-void SampleProfileMatcher::runOnModule() {
-  ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
-                                   FunctionSamples::ProfileIsCS);
-  for (auto &F : M) {
-    if (skipProfileForFunction(F))
-      continue;
-    runOnFunction(F);
-  }
-  if (SalvageStaleProfile)
-    distributeIRToProfileLocationMap();
-
-  computeAndReportProfileStaleness();
-}
-
-void SampleProfileMatcher::distributeIRToProfileLocationMap(
-    FunctionSamples &FS) {
-  const auto ProfileMappings = FuncMappings.find(FS.getFuncName());
-  if (ProfileMappings != FuncMappings.end()) {
-    FS.setIRToProfileLocationMap(&(ProfileMappings->second));
-  }
-
-  for (auto &Callees :
-       const_cast<CallsiteSampleMap &>(FS.getCallsiteSamples())) {
-    for (auto &FS : Callees.second) {
-      distributeIRToProfileLocationMap(FS.second);
-    }
-  }
-}
-
-// Use a central place to distribute the matching results. Outlined and inlined
-// profile with the function name will be set to the same pointer.
-void SampleProfileMatcher::distributeIRToProfileLocationMap() {
-  for (auto &I : Reader.getProfiles()) {
-    distributeIRToProfileLocationMap(I.second);
-  }
-}
-
 bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
                                       ProfileSummaryInfo *_PSI,
                                       LazyCallGraph &CG) {
diff --git a/llvm/lib/Transforms/IPO/SampleProfileMatcher.cpp b/llvm/lib/Transforms/IPO/SampleProfileMatcher.cpp
new file mode 100644
index 00000000000000..aa0a7364839c9b
--- /dev/null
+++ b/llvm/lib/Transforms/IPO/SampleProfileMatcher.cpp
@@ -0,0 +1,553 @@
+//===- SampleProfileMatcher.cpp - Sampling-based Stale Profile Matcher ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SampleProfileMatcher used for stale
+// profile matching.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/SampleProfileMatcher.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/MDBuilder.h"
+
+using namespace llvm;
+using namespace sampleprof;
+
+#define DEBUG_TYPE "sample-profile-matcher"
+
+extern cl::opt<bool> SalvageStaleProfile;
+extern cl::opt<bool> PersistProfileStaleness;
+extern cl::opt<bool> ReportProfileStaleness;
+
+void SampleProfileMatcher::findIRAnchors(
+    const Function &F, std::map<LineLocation, StringRef> &IRAnchors) {
+  // For inlined code, recover the original callsite and callee by finding the
+  // top-level inline frame. e.g. For frame stack "main:1 @ foo:2 @ bar:3", the
+  // top-level frame is "main:1", the callsite is "1" and the callee is "foo".
+  auto FindTopLevelInlinedCallsite = [](const DILocation *DIL) {
+    assert((DIL && DIL->getInlinedAt()) && "No inlined callsite");
+    const DILocation *PrevDIL = nullptr;
+    do {
+      PrevDIL = DIL;
+      DIL = DIL->getInlinedAt();
+    } while (DIL->getInlinedAt());
+
+    LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
+    StringRef CalleeName = PrevDIL->getSubprogramLinkageName();
+    return std::make_pair(Callsite, CalleeName);
+  };
+
+  auto GetCanonicalCalleeName = [](const CallBase *CB) {
+    StringRef CalleeName = UnknownIndirectCallee;
+    if (Function *Callee = CB->getCalledFunction())
+      CalleeName = FunctionSamples::getCanonicalFnName(Callee->getName());
+    return CalleeName;
+  };
+
+  // Extract profile matching anchors in the IR.
+  for (auto &BB : F) {
+    for (auto &I : BB) {
+      DILocation *DIL = I.getDebugLoc();
+      if (!DIL)
+        continue;
+
+      if (FunctionSamples::ProfileIsProbeBased) {
+        if (auto Probe = extractProbe(I)) {
+          // Flatten inlined IR for the matching.
+          if (DIL->getInlinedAt()) {
+            IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL));
+          } else {
+            // Use empty StringRef for basic block probe.
+            StringRef CalleeName;
+            if (const auto *CB = dyn_cast<CallBase>(&I)) {
+              // Skip the probe inst whose callee name is "llvm.pseudoprobe".
+              if (!isa<IntrinsicInst>(&I))
+                CalleeName = GetCanonicalCalleeName(CB);
+            }
+            IRAnchors.emplace(LineLocation(Probe->Id, 0), CalleeName);
+          }
+        }
+      } else {
+        // TODO: For line-number based profile(AutoFDO), currently only support
+        // find callsite anchors. In future, we need to parse all the non-call
+        // instructions to extract the line locations for profile matching.
+        if (!isa<CallBase>(&I) || isa<IntrinsicInst>(&I))
+          continue;
+
+        if (DIL->getInlinedAt()) {
+          IRAnchors.emplace(FindTopLevelInlinedCallsite(DIL));
+        } else {
+          LineLocation Callsite = FunctionSamples::getCallSiteIdentifier(DIL);
+          StringRef CalleeName = GetCanonicalCalleeName(dyn_cast<CallBase>(&I));
+          IRAnchors.emplace(Callsite, CalleeName);
+        }
+      }
+    }
+  }
+}
+
+void SampleProfileMatcher::findProfileAnchors(
+    const FunctionSamples &FS,
+    std::map<LineLocation, std::unordered_set<FunctionId>> &ProfileAnchors) {
+  auto isInvalidLineOffset = [](uint32_t LineOffset) {
+    return LineOffset & 0x8000;
+  };
+
+  for (const auto &I : FS.getBodySamples()) {
+    const LineLocation &Loc = I.first;
+    if (isInvalidLineOffset(Loc.LineOffset))
+      continue;
+    for (const auto &I : I.second.getCallTargets()) {
+      auto Ret =
+          ProfileAnchors.try_emplace(Loc, std::unordered_set<FunctionId>());
+      Ret.first->second.insert(I.first);
+    }
+  }
+
+  for (const auto &I : FS.getCallsiteSamples()) {
+    const LineLocation &Loc = I.first;
+    if (isInvalidLineOffset(Loc.LineOffset))
+      continue;
+    const auto &CalleeMap = I.second;
+    for (const auto &I : CalleeMap) {
+      auto Ret =
+          ProfileAnchors.try_emplace(Loc, std::unordered_set<FunctionId>());
+      Ret.first->second.insert(I.first);
+    }
+  }
+}
+
+// Call target name anchor based profile fuzzy matching.
+// Input:
+// For IR locations, the anchor is the callee name of direct callsite; For
+// profile locations, it's the call target name for BodySamples or inlinee's
+// profile name for CallsiteSamples.
+// Matching heuristic:
+// First match all the anchors in lexical order, then split the non-anchor
+// locations between the two anchors evenly, first half are matched based on the
+// start anchor, second half are matched based on the end anchor.
+// For example, given:
+// IR locations:      [1, 2(foo), 3, 5, 6(bar), 7]
+// Profile locations: [1, 2, 3(foo), 4, 7, 8(bar), 9]
+// The matching gives:
+//   [1,    2(foo), 3,  5,  6(bar), 7]
+//    |     |       |   |     |     |
+//   [1, 2, 3(foo), 4,  7,  8(bar), 9]
+// The output mapping: [2->3, 3->4, 5->7, 6->8, 7->9].
+void SampleProfileMatcher::runStaleProfileMatching(
+    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+    const std::map<LineLocation, std::unordered_set<FunctionId>>
+        &ProfileAnchors,
+    LocToLocMap &IRToProfileLocationMap) {
+  LLVM_DEBUG(dbgs() << "Run stale profile matching for " << F.getName()
+                    << "\n");
+  assert(IRToProfileLocationMap.empty() &&
+         "Run stale profile matching only once per function");
+
+  std::unordered_map<FunctionId, std::set<LineLocation>> CalleeToCallsitesMap;
+  for (const auto &I : ProfileAnchors) {
+    const auto &Loc = I.first;
+    const auto &Callees = I.second;
+    // Filter out possible indirect calls, use direct callee name as anchor.
+    if (Callees.size() == 1) {
+      FunctionId CalleeName = *Callees.begin();
+      const auto &Candidates = CalleeToCallsitesMap.try_emplace(
+          CalleeName, std::set<LineLocation>());
+      Candidates.first->second.insert(Loc);
+    }
+  }
+
+  auto InsertMatching = [&](const LineLocation &From, const LineLocation &To) {
+    // Skip the unchanged location mapping to save memory.
+    if (From != To)
+      IRToProfileLocationMap.insert({From, To});
+  };
+
+  // Use function's beginning location as the initial anchor.
+  int32_t LocationDelta = 0;
+  SmallVector<LineLocation> LastMatchedNonAnchors;
+
+  for (const auto &IR : IRAnchors) {
+    const auto &Loc = IR.first;
+    auto CalleeName = IR.second;
+    bool IsMatchedAnchor = false;
+    // Match the anchor location in lexical order.
+    if (!CalleeName.empty()) {
+      auto CandidateAnchors =
+          CalleeToCallsitesMap.find(getRepInFormat(CalleeName));
+      if (CandidateAnchors != CalleeToCallsitesMap.end() &&
+          !CandidateAnchors->second.empty()) {
+        auto CI = CandidateAnchors->second.begin();
+        const auto Candidate = *CI;
+        CandidateAnchors->second.erase(CI);
+        InsertMatching(Loc, Candidate);
+        LLVM_DEBUG(dbgs() << "Callsite with callee:" << CalleeName
+                          << " is matched from " << Loc << " to " << Candidate
+                          << "\n");
+        LocationDelta = Candidate.LineOffset - Loc.LineOffset;
+
+        // Match backwards for non-anchor locations.
+        // The locations in LastMatchedNonAnchors have been matched forwards
+        // based on the previous anchor, split it evenly and overwrite the
+        // second half based on the current anchor.
+        for (size_t I = (LastMatchedNonAnchors.size() + 1) / 2;
+             I < LastMatchedNonAnchors.size(); I++) {
+          const auto &L = LastMatchedNonAnchors[I];
+          uint32_t CandidateLineOffset = L.LineOffset + LocationDelta;
+          LineLocation Candidate(CandidateLineOffset, L.Discriminator);
+          InsertMatching(L, Candidate);
+          LLVM_DEBUG(dbgs() << "Location is rematched backwards from " << L
+                            << " to " << Candidate << "\n");
+        }
+
+        IsMatchedAnchor = true;
+        LastMatchedNonAnchors.clear();
+      }
+    }
+
+    // Match forwards for non-anchor locations.
+    if (!IsMatchedAnchor) {
+      uint32_t CandidateLineOffset = Loc.LineOffset + LocationDelta;
+      LineLocation Candidate(CandidateLineOffset, Loc.Discriminator);
+      InsertMatching(Loc, Candidate);
+      LLVM_DEBUG(dbgs() << "Location is matched from " << Loc << " to "
+                        << Candidate << "\n");
+      LastMatchedNonAnchors.emplace_back(Loc);
+    }
+  }
+}
+
+void SampleProfileMatcher::runOnFunction(Function &F) {
+  // We need to use flattened function samples for matching.
+  // Unlike IR, which includes all callsites from the source code, the callsites
+  // in profile only show up when they are hit by samples, i.e. the profile
+  // callsites in one context may differ from those in another context. To get
+  // the maximum number of callsites, we merge the function profiles from all
+  // contexts, aka, the flattened profile to find profile anchors.
+  const auto *FSFlattened = getFlattenedSamplesFor(F);
+  if (!FSFlattened)
+    return;
+
+  // Anchors for IR. It's a map from IR location to callee name, callee name is
+  // empty for non-call instruction and use a dummy name(UnknownIndirectCallee)
+  // for unknown indirect callee name.
+  std::map<LineLocation, StringRef> IRAnchors;
+  findIRAnchors(F, IRAnchors);
+  // Anchors for profile. It's a map from callsite location to a set of callee
+  // name.
+  std::map<LineLocation, std::unordered_set<FunctionId>> ProfileAnchors;
+  findProfileAnchors(*FSFlattened, ProfileAnchors);
+
+  // Compute the callsite match states for profile staleness report.
+  if (ReportProfileStaleness || PersistProfileStaleness)
+    recordCallsiteMatchStates(F, IRAnchors, ProfileAnchors, nullptr);
+
+  // Run profile matching for checksum mismatched profile, currently only
+  // support for pseudo-probe.
+  if (SalvageStaleProfile && FunctionSamples::ProfileIsProbeBased &&
+      !ProbeManager->profileIsValid(F, *FSFlattened)) {
+    // For imported functions, the checksum metadata(pseudo_probe_desc) are
+    // dropped, so we leverage function attribute(profile-checksum-mismatch) to
+    // transfer the info: add the attribute during pre-link phase and check it
+    // during post-link phase(see "profileIsValid").
+    if (FunctionSamples::ProfileIsProbeBased &&
+        LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink)
+      F.addFnAttr("profile-checksum-mismatch");
+
+    // The matching result will be saved to IRToProfileLocationMap, create a
+    // new map for each function.
+    auto &IRToProfileLocationMap = getIRToProfileLocationMap(F);
+    runStaleProfileMatching(F, IRAnchors, ProfileAnchors,
+                            IRToProfileLocationMap);
+    // Find and update callsite match states after matching.
+    if (ReportProfileStaleness || PersistProfileStaleness)
+      recordCallsiteMatchStates(F, IRAnchors, ProfileAnchors,
+                                &IRToProfileLocationMap);
+  }
+}
+
+void SampleProfileMatcher::recordCallsiteMatchStates(
+    const Function &F, const std::map<LineLocation, StringRef> &IRAnchors,
+    const std::map<LineLocation, std::unordered_set<FunctionId>>
+        &ProfileAnchors,
+    const LocToLocMap *IRToProfileLocationMap) {
+  bool IsPostMatch = IRToProfileLocationMap != nullptr;
+  auto &CallsiteMatchStates =
+      FuncCallsiteMatchStates[FunctionSamples::getCanonicalFnName(F.getName())];
+
+  auto MapIRLocToProfileLoc = [&](const LineLocation &IRLoc) {
+    // IRToProfileLocationMap is null in pre-match phase.
+    if (!IRToProfileLocationMap)
+      return IRLoc;
+    const auto &ProfileLoc = IRToProfileLocationMap->find(IRLoc);
+    if (ProfileLoc != IRToProfileLocationMap->end())
+      return ProfileLoc->second;
+    else
+      return IRLoc;
+  };
+
+  for (const auto &I : IRAnchors) {
+    // After fuzzy profile matching, use the matching result to remap the
+    // current IR callsite.
+    const auto &ProfileLoc = MapIRLocToProfileLoc(I.first);
+    const auto &IRCalleeName = I.second;
+    const auto &It = ProfileAnchors.find(ProfileLoc);
+    if (It == ProfileAnchors.end())
+      continue;
+    const auto &Callees = It->second;
+
+    bool IsCallsiteMatched = false;
+    // Since indirect call does not have CalleeName, check conservatively if
+    // callsite in the profile is a callsite location. This is to reduce num of
+    // false positive since otherwise all the indirect call samples will be
+    // reported as mismatching.
+    if (IRCalleeName == SampleProfileMatcher::UnknownIndirectCallee)
+      IsCallsiteMatched = true;
+    else if (Callees.size() == 1 && Callees.count(getRepInFormat(IRCalleeName)))
+      IsCallsiteMatched = true;
+
+    if (IsCallsiteMatched) {
+      auto It = CallsiteMatchStates.find(ProfileLoc);
+      if (It == CallsiteMatchStates.end())
+        CallsiteMatchStates.emplace(ProfileLoc, MatchState::InitialMatch);
+      else if (IsPostMatch) {
+        if (It->second == MatchState::InitialMatch)
+          It->second = MatchState::UnchangedMatch;
+        else if (It->second == MatchState::InitialMismatch)
+          It->second = MatchState::RecoveredMismatch;
+      }
+    }
+  }
+
+  // Check if there are any callsites in the profile that do not match to any
+  // IR callsites.
+  for (const auto &I : ProfileAnchors) {
+    const auto &Loc = I.first;
+    [[maybe_unused]] const auto &Callees = I.second;
+    assert(!Callees.empty() && "Callees should not be empty");
+    auto It = CallsiteMatchStates.find(Loc);
+    if (It == CallsiteMatchStates.end())
+      CallsiteMatchStates.emplace(Loc, MatchState::InitialMismatch);
+    else if (IsPostMatch) {
+      // Update the state if it's not matched(UnchangedMatch or
+      // RecoveredMismatch).
+      if (It->second == MatchState::InitialMismatch)
+        It->second = MatchState::UnchangedMismatch;
+      else if (It->second == MatchState::InitialMatch)
+        It->second = MatchState::RemovedMatch;
+    }
+  }
+}
+
+void SampleProfileMatcher::countMismatchedFuncSamples(const FunctionSamples &FS,
+                                                      bool IsTopLevel) {
+  const auto *FuncDesc = ProbeManager->getDesc(FS.getGUID());
+  // Skip the function that is external or renamed.
+  if (!FuncDesc)
+    return;
+
+  if (ProbeManager->profileIsHashMismatched(*FuncDesc, FS)) {
+    if (IsTopLevel)
+      NumStaleProfileFunc++;
+    // Given currently all probe ids are after block probe ids, once the
+    // checksum is mismatched, it's likely all the callsites are mismatched and
+    // dropped. We conservatively count all the samples as mismatched and stop
+    // counting the inlinees' profiles.
+    MismatchedFunctionSamples += FS.getTotalSamples();
+    return;
+  }
+
+  // Even the current-level function checksum is matched, it's possible that the
+  // nested inlinees' checksums are mismatched that affect the inlinee's sample
+  // loading, we need to go deeper to check the inlinees' function samples.
+  // Similarly, count all the samples as mismatched if the inlinee's checksum is
+  // mismatched using this recursive function.
+  for (const auto &I : FS.getCallsiteSamples())
+    for (const auto &CS : I.second)
+      countMismatchedFuncSamples(CS.second, false);
+}
+
+void SampleProfileMatcher::countMismatchedCallsiteSamples(
+    const FunctionSamples &FS) {
+  auto It = FuncCallsiteMatchStates.find(FS.getFuncName());
+  // Skip it if no mismatched callsite or this is an external function.
+  if (It == FuncCallsiteMatchStates.end() || It->second.empty())
+    return;
+  const auto &CallsiteMatchStates = It->second;
+
+  auto findMatchState = [&](const LineLocation &Loc) {
+    auto It = CallsiteMatchStates.find(Loc);
+    if (It == CallsiteMatchStates.end())
+      return MatchState::Unknown;
+    return It->second;
+  };
+
+  auto AttributeMismatchedSamples = [&](const enum MatchState &State,
+                                        uint64_t Samples) {
+    if (isMismatchState(State))
+      MismatchedCallsiteSamples += Samples;
+    else if (State == MatchState::RecoveredMismatch)
+      RecoveredCallsiteSamples += Samples;
+  };
+
+  // The non-inlined callsites are saved in the body samples of function
+  // profile, go through it to count the non-inlined callsite samples.
+  for (const auto &I : FS.getBodySamples())
+    AttributeMismatchedSamples(findMatchState(I.first), I.second.getSamples());
+
+  // Count the inlined callsite samples.
+  for (const auto &I : FS.getCallsiteSamples()) {
+    auto State = findMatchState(I.first);
+    uint64_t CallsiteSamples = 0;
+    for (const auto &CS : I.second)
+      CallsiteSamples += CS.second.getTotalSamples();
+    AttributeMismatchedSamples(State, CallsiteSamples);
+
+    if (isMismatchState(State))
+      continue;
+
+    // When the current level of inlined call site matches the profiled call
+    // site, we need to go deeper along the inline tree to count mismatches from
+    // lower level inlinees.
+    for (const auto &CS : I.second)
+      countMismatchedCallsiteSamples(CS.second);
+  }
+}
+
+void SampleProfileMatcher::countMismatchCallsites(const FunctionSamples &FS) {
+  auto It = FuncCallsiteMatchStates.find(FS.getFuncName());
+  // Skip it if no mismatched callsite or this is an external function.
+  if (It == FuncCallsiteMatchStates.end() || It->second.empty())
+    return;
+  const auto &MatchStates = It->second;
+  [[maybe_unused]] bool OnInitialState =
+      isInitialState(MatchStates.begin()->second);
+  for (const auto &I : MatchStates) {
+    TotalProfiledCallsites++;
+    assert(
+        (OnInitialState ? isInitialState(I.second) : isFinalState(I.second)) &&
+        "Profile matching state is inconsistent");
+
+    if (isMismatchState(I.second))
+      NumMismatchedCallsites++;
+    else if (I.second == MatchState::RecoveredMismatch)
+      NumRecoveredCallsites++;
+  }
+}
+
+void SampleProfileMatcher::computeAndReportProfileStaleness() {
+  if (!ReportProfileStaleness && !PersistProfileStaleness)
+    return;
+
+  // Count profile mismatches for profile staleness report.
+  for (const auto &F : M) {
+    if (skipProfileForFunction(F))
+      continue;
+    // As the stats will be merged by linker, skip reporting the metrics for
+    // imported functions to avoid repeated counting.
+    if (GlobalValue::isAvailableExternallyLinkage(F.getLinkage()))
+      continue;
+    const auto *FS = Reader.getSamplesFor(F);
+    if (!FS)
+      continue;
+    TotalProfiledFunc++;
+    TotalFunctionSamples += FS->getTotalSamples();
+
+    // Checksum mismatch is only used in pseudo-probe mode.
+    if (FunctionSamples::ProfileIsProbeBased)
+      countMismatchedFuncSamples(*FS, true);
+
+    // Count mismatches and samples for callsite.
+    countMismatchCallsites(*FS);
+    countMismatchedCallsiteSamples(*FS);
+  }
+
+  if (ReportProfileStaleness) {
+    if (FunctionSamples::ProfileIsProbeBased) {
+      errs() << "(" << NumStaleProfileFunc << "/" << TotalProfiledFunc << ")"
+             << " of functions' profile are invalid and "
+             << " (" << MismatchedFunctionSamples << "/" << TotalFunctionSamples
+             << ") of samples are discarded due to function hash mismatch.\n";
+    }
+    errs() << "(" << (NumMismatchedCallsites + NumRecoveredCallsites) << "/"
+           << TotalProfiledCallsites << ")"
+           << " of callsites' profile are invalid and "
+           << "(" << (MismatchedCallsiteSamples + RecoveredCallsiteSamples)
+           << "/" << TotalFunctionSamples << ")"
+           << " of samples are discarded due to callsite location mismatch.\n";
+    errs() << "(" << NumRecoveredCallsites << "/"
+           << (NumRecoveredCallsites + NumMismatchedCallsites) << ")"
+           << " of callsites and "
+           << "(" << RecoveredCallsiteSamples << "/"
+           << (RecoveredCallsiteSamples + MismatchedCallsiteSamples) << ")"
+           << " of samples are recovered by stale profile matching.\n";
+  }
+
+  if (PersistProfileStaleness) {
+    LLVMContext &Ctx = M.getContext();
+    MDBuilder MDB(Ctx);
+
+    SmallVector<std::pair<StringRef, uint64_t>> ProfStatsVec;
+    if (FunctionSamples::ProfileIsProbeBased) {
+      ProfStatsVec.emplace_back("NumStaleProfileFunc", NumStaleProfileFunc);
+      ProfStatsVec.emplace_back("TotalProfiledFunc", TotalProfiledFunc);
+      ProfStatsVec.emplace_back("MismatchedFunctionSamples",
+                                MismatchedFunctionSamples);
+      ProfStatsVec.emplace_back("TotalFunctionSamples", TotalFunctionSamples);
+    }
+
+    ProfStatsVec.emplace_back("NumMismatchedCallsites", NumMismatchedCallsites);
+    ProfStatsVec.emplace_back("NumRecoveredCallsites", NumRecoveredCallsites);
+    ProfStatsVec.emplace_back("TotalProfiledCallsites", TotalProfiledCallsites);
+    ProfStatsVec.emplace_back("MismatchedCallsiteSamples",
+                              MismatchedCallsiteSamples);
+    ProfStatsVec.emplace_back("RecoveredCallsiteSamples",
+                              RecoveredCallsiteSamples);
+
+    auto *MD = MDB.createLLVMStats(ProfStatsVec);
+    auto *NMD = M.getOrInsertNamedMetadata("llvm.stats");
+    NMD->addOperand(MD);
+  }
+}
+
+void SampleProfileMatcher::runOnModule() {
+  ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
+                                   FunctionSamples::ProfileIsCS);
+  for (auto &F : M) {
+    if (skipProfileForFunction(F))
+      continue;
+    runOnFunction(F);
+  }
+  if (SalvageStaleProfile)
+    distributeIRToProfileLocationMap();
+
+  computeAndReportProfileStaleness();
+}
+
+void SampleProfileMatcher::distributeIRToProfileLocationMap(
+    FunctionSamples &FS) {
+  const auto ProfileMappings = FuncMappings.find(FS.getFuncName());
+  if (ProfileMappings != FuncMappings.end()) {
+    FS.setIRToProfileLocationMap(&(ProfileMappings->second));
+  }
+
+  for (auto &Callees :
+       const_cast<CallsiteSampleMap &>(FS.getCallsiteSamples())) {
+    for (auto &FS : Callees.second) {
+      distributeIRToProfileLocationMap(FS.second);
+    }
+  }
+}
+
+// Use a central place to distribute the matching results. Outlined and inlined
+// profile with the same function name will be set to the same pointer.
+void SampleProfileMatcher::distributeIRToProfileLocationMap() {
+  for (auto &I : Reader.getProfiles()) {
+    distributeIRToProfileLocationMap(I.second);
+  }
+}
diff --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-callee-profile-mismatch.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-callee-profile-mismatch.ll
index e00b737cae4e85..4881937df101ac 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-callee-profile-mismatch.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-callee-profile-mismatch.ll
@@ -1,6 +1,6 @@
 ; REQUIRES: x86_64-linux
 ; REQUIRES: asserts
-; RUN: opt < %s -passes=sample-profile -sample-profile-file=%S/Inputs/pseudo-probe-callee-profile-mismatch.prof --salvage-stale-profile -S --debug-only=sample-profile,sample-profile-impl  -pass-remarks=inline 2>&1 | FileCheck %s
+; RUN: opt < %s -passes=sample-profile -sample-profile-file=%S/Inputs/pseudo-probe-callee-profile-mismatch.prof --salvage-stale-profile -S --debug-only=sample-profile,sample-profile-matcher,sample-profile-impl  -pass-remarks=inline 2>&1 | FileCheck %s
 
 
 ; CHECK: Run stale profile matching for bar
diff --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-stale-profile-matching-lto.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-stale-profile-matching-lto.ll
index 270beee4ebc2bd..7aabeeca2585b6 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-stale-profile-matching-lto.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-stale-profile-matching-lto.ll
@@ -1,6 +1,6 @@
 ; REQUIRES: x86_64-linux
 ; REQUIRES: asserts
-; RUN: opt < %s -passes=sample-profile -sample-profile-file=%S/Inputs/pseudo-probe-stale-profile-matching-lto.prof --salvage-stale-profile -S --debug-only=sample-profile,sample-profile-impl 2>&1 | FileCheck %s
+; RUN: opt < %s -passes=sample-profile -sample-profile-file=%S/Inputs/pseudo-probe-stale-profile-matching-lto.prof --salvage-stale-profile -S --debug-only=sample-profile,sample-profile-matcher,sample-profile-impl 2>&1 | FileCheck %s
 
 
 ; CHECK: Run stale profile matching for main
diff --git a/llvm/test/Transforms/SampleProfile/pseudo-probe-stale-profile-matching.ll b/llvm/test/Transforms/SampleProfile/pseudo-probe-stale-profile-matching.ll
index 29877fb22a2c2e..0d471e43d2a723 100644
--- a/llvm/test/Transforms/SampleProfile/pseudo-probe-stale-profile-matching.ll
+++ b/llvm/test/Transforms/SampleProfile/pseudo-probe-stale-profile-matching.ll
@@ -1,6 +1,6 @@
 ; REQUIRES: x86_64-linux
 ; REQUIRES: asserts
-; RUN: opt < %s -passes=sample-profile -sample-profile-file=%S/Inputs/pseudo-probe-stale-profile-matching.prof --salvage-stale-profile -S --debug-only=sample-profile,sample-profile-impl 2>&1 | FileCheck %s
+; RUN: opt < %s -passes=sample-profile -sample-profile-file=%S/Inputs/pseudo-probe-stale-profile-matching.prof --salvage-stale-profile -S --debug-only=sample-profile,sample-profile-matcher,sample-profile-impl 2>&1 | FileCheck %s
 
 ; The profiled source code:
 

>From be054980120a22f2f38a69dbd027b0f95e8baf6d Mon Sep 17 00:00:00 2001
From: wlei <wlei at fb.com>
Date: Thu, 28 Mar 2024 15:21:24 -0700
Subject: [PATCH 2/2] fix lint

---
 .../Transforms/IPO/SampleProfileMatcher.cpp   | 25 +++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/SampleProfileMatcher.cpp b/llvm/lib/Transforms/IPO/SampleProfileMatcher.cpp
index aa0a7364839c9b..bb46539989ab5b 100644
--- a/llvm/lib/Transforms/IPO/SampleProfileMatcher.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfileMatcher.cpp
@@ -469,23 +469,22 @@ void SampleProfileMatcher::computeAndReportProfileStaleness() {
 
   if (ReportProfileStaleness) {
     if (FunctionSamples::ProfileIsProbeBased) {
-      errs() << "(" << NumStaleProfileFunc << "/" << TotalProfiledFunc << ")"
-             << " of functions' profile are invalid and "
-             << " (" << MismatchedFunctionSamples << "/" << TotalFunctionSamples
+      errs() << "(" << NumStaleProfileFunc << "/" << TotalProfiledFunc
+             << ") of functions' profile are invalid and ("
+             << MismatchedFunctionSamples << "/" << TotalFunctionSamples
              << ") of samples are discarded due to function hash mismatch.\n";
     }
     errs() << "(" << (NumMismatchedCallsites + NumRecoveredCallsites) << "/"
-           << TotalProfiledCallsites << ")"
-           << " of callsites' profile are invalid and "
-           << "(" << (MismatchedCallsiteSamples + RecoveredCallsiteSamples)
-           << "/" << TotalFunctionSamples << ")"
-           << " of samples are discarded due to callsite location mismatch.\n";
+           << TotalProfiledCallsites
+           << ") of callsites' profile are invalid and ("
+           << (MismatchedCallsiteSamples + RecoveredCallsiteSamples) << "/"
+           << TotalFunctionSamples
+           << ") of samples are discarded due to callsite location mismatch.\n";
     errs() << "(" << NumRecoveredCallsites << "/"
-           << (NumRecoveredCallsites + NumMismatchedCallsites) << ")"
-           << " of callsites and "
-           << "(" << RecoveredCallsiteSamples << "/"
-           << (RecoveredCallsiteSamples + MismatchedCallsiteSamples) << ")"
-           << " of samples are recovered by stale profile matching.\n";
+           << (NumRecoveredCallsites + NumMismatchedCallsites)
+           << ") of callsites and (" << RecoveredCallsiteSamples << "/"
+           << (RecoveredCallsiteSamples + MismatchedCallsiteSamples)
+           << ") of samples are recovered by stale profile matching.\n";
   }
 
   if (PersistProfileStaleness) {



More information about the llvm-commits mailing list