[llvm] a98d6a1 - [SamplePGO] Stale profile matching (part 1)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 28 13:13:55 PDT 2023


Author: wlei
Date: 2023-04-28T13:07:32-07:00
New Revision: a98d6a11ea192756ee99dc2547d46599141977f2

URL: https://github.com/llvm/llvm-project/commit/a98d6a11ea192756ee99dc2547d46599141977f2
DIFF: https://github.com/llvm/llvm-project/commit/a98d6a11ea192756ee99dc2547d46599141977f2.diff

LOG: [SamplePGO] Stale profile matching (part 1)

AutoFDO/CSSPGO often has to deal with stale profiles collected on binaries built from sources several revisions behind the release. Using a stale profile is likely to produce incorrect profile annotations, which results in unstable or poorly performing binaries. Currently, for source-location-based profiles, once a code change causes a profile mismatch, every location after it is mismatched as well, and the affected samples and inlining information are lost. A matching framework that reuses the still-valid parts of a mismatched profile (a.k.a. incremental PGO) would make PGO more stable, increase optimization coverage, and improve the performance of the binary.

This patch is the part 1 of stale profile matching, summary of the implementation:
 - Added a structure for the matching result, `LocToLocMap`: a location-to-location map from each location in the current build to its matched location in the previous build (used to query the “stale” profile).
 - To use the matching results for sample queries, they must be available to every location query. For code cleanliness, we added a new pointer field (`IRToProfileLocationMap`) to `FunctionSamples`.
 - Added a wrapper (`mapIRLocToProfileLoc`) for location queries: a location from the input IR is remapped to its matched profile location.
 - Added a new switch `--salvage-stale-profile`.
 - Refactored the staleness detection (mismatch counting is split out of the per-function driver).

The test case is in part 2, along with the matching algorithm.

Reviewed By: wenlei

Differential Revision: https://reviews.llvm.org/D147456

Added: 
    

Modified: 
    llvm/include/llvm/ProfileData/SampleProf.h
    llvm/lib/ProfileData/SampleProf.cpp
    llvm/lib/Transforms/IPO/SampleProfile.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h
index ddab9a1e49b3d..69003d32699b5 100644
--- a/llvm/include/llvm/ProfileData/SampleProf.h
+++ b/llvm/include/llvm/ProfileData/SampleProf.h
@@ -723,6 +723,8 @@ using BodySampleMap = std::map<LineLocation, SampleRecord>;
 // memory, which is *very* significant for large profiles.
 using FunctionSamplesMap = std::map<std::string, FunctionSamples, std::less<>>;
 using CallsiteSampleMap = std::map<LineLocation, FunctionSamplesMap>;
+using LocToLocMap =
+    std::unordered_map<LineLocation, LineLocation, LineLocationHash>;
 
 /// Representation of the samples collected for a function.
 ///
@@ -842,12 +844,26 @@ class FunctionSamples {
     }
   }
 
+  // Query the stale profile matching results and remap the location. The
+  // result is returned by value so it never aliases a caller's temporary.
+  LineLocation mapIRLocToProfileLoc(const LineLocation &IRLoc) const {
+    // No remapping if there are no matching results, or if this location has
+    // no entry in them (the matching kept it at the same location).
+    if (!IRToProfileLocationMap)
+      return IRLoc;
+    auto ProfileLoc = IRToProfileLocationMap->find(IRLoc);
+    if (ProfileLoc != IRToProfileLocationMap->end())
+      return ProfileLoc->second;
+    return IRLoc;
+  }
+
   /// Return the number of samples collected at the given location.
   /// Each location is specified by \p LineOffset and \p Discriminator.
   /// If the location is not found in profile, return error.
   ErrorOr<uint64_t> findSamplesAt(uint32_t LineOffset,
                                   uint32_t Discriminator) const {
-    const auto &ret = BodySamples.find(LineLocation(LineOffset, Discriminator));
+    const auto &ret = BodySamples.find(
+        mapIRLocToProfileLoc(LineLocation(LineOffset, Discriminator)));
     if (ret == BodySamples.end())
       return std::error_code();
     return ret->second.getSamples();
@@ -858,7 +874,8 @@ class FunctionSamples {
   /// If the location is not found in profile, return error.
   ErrorOr<SampleRecord::CallTargetMap>
   findCallTargetMapAt(uint32_t LineOffset, uint32_t Discriminator) const {
-    const auto &ret = BodySamples.find(LineLocation(LineOffset, Discriminator));
+    const auto &ret = BodySamples.find(
+        mapIRLocToProfileLoc(LineLocation(LineOffset, Discriminator)));
     if (ret == BodySamples.end())
       return std::error_code();
     return ret->second.getCallTargets();
@@ -868,7 +885,7 @@ class FunctionSamples {
   /// CallSite. If the location is not found in profile, return error.
   ErrorOr<SampleRecord::CallTargetMap>
   findCallTargetMapAt(const LineLocation &CallSite) const {
-    const auto &Ret = BodySamples.find(CallSite);
+    const auto &Ret = BodySamples.find(mapIRLocToProfileLoc(CallSite));
     if (Ret == BodySamples.end())
       return std::error_code();
     return Ret->second.getCallTargets();
@@ -876,13 +893,13 @@ class FunctionSamples {
 
   /// Return the function samples at the given callsite location.
   FunctionSamplesMap &functionSamplesAt(const LineLocation &Loc) {
-    return CallsiteSamples[Loc];
+    return CallsiteSamples[mapIRLocToProfileLoc(Loc)];
   }
 
   /// Returns the FunctionSamplesMap at the given \p Loc.
   const FunctionSamplesMap *
   findFunctionSamplesMapAt(const LineLocation &Loc) const {
-    auto iter = CallsiteSamples.find(Loc);
+    auto iter = CallsiteSamples.find(mapIRLocToProfileLoc(Loc));
     if (iter == CallsiteSamples.end())
       return nullptr;
     return &iter->second;
@@ -1046,6 +1063,11 @@ class FunctionSamples {
 
   uint64_t getFunctionHash() const { return FunctionHash; }
 
+  void setIRToProfileLocationMap(const LocToLocMap *LocMap) {
+    assert(!IRToProfileLocationMap && "this should be set only once");
+    IRToProfileLocationMap = LocMap;
+  }
+
   /// Return the canonical name for a function, taking into account
   /// suffix elision policy attributes.
   static StringRef getCanonicalFnName(const Function &F) {
@@ -1229,6 +1251,25 @@ class FunctionSamples {
   /// in the call to bar() at line offset 1, the other for all the samples
   /// collected in the call to baz() at line offset 8.
   CallsiteSampleMap CallsiteSamples;
+
+  /// IR-to-profile location map from stale profile matching (null if unset).
+  ///
+  /// Each entry maps a location in the current build to its matched
+  /// location in the "stale" profile. For example:
+  ///   Profiled source code:
+  ///      void foo() {
+  ///   1    bar();
+  ///      }
+  ///
+  ///   Current source code:
+  ///      void foo() {
+  ///   1    // Code change
+  ///   2    bar();
+  ///      }
+  /// If the stale profile matching algorithm generates the mapping [2 -> 1],
+  /// a profile query for bar's IR location (line 2) is remapped to line 1,
+  /// which is where bar's samples are recorded in the profile.
+  const LocToLocMap *IRToProfileLocationMap = nullptr;
 };
 
 raw_ostream &operator<<(raw_ostream &OS, const FunctionSamples &FS);

diff  --git a/llvm/lib/ProfileData/SampleProf.cpp b/llvm/lib/ProfileData/SampleProf.cpp
index 3c60ef436db91..fdae8a011e712 100644
--- a/llvm/lib/ProfileData/SampleProf.cpp
+++ b/llvm/lib/ProfileData/SampleProf.cpp
@@ -291,7 +291,7 @@ const FunctionSamples *FunctionSamples::findFunctionSamplesAt(
   std::string CalleeGUID;
   CalleeName = getRepInFormat(CalleeName, UseMD5, CalleeGUID);
 
-  auto iter = CallsiteSamples.find(Loc);
+  auto iter = CallsiteSamples.find(mapIRLocToProfileLoc(Loc));
   if (iter == CallsiteSamples.end())
     return nullptr;
   auto FS = iter->second.find(CalleeName);

diff  --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 6c430aa5988c6..b80572adfef8d 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -128,6 +128,11 @@ static cl::opt<std::string> SampleProfileRemappingFile(
     "sample-profile-remapping-file", cl::init(""), cl::value_desc("filename"),
     cl::desc("Profile remapping file loaded by -sample-profile"), cl::Hidden);
 
+static cl::opt<bool> SalvageStaleProfile(
+    "salvage-stale-profile", cl::Hidden, cl::init(false),
+    cl::desc("Salvage stale profile by fuzzy matching and use the "
+             "remapped location for sample profile query."));
+
 static cl::opt<bool> ReportProfileStaleness(
     "report-profile-staleness", cl::Hidden, cl::init(false),
     cl::desc("Compute and report stale profile statistical metrics."));
@@ -458,7 +463,9 @@ class SampleProfileMatcher {
                                        FunctionSamples::ProfileIsCS);
     }
   }
+  void runOnModule();
 
+private:
   FunctionSamples *getFlattenedSamplesFor(const Function &F) {
     StringRef CanonFName = FunctionSamples::getCanonicalFnName(F);
     auto It = FlattenedProfiles.find(CanonFName);
@@ -466,9 +473,11 @@ class SampleProfileMatcher {
       return &It->second;
     return nullptr;
   }
-
-  void detectProfileMismatch();
-  void detectProfileMismatch(const Function &F, const FunctionSamples &FS);
+  void runOnFunction(const Function &F, const FunctionSamples &FS);
+  void countProfileMismatches(
+      const FunctionSamples &FS,
+      const std::unordered_set<LineLocation, LineLocationHash>
+          &MatchedCallsiteLocs);
 };
 
 /// Sample profile pass.
@@ -2071,7 +2080,8 @@ bool SampleProfileLoader::doInitialization(Module &M,
     }
   }
 
-  if (ReportProfileStaleness || PersistProfileStaleness) {
+  if (ReportProfileStaleness || PersistProfileStaleness ||
+      SalvageStaleProfile) {
     MatchingManager =
         std::make_unique<SampleProfileMatcher>(M, *Reader, ProbeManager.get());
   }
@@ -2079,8 +2089,53 @@ bool SampleProfileLoader::doInitialization(Module &M,
   return true;
 }
 
-void SampleProfileMatcher::detectProfileMismatch(const Function &F,
-                                                 const FunctionSamples &FS) {
+void SampleProfileMatcher::countProfileMismatches(
+    const FunctionSamples &FS,
+    const std::unordered_set<LineLocation, LineLocationHash>
+        &MatchedCallsiteLocs) {
+  // Accumulate callsite staleness metrics from one function's profile.
+  auto isInvalidLineOffset = [](uint32_t LineOffset) {
+    return LineOffset & 0x8000;
+  };
+
+  // Check if there are any callsites in the profile that do not match any IR
+  // callsite; those callsite samples will be discarded.
+  for (auto &I : FS.getBodySamples()) {
+    const LineLocation &Loc = I.first;
+    if (isInvalidLineOffset(Loc.LineOffset))
+      continue;
+
+    uint64_t Count = I.second.getSamples();
+    if (!I.second.getCallTargets().empty()) {
+      TotalCallsiteSamples += Count;
+      TotalProfiledCallsites++;
+      if (!MatchedCallsiteLocs.count(Loc)) {
+        MismatchedCallsiteSamples += Count;
+        NumMismatchedCallsites++;
+      }
+    }
+  }
+  // Same check for callsite samples, using summed head-sample estimates.
+  for (auto &I : FS.getCallsiteSamples()) {
+    const LineLocation &Loc = I.first;
+    if (isInvalidLineOffset(Loc.LineOffset))
+      continue;
+
+    uint64_t Count = 0;
+    for (auto &FM : I.second) {
+      Count += FM.second.getHeadSamplesEstimate();
+    }
+    TotalCallsiteSamples += Count;
+    TotalProfiledCallsites++;
+    if (!MatchedCallsiteLocs.count(Loc)) {
+      MismatchedCallsiteSamples += Count;
+      NumMismatchedCallsites++;
+    }
+  }
+}
+
+void SampleProfileMatcher::runOnFunction(const Function &F,
+                                         const FunctionSamples &FS) {
   if (FunctionSamples::ProfileIsProbeBased) {
     uint64_t Count = FS.getTotalSamples();
     TotalFuncHashSamples += Count;
@@ -2130,47 +2185,12 @@ void SampleProfileMatcher::detectProfileMismatch(const Function &F,
     }
   }
 
-  auto isInvalidLineOffset = [](uint32_t LineOffset) {
-    return LineOffset & 0x8000;
-  };
-
-  // Check if there are any callsites in the profile that does not match to any
-  // IR callsites, those callsite samples will be discarded.
-  for (auto &I : FS.getBodySamples()) {
-    const LineLocation &Loc = I.first;
-    if (isInvalidLineOffset(Loc.LineOffset))
-      continue;
-
-    uint64_t Count = I.second.getSamples();
-    if (!I.second.getCallTargets().empty()) {
-      TotalCallsiteSamples += Count;
-      TotalProfiledCallsites++;
-      if (!MatchedCallsiteLocs.count(Loc)) {
-        MismatchedCallsiteSamples += Count;
-        NumMismatchedCallsites++;
-      }
-    }
-  }
-
-  for (auto &I : FS.getCallsiteSamples()) {
-    const LineLocation &Loc = I.first;
-    if (isInvalidLineOffset(Loc.LineOffset))
-      continue;
-
-    uint64_t Count = 0;
-    for (auto &FM : I.second) {
-      Count += FM.second.getHeadSamplesEstimate();
-    }
-    TotalCallsiteSamples += Count;
-    TotalProfiledCallsites++;
-    if (!MatchedCallsiteLocs.count(Loc)) {
-      MismatchedCallsiteSamples += Count;
-      NumMismatchedCallsites++;
-    }
-  }
+  // Detect profile mismatch for profile staleness metrics report.
+  if (ReportProfileStaleness || PersistProfileStaleness)
+    countProfileMismatches(FS, MatchedCallsiteLocs);
 }
 
-void SampleProfileMatcher::detectProfileMismatch() {
+void SampleProfileMatcher::runOnModule() {
   for (auto &F : M) {
     if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
       continue;
@@ -2181,7 +2201,7 @@ void SampleProfileMatcher::detectProfileMismatch() {
       FS = Reader.getSamplesFor(F);
     if (!FS)
       continue;
-    detectProfileMismatch(F, *FS);
+    runOnFunction(F, *FS);
   }
 
   if (ReportProfileStaleness) {
@@ -2270,8 +2290,10 @@ bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM,
   assert(SymbolMap.count(StringRef()) == 0 &&
          "No empty StringRef should be added in SymbolMap");
 
-  if (ReportProfileStaleness || PersistProfileStaleness)
-    MatchingManager->detectProfileMismatch();
+  if (ReportProfileStaleness || PersistProfileStaleness ||
+      SalvageStaleProfile) {
+    MatchingManager->runOnModule();
+  }
 
   bool retval = false;
   for (auto *F : buildFunctionOrder(M, CG)) {


        


More information about the llvm-commits mailing list