[llvm] [MemProf] Handle missing tail call frames (PR #75823)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 18 08:27:38 PST 2023


llvmbot wrote:


@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: Teresa Johnson (teresajohnson)

<details>
<summary>Changes</summary>

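If tail call optimization was not disabled for the profiled binary, the
profiled call contexts will be missing frames for tail calls. Handle this
by performing a depth-limited search through tail call edges for the
profiled callee when a discontinuity is detected, i.e. when the callee
recorded in the profile does not match the callee in the IR or summary.
The search depth is adjustable via the hidden
`-memprof-tail-call-search-depth` option and defaults to 5.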
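
As a rough illustration of the search described above, here is a minimal standalone sketch. It is not the code from the patch: `TailCallMap`, `Chain` and `findThroughTailCalls` are hypothetical names, functions are plain strings, and the map stands in for walking tail-call instructions (regular LTO) or summary edges (ThinLTO). It keeps two properties of the patch: the search starts at depth 1 and gives up past the maximum depth, and the discovered links are recorded in callee-to-caller order while the recursion unwinds.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the IR or summary: maps each function name to
// the functions it tail-calls.
using TailCallMap = std::map<std::string, std::vector<std::string>>;

// One discovered link: (function making the tail call, its tail-call target).
using Chain = std::vector<std::pair<std::string, std::string>>;

// Depth-limited DFS from CurCallee looking for ProfiledCallee through tail
// calls. On success, Found holds the links in callee-to-caller order.
bool findThroughTailCalls(const TailCallMap &TailCalls,
                          const std::string &ProfiledCallee,
                          const std::string &CurCallee, unsigned Depth,
                          unsigned MaxDepth, Chain &Found) {
  assert(Depth >= 1 && "the search is assumed to start at depth 1");
  // Stop once the maximum depth has been explored; this also bounds
  // recursion through cycles of tail calls.
  if (Depth > MaxDepth)
    return false;
  auto It = TailCalls.find(CurCallee);
  if (It == TailCalls.end())
    return false;
  for (const std::string &Target : It->second) {
    // Either the tail call reaches the profiled callee directly, or we
    // search one level deeper starting from the tail-call target.
    if (Target == ProfiledCallee ||
        findThroughTailCalls(TailCalls, ProfiledCallee, Target, Depth + 1,
                             MaxDepth, Found)) {
      Found.push_back({CurCallee, Target});
      return true;
    }
  }
  return false;
}
```

For example, if the profile says the call in `a` reached `d`, but `a` actually calls `b`, which tail-calls `c`, which tail-calls `d`, then searching from `b` with a maximum depth of 5 yields the links `(c, d)` and `(b, c)`, in that order. With a maximum depth of 1 the same search fails.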

If we are able to identify a short sequence of tail calls, update the
graph with new context nodes for those calls, linked between the profiled
caller and callee. In the case of ThinLTO, synthesize the necessary
CallsiteInfos for carrying the cloning information to the backends.
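
The graph update can be pictured with the following simplified sketch. It is only an approximation under stated assumptions: `Node`, `Edge`, `addOrUpdateEdge` and `spliceTailCalls` are hypothetical stand-ins for the pass's context nodes and edges, only callee edges are tracked, and allocation types are omitted. The idea it shows is replacing the direct caller-to-callee edge with a path through the discovered tail calls, carrying the context ids along and merging into an existing edge when one is already present.

```cpp
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for the graph's context nodes and edges.
struct Node;
struct Edge {
  Node *Callee;
  Node *Caller;
  std::set<unsigned> ContextIds;
};
struct Node {
  std::string Name;
  std::vector<std::shared_ptr<Edge>> CalleeEdges;
};

// Merge Ids into an existing Caller->Callee edge, or create a new edge.
void addOrUpdateEdge(Node *Caller, Node *Callee,
                     const std::set<unsigned> &Ids) {
  for (auto &E : Caller->CalleeEdges)
    if (E->Callee == Callee) {
      E->ContextIds.insert(Ids.begin(), Ids.end());
      return;
    }
  Caller->CalleeEdges.push_back(
      std::make_shared<Edge>(Edge{Callee, Caller, Ids}));
}

// Replace the direct Caller->Callee edge with a path through Chain, whose
// nodes are ordered callee-to-caller (the node nearest Callee comes first).
void spliceTailCalls(Node *Caller, Node *Callee,
                     const std::vector<Node *> &Chain) {
  assert(!Chain.empty() && "nothing to splice");
  // Detach the old edge, remembering the context ids it carried.
  std::set<unsigned> Ids;
  auto &Edges = Caller->CalleeEdges;
  for (auto It = Edges.begin(); It != Edges.end(); ++It)
    if ((*It)->Callee == Callee) {
      Ids = (*It)->ContextIds;
      Edges.erase(It);
      break;
    }
  // Link each tail-call node to the node below it, walking up the chain.
  Node *Cur = Callee;
  for (Node *Tail : Chain) {
    addOrUpdateEdge(Tail, Cur, Ids);
    Cur = Tail;
  }
  // Finally hook the original caller up to the top of the chain.
  addOrUpdateEdge(Caller, Cur, Ids);
}
```

With nodes `a`, `b` and `c`, where the profile recorded `a -> c` but `a` really calls `b` and `b` tail-calls `c`, calling `spliceTailCalls(&A, &C, {&B})` leaves `a -> b -> c`, with the original context ids on both edges.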


---

Patch is 36.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75823.diff


5 Files Affected:

- (modified) llvm/include/llvm/IR/ModuleSummaryIndex.h (+6) 
- (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+16-1) 
- (modified) llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp (+422-54) 
- (added) llvm/test/ThinLTO/X86/memprof-tailcall.ll (+110) 
- (added) llvm/test/Transforms/MemProfContextDisambiguation/tailcall.ll (+84) 


``````````diff
diff --git a/llvm/include/llvm/IR/ModuleSummaryIndex.h b/llvm/include/llvm/IR/ModuleSummaryIndex.h
index e72f74ad4adb66..66c7d10d823d9c 100644
--- a/llvm/include/llvm/IR/ModuleSummaryIndex.h
+++ b/llvm/include/llvm/IR/ModuleSummaryIndex.h
@@ -1011,6 +1011,12 @@ class FunctionSummary : public GlobalValueSummary {
     return *Callsites;
   }
 
+  void addCallsite(CallsiteInfo &Callsite) {
+    if (!Callsites)
+      Callsites = std::make_unique<CallsitesTy>();
+    Callsites->push_back(Callsite);
+  }
+
   ArrayRef<AllocInfo> allocs() const {
     if (Allocs)
       return *Allocs;
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 8fca569a391baf..a5fc267b1883bf 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -459,9 +459,24 @@ class IndexBitcodeWriter : public BitcodeWriterBase {
       // Record all stack id indices actually used in the summary entries being
       // written, so that we can compact them in the case of distributed ThinLTO
       // indexes.
-      for (auto &CI : FS->callsites())
+      for (auto &CI : FS->callsites()) {
+        // If the stack id list is empty, this callsite info was synthesized for
+        // a missing tail call frame. Ensure that the callee's GUID gets a value
+        // id. Normally we only generate these for defined summaries, which in
+        // the case of distributed ThinLTO is only the functions already defined
+        // in the module or that we want to import. We don't bother to include
+        // all the callee symbols as they aren't normally needed in the backend.
+        // However, for the synthesized callsite infos we do need the callee
+        // GUID in the backend so that we can correlate the identified callee
+        // with this callsite info (which for non-tail calls is done by the
+        // ordering of the callsite infos and verified via stack ids).
+        if (CI.StackIdIndices.empty()) {
+          GUIDToValueIdMap[CI.Callee.getGUID()] = ++GlobalValueId;
+          continue;
+        }
         for (auto Idx : CI.StackIdIndices)
           StackIdIndices.push_back(Idx);
+      }
       for (auto &AI : FS->allocs())
         for (auto &MIB : AI.MIBs)
           for (auto Idx : MIB.StackIdIndices)
diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
index 70a3f3067d9d6d..59f982cd420bb6 100644
--- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
+++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
@@ -77,6 +77,14 @@ STATISTIC(MaxAllocVersionsThinBackend,
           "allocation during ThinLTO backend");
 STATISTIC(UnclonableAllocsThinBackend,
           "Number of unclonable ambigous allocations during ThinLTO backend");
+STATISTIC(RemovedEdgesWithMismatchedCallees,
+          "Number of edges removed due to mismatched callees (profiled vs IR)");
+STATISTIC(FoundProfiledCalleeCount,
+          "Number of profiled callees found via tail calls");
+STATISTIC(FoundProfiledCalleeDepth,
+          "Aggregate depth of profiled callees found via tail calls");
+STATISTIC(FoundProfiledCalleeMaxDepth,
+          "Maximum depth of profiled callees found via tail calls");
 
 static cl::opt<std::string> DotFilePathPrefix(
     "memprof-dot-file-path-prefix", cl::init(""), cl::Hidden,
@@ -104,6 +112,12 @@ static cl::opt<std::string> MemProfImportSummary(
     cl::desc("Import summary to use for testing the ThinLTO backend via opt"),
     cl::Hidden);
 
+static cl::opt<unsigned>
+    TailCallSearchDepth("memprof-tail-call-search-depth", cl::init(5),
+                        cl::Hidden,
+                        cl::desc("Max depth to recursively search for missing "
+                                 "frames through tail calls."));
+
 namespace llvm {
 // Indicate we are linking with an allocator that supports hot/cold operator
 // new interfaces.
@@ -365,8 +379,7 @@ class CallsiteContextGraph {
 
   /// Save lists of calls with MemProf metadata in each function, for faster
   /// iteration.
-  std::vector<std::pair<FuncTy *, std::vector<CallInfo>>>
-      FuncToCallsWithMetadata;
+  std::map<FuncTy *, std::vector<CallInfo>> FuncToCallsWithMetadata;
 
   /// Map from callsite node to the enclosing caller function.
   std::map<const ContextNode *, const FuncTy *> NodeToCallingFunc;
@@ -411,9 +424,25 @@ class CallsiteContextGraph {
     return static_cast<const DerivedCCG *>(this)->getStackId(IdOrIndex);
   }
 
-  /// Returns true if the given call targets the given function.
-  bool calleeMatchesFunc(CallTy Call, const FuncTy *Func) {
-    return static_cast<DerivedCCG *>(this)->calleeMatchesFunc(Call, Func);
+  /// Returns true if the given call targets the callee of the given edge, or if
+  /// we were able to identify the call chain through intermediate tail calls.
+  /// In the latter case new context nodes are added to the graph for the
+  /// identified tail calls, and their synthesized nodes are added to
+  /// TailCallToContextNodeMap. The EdgeIter is updated in either case to the
+  /// next element after the input position (either incremented or updated after
+  /// removing the old edge).
+  bool
+  calleesMatch(CallTy Call, EdgeIter &EI,
+               MapVector<CallInfo, ContextNode *> &TailCallToContextNodeMap);
+
+  /// Returns true if the given call targets the given function, or if we were
+  /// able to identify the call chain through intermediate tail calls (in which
+  /// case FoundCalleeChain will be populated).
+  bool calleeMatchesFunc(
+      CallTy Call, const FuncTy *Func, const FuncTy *CallerFunc,
+      std::vector<std::pair<CallTy, FuncTy *>> &FoundCalleeChain) {
+    return static_cast<DerivedCCG *>(this)->calleeMatchesFunc(
+        Call, Func, CallerFunc, FoundCalleeChain);
   }
 
   /// Get a list of nodes corresponding to the stack ids in the given
@@ -553,7 +582,12 @@ class ModuleCallsiteContextGraph
                               Instruction *>;
 
   uint64_t getStackId(uint64_t IdOrIndex) const;
-  bool calleeMatchesFunc(Instruction *Call, const Function *Func);
+  bool calleeMatchesFunc(
+      Instruction *Call, const Function *Func, const Function *CallerFunc,
+      std::vector<std::pair<Instruction *, Function *>> &FoundCalleeChain);
+  bool findProfiledCalleeThroughTailCalls(
+      const Function *ProfiledCallee, Value *CurCallee, unsigned Depth,
+      std::vector<std::pair<Instruction *, Function *>> &FoundCalleeChain);
   uint64_t getLastStackId(Instruction *Call);
   std::vector<uint64_t> getStackIdsWithContextNodesForCall(Instruction *Call);
   void updateAllocationCall(CallInfo &Call, AllocationType AllocType);
@@ -606,12 +640,30 @@ class IndexCallsiteContextGraph
       function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
           isPrevailing);
 
+  ~IndexCallsiteContextGraph() {
+    // Now that we are done with the graph it is safe to add the new
+    // CallsiteInfo structs to the function summary vectors. The graph nodes
+    // point into locations within these vectors, so we don't want to add them
+    // any earlier.
+    for (auto &I : FunctionCalleesToSynthesizedCallsiteInfos) {
+      auto *FS = I.first;
+      for (auto &Callsite : I.second)
+        FS->addCallsite(*Callsite.second);
+    }
+  }
+
 private:
   friend CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary,
                               IndexCall>;
 
   uint64_t getStackId(uint64_t IdOrIndex) const;
-  bool calleeMatchesFunc(IndexCall &Call, const FunctionSummary *Func);
+  bool calleeMatchesFunc(
+      IndexCall &Call, const FunctionSummary *Func,
+      const FunctionSummary *CallerFunc,
+      std::vector<std::pair<IndexCall, FunctionSummary *>> &FoundCalleeChain);
+  bool findProfiledCalleeThroughTailCalls(
+      ValueInfo ProfiledCallee, ValueInfo CurCallee, unsigned Depth,
+      std::vector<std::pair<IndexCall, FunctionSummary *>> &FoundCalleeChain);
   uint64_t getLastStackId(IndexCall &Call);
   std::vector<uint64_t> getStackIdsWithContextNodesForCall(IndexCall &Call);
   void updateAllocationCall(CallInfo &Call, AllocationType AllocType);
@@ -630,6 +682,16 @@ class IndexCallsiteContextGraph
   std::map<const FunctionSummary *, ValueInfo> FSToVIMap;
 
   const ModuleSummaryIndex &Index;
+  function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+      isPrevailing;
+
+  // Saves/owns the callsite info structures synthesized for missing tail call
+  // frames that we discover while building the graph.
+  // It maps from the summary of the function making the tail call, to a map
+  // of callee ValueInfo to corresponding synthesized callsite info.
+  std::map<FunctionSummary *,
+           std::map<ValueInfo, std::unique_ptr<CallsiteInfo>>>
+      FunctionCalleesToSynthesizedCallsiteInfos;
 };
 } // namespace
 
@@ -1493,7 +1555,7 @@ ModuleCallsiteContextGraph::ModuleCallsiteContextGraph(
       }
     }
     if (!CallsWithMetadata.empty())
-      FuncToCallsWithMetadata.push_back({&F, CallsWithMetadata});
+      FuncToCallsWithMetadata[&F] = CallsWithMetadata;
   }
 
   if (DumpCCG) {
@@ -1518,7 +1580,7 @@ IndexCallsiteContextGraph::IndexCallsiteContextGraph(
     ModuleSummaryIndex &Index,
     function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
         isPrevailing)
-    : Index(Index) {
+    : Index(Index), isPrevailing(isPrevailing) {
   for (auto &I : Index) {
     auto VI = Index.getValueInfo(I);
     for (auto &S : VI.getSummaryList()) {
@@ -1572,7 +1634,7 @@ IndexCallsiteContextGraph::IndexCallsiteContextGraph(
           CallsWithMetadata.push_back({&SN});
 
       if (!CallsWithMetadata.empty())
-        FuncToCallsWithMetadata.push_back({FS, CallsWithMetadata});
+        FuncToCallsWithMetadata[FS] = CallsWithMetadata;
 
       if (!FS->allocs().empty() || !FS->callsites().empty())
         FSToVIMap[FS] = VI;
@@ -1604,6 +1666,11 @@ void CallsiteContextGraph<DerivedCCG, FuncTy,
   // this transformation for regular LTO, and for ThinLTO we can simulate that
   // effect in the summary and perform the actual speculative devirtualization
   // while cloning in the ThinLTO backend.
+
+  // Keep track of the new nodes synthesized for discovered tail calls missing
+  // from the profiled contexts.
+  MapVector<CallInfo, ContextNode *> TailCallToContextNodeMap;
+
   for (auto Entry = NonAllocationCallToContextNodeMap.begin();
        Entry != NonAllocationCallToContextNodeMap.end();) {
     auto *Node = Entry->second;
@@ -1611,13 +1678,17 @@ void CallsiteContextGraph<DerivedCCG, FuncTy,
     // Check all node callees and see if in the same function.
     bool Removed = false;
     auto Call = Node->Call.call();
-    for (auto &Edge : Node->CalleeEdges) {
-      if (!Edge->Callee->hasCall())
+    for (auto EI = Node->CalleeEdges.begin(); EI != Node->CalleeEdges.end();) {
+      auto Edge = *EI;
+      if (!Edge->Callee->hasCall()) {
+        ++EI;
         continue;
+      }
       assert(NodeToCallingFunc.count(Edge->Callee));
       // Check if the called function matches that of the callee node.
-      if (calleeMatchesFunc(Call, NodeToCallingFunc[Edge->Callee]))
+      if (calleesMatch(Call, EI, TailCallToContextNodeMap))
         continue;
+      RemovedEdgesWithMismatchedCallees++;
       // Work around by setting Node to have a null call, so it gets
       // skipped during cloning. Otherwise assignFunctions will assert
       // because its data structures are not designed to handle this case.
@@ -1629,6 +1700,11 @@ void CallsiteContextGraph<DerivedCCG, FuncTy,
     if (!Removed)
       Entry++;
   }
+
+  // Add the new nodes after the above loop so that the iteration is not
+  // invalidated.
+  for (auto I : TailCallToContextNodeMap)
+    NonAllocationCallToContextNodeMap[I.first] = I.second;
 }
 
 uint64_t ModuleCallsiteContextGraph::getStackId(uint64_t IdOrIndex) const {
@@ -1642,8 +1718,160 @@ uint64_t IndexCallsiteContextGraph::getStackId(uint64_t IdOrIndex) const {
   return Index.getStackIdAtIndex(IdOrIndex);
 }
 
-bool ModuleCallsiteContextGraph::calleeMatchesFunc(Instruction *Call,
-                                                   const Function *Func) {
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::calleesMatch(
+    CallTy Call, EdgeIter &EI,
+    MapVector<CallInfo, ContextNode *> &TailCallToContextNodeMap) {
+  auto Edge = *EI;
+  const FuncTy *ProfiledCalleeFunc = NodeToCallingFunc[Edge->Callee];
+  const FuncTy *CallerFunc = NodeToCallingFunc[Edge->Caller];
+  // Will be populated in order of callee to caller if we find a chain of tail
+  // calls between the profiled caller and callee.
+  std::vector<std::pair<CallTy, FuncTy *>> FoundCalleeChain;
+  if (!calleeMatchesFunc(Call, ProfiledCalleeFunc, CallerFunc,
+                         FoundCalleeChain)) {
+    ++EI;
+    return false;
+  }
+
+  // The usual case where the profiled callee matches that of the IR/summary.
+  if (FoundCalleeChain.empty()) {
+    ++EI;
+    return true;
+  }
+
+  auto AddEdge = [Edge, &EI](ContextNode *Caller, ContextNode *Callee) {
+    auto *CurEdge = Callee->findEdgeFromCaller(Caller);
+    // If there is already an edge between these nodes, simply update it and
+    // return.
+    if (CurEdge) {
+      CurEdge->ContextIds.insert(Edge->ContextIds.begin(),
+                                 Edge->ContextIds.end());
+      CurEdge->AllocTypes |= Edge->AllocTypes;
+      return;
+    }
+    // Otherwise, create a new edge and insert it into the caller and callee
+    // lists.
+    auto NewEdge = std::make_shared<ContextEdge>(
+        Callee, Caller, Edge->AllocTypes, Edge->ContextIds);
+    Callee->CallerEdges.push_back(NewEdge);
+    if (Caller == Edge->Caller) {
+      // If we are inserting the new edge into the current edge's caller, insert
+      // the new edge before the current iterator position, and then increment
+      // back to the current edge.
+      EI = Caller->CalleeEdges.insert(EI, NewEdge);
+      ++EI;
+      assert(*EI == Edge);
+    } else
+      Caller->CalleeEdges.push_back(NewEdge);
+  };
+
+  // Create new nodes for each found callee and connect in between the profiled
+  // caller and callee.
+  auto *CurCalleeNode = Edge->Callee;
+  for (auto I : FoundCalleeChain) {
+    auto &NewCall = I.first;
+    auto *Func = I.second;
+    ContextNode *NewNode = nullptr;
+    // First check if we have already synthesized a node for this tail call.
+    if (TailCallToContextNodeMap.count(NewCall)) {
+      NewNode = TailCallToContextNodeMap[NewCall];
+      NewNode->ContextIds.insert(Edge->ContextIds.begin(),
+                                 Edge->ContextIds.end());
+      NewNode->AllocTypes |= Edge->AllocTypes;
+    } else {
+      FuncToCallsWithMetadata[Func].push_back({NewCall});
+      // Create Node and record node info.
+      NodeOwner.push_back(
+          std::make_unique<ContextNode>(/*IsAllocation=*/false, NewCall));
+      NewNode = NodeOwner.back().get();
+      NodeToCallingFunc[NewNode] = Func;
+      TailCallToContextNodeMap[NewCall] = NewNode;
+      NewNode->ContextIds = Edge->ContextIds;
+      NewNode->AllocTypes = Edge->AllocTypes;
+    }
+
+    // Hook up node to its callee node
+    AddEdge(NewNode, CurCalleeNode);
+
+    CurCalleeNode = NewNode;
+  }
+
+  // Hook up edge's original caller to new callee node.
+  AddEdge(Edge->Caller, CurCalleeNode);
+
+  // Remove old edge
+  Edge->Callee->eraseCallerEdge(Edge.get());
+  EI = Edge->Caller->CalleeEdges.erase(EI);
+
+  return true;
+}
+
+bool ModuleCallsiteContextGraph::findProfiledCalleeThroughTailCalls(
+    const Function *ProfiledCallee, Value *CurCallee, unsigned Depth,
+    std::vector<std::pair<Instruction *, Function *>> &FoundCalleeChain) {
+  // Stop recursive search if we have already explored the maximum specified
+  // depth.
+  if (Depth > TailCallSearchDepth)
+    return false;
+
+  auto SaveCallsiteInfo = [&](Instruction *Callsite, Function *F) {
+    FoundCalleeChain.push_back({Callsite, F});
+  };
+
+  auto *CalleeFunc = dyn_cast<Function>(CurCallee);
+  if (!CalleeFunc) {
+    auto *Alias = dyn_cast<GlobalAlias>(CurCallee);
+    assert(Alias);
+    CalleeFunc = dyn_cast<Function>(Alias->getAliasee());
+    assert(CalleeFunc);
+  }
+
+  // Look for tail calls in this function, and check if they either call the
+  // profiled callee directly, or indirectly (via a recursive search).
+  for (auto &BB : *CalleeFunc) {
+    for (auto &I : BB) {
+      auto *CB = dyn_cast<CallBase>(&I);
+      if (!CB || !CB->isTailCall())
+        continue;
+      auto *CalledValue = CB->getCalledOperand();
+      auto *CalledFunction = CB->getCalledFunction();
+      if (CalledValue && !CalledFunction) {
+        CalledValue = CalledValue->stripPointerCasts();
+        // Stripping pointer casts can reveal a called function.
+        CalledFunction = dyn_cast<Function>(CalledValue);
+      }
+      // Check if this is an alias to a function. If so, get the
+      // called aliasee for the checks below.
+      if (auto *GA = dyn_cast<GlobalAlias>(CalledValue)) {
+        assert(!CalledFunction &&
+               "Expected null called function in callsite for alias");
+        CalledFunction = dyn_cast<Function>(GA->getAliaseeObject());
+      }
+      if (!CalledFunction)
+        continue;
+      if (CalledFunction == ProfiledCallee) {
+        FoundProfiledCalleeCount++;
+        FoundProfiledCalleeDepth += Depth;
+        if (Depth > FoundProfiledCalleeMaxDepth)
+          FoundProfiledCalleeMaxDepth = Depth;
+        SaveCallsiteInfo(&I, CalleeFunc);
+        return true;
+      }
+      if (findProfiledCalleeThroughTailCalls(ProfiledCallee, CalledFunction,
+                                             Depth + 1, FoundCalleeChain)) {
+        SaveCallsiteInfo(&I, CalleeFunc);
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
+bool ModuleCallsiteContextGraph::calleeMatchesFunc(
+    Instruction *Call, const Function *Func, const Function *CallerFunc,
+    std::vector<std::pair<Instruction *, Function *>> &FoundCalleeChain) {
   auto *CB = dyn_cast<CallBase>(Call);
   if (!CB->getCalledOperand())
     return false;
@@ -1652,11 +1880,96 @@ bool ModuleCallsiteContextGraph::calleeMatchesFunc(Instruction *Call,
   if (CalleeFunc == Func)
     return true;
   auto *Alias = dyn_cast<GlobalAlias>(CalleeVal);
-  return Alias && Alias->getAliasee() == Func;
+  if (Alias && Alias->getAliasee() == Func)
+    return true;
+
+  // Recursively search for the profiled callee through tail calls starting with
+  // the actual Callee. The discovered tail call chain is saved in
+  // FoundCalleeChain, and we will fixup the graph to include these callsites
+  // after returning.
+  // FIXME: We will currently redo the same recursive walk if we find the same
+  // mismatched callee from another callsite. We can improve this with more
+  // bookkeeping of the created chain of new nodes for each mismatch.
+  unsigned Depth = 1;
+  if (!findProfiledCalleeThroughTailCalls(Func, CalleeVal, Depth,
+                                          FoundCalleeChain)) {
+    LLVM_DEBUG(dbgs() << "Not found through tail calls: " << Func->getName()
+                      << " from " << CallerFunc->getName()
+                      << " that actually called " << CalleeVal->getName()
+                      << "\n");
+    return false;
+  }
+
+  return true;
 }
 
-bool IndexCallsiteContextGraph::calleeMatchesFunc(IndexCall &Call,
-                                                  const FunctionSummary *Func) {
+bool IndexCallsiteContextGraph::findProfiledCalleeThroughTailCalls(
+    ValueInfo ProfiledCallee, ValueInfo CurCallee, unsigned Depth,
+    std::vector<std::pair<IndexCall, FunctionSummary *>> &FoundCalleeChain) {
+  // Stop recursive search if we have already explored the maximum specified
+  // depth.
+  if (Depth > TailCallSearchDepth)
+    return false;
+
+  auto CreateAndSaveCallsiteInfo = [&](ValueInfo Callee, FunctionSummary *FS) {
+    // Make a CallsiteInfo for each discovered callee, if one hasn't already
+    // been synthesized.
+    if (!FunctionCalleesToSynthesizedCallsiteInfos.count(FS) ||
+        !FunctionCalleesToSynthesizedCallsiteInfos[FS].count(Callee))
+      // StackIds is empty (we don't have debug info available in the index for
+      // these callsites)
+      FunctionCal...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/75823

