[llvm] [CSSPGO] Compute and report profile matching recovered callsites and samples (PR #79090)

Lei Wang via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 31 16:29:40 PST 2024


================
@@ -2460,63 +2528,108 @@ void SampleProfileMatcher::runOnFunction(const Function &F) {
       !ProbeManager->profileIsValid(F, *FSFlattened)) {
     // The matching result will be saved to IRToProfileLocationMap, create a new
     // map for each function.
+    auto &IRToProfileLocationMap = getIRToProfileLocationMap(F);
     runStaleProfileMatching(F, IRAnchors, ProfileAnchors,
-                            getIRToProfileLocationMap(F));
+                            IRToProfileLocationMap);
+    PostMatchStats.countMismatchedCallsites(F, IRAnchors, ProfileAnchors,
+                                            IRToProfileLocationMap);
   }
 }
 
-void SampleProfileMatcher::runOnModule() {
-  ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
-                                   FunctionSamples::ProfileIsCS);
-  for (auto &F : M) {
-    if (F.isDeclaration() || !F.hasFnAttribute("use-sample-profile"))
-      continue;
-    runOnFunction(F);
-  }
-  if (SalvageStaleProfile)
-    distributeIRToProfileLocationMap();
-
+void SampleProfileMatcher::reportOrPersistProfileStats() {
   if (ReportProfileStaleness) {
     if (FunctionSamples::ProfileIsProbeBased) {
-      errs() << "(" << NumMismatchedFuncHash << "/" << TotalProfiledFunc << ")"
+      errs() << "(" << PreMatchStats.NumMismatchedFuncHash << "/"
+             << PreMatchStats.TotalProfiledFunc << ")"
              << " of functions' profile are invalid and "
-             << " (" << MismatchedFuncHashSamples << "/" << TotalFuncHashSamples
-             << ")"
+             << " (" << PreMatchStats.MismatchedFuncHashSamples << "/"
+             << PreMatchStats.TotalFunctionSamples << ")"
              << " of samples are discarded due to function hash mismatch.\n";
     }
-    errs() << "(" << NumMismatchedCallsites << "/" << TotalProfiledCallsites
-           << ")"
+    errs() << "(" << PreMatchStats.NumMismatchedCallsites << "/"
+           << PreMatchStats.TotalProfiledCallsites << ")"
            << " of callsites' profile are invalid and "
-           << "(" << MismatchedCallsiteSamples << "/" << TotalCallsiteSamples
-           << ")"
+           << "(" << PreMatchStats.MismatchedCallsiteSamples << "/"
+           << PreMatchStats.TotalFunctionSamples << ")"
            << " of samples are discarded due to callsite location mismatch.\n";
+    if (SalvageStaleProfile) {
+      uint64_t NumRecoveredCallsites = PostMatchStats.TotalProfiledCallsites -
+                                       PostMatchStats.NumMismatchedCallsites;
+      uint64_t NumMismatchedCallsites =
+          PreMatchStats.NumMismatchedCallsites - NumRecoveredCallsites;
+      errs() << "Out of " << PostMatchStats.TotalProfiledCallsites
+             << " callsites used for profile matching, "
+             << NumRecoveredCallsites
+             << " callsites have been recovered. After the matching, ("
+             << NumMismatchedCallsites << "/"
+             << PreMatchStats.TotalProfiledCallsites
+             << ") of callsites are still invalid ("
+             << PostMatchStats.MismatchedCallsiteSamples << "/"
+             << PreMatchStats.TotalFunctionSamples << ")"
+             << " of samples are still discarded.\n";
+    }
   }
 
   if (PersistProfileStaleness) {
     LLVMContext &Ctx = M.getContext();
     MDBuilder MDB(Ctx);
 
     SmallVector<std::pair<StringRef, uint64_t>> ProfStatsVec;
+    ProfStatsVec.emplace_back("NumMismatchedCallsites",
+                              PreMatchStats.NumMismatchedCallsites);
+    ProfStatsVec.emplace_back("TotalProfiledCallsites",
+                              PreMatchStats.TotalProfiledCallsites);
+    ProfStatsVec.emplace_back("MismatchedCallsiteSamples",
+                              PreMatchStats.MismatchedCallsiteSamples);
+    ProfStatsVec.emplace_back("TotalProfiledFunc",
+                              PreMatchStats.TotalProfiledFunc);
+    ProfStatsVec.emplace_back("TotalFunctionSamples",
+                              PreMatchStats.TotalFunctionSamples);
     if (FunctionSamples::ProfileIsProbeBased) {
-      ProfStatsVec.emplace_back("NumMismatchedFuncHash", NumMismatchedFuncHash);
-      ProfStatsVec.emplace_back("TotalProfiledFunc", TotalProfiledFunc);
+      ProfStatsVec.emplace_back("NumMismatchedFuncHash",
+                                PreMatchStats.NumMismatchedFuncHash);
       ProfStatsVec.emplace_back("MismatchedFuncHashSamples",
-                                MismatchedFuncHashSamples);
-      ProfStatsVec.emplace_back("TotalFuncHashSamples", TotalFuncHashSamples);
+                                PreMatchStats.MismatchedFuncHashSamples);
+    }
+    if (SalvageStaleProfile) {
+      ProfStatsVec.emplace_back("PostMatchNumMismatchedCallsites",
+                                PostMatchStats.NumMismatchedCallsites);
+      ProfStatsVec.emplace_back("NumCallsitesForMatching",
+                                PostMatchStats.TotalProfiledCallsites);
+      ProfStatsVec.emplace_back("PostMatchMismatchedCallsiteSamples",
+                                PostMatchStats.MismatchedCallsiteSamples);
     }
-
-    ProfStatsVec.emplace_back("NumMismatchedCallsites", NumMismatchedCallsites);
-    ProfStatsVec.emplace_back("TotalProfiledCallsites", TotalProfiledCallsites);
-    ProfStatsVec.emplace_back("MismatchedCallsiteSamples",
-                              MismatchedCallsiteSamples);
-    ProfStatsVec.emplace_back("TotalCallsiteSamples", TotalCallsiteSamples);
 
     auto *MD = MDB.createLLVMStats(ProfStatsVec);
     auto *NMD = M.getOrInsertNamedMetadata("llvm.stats");
     NMD->addOperand(MD);
   }
 }
 
+void SampleProfileMatcher::runOnModule() {
+  ProfileConverter::flattenProfile(Reader.getProfiles(), FlattenedProfiles,
+                                   FunctionSamples::ProfileIsCS);
+  for (auto &F : M) {
+    if (ShouldSkipProfileLoading(F))
+      continue;
+    runOnFunction(F);
+  }
+
+  if (SalvageStaleProfile)
+    distributeIRToProfileLocationMap();
+
+  PreMatchStats.countMismatchedCallsiteSamples();
+  if (SalvageStaleProfile) {
+    // If a function doesn't run the matching but has mismatched callsites, this
+    // won't be any data for that function in post-match stats, so just reuse
+    // the pre-match stats.
+    PostMatchStats.copyUnchangedCallsiteMismatches(
+        PreMatchStats.FuncMismatchedCallsites);
+    PostMatchStats.countMismatchedCallsiteSamples();
----------------
wlei-llvm wrote:

Pushed a new version. The changes are:

Most of the counting has been moved out of the function-level matching and into a new module-level function, `computeAndReportProfileStaleness`. It computes all the counts for function checksum mismatches and callsite mismatches, covering both weighted (samples) and unweighted mismatches.

Extended the old `FuncMismatchedCallsites` map to `StringMap<std::unordered_map<LineLocation, MatchState, LineLocationHash>> FuncCallsiteMatchStates;`, which uses an enum `MatchState` to represent the match state. One reason is that we can now compress the structure, as we discussed before. Another reason is that we have to move the `total profiled function metrics` and `total profiled callsites` out into `computeAndReportProfileStaleness`. It's not easy or clean to compute these two metrics in the old way, because they are only counted for non-imported functions; for match states, however, we need to count the imported functions too. To achieve this, we need a "Matched" callsite state that is used to count the original profiled functions.

https://github.com/llvm/llvm-project/pull/79090


More information about the llvm-commits mailing list