[llvm] 30b0232 - [CSSPGO][llvm-profgen] Context-sensitive global pre-inliner

Wenlei He via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 29 09:54:09 PDT 2021


Author: Wenlei He
Date: 2021-03-29T09:46:14-07:00
New Revision: 30b023233696f044427d6c3ae6c0e290e3ef1aa0

URL: https://github.com/llvm/llvm-project/commit/30b023233696f044427d6c3ae6c0e290e3ef1aa0
DIFF: https://github.com/llvm/llvm-project/commit/30b023233696f044427d6c3ae6c0e290e3ef1aa0.diff

LOG: [CSSPGO][llvm-profgen] Context-sensitive global pre-inliner

This change sets up a framework in llvm-profgen to estimate inline decisions and adjust context-sensitive profiles based on them. We call it a global pre-inliner in llvm-profgen.
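
"Global" here means the estimate walks the whole profiled call graph top-down. In the diff, CSPreInliner::buildTopDownOrder gets that order by reversing the SCC order produced by llvm::scc_iterator over ProfiledCallGraph. The standalone sketch below reproduces the idea with a hand-written Tarjan SCC pass. The CallGraph alias (caller name to callee names) is a hypothetical stand-in for ProfiledCallGraph, and it has no synthetic root node, so every node is used as a start point instead.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for ProfiledCallGraph: caller name -> callee names.
using CallGraph = std::map<std::string, std::vector<std::string>>;

// Tarjan's SCC algorithm completes an SCC only after every SCC reachable from
// it, so SCCs come out bottom-up (callees before callers), which is the order
// llvm::scc_iterator also yields. Reversing the result gives a top-down order.
std::vector<std::string> buildTopDownOrder(const CallGraph &G) {
  std::map<std::string, int> Index, LowLink;
  std::map<std::string, bool> OnStack;
  std::vector<std::string> Stack, Order;
  int NextIndex = 0;

  std::function<void(const std::string &)> Visit = [&](const std::string &N) {
    Index[N] = LowLink[N] = NextIndex++;
    Stack.push_back(N);
    OnStack[N] = true;
    auto It = G.find(N);
    if (It != G.end()) {
      for (const std::string &Callee : It->second) {
        if (!Index.count(Callee)) {
          Visit(Callee); // Tree edge: recurse, then propagate the low-link.
          LowLink[N] = std::min(LowLink[N], LowLink[Callee]);
        } else if (OnStack[Callee]) {
          LowLink[N] = std::min(LowLink[N], Index[Callee]); // Back edge.
        }
      }
    }
    // N is the root of an SCC: pop the whole component.
    if (LowLink[N] == Index[N]) {
      std::string Member;
      do {
        Member = Stack.back();
        Stack.pop_back();
        OnStack[Member] = false;
        Order.push_back(Member);
      } while (Member != N);
    }
  };

  for (const auto &Entry : G)
    if (!Index.count(Entry.first))
      Visit(Entry.first);
  assert(Stack.empty() && "every SCC should have been popped");
  std::reverse(Order.begin(), Order.end());
  return Order;
}
```

For a chain main -> foo -> bar this yields main, foo, bar, so a caller is always processed before its callees. Inside one SCC (mutual recursion) the relative order is arbitrary, as it is in the diff.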

It will serve two purposes:
  1) Since the context profile of a context that is not inlined will be merged into the base profile anyway, if we estimate that a context will not be inlined, we can merge its context profile in the output to save profile size.
  2) For ThinLTO, when a context involving functions from different modules is not inlined, we can't merge function profiles across modules, which leads to suboptimal post-inline count quality. By estimating some inline decisions, we can adjust/merge context profiles beforehand as a mitigation.
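
A minimal sketch of purpose 1, assuming a heavily simplified profile model: the Samples struct, the ProfileMap alias and mergeNotInlinedContexts are hypothetical and are not the FunctionSamples/SampleContextTracker API. Every non-base context that is not estimated to be inlined is folded into the base profile of its leaf function and then dropped from the output. The real tracker additionally re-roots the child contexts of a merged context (in the test below, [main:1 @ foo:3.1 @ bar] becomes [foo:3.1 @ bar]), which this sketch omits.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <set>
#include <string>

// Hypothetical, heavily simplified stand-in for FunctionSamples.
struct Samples {
  uint64_t Total = 0;
  std::map<std::string, uint64_t> Body; // location -> sample count
};

// Profiles keyed by context string, e.g. "main:1 @ foo"; a base profile is
// keyed by the bare function name, e.g. "foo".
using ProfileMap = std::map<std::string, Samples>;

// Leaf function of a context: the frame after the last " @ ".
std::string leafName(const std::string &Context) {
  size_t Pos = Context.rfind(" @ ");
  return Pos == std::string::npos ? Context : Context.substr(Pos + 3);
}

// Merge every non-base context profile that is not estimated to be inlined
// into its leaf function's base profile, then drop it from the output.
void mergeNotInlinedContexts(ProfileMap &Profiles,
                             const std::set<std::string> &InlinedContexts) {
  for (auto It = Profiles.begin(); It != Profiles.end();) {
    const std::string &Context = It->first;
    bool IsBase = Context.find(" @ ") == std::string::npos;
    if (IsBase || InlinedContexts.count(Context)) {
      ++It;
      continue;
    }
    // std::map insertion does not invalidate It, so creating the base
    // profile here is safe.
    Samples &Base = Profiles[leafName(Context)];
    assert(&Base != &It->second && "a context never merges into itself");
    Base.Total += It->second.Total;
    for (const auto &Entry : It->second.Body)
      Base.Body[Entry.first] += Entry.second;
    It = Profiles.erase(It);
  }
}
```

With a base profile [foo] and an outlined context [main:1 @ foo], the context's samples are added to [foo] and the context entry disappears, which is the profile size saving described in 1).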

The compiler's inline heuristic uses inline cost, which is not available in llvm-profgen. But since inline cost is closely related to size, we can get an estimate from the function size in debug info. Because the size available in llvm-profgen is the final size, it could also be more accurate than the inline cost estimation in the compiler.

This change only sets up the framework; a few TODOs are left for follow-up patches to complete the implementation:
  1) We need to retrieve the size of each function/inlinee from debug info for inline estimation. Currently we use the number of body sample entries (lines/probes) in a profile as a placeholder for size estimation.
  2) Currently the thresholds reuse the values of the sample loader inliner. They need to be tuned, since the size here is fully optimized machine code size rather than an inline cost based on not yet fully optimized IR.
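
The per-function decision in CSPreInliner::processFunction can be sketched as a greedy loop: candidates are popped hottest first (ties go to the smaller callee), a hot call site gets the hot threshold and any other call site gets the cold one, and the caller stops accepting inlinees once its estimated size reaches clamp(size * growth limit, min, max). The constants below are the default values of the reused sample loader switches from the diff. The Candidate struct, estimateInlines and the HotCountThreshold parameter (which the diff takes from ProfileSummaryInfo) are illustrative assumptions, and the sketch does not enqueue the inlined callee's own child contexts as the diff does.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// Default values of the sample loader switches that the diff reuses.
constexpr uint64_t HotCallSiteThreshold = 3000; // sample-profile-hot-inline-threshold
constexpr uint64_t ColdCallSiteThreshold = 45;  // sample-profile-cold-inline-threshold
constexpr uint64_t GrowthLimit = 12;            // sample-profile-inline-growth-limit
constexpr uint64_t LimitMin = 100;              // sample-profile-inline-limit-min
constexpr uint64_t LimitMax = 10000;            // sample-profile-inline-limit-max

struct Candidate {
  std::string Name;
  uint64_t CallsiteCount;
  uint64_t SizeCost; // Placeholder size: number of body sample entries.
};

// Max-heap order: hotter call sites first, then smaller callees first.
struct CandidateLess {
  bool operator()(const Candidate &L, const Candidate &R) const {
    if (L.CallsiteCount != R.CallsiteCount)
      return L.CallsiteCount < R.CallsiteCount;
    return L.SizeCost > R.SizeCost;
  }
};

// Hot call sites get the large threshold, everything else the cold one.
bool shouldInline(const Candidate &C, uint64_t HotCountThreshold) {
  uint64_t Threshold = C.CallsiteCount > HotCountThreshold
                           ? HotCallSiteThreshold
                           : ColdCallSiteThreshold;
  return C.SizeCost < Threshold;
}

// Greedily accept candidates until the caller's size budget is used up.
std::vector<std::string> estimateInlines(uint64_t FuncSize,
                                         std::vector<Candidate> Candidates,
                                         uint64_t HotCountThreshold) {
  assert(LimitMin <= LimitMax && "std::clamp requires lo <= hi");
  uint64_t SizeLimit = std::clamp(FuncSize * GrowthLimit, LimitMin, LimitMax);
  uint64_t FinalSize = FuncSize;
  std::priority_queue<Candidate, std::vector<Candidate>, CandidateLess> Queue(
      CandidateLess(), std::move(Candidates));
  std::vector<std::string> Inlined;
  while (!Queue.empty() && FinalSize < SizeLimit) {
    Candidate C = Queue.top();
    Queue.pop();
    if (shouldInline(C, HotCountThreshold)) {
      Inlined.push_back(C.Name);
      FinalSize += C.SizeCost;
    }
  }
  return Inlined;
}
```

For a caller with 5 body entries the budget is clamp(60, 100, 10000) = 100. This also shows why TODO 2) matters: the thresholds were calibrated against IR inline cost, while the sizes fed in here are a machine-code proxy.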

Differential Revision: https://reviews.llvm.org/D99146

Added: 
    llvm/include/llvm/Transforms/IPO/ProfiledCallGraph.h
    llvm/test/tools/llvm-profgen/cs-preinline.test
    llvm/tools/llvm-profgen/CSPreInliner.cpp
    llvm/tools/llvm-profgen/CSPreInliner.h

Modified: 
    llvm/include/llvm/Transforms/IPO/SampleContextTracker.h
    llvm/lib/Transforms/IPO/SampleContextTracker.cpp
    llvm/lib/Transforms/IPO/SampleProfile.cpp
    llvm/tools/llvm-profgen/CMakeLists.txt
    llvm/tools/llvm-profgen/ProfileGenerator.cpp
    llvm/tools/llvm-profgen/ProfileGenerator.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/IPO/ProfiledCallGraph.h b/llvm/include/llvm/Transforms/IPO/ProfiledCallGraph.h
new file mode 100644
index 0000000000000..8eea41ca77a7f
--- /dev/null
+++ b/llvm/include/llvm/Transforms/IPO/ProfiledCallGraph.h
@@ -0,0 +1,135 @@
+//===-- ProfiledCallGraph.h - Profiled Call Graph ----------------- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TOOLS_LLVM_PROFGEN_PROFILEDCALLGRAPH_H
+#define LLVM_TOOLS_LLVM_PROFGEN_PROFILEDCALLGRAPH_H
+
+#include "llvm/ADT/GraphTraits.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ProfileData/SampleProf.h"
+#include "llvm/Transforms/IPO/SampleContextTracker.h"
+#include <queue>
+#include <set>
+#include <string>
+
+using namespace llvm;
+using namespace sampleprof;
+
+namespace llvm {
+namespace sampleprof {
+
+struct ProfiledCallGraphNode {
+  ProfiledCallGraphNode(StringRef FName = StringRef()) : Name(FName) {}
+  StringRef Name;
+
+  struct ProfiledCallGraphNodeComparer {
+    bool operator()(const ProfiledCallGraphNode *L,
+                    const ProfiledCallGraphNode *R) const {
+      return L->Name < R->Name;
+    }
+  };
+  std::set<ProfiledCallGraphNode *, ProfiledCallGraphNodeComparer> Callees;
+};
+
+class ProfiledCallGraph {
+public:
+  using iterator = std::set<ProfiledCallGraphNode *>::iterator;
+  ProfiledCallGraph(StringMap<FunctionSamples> &ProfileMap,
+                    SampleContextTracker &ContextTracker) {
+    // Add all profiled functions into the profiled call graph.
+    // We only add functions that have an actual context profile.
+    for (auto &FuncSample : ProfileMap) {
+      FunctionSamples *FSamples = &FuncSample.second;
+      addProfiledFunction(FSamples->getName());
+    }
+
+    // BFS traverse the context profile trie to add call edges for
+    // both sampled calls and calls shown in context.
+    std::queue<ContextTrieNode *> Queue;
+    Queue.push(&ContextTracker.getRootContext());
+    while (!Queue.empty()) {
+      ContextTrieNode *Caller = Queue.front();
+      Queue.pop();
+      FunctionSamples *CallerSamples = Caller->getFunctionSamples();
+
+      // Add calls for context if both caller and callee have context profiles.
+      for (auto &Child : Caller->getAllChildContext()) {
+        ContextTrieNode *Callee = &Child.second;
+        Queue.push(Callee);
+        if (CallerSamples && Callee->getFunctionSamples()) {
+          addProfiledCall(Caller->getFuncName(), Callee->getFuncName());
+        }
+      }
+
+      // Add calls from call site samples
+      if (CallerSamples) {
+        for (auto &LocCallSite : CallerSamples->getBodySamples()) {
+          for (auto &NameCallSite : LocCallSite.second.getCallTargets()) {
+            addProfiledCall(Caller->getFuncName(), NameCallSite.first());
+          }
+        }
+      }
+    }
+  }
+
+  iterator begin() { return Root.Callees.begin(); }
+  iterator end() { return Root.Callees.end(); }
+  ProfiledCallGraphNode *getEntryNode() { return &Root; }
+  void addProfiledFunction(StringRef Name) {
+    if (!ProfiledFunctions.count(Name)) {
+      // Link to synthetic root to make sure every node is reachable
+      // from root. This does not affect SCC order.
+      ProfiledFunctions[Name] = ProfiledCallGraphNode(Name);
+      Root.Callees.insert(&ProfiledFunctions[Name]);
+    }
+  }
+  void addProfiledCall(StringRef CallerName, StringRef CalleeName) {
+    assert(ProfiledFunctions.count(CallerName));
+    auto CalleeIt = ProfiledFunctions.find(CalleeName);
+    if (CalleeIt == ProfiledFunctions.end()) {
+      return;
+    }
+    ProfiledFunctions[CallerName].Callees.insert(&CalleeIt->second);
+  }
+
+private:
+  ProfiledCallGraphNode Root;
+  StringMap<ProfiledCallGraphNode> ProfiledFunctions;
+};
+
+} // end namespace sampleprof
+
+template <> struct GraphTraits<ProfiledCallGraphNode *> {
+  using NodeRef = ProfiledCallGraphNode *;
+  using ChildIteratorType = std::set<ProfiledCallGraphNode *>::iterator;
+
+  static NodeRef getEntryNode(NodeRef PCGN) { return PCGN; }
+  static ChildIteratorType child_begin(NodeRef N) { return N->Callees.begin(); }
+  static ChildIteratorType child_end(NodeRef N) { return N->Callees.end(); }
+};
+
+template <>
+struct GraphTraits<ProfiledCallGraph *>
+    : public GraphTraits<ProfiledCallGraphNode *> {
+  static NodeRef getEntryNode(ProfiledCallGraph *PCG) {
+    return PCG->getEntryNode();
+  }
+
+  static ChildIteratorType nodes_begin(ProfiledCallGraph *PCG) {
+    return PCG->begin();
+  }
+
+  static ChildIteratorType nodes_end(ProfiledCallGraph *PCG) {
+    return PCG->end();
+  }
+};
+
+} // end namespace llvm
+
+#endif

diff  --git a/llvm/include/llvm/Transforms/IPO/SampleContextTracker.h b/llvm/include/llvm/Transforms/IPO/SampleContextTracker.h
index bc8f602795a9e..685a060fe463d 100644
--- a/llvm/include/llvm/Transforms/IPO/SampleContextTracker.h
+++ b/llvm/include/llvm/Transforms/IPO/SampleContextTracker.h
@@ -114,13 +114,14 @@ class SampleContextTracker {
   FunctionSamples *getBaseSamplesFor(const Function &Func,
                                      bool MergeContext = true);
   // Query base profile for a given function by name.
-  FunctionSamples *getBaseSamplesFor(StringRef Name, bool MergeContext);
+  FunctionSamples *getBaseSamplesFor(StringRef Name, bool MergeContext = true);
   // Retrieve the context trie node for given profile context
   ContextTrieNode *getContextFor(const SampleContext &Context);
   // Mark a context profile as inlined when function is inlined.
   // This makes sure that inlined context profile will be excluded in
   // function's base profile.
   void markContextSamplesInlined(const FunctionSamples *InlinedSamples);
+  ContextTrieNode &getRootContext();
   void promoteMergeContextSamplesTree(const Instruction &Inst,
                                       StringRef CalleeName);
   void addCallGraphEdges(CallGraph &CG, StringMap<Function *> &SymbolMap);

diff  --git a/llvm/lib/Transforms/IPO/SampleContextTracker.cpp b/llvm/lib/Transforms/IPO/SampleContextTracker.cpp
index 5ad0ba20b3e02..863e8f3833fb8 100644
--- a/llvm/lib/Transforms/IPO/SampleContextTracker.cpp
+++ b/llvm/lib/Transforms/IPO/SampleContextTracker.cpp
@@ -328,6 +328,8 @@ void SampleContextTracker::markContextSamplesInlined(
   InlinedSamples->getContext().setState(InlinedContext);
 }
 
+ContextTrieNode &SampleContextTracker::getRootContext() { return RootContext; }
+
 void SampleContextTracker::promoteMergeContextSamplesTree(
     const Instruction &Inst, StringRef CalleeName) {
   LLVM_DEBUG(dbgs() << "Promoting and merging context tree for instr: \n"
@@ -490,6 +492,7 @@ SampleContextTracker::getOrCreateContextPath(const SampleContext &Context,
 }
 
 ContextTrieNode *SampleContextTracker::getTopLevelContextNode(StringRef FName) {
+  assert(!FName.empty() && "Top level node query must provide valid name");
   return RootContext.getChildContext(LineLocation(0, 0), FName);
 }
 

diff  --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 548a8ad216b1e..79d68f2c62cfc 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -175,41 +175,42 @@ static cl::opt<bool> ProfileSizeInline(
     cl::desc("Inline cold call sites in profile loader if it's beneficial "
              "for code size."));
 
-static cl::opt<int> ProfileInlineGrowthLimit(
+cl::opt<int> ProfileInlineGrowthLimit(
     "sample-profile-inline-growth-limit", cl::Hidden, cl::init(12),
     cl::desc("The size growth ratio limit for proirity-based sample profile "
              "loader inlining."));
 
-static cl::opt<int> ProfileInlineLimitMin(
+cl::opt<int> ProfileInlineLimitMin(
     "sample-profile-inline-limit-min", cl::Hidden, cl::init(100),
     cl::desc("The lower bound of size growth limit for "
              "proirity-based sample profile loader inlining."));
 
-static cl::opt<int> ProfileInlineLimitMax(
+cl::opt<int> ProfileInlineLimitMax(
     "sample-profile-inline-limit-max", cl::Hidden, cl::init(10000),
     cl::desc("The upper bound of size growth limit for "
              "proirity-based sample profile loader inlining."));
 
+cl::opt<int> SampleHotCallSiteThreshold(
+    "sample-profile-hot-inline-threshold", cl::Hidden, cl::init(3000),
+    cl::desc("Hot callsite threshold for proirity-based sample profile loader "
+             "inlining."));
+
+cl::opt<int> SampleColdCallSiteThreshold(
+    "sample-profile-cold-inline-threshold", cl::Hidden, cl::init(45),
+    cl::desc("Threshold for inlining cold callsites"));
+
 static cl::opt<int> ProfileICPThreshold(
     "sample-profile-icp-threshold", cl::Hidden, cl::init(5),
     cl::desc(
         "Relative hotness threshold for indirect "
         "call promotion in proirity-based sample profile loader inlining."));
 
-static cl::opt<int> SampleHotCallSiteThreshold(
-    "sample-profile-hot-inline-threshold", cl::Hidden, cl::init(3000),
-    cl::desc("Hot callsite threshold for proirity-based sample profile loader "
-             "inlining."));
-
 static cl::opt<bool> CallsitePrioritizedInline(
     "sample-profile-prioritized-inline", cl::Hidden, cl::ZeroOrMore,
     cl::init(false),
     cl::desc("Use call site prioritized inlining for sample profile loader."
              "Currently only CSSPGO is supported."));
 
-static cl::opt<int> SampleColdCallSiteThreshold(
-    "sample-profile-cold-inline-threshold", cl::Hidden, cl::init(45),
-    cl::desc("Threshold for inlining cold callsites"));
 
 static cl::opt<std::string> ProfileInlineReplayFile(
     "sample-profile-inline-replay", cl::init(""), cl::value_desc("filename"),

diff  --git a/llvm/test/tools/llvm-profgen/cs-preinline.test b/llvm/test/tools/llvm-profgen/cs-preinline.test
new file mode 100644
index 0000000000000..e9aa7cbc73aa5
--- /dev/null
+++ b/llvm/test/tools/llvm-profgen/cs-preinline.test
@@ -0,0 +1,41 @@
+; Test default llvm-profgen with preinline off
+; RUN: llvm-profgen --perfscript=%S/Inputs/inline-cs-noprobe.perfscript --binary=%S/Inputs/inline-cs-noprobe.perfbin --output=%t
+; RUN: FileCheck %s --input-file %t --check-prefix=CHECK-DEFAULT
+
+; Test llvm-profgen with preinliner on: non-inlinable profiles are merged into the base profile.
+; RUN: llvm-profgen --perfscript=%S/Inputs/inline-cs-noprobe.perfscript --binary=%S/Inputs/inline-cs-noprobe.perfbin --output=%t --csspgo-preinliner=1
+; RUN: FileCheck %s --input-file %t --check-prefix=CHECK-PREINL
+
+; Test preinliner threshold that prevents all possible inlining and merges everything into base profile.
+; RUN: llvm-profgen --perfscript=%S/Inputs/inline-cs-noprobe.perfscript --binary=%S/Inputs/inline-cs-noprobe.perfbin --output=%t --csspgo-preinliner=1  -sample-profile-hot-inline-threshold=0
+; RUN: FileCheck %s --input-file %t --check-prefix=CHECK-NO-PREINL
+
+; CHECK-DEFAULT:     [main:1 @ foo]:309:0
+; CHECK-DEFAULT-NEXT: 2.1: 14
+; CHECK-DEFAULT-NEXT: 3: 15
+; CHECK-DEFAULT-NEXT: 3.1: 14 bar:14
+; CHECK-DEFAULT-NEXT: 3.2: 1
+; CHECK-DEFAULT-NEXT: !Attributes: 1
+; CHECK-DEFAULT-NEXT:[main:1 @ foo:3.1 @ bar]:84:0
+; CHECK-DEFAULT-NEXT: 1: 14
+; CHECK-DEFAULT-NEXT: !Attributes: 1
+
+; CHECK-PREINL:     [foo]:309:0
+; CHECK-PREINL-NEXT: 2.1: 14
+; CHECK-PREINL-NEXT: 3: 15
+; CHECK-PREINL-NEXT: 3.1: 14 bar:14
+; CHECK-PREINL-NEXT: 3.2: 1
+; CHECK-PREINL-NEXT: !Attributes: 1
+; CHECK-PREINL-NEXT:[foo:3.1 @ bar]:84:0
+; CHECK-PREINL-NEXT: 1: 14
+; CHECK-PREINL-NEXT: !Attributes: 3
+
+; CHECK-NO-PREINL:     [foo]:309:0
+; CHECK-NO-PREINL-NEXT: 2.1: 14
+; CHECK-NO-PREINL-NEXT: 3: 15
+; CHECK-NO-PREINL-NEXT: 3.1: 14 bar:14
+; CHECK-NO-PREINL-NEXT: 3.2: 1
+; CHECK-NO-PREINL-NEXT: !Attributes: 1
+; CHECK-NO-PREINL-NEXT:[bar]:84:0
+; CHECK-NO-PREINL-NEXT: 1: 14
+; CHECK-NO-PREINL-NEXT: !Attributes: 1

diff  --git a/llvm/tools/llvm-profgen/CMakeLists.txt b/llvm/tools/llvm-profgen/CMakeLists.txt
index e7705eb21c9f6..949b45ff2f96e 100644
--- a/llvm/tools/llvm-profgen/CMakeLists.txt
+++ b/llvm/tools/llvm-profgen/CMakeLists.txt
@@ -5,6 +5,7 @@ set(LLVM_LINK_COMPONENTS
   AllTargetsInfos
   Core
   MC
+  IPO
   MCDisassembler
   Object
   ProfileData
@@ -15,6 +16,7 @@ set(LLVM_LINK_COMPONENTS
 add_llvm_tool(llvm-profgen
   llvm-profgen.cpp
   PerfReader.cpp
+  CSPreInliner.cpp
   ProfiledBinary.cpp
   ProfileGenerator.cpp
   PseudoProbe.cpp

diff  --git a/llvm/tools/llvm-profgen/CSPreInliner.cpp b/llvm/tools/llvm-profgen/CSPreInliner.cpp
new file mode 100644
index 0000000000000..74cd09c25de74
--- /dev/null
+++ b/llvm/tools/llvm-profgen/CSPreInliner.cpp
@@ -0,0 +1,229 @@
+//===-- CSPreInliner.cpp - Profile guided preinliner -------------- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "CSPreInliner.h"
+#include "llvm/ADT/SCCIterator.h"
+#include <cstdint>
+#include <queue>
+
+#define DEBUG_TYPE "cs-preinliner"
+
+using namespace llvm;
+using namespace sampleprof;
+
+static cl::opt<bool> EnableCSPreInliner(
+    "csspgo-preinliner", cl::Hidden, cl::init(false),
+    cl::desc("Run a global pre-inliner to merge context profile based on "
+             "estimated global top-down inline decisions"));
+
+// The switches specify inline thresholds used in SampleProfileLoader inlining.
+// TODO: the actual thresholds need to be tuned here because the size here is
+// based on machine code, not LLVM IR.
+extern cl::opt<int> SampleHotCallSiteThreshold;
+extern cl::opt<int> SampleColdCallSiteThreshold;
+extern cl::opt<int> ProfileInlineGrowthLimit;
+extern cl::opt<int> ProfileInlineLimitMin;
+extern cl::opt<int> ProfileInlineLimitMax;
+
+static cl::opt<bool> SamplePreInlineReplay(
+    "csspgo-replay-preinline", cl::Hidden, cl::init(false),
+    cl::desc(
+        "Replay previous inlining and adjust context profile accordingly"));
+
+CSPreInliner::CSPreInliner(StringMap<FunctionSamples> &Profiles,
+                           uint64_t HotThreshold, uint64_t ColdThreshold)
+    : ContextTracker(Profiles), ProfileMap(Profiles),
+      HotCountThreshold(HotThreshold), ColdCountThreshold(ColdThreshold) {}
+
+std::vector<StringRef> CSPreInliner::buildTopDownOrder() {
+  std::vector<StringRef> Order;
+  ProfiledCallGraph ProfiledCG(ProfileMap, ContextTracker);
+
+  // Now that we have a profiled call graph, construct top-down order
+  // by building up SCC and reversing SCC order.
+  scc_iterator<ProfiledCallGraph *> I = scc_begin(&ProfiledCG);
+  while (!I.isAtEnd()) {
+    for (ProfiledCallGraphNode *Node : *I) {
+      if (Node != ProfiledCG.getEntryNode())
+        Order.push_back(Node->Name);
+    }
+    ++I;
+  }
+  std::reverse(Order.begin(), Order.end());
+
+  return Order;
+}
+
+bool CSPreInliner::getInlineCandidates(ProfiledCandidateQueue &CQueue,
+                                       const FunctionSamples *CallerSamples) {
+  assert(CallerSamples && "Expect non-null caller samples");
+
+  // Ideally we want to consider everything a function calls, but as far as
+  // context profile is concerned, only those frames that are children of the
+  // current one in the trie are relevant. So we walk the trie instead of the
+  // call targets from the function profile.
+  ContextTrieNode *CallerNode =
+      ContextTracker.getContextFor(CallerSamples->getContext());
+
+  bool HasNewCandidate = false;
+  for (auto &Child : CallerNode->getAllChildContext()) {
+    ContextTrieNode *CalleeNode = &Child.second;
+    FunctionSamples *CalleeSamples = CalleeNode->getFunctionSamples();
+    if (!CalleeSamples)
+      continue;
+
+    // Call site count is more reliable, so we look up the corresponding call
+    // target profile in caller's context profile to retrieve call site count.
+    uint64_t CalleeEntryCount = CalleeSamples->getEntrySamples();
+    uint64_t CallsiteCount = 0;
+    LineLocation Callsite = CalleeNode->getCallSiteLoc();
+    if (auto CallTargets = CallerSamples->findCallTargetMapAt(Callsite)) {
+      SampleRecord::CallTargetMap &TargetCounts = CallTargets.get();
+      auto It = TargetCounts.find(CalleeSamples->getName());
+      if (It != TargetCounts.end())
+        CallsiteCount = It->second;
+    }
+
+    // TODO: call site and callee entry count should be mostly consistent, add
+    // check for that.
+    HasNewCandidate = true;
+    CQueue.emplace(CalleeSamples, std::max(CallsiteCount, CalleeEntryCount));
+  }
+
+  return HasNewCandidate;
+}
+
+bool CSPreInliner::shouldInline(ProfiledInlineCandidate &Candidate) {
+  // If replay inline is requested, simply follow the inline decision of the
+  // profiled binary.
+  if (SamplePreInlineReplay)
+    return Candidate.CalleeSamples->getContext().hasAttribute(
+        ContextWasInlined);
+
+  // Adjust threshold based on call site hotness, only do this for callsite
+  // prioritized inliner because otherwise cost-benefit check is done earlier.
+  unsigned int SampleThreshold = SampleColdCallSiteThreshold;
+  if (Candidate.CallsiteCount > HotCountThreshold)
+    SampleThreshold = SampleHotCallSiteThreshold;
+
+  // TODO: for small cold functions, we may inline them and we need to keep
+  // context profiles accordingly.
+  if (Candidate.CallsiteCount < ColdCountThreshold)
+    SampleThreshold = SampleColdCallSiteThreshold;
+
+  return (Candidate.SizeCost < SampleThreshold);
+}
+
+void CSPreInliner::processFunction(const StringRef Name) {
+  LLVM_DEBUG(dbgs() << "Process " << Name
+                    << " for context-sensitive pre-inlining\n");
+
+  FunctionSamples *FSamples = ContextTracker.getBaseSamplesFor(Name);
+  if (!FSamples)
+    return;
+
+  // Use the number of lines/probes as proxy for function size for now.
+  // TODO: retrieve accurate size from dwarf or binary instead.
+  unsigned FuncSize = FSamples->getBodySamples().size();
+  unsigned FuncFinalSize = FuncSize;
+  unsigned SizeLimit = FuncSize * ProfileInlineGrowthLimit;
+  SizeLimit = std::min(SizeLimit, (unsigned)ProfileInlineLimitMax);
+  SizeLimit = std::max(SizeLimit, (unsigned)ProfileInlineLimitMin);
+
+  ProfiledCandidateQueue CQueue;
+  getInlineCandidates(CQueue, FSamples);
+
+  while (!CQueue.empty() && FuncFinalSize < SizeLimit) {
+    ProfiledInlineCandidate Candidate = CQueue.top();
+    CQueue.pop();
+    bool ShouldInline = false;
+    if ((ShouldInline = shouldInline(Candidate))) {
+      // We mark context as inlined as the corresponding context profile
+      // won't be merged into that function's base profile.
+      ContextTracker.markContextSamplesInlined(Candidate.CalleeSamples);
+      Candidate.CalleeSamples->getContext().setAttribute(
+          ContextShouldBeInlined);
+      FuncFinalSize += Candidate.SizeCost;
+      getInlineCandidates(CQueue, Candidate.CalleeSamples);
+    }
+    LLVM_DEBUG(dbgs() << (ShouldInline ? "  Inlined" : "  Outlined")
+                      << " context profile for: "
+                      << Candidate.CalleeSamples->getNameWithContext()
+                      << " (callee size: " << Candidate.SizeCost
+                      << ", call count:" << Candidate.CallsiteCount << ")\n");
+  }
+
+  LLVM_DEBUG({
+    if (!CQueue.empty())
+      dbgs() << "  Inline candidates ignored due to size limit (inliner "
+                "original size: "
+             << FuncSize << ", inliner final size: " << FuncFinalSize
+             << ", size limit: " << SizeLimit << ")\n";
+
+    while (!CQueue.empty()) {
+      ProfiledInlineCandidate Candidate = CQueue.top();
+      CQueue.pop();
+      bool WasInlined =
+          Candidate.CalleeSamples->getContext().hasAttribute(ContextWasInlined);
+      dbgs() << "    " << Candidate.CalleeSamples->getNameWithContext()
+             << " (candidate size:" << Candidate.SizeCost
+             << ", call count: " << Candidate.CallsiteCount << ", previously "
+             << (WasInlined ? "inlined)\n" : "not inlined)\n");
+    }
+  });
+}
+
+void CSPreInliner::run() {
+  if (!EnableCSPreInliner)
+    return;
+
+#ifndef NDEBUG
+  auto printProfileNames = [](StringMap<FunctionSamples> &Profiles,
+                              bool IsInput) {
+    dbgs() << (IsInput ? "Input" : "Output") << " context-sensitive profiles ("
+           << Profiles.size() << " total):\n";
+    for (auto &It : Profiles) {
+      const FunctionSamples &Samples = It.second;
+      dbgs() << "  [" << Samples.getNameWithContext() << "] "
+             << Samples.getTotalSamples() << ":" << Samples.getHeadSamples()
+             << "\n";
+    }
+  };
+#endif
+
+  LLVM_DEBUG(printProfileNames(ProfileMap, true));
+
+  // Execute global pre-inliner to estimate global top-down inline
+  // decisions and merge profiles accordingly. This helps with profile
+  // merging for ThinLTO, since otherwise we won't be able to merge profiles
+  // back to the base profile across module/thin-backend boundaries.
+  // It also helps compress context profiles to control profile size,
+  // as we now only need context profiles for functions that are going
+  // to be inlined.
+  for (StringRef FuncName : buildTopDownOrder()) {
+    processFunction(FuncName);
+  }
+
+  // Not inlined context profiles are merged into their base, so we can
+  // trim out such profiles from the output.
+  std::vector<StringRef> ProfilesToBeRemoved;
+  for (auto &It : ProfileMap) {
+    SampleContext Context = It.second.getContext();
+    if (!Context.isBaseContext() && !Context.hasState(InlinedContext)) {
+      assert(Context.hasState(MergedContext) &&
+             "Not inlined context profile should be merged already");
+      ProfilesToBeRemoved.push_back(It.first());
+    }
+  }
+
+  for (StringRef ContextName : ProfilesToBeRemoved) {
+    ProfileMap.erase(ContextName);
+  }
+
+  LLVM_DEBUG(printProfileNames(ProfileMap, false));
+}

diff  --git a/llvm/tools/llvm-profgen/CSPreInliner.h b/llvm/tools/llvm-profgen/CSPreInliner.h
new file mode 100644
index 0000000000000..5c65d8fd4a3b7
--- /dev/null
+++ b/llvm/tools/llvm-profgen/CSPreInliner.h
@@ -0,0 +1,92 @@
+//===-- CSPreInliner.h - Profile guided preinliner ---------------- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TOOLS_LLVM_PROFGEN_PGOINLINEADVISOR_H
+#define LLVM_TOOLS_LLVM_PROFGEN_PGOINLINEADVISOR_H
+
+#include "llvm/ADT/PriorityQueue.h"
+#include "llvm/ProfileData/ProfileCommon.h"
+#include "llvm/ProfileData/SampleProf.h"
+#include "llvm/Transforms/IPO/ProfiledCallGraph.h"
+#include "llvm/Transforms/IPO/SampleContextTracker.h"
+
+using namespace llvm;
+using namespace sampleprof;
+
+namespace llvm {
+namespace sampleprof {
+
+// Inline candidate seen from profile
+struct ProfiledInlineCandidate {
+  ProfiledInlineCandidate(const FunctionSamples *Samples, uint64_t Count)
+      : CalleeSamples(Samples), CallsiteCount(Count),
+        SizeCost(Samples->getBodySamples().size()) {}
+  // Context-sensitive function profile for inline candidate
+  const FunctionSamples *CalleeSamples;
+  // Call site count for an inline candidate
+  // TODO: make sure entry count for context profile and call site
+  // target count for corresponding call are consistent.
+  uint64_t CallsiteCount;
+  // Size proxy for function under particular call context.
+  // TODO: use post-inline callee size from debug info.
+  uint64_t SizeCost;
+};
+
+// Inline candidate comparer using call site weight
+struct ProfiledCandidateComparer {
+  bool operator()(const ProfiledInlineCandidate &LHS,
+                  const ProfiledInlineCandidate &RHS) {
+    if (LHS.CallsiteCount != RHS.CallsiteCount)
+      return LHS.CallsiteCount < RHS.CallsiteCount;
+
+    if (LHS.SizeCost != RHS.SizeCost)
+      return LHS.SizeCost > RHS.SizeCost;
+
+    // Tie breaker using GUID so we have stable/deterministic inlining order
+    assert(LHS.CalleeSamples && RHS.CalleeSamples &&
+           "Expect non-null FunctionSamples");
+    return LHS.CalleeSamples->getGUID(LHS.CalleeSamples->getName()) <
+           RHS.CalleeSamples->getGUID(RHS.CalleeSamples->getName());
+  }
+};
+
+using ProfiledCandidateQueue =
+    PriorityQueue<ProfiledInlineCandidate, std::vector<ProfiledInlineCandidate>,
+                  ProfiledCandidateComparer>;
+
+// Pre-compilation inliner based on context-sensitive profile.
+// The PreInliner estimates inline decisions using hotness from the profile
+// and cost estimation from machine code size. It helps merge context
+// profiles globally and achieves better post-inline profile quality, which
+// otherwise won't be possible for ThinLTO. It also reduces context profile
+// size by only keeping contexts that are estimated to be inlined.
+class CSPreInliner {
+public:
+  CSPreInliner(StringMap<FunctionSamples> &Profiles, uint64_t HotThreshold,
+               uint64_t ColdThreshold);
+  void run();
+
+private:
+  bool getInlineCandidates(ProfiledCandidateQueue &CQueue,
+                           const FunctionSamples *FCallerContextSamples);
+  std::vector<StringRef> buildTopDownOrder();
+  void processFunction(StringRef Name);
+  bool shouldInline(ProfiledInlineCandidate &Candidate);
+  SampleContextTracker ContextTracker;
+  StringMap<FunctionSamples> &ProfileMap;
+
+  // Count thresholds to answer isHotCount and isColdCount queries.
+  // Mirrors the threshold in ProfileSummaryInfo.
+  uint64_t HotCountThreshold;
+  uint64_t ColdCountThreshold;
+};
+
+} // end namespace sampleprof
+} // end namespace llvm
+
+#endif

diff  --git a/llvm/tools/llvm-profgen/ProfileGenerator.cpp b/llvm/tools/llvm-profgen/ProfileGenerator.cpp
index b3fb015b6725c..a6794f01551e8 100644
--- a/llvm/tools/llvm-profgen/ProfileGenerator.cpp
+++ b/llvm/tools/llvm-profgen/ProfileGenerator.cpp
@@ -234,9 +234,7 @@ void CSProfileGenerator::generateProfile() {
   // body sample.
   populateInferredFunctionSamples();
 
-  // Compute hot/cold threshold based on profile. This will be used for cold
-  // context profile merging/trimming.
-  computeSummaryAndThreshold();
+  postProcessProfiles();
 }
 
 void CSProfileGenerator::updateBodySamplesforFunctionProfile(
@@ -392,6 +390,20 @@ void CSProfileGenerator::populateInferredFunctionSamples() {
   }
 }
 
+void CSProfileGenerator::postProcessProfiles() {
+  // Compute hot/cold threshold based on profile. This will be used for cold
+  // context profile merging/trimming.
+  computeSummaryAndThreshold();
+
+  // Run global pre-inliner to adjust/merge context profile based on estimated
+  // inline decisions.
+  CSPreInliner(ProfileMap, PSI->getHotCountThreshold(),
+               PSI->getColdCountThreshold())
+      .run();
+
+  mergeAndTrimColdProfile(ProfileMap);
+}
+
 void CSProfileGenerator::computeSummaryAndThreshold() {
   SampleProfileSummaryBuilder Builder(ProfileSummaryBuilder::DefaultCutoffs);
   auto Summary = Builder.computeSummaryForProfiles(ProfileMap);
@@ -451,17 +463,19 @@ void CSProfileGenerator::mergeAndTrimColdProfile(
 
 void CSProfileGenerator::write(std::unique_ptr<SampleProfileWriter> Writer,
                                StringMap<FunctionSamples> &ProfileMap) {
-  mergeAndTrimColdProfile(ProfileMap);
   // Add bracket for context key to support different profile binary format
   StringMap<FunctionSamples> CxtWithBracketPMap;
   for (const auto &Item : ProfileMap) {
-    std::string ContextWithBracket = "[" + Item.first().str() + "]";
+    // After CSPreInliner the key of ProfileMap is no longer accurate for
+    // context, use the context attached to function samples instead.
+    std::string ContextWithBracket =
+        "[" + Item.second.getNameWithContext().str() + "]";
     auto Ret = CxtWithBracketPMap.try_emplace(ContextWithBracket, Item.second);
     assert(Ret.second && "Must be a unique context");
     SampleContext FContext(Ret.first->first(), RawContext);
     FunctionSamples &FProfile = Ret.first->second;
     FContext.setAllAttributes(FProfile.getContext().getAllAttributes());
-    FProfile.setName(FContext.getNameWithContext(true));
+    FProfile.setName(FContext.getNameWithoutContext());
     FProfile.setContext(FContext);
   }
   Writer->write(CxtWithBracketPMap);
@@ -500,9 +514,7 @@ void PseudoProbeCSProfileGenerator::generateProfile() {
     }
   }
 
-  // Compute hot/cold threshold based on profile. This will be used for cold
-  // context profile merging/trimming.
-  computeSummaryAndThreshold();
+  postProcessProfiles();
 }
 
 void PseudoProbeCSProfileGenerator::extractProbesFromRange(

diff  --git a/llvm/tools/llvm-profgen/ProfileGenerator.h b/llvm/tools/llvm-profgen/ProfileGenerator.h
index ff0116fb5c351..0ba884f3afbb8 100644
--- a/llvm/tools/llvm-profgen/ProfileGenerator.h
+++ b/llvm/tools/llvm-profgen/ProfileGenerator.h
@@ -8,6 +8,7 @@
 
 #ifndef LLVM_TOOLS_LLVM_PROGEN_PROFILEGENERATOR_H
 #define LLVM_TOOLS_LLVM_PROGEN_PROFILEGENERATOR_H
+#include "CSPreInliner.h"
 #include "ErrorHandling.h"
 #include "PerfReader.h"
 #include "ProfiledBinary.h"
@@ -178,6 +179,9 @@ class CSProfileGenerator : public ProfileGenerator {
   // Lookup or create FunctionSamples for the context
   FunctionSamples &getFunctionProfileForContext(StringRef ContextId,
                                                 bool WasLeafInlined = false);
+  // Post processing for profiles before writing out, such as merging
+  // and trimming cold profiles, running preinliner on profiles.
+  void postProcessProfiles();
   // Merge cold context profile whose total sample is below threshold
   // into base profile.
   void mergeAndTrimColdProfile(StringMap<FunctionSamples> &ProfileMap);
@@ -185,6 +189,9 @@ class CSProfileGenerator : public ProfileGenerator {
   void write(std::unique_ptr<SampleProfileWriter> Writer,
              StringMap<FunctionSamples> &ProfileMap) override;
 
+  // Profile summary to answer isHotCount and isColdCount queries.
+  std::unique_ptr<ProfileSummaryInfo> PSI;
+
 private:
   // Helper function for updating body sample for a leaf location in
   // FunctionProfile
@@ -200,9 +207,6 @@ class CSProfileGenerator : public ProfileGenerator {
                                        ProfiledBinary *Binary);
   void populateInferredFunctionSamples();
 
-  // Profile summary to answer isHotCount and isColdCount queries.
-  std::unique_ptr<ProfileSummaryInfo> PSI;
-
 public:
   // Deduplicate adjacent repeated context sequences up to a given sequence
   // length. -1 means no size limit.
