[llvm] aa58b7b - [CSSPGO][llvm-profgen] Reimplement computeSummaryAndThreshold using context trie

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 27 23:30:09 PDT 2022


Author: wlei
Date: 2022-06-27T23:22:21-07:00
New Revision: aa58b7b1e30fbbd9c8c2bf6ba291f1742f53afed

URL: https://github.com/llvm/llvm-project/commit/aa58b7b1e30fbbd9c8c2bf6ba291f1742f53afed
DIFF: https://github.com/llvm/llvm-project/commit/aa58b7b1e30fbbd9c8c2bf6ba291f1742f53afed.diff

LOG: [CSSPGO][llvm-profgen] Reimplement computeSummaryAndThreshold using context trie

Follow-up patch to https://reviews.llvm.org/D125246: support `computeSummaryAndThreshold` based on the context trie.

Reviewed By: hoy, wenlei

Differential Revision: https://reviews.llvm.org/D127026

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/IPO/SampleContextTracker.h
    llvm/lib/Transforms/IPO/SampleContextTracker.cpp
    llvm/tools/llvm-profgen/ProfileGenerator.cpp
    llvm/tools/llvm-profgen/ProfileGenerator.h
    llvm/tools/llvm-profgen/llvm-profgen.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/IPO/SampleContextTracker.h b/llvm/include/llvm/Transforms/IPO/SampleContextTracker.h
index 45dec6242f2c2..3bc417c9d2d1b 100644
--- a/llvm/include/llvm/Transforms/IPO/SampleContextTracker.h
+++ b/llvm/include/llvm/Transforms/IPO/SampleContextTracker.h
@@ -142,6 +142,9 @@ class SampleContextTracker {
   ContextTrieNode &getRootContext();
   void promoteMergeContextSamplesTree(const Instruction &Inst,
                                       StringRef CalleeName);
+
+  // Merge all context profiles into a context-less map keyed by function name.
+  void createContextLessProfileMap(SampleProfileMap &ContextLessProfiles);
   // Dump the internal context profile trie.
   void dump();
 
@@ -158,7 +161,6 @@ class SampleContextTracker {
   promoteMergeContextSamplesTree(ContextTrieNode &FromNode,
                                  ContextTrieNode &ToNodeParent,
                                  uint32_t ContextFramesToRemove);
-
   // Map from function name to context profiles (excluding base profile)
   StringMap<ContextSamplesTy> FuncToCtxtProfiles;
 

diff  --git a/llvm/lib/Transforms/IPO/SampleContextTracker.cpp b/llvm/lib/Transforms/IPO/SampleContextTracker.cpp
index b154aa51ed844..2471f3271308a 100644
--- a/llvm/lib/Transforms/IPO/SampleContextTracker.cpp
+++ b/llvm/lib/Transforms/IPO/SampleContextTracker.cpp
@@ -595,4 +595,24 @@ ContextTrieNode &SampleContextTracker::promoteMergeContextSamplesTree(
 
   return *ToNode;
 }
+
+void SampleContextTracker::createContextLessProfileMap(
+    SampleProfileMap &ContextLessProfiles) {
+  std::queue<ContextTrieNode *> NodeQueue;
+  NodeQueue.push(&RootContext);
+
+  while (!NodeQueue.empty()) {
+    ContextTrieNode *Node = NodeQueue.front();
+    FunctionSamples *FProfile = Node->getFunctionSamples();
+    NodeQueue.pop();
+
+    if (FProfile) {
+      // Profile's context may be empty; key by the trie node's func name.
+      ContextLessProfiles[Node->getFuncName()].merge(*FProfile);
+    }
+
+    for (auto &It : Node->getAllChildContext())
+      NodeQueue.push(&It.second);
+  }
+}
 } // namespace llvm

diff  --git a/llvm/tools/llvm-profgen/ProfileGenerator.cpp b/llvm/tools/llvm-profgen/ProfileGenerator.cpp
index 4bc72f7678216..90df77171cb60 100644
--- a/llvm/tools/llvm-profgen/ProfileGenerator.cpp
+++ b/llvm/tools/llvm-profgen/ProfileGenerator.cpp
@@ -91,6 +91,7 @@ static cl::opt<bool> UpdateTotalSamples(
     llvm::cl::Optional);
 
 extern cl::opt<int> ProfileSummaryCutoffHot;
+extern cl::opt<bool> UseContextLessSummary;
 
 static cl::opt<bool> GenCSNestedProfile(
     "gen-cs-nested-profile", cl::Hidden, cl::init(true),
@@ -128,14 +129,13 @@ ProfileGeneratorBase::create(ProfiledBinary *Binary,
 }
 
 std::unique_ptr<ProfileGeneratorBase>
-ProfileGeneratorBase::create(ProfiledBinary *Binary,
-                             const SampleProfileMap &&Profiles,
+ProfileGeneratorBase::create(ProfiledBinary *Binary, SampleProfileMap &Profiles,
                              bool ProfileIsCS) {
   std::unique_ptr<ProfileGeneratorBase> Generator;
   if (ProfileIsCS) {
     if (Binary->useFSDiscriminator())
       exitWithError("FS discriminator is not supported in CS profile.");
-    Generator.reset(new CSProfileGenerator(Binary, std::move(Profiles)));
+    Generator.reset(new CSProfileGenerator(Binary, Profiles));
   } else {
     Generator.reset(new ProfileGenerator(Binary, std::move(Profiles)));
   }
@@ -403,43 +403,73 @@ void ProfileGeneratorBase::updateFunctionSamples() {
 
 void ProfileGeneratorBase::collectProfiledFunctions() {
   std::unordered_set<const BinaryFunction *> ProfiledFunctions;
-  if (SampleCounters) {
-    // Go through all the stacks, ranges and branches in sample counters, use
-    // the start of the range to look up the function it belongs and record the
-    // function.
-    for (const auto &CI : *SampleCounters) {
-      if (const auto *CtxKey = dyn_cast<AddrBasedCtxKey>(CI.first.getPtr())) {
-        for (auto Addr : CtxKey->Context) {
-          if (FuncRange *FRange = Binary->findFuncRangeForOffset(
-                  Binary->virtualAddrToOffset(Addr)))
-            ProfiledFunctions.insert(FRange->Func);
-        }
-      }
+  if (collectFunctionsFromRawProfile(ProfiledFunctions))
+    Binary->setProfiledFunctions(ProfiledFunctions);
+  else if (collectFunctionsFromLLVMProfile(ProfiledFunctions))
+    Binary->setProfiledFunctions(ProfiledFunctions);
+  else
+    llvm_unreachable("Unsupported input profile");
+}
 
-      for (auto Item : CI.second.RangeCounter) {
-        uint64_t StartOffset = Item.first.first;
-        if (FuncRange *FRange = Binary->findFuncRangeForOffset(StartOffset))
+bool ProfileGeneratorBase::collectFunctionsFromRawProfile(
+    std::unordered_set<const BinaryFunction *> &ProfiledFunctions) {
+  if (!SampleCounters)
+    return false;
+  // Walk every sample counter entry: map each call-stack address, range start
+  // offset and branch source/target offset to its enclosing function, and
+  // record that function as profiled.
+  for (const auto &CI : *SampleCounters) {
+    if (const auto *CtxKey = dyn_cast<AddrBasedCtxKey>(CI.first.getPtr())) {
+      for (auto Addr : CtxKey->Context) {
+        if (FuncRange *FRange = Binary->findFuncRangeForOffset(
+                Binary->virtualAddrToOffset(Addr)))
           ProfiledFunctions.insert(FRange->Func);
       }
+    }
 
-      for (auto Item : CI.second.BranchCounter) {
-        uint64_t SourceOffset = Item.first.first;
-        uint64_t TargetOffset = Item.first.first;
-        if (FuncRange *FRange = Binary->findFuncRangeForOffset(SourceOffset))
-          ProfiledFunctions.insert(FRange->Func);
-        if (FuncRange *FRange = Binary->findFuncRangeForOffset(TargetOffset))
-          ProfiledFunctions.insert(FRange->Func);
-      }
+    for (auto Item : CI.second.RangeCounter) {
+      uint64_t StartOffset = Item.first.first;
+      if (FuncRange *FRange = Binary->findFuncRangeForOffset(StartOffset))
+        ProfiledFunctions.insert(FRange->Func);
     }
-  } else {
-    // This is for the case the input is a llvm sample profile.
-    for (const auto &FS : ProfileMap) {
-      if (auto *Func = Binary->getBinaryFunction(FS.first.getName()))
-        ProfiledFunctions.insert(Func);
+
+    for (auto Item : CI.second.BranchCounter) {
+      uint64_t SourceOffset = Item.first.first;
+      uint64_t TargetOffset = Item.first.second;
+      if (FuncRange *FRange = Binary->findFuncRangeForOffset(SourceOffset))
+        ProfiledFunctions.insert(FRange->Func);
+      if (FuncRange *FRange = Binary->findFuncRangeForOffset(TargetOffset))
+        ProfiledFunctions.insert(FRange->Func);
     }
   }
+  return true;
+}
+
+bool ProfileGenerator::collectFunctionsFromLLVMProfile(
+    std::unordered_set<const BinaryFunction *> &ProfiledFunctions) {
+  for (const auto &FS : ProfileMap) {
+    if (auto *Func = Binary->getBinaryFunction(FS.first.getName()))
+      ProfiledFunctions.insert(Func);
+  }
+  return true;
+}
 
-  Binary->setProfiledFunctions(ProfiledFunctions);
+bool CSProfileGenerator::collectFunctionsFromLLVMProfile(
+    std::unordered_set<const BinaryFunction *> &ProfiledFunctions) {
+  std::queue<ContextTrieNode *> NodeQueue;
+  NodeQueue.push(&getRootContext());
+  while (!NodeQueue.empty()) {
+    ContextTrieNode *Node = NodeQueue.front();
+    NodeQueue.pop();
+
+    if (!Node->getFuncName().empty())
+      if (auto *Func = Binary->getBinaryFunction(Node->getFuncName()))
+        ProfiledFunctions.insert(Func);
+
+    for (auto &It : Node->getAllChildContext())
+      NodeQueue.push(&It.second);
+  }
+  return true;
 }
 
 FunctionSamples &
@@ -471,7 +501,7 @@ void ProfileGenerator::generateProfile() {
 }
 
 void ProfileGenerator::postProcessProfiles() {
-  computeSummaryAndThreshold();
+  computeSummaryAndThreshold(ProfileMap);
   trimColdProfiles(ProfileMap, ColdCountThreshold);
   calculateAndShowDensity(ProfileMap);
 }
@@ -965,13 +995,12 @@ void CSProfileGenerator::convertToProfileMap() {
 }
 
 void CSProfileGenerator::postProcessProfiles() {
-  if (SampleCounters)
-    convertToProfileMap();
-
   // Compute hot/cold threshold based on profile. This will be used for cold
   // context profile merging/trimming.
   computeSummaryAndThreshold();
 
+  convertToProfileMap();
+
   // Run global pre-inliner to adjust/merge context profile based on estimated
   // inline decisions.
   if (EnableCSPreInliner) {
@@ -1003,15 +1032,33 @@ void CSProfileGenerator::postProcessProfiles() {
   }
 }
 
-void ProfileGeneratorBase::computeSummaryAndThreshold() {
+void ProfileGeneratorBase::computeSummaryAndThreshold(
+    SampleProfileMap &Profiles) {
   SampleProfileSummaryBuilder Builder(ProfileSummaryBuilder::DefaultCutoffs);
-  Summary = Builder.computeSummaryForProfiles(ProfileMap);
+  Summary = Builder.computeSummaryForProfiles(Profiles);
   HotCountThreshold = ProfileSummaryBuilder::getHotCountThreshold(
       (Summary->getDetailedSummary()));
   ColdCountThreshold = ProfileSummaryBuilder::getColdCountThreshold(
       (Summary->getDetailedSummary()));
 }
 
+void CSProfileGenerator::computeSummaryAndThreshold() {
+  // Always merge and use context-less profile map to compute summary.
+  SampleProfileMap ContextLessProfiles;
+  ContextTracker.createContextLessProfileMap(ContextLessProfiles);
+
+  // Clear the CS flag so computeSummaryAndThreshold does not merge the
+  // already context-less profiles again; restore the previous value after.
+  const bool SavedProfileIsCS = FunctionSamples::ProfileIsCS;
+  FunctionSamples::ProfileIsCS = false;
+  assert(
+      (!UseContextLessSummary.getNumOccurrences() || UseContextLessSummary) &&
+      "Don't set --profile-summary-contextless to false for profile "
+      "generation");
+  ProfileGeneratorBase::computeSummaryAndThreshold(ContextLessProfiles);
+  FunctionSamples::ProfileIsCS = SavedProfileIsCS;
+}
+
 void ProfileGeneratorBase::extractProbesFromRange(
     const RangeSample &RangeCounter, ProbeCounterMap &ProbeCounter,
     bool FindDisjointRanges) {

diff  --git a/llvm/tools/llvm-profgen/ProfileGenerator.h b/llvm/tools/llvm-profgen/ProfileGenerator.h
index f26c3d0d82097..0ce464506be15 100644
--- a/llvm/tools/llvm-profgen/ProfileGenerator.h
+++ b/llvm/tools/llvm-profgen/ProfileGenerator.h
@@ -32,6 +32,7 @@ using ProbeCounterMap =
 class ProfileGeneratorBase {
 
 public:
+  explicit ProfileGeneratorBase(ProfiledBinary *Binary) : Binary(Binary){};
   ProfileGeneratorBase(ProfiledBinary *Binary,
                        const ContextSampleCounterMap *Counters)
       : Binary(Binary), SampleCounters(Counters){};
@@ -44,7 +45,7 @@ class ProfileGeneratorBase {
   create(ProfiledBinary *Binary, const ContextSampleCounterMap *Counters,
          bool profileIsCS);
   static std::unique_ptr<ProfileGeneratorBase>
-  create(ProfiledBinary *Binary, const SampleProfileMap &&ProfileMap,
+  create(ProfiledBinary *Binary, SampleProfileMap &ProfileMap,
          bool profileIsCS);
   virtual void generateProfile() = 0;
   void write();
@@ -109,7 +110,7 @@ class ProfileGeneratorBase {
 
   StringRef getCalleeNameForOffset(uint64_t TargetOffset);
 
-  void computeSummaryAndThreshold();
+  void computeSummaryAndThreshold(SampleProfileMap &Profiles);
 
   void calculateAndShowDensity(const SampleProfileMap &Profiles);
 
@@ -120,6 +121,13 @@ class ProfileGeneratorBase {
 
   void collectProfiledFunctions();
 
+  bool collectFunctionsFromRawProfile(
+      std::unordered_set<const BinaryFunction *> &ProfiledFunctions);
+
+  // Collect profiled functions for LLVM sample profile input.
+  virtual bool collectFunctionsFromLLVMProfile(
+      std::unordered_set<const BinaryFunction *> &ProfiledFunctions) = 0;
+
   // Thresholds from profile summary to answer isHotCount/isColdCount queries.
   uint64_t HotCountThreshold;
 
@@ -166,6 +174,8 @@ class ProfileGenerator : public ProfileGeneratorBase {
   void postProcessProfiles();
   void trimColdProfiles(const SampleProfileMap &Profiles,
                         uint64_t ColdCntThreshold);
+  bool collectFunctionsFromLLVMProfile(
+      std::unordered_set<const BinaryFunction *> &ProfiledFunctions) override;
 };
 
 class CSProfileGenerator : public ProfileGeneratorBase {
@@ -173,8 +183,8 @@ class CSProfileGenerator : public ProfileGeneratorBase {
   CSProfileGenerator(ProfiledBinary *Binary,
                      const ContextSampleCounterMap *Counters)
       : ProfileGeneratorBase(Binary, Counters){};
-  CSProfileGenerator(ProfiledBinary *Binary, const SampleProfileMap &&Profiles)
-      : ProfileGeneratorBase(Binary, std::move(Profiles)){};
+  CSProfileGenerator(ProfiledBinary *Binary, SampleProfileMap &Profiles)
+      : ProfileGeneratorBase(Binary), ContextTracker(Profiles, nullptr){};
   void generateProfile() override;
 
   // Trim the context stack at a given depth.
@@ -343,6 +353,11 @@ class CSProfileGenerator : public ProfileGeneratorBase {
 
   void convertToProfileMap();
 
+  void computeSummaryAndThreshold();
+
+  bool collectFunctionsFromLLVMProfile(
+      std::unordered_set<const BinaryFunction *> &ProfiledFunctions) override;
+
   ContextTrieNode &getRootContext() { return ContextTracker.getRootContext(); };
 
   // The container for holding the FunctionSamples used by context trie.

diff  --git a/llvm/tools/llvm-profgen/llvm-profgen.cpp b/llvm/tools/llvm-profgen/llvm-profgen.cpp
index 3aff3eab2c013..8b12c2fe46c72 100644
--- a/llvm/tools/llvm-profgen/llvm-profgen.cpp
+++ b/llvm/tools/llvm-profgen/llvm-profgen.cpp
@@ -164,8 +164,7 @@ int main(int argc, const char *argv[]) {
         std::move(ReaderOrErr.get());
     Reader->read();
     std::unique_ptr<ProfileGeneratorBase> Generator =
-        ProfileGeneratorBase::create(Binary.get(),
-                                     std::move(Reader->getProfiles()),
+        ProfileGeneratorBase::create(Binary.get(), Reader->getProfiles(),
                                      Reader->profileIsCS());
     Generator->generateProfile();
     Generator->write();


        


More information about the llvm-commits mailing list