[llvm] [MemProf] Handle missing tail call frames (PR #75823)

Teresa Johnson via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 18 08:27:06 PST 2023


https://github.com/teresajohnson created https://github.com/llvm/llvm-project/pull/75823

If tail call optimization was not disabled for the profiled binary, the
call contexts will be missing frames for tail calls. Handle this by
performing a limited search through tail call edges for the profiled
callee when a discontinuity is detected. The search depth is adjustable
but defaults to 5.

If we are able to identify a short sequence of tail calls, update the
graph for those calls. In the case of ThinLTO, synthesize the necessary
CallsiteInfos for carrying the cloning information to the backends.
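
To make the scenario concrete, consider a hypothetical C++ example
(illustrative names, compiled with tail call optimization enabled):

  // c() contains the profiled allocation.
  __attribute__((noinline)) int *c() { return new int[10]; }
  // b() tail calls c(), so b's frame is replaced by c's at runtime and
  // is missing from the profiled stack context.
  __attribute__((noinline)) int *b() { return c(); }
  __attribute__((noinline)) int *a() { return b(); }

The profiled context for the allocation then contains frames for a and c but
not b. Since a's IR calls b rather than the profiled callee c, the graph
build detects a discontinuity and searches from b through tail call edges
until it finds c (here at depth 1).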


From c6db18e6c1ac30217b534692fef12e4338d1393a Mon Sep 17 00:00:00 2001
From: Teresa Johnson <tejohnson at google.com>
Date: Fri, 15 Dec 2023 11:56:53 -0800
Subject: [PATCH] [MemProf] Handle missing tail call frames

If tail call optimization was not disabled for the profiled binary, the
call contexts will be missing frames for tail calls. Handle this by
performing a limited search through tail call edges for the profiled
callee when a discontinuity is detected. The search depth is adjustable
but defaults to 5.

If we are able to identify a short sequence of tail calls, update the
graph for those calls. In the case of ThinLTO, synthesize the necessary
CallsiteInfos for carrying the cloning information to the backends.
---
 llvm/include/llvm/IR/ModuleSummaryIndex.h     |   6 +
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     |  17 +-
 .../IPO/MemProfContextDisambiguation.cpp      | 476 ++++++++++++++++--
 llvm/test/ThinLTO/X86/memprof-tailcall.ll     | 110 ++++
 .../MemProfContextDisambiguation/tailcall.ll  |  84 ++++
 5 files changed, 638 insertions(+), 55 deletions(-)
 create mode 100644 llvm/test/ThinLTO/X86/memprof-tailcall.ll
 create mode 100644 llvm/test/Transforms/MemProfContextDisambiguation/tailcall.ll

diff --git a/llvm/include/llvm/IR/ModuleSummaryIndex.h b/llvm/include/llvm/IR/ModuleSummaryIndex.h
index e72f74ad4adb66..66c7d10d823d9c 100644
--- a/llvm/include/llvm/IR/ModuleSummaryIndex.h
+++ b/llvm/include/llvm/IR/ModuleSummaryIndex.h
@@ -1011,6 +1011,12 @@ class FunctionSummary : public GlobalValueSummary {
     return *Callsites;
   }
 
+  void addCallsite(CallsiteInfo &Callsite) {
+    if (!Callsites)
+      Callsites = std::make_unique<CallsitesTy>();
+    Callsites->push_back(Callsite);
+  }
+
   ArrayRef<AllocInfo> allocs() const {
     if (Allocs)
       return *Allocs;
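
A minimal usage sketch of the new helper, assuming a FunctionSummary *FS and
a synthesized CallsiteInfo CI are in hand (as in the IndexCallsiteContextGraph
destructor added later in this patch):

  // Lazily allocates the callsite list on first use, then appends CI.
  FS->addCallsite(CI);
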
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index 8fca569a391baf..a5fc267b1883bf 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -459,9 +459,24 @@ class IndexBitcodeWriter : public BitcodeWriterBase {
       // Record all stack id indices actually used in the summary entries being
       // written, so that we can compact them in the case of distributed ThinLTO
       // indexes.
-      for (auto &CI : FS->callsites())
+      for (auto &CI : FS->callsites()) {
+        // If the stack id list is empty, this callsite info was synthesized for
+        // a missing tail call frame. Ensure that the callee's GUID gets a value
+        // id. Normally we only generate these for defined summaries, which in
+        // the case of distributed ThinLTO is only the functions already defined
+        // in the module or that we want to import. We don't bother to include
+        // all the callee symbols as they aren't normally needed in the backend.
+        // However, for the synthesized callsite infos we do need the callee
+        // GUID in the backend so that we can correlate the identified callee
+        // with this callsite info (which for non-tail calls is done by the
+        // ordering of the callsite infos and verified via stack ids).
+        if (CI.StackIdIndices.empty()) {
+          GUIDToValueIdMap[CI.Callee.getGUID()] = ++GlobalValueId;
+          continue;
+        }
         for (auto Idx : CI.StackIdIndices)
           StackIdIndices.push_back(Idx);
+      }
       for (auto &AI : FS->allocs())
         for (auto &MIB : AI.MIBs)
           for (auto Idx : MIB.StackIdIndices)
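
The backend matching code added later in this patch relies on the same
convention: a callsite info synthesized for a missing tail call frame is
recognizable purely by its empty stack id list. A minimal sketch of that
check (the helper name is hypothetical):

  // Synthesized tail call callsite infos are the only records written with
  // an empty StackIdIndices list.
  static bool isSynthesizedTailCallInfo(const CallsiteInfo &CI) {
    return CI.StackIdIndices.empty();
  }
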
diff --git a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
index 70a3f3067d9d6d..59f982cd420bb6 100644
--- a/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
+++ b/llvm/lib/Transforms/IPO/MemProfContextDisambiguation.cpp
@@ -77,6 +77,14 @@ STATISTIC(MaxAllocVersionsThinBackend,
           "allocation during ThinLTO backend");
 STATISTIC(UnclonableAllocsThinBackend,
           "Number of unclonable ambigous allocations during ThinLTO backend");
+STATISTIC(RemovedEdgesWithMismatchedCallees,
+          "Number of edges removed due to mismatched callees (profiled vs IR)");
+STATISTIC(FoundProfiledCalleeCount,
+          "Number of profiled callees found via tail calls");
+STATISTIC(FoundProfiledCalleeDepth,
+          "Aggregate depth of profiled callees found via tail calls");
+STATISTIC(FoundProfiledCalleeMaxDepth,
+          "Maximum depth of profiled callees found via tail calls");
 
 static cl::opt<std::string> DotFilePathPrefix(
     "memprof-dot-file-path-prefix", cl::init(""), cl::Hidden,
@@ -104,6 +112,12 @@ static cl::opt<std::string> MemProfImportSummary(
     cl::desc("Import summary to use for testing the ThinLTO backend via opt"),
     cl::Hidden);
 
+static cl::opt<unsigned>
+    TailCallSearchDepth("memprof-tail-call-search-depth", cl::init(5),
+                        cl::Hidden,
+                        cl::desc("Max depth to recursively search for missing "
+                                 "frames through tail calls."));
+
 namespace llvm {
 // Indicate we are linking with an allocator that supports hot/cold operator
 // new interfaces.
@@ -365,8 +379,7 @@ class CallsiteContextGraph {
 
   /// Save lists of calls with MemProf metadata in each function, for faster
   /// iteration.
-  std::vector<std::pair<FuncTy *, std::vector<CallInfo>>>
-      FuncToCallsWithMetadata;
+  std::map<FuncTy *, std::vector<CallInfo>> FuncToCallsWithMetadata;
 
   /// Map from callsite node to the enclosing caller function.
   std::map<const ContextNode *, const FuncTy *> NodeToCallingFunc;
@@ -411,9 +424,25 @@ class CallsiteContextGraph {
     return static_cast<const DerivedCCG *>(this)->getStackId(IdOrIndex);
   }
 
-  /// Returns true if the given call targets the given function.
-  bool calleeMatchesFunc(CallTy Call, const FuncTy *Func) {
-    return static_cast<DerivedCCG *>(this)->calleeMatchesFunc(Call, Func);
+  /// Returns true if the given call targets the callee of the given edge, or if
+  /// we were able to identify the call chain through intermediate tail calls.
+  /// In the latter case new context nodes are added to the graph for the
+  /// identified tail calls, and their synthesized nodes are added to
+  /// TailCallToContextNodeMap. The EdgeIter is updated in either case to the
+  /// next element after the input position (either incremented or updated after
+  /// removing the old edge).
+  bool
+  calleesMatch(CallTy Call, EdgeIter &EI,
+               MapVector<CallInfo, ContextNode *> &TailCallToContextNodeMap);
+
+  /// Returns true if the given call targets the given function, or if we were
+  /// able to identify the call chain through intermediate tail calls (in which
+  /// case FoundCalleeChain will be populated).
+  bool calleeMatchesFunc(
+      CallTy Call, const FuncTy *Func, const FuncTy *CallerFunc,
+      std::vector<std::pair<CallTy, FuncTy *>> &FoundCalleeChain) {
+    return static_cast<DerivedCCG *>(this)->calleeMatchesFunc(
+        Call, Func, CallerFunc, FoundCalleeChain);
   }
 
   /// Get a list of nodes corresponding to the stack ids in the given
@@ -553,7 +582,12 @@ class ModuleCallsiteContextGraph
                               Instruction *>;
 
   uint64_t getStackId(uint64_t IdOrIndex) const;
-  bool calleeMatchesFunc(Instruction *Call, const Function *Func);
+  bool calleeMatchesFunc(
+      Instruction *Call, const Function *Func, const Function *CallerFunc,
+      std::vector<std::pair<Instruction *, Function *>> &FoundCalleeChain);
+  bool findProfiledCalleeThroughTailCalls(
+      const Function *ProfiledCallee, Value *CurCallee, unsigned Depth,
+      std::vector<std::pair<Instruction *, Function *>> &FoundCalleeChain);
   uint64_t getLastStackId(Instruction *Call);
   std::vector<uint64_t> getStackIdsWithContextNodesForCall(Instruction *Call);
   void updateAllocationCall(CallInfo &Call, AllocationType AllocType);
@@ -606,12 +640,30 @@ class IndexCallsiteContextGraph
       function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
           isPrevailing);
 
+  ~IndexCallsiteContextGraph() {
+    // Now that we are done with the graph it is safe to add the new
+    // CallsiteInfo structs to the function summary vectors. The graph nodes
+    // point into locations within these vectors, so we don't want to add them
+    // any earlier.
+    for (auto &I : FunctionCalleesToSynthesizedCallsiteInfos) {
+      auto *FS = I.first;
+      for (auto &Callsite : I.second)
+        FS->addCallsite(*Callsite.second);
+    }
+  }
+
 private:
   friend CallsiteContextGraph<IndexCallsiteContextGraph, FunctionSummary,
                               IndexCall>;
 
   uint64_t getStackId(uint64_t IdOrIndex) const;
-  bool calleeMatchesFunc(IndexCall &Call, const FunctionSummary *Func);
+  bool calleeMatchesFunc(
+      IndexCall &Call, const FunctionSummary *Func,
+      const FunctionSummary *CallerFunc,
+      std::vector<std::pair<IndexCall, FunctionSummary *>> &FoundCalleeChain);
+  bool findProfiledCalleeThroughTailCalls(
+      ValueInfo ProfiledCallee, ValueInfo CurCallee, unsigned Depth,
+      std::vector<std::pair<IndexCall, FunctionSummary *>> &FoundCalleeChain);
   uint64_t getLastStackId(IndexCall &Call);
   std::vector<uint64_t> getStackIdsWithContextNodesForCall(IndexCall &Call);
   void updateAllocationCall(CallInfo &Call, AllocationType AllocType);
@@ -630,6 +682,16 @@ class IndexCallsiteContextGraph
   std::map<const FunctionSummary *, ValueInfo> FSToVIMap;
 
   const ModuleSummaryIndex &Index;
+  function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
+      isPrevailing;
+
+  // Saves/owns the callsite info structures synthesized for missing tail call
+  // frames that we discover while building the graph.
+  // It maps from the summary of the function making the tail call, to a map
+  // of callee ValueInfo to corresponding synthesized callsite info.
+  std::map<FunctionSummary *,
+           std::map<ValueInfo, std::unique_ptr<CallsiteInfo>>>
+      FunctionCalleesToSynthesizedCallsiteInfos;
 };
 } // namespace
 
@@ -1493,7 +1555,7 @@ ModuleCallsiteContextGraph::ModuleCallsiteContextGraph(
       }
     }
     if (!CallsWithMetadata.empty())
-      FuncToCallsWithMetadata.push_back({&F, CallsWithMetadata});
+      FuncToCallsWithMetadata[&F] = CallsWithMetadata;
   }
 
   if (DumpCCG) {
@@ -1518,7 +1580,7 @@ IndexCallsiteContextGraph::IndexCallsiteContextGraph(
     ModuleSummaryIndex &Index,
     function_ref<bool(GlobalValue::GUID, const GlobalValueSummary *)>
         isPrevailing)
-    : Index(Index) {
+    : Index(Index), isPrevailing(isPrevailing) {
   for (auto &I : Index) {
     auto VI = Index.getValueInfo(I);
     for (auto &S : VI.getSummaryList()) {
@@ -1572,7 +1634,7 @@ IndexCallsiteContextGraph::IndexCallsiteContextGraph(
           CallsWithMetadata.push_back({&SN});
 
       if (!CallsWithMetadata.empty())
-        FuncToCallsWithMetadata.push_back({FS, CallsWithMetadata});
+        FuncToCallsWithMetadata[FS] = CallsWithMetadata;
 
       if (!FS->allocs().empty() || !FS->callsites().empty())
         FSToVIMap[FS] = VI;
@@ -1604,6 +1666,11 @@ void CallsiteContextGraph<DerivedCCG, FuncTy,
   // this transformation for regular LTO, and for ThinLTO we can simulate that
   // effect in the summary and perform the actual speculative devirtualization
   // while cloning in the ThinLTO backend.
+
+  // Keep track of the new nodes synthesized for discovered tail calls missing
+  // from the profiled contexts.
+  MapVector<CallInfo, ContextNode *> TailCallToContextNodeMap;
+
   for (auto Entry = NonAllocationCallToContextNodeMap.begin();
        Entry != NonAllocationCallToContextNodeMap.end();) {
     auto *Node = Entry->second;
@@ -1611,13 +1678,17 @@ void CallsiteContextGraph<DerivedCCG, FuncTy,
     // Check all node callees and see if in the same function.
     bool Removed = false;
     auto Call = Node->Call.call();
-    for (auto &Edge : Node->CalleeEdges) {
-      if (!Edge->Callee->hasCall())
+    for (auto EI = Node->CalleeEdges.begin(); EI != Node->CalleeEdges.end();) {
+      auto Edge = *EI;
+      if (!Edge->Callee->hasCall()) {
+        ++EI;
         continue;
+      }
       assert(NodeToCallingFunc.count(Edge->Callee));
       // Check if the called function matches that of the callee node.
-      if (calleeMatchesFunc(Call, NodeToCallingFunc[Edge->Callee]))
+      if (calleesMatch(Call, EI, TailCallToContextNodeMap))
         continue;
+      RemovedEdgesWithMismatchedCallees++;
       // Work around by setting Node to have a null call, so it gets
       // skipped during cloning. Otherwise assignFunctions will assert
       // because its data structures are not designed to handle this case.
@@ -1629,6 +1700,11 @@ void CallsiteContextGraph<DerivedCCG, FuncTy,
     if (!Removed)
       Entry++;
   }
+
+  // Add the new nodes after the above loop so that the iteration is not
+  // invalidated.
+  for (auto I : TailCallToContextNodeMap)
+    NonAllocationCallToContextNodeMap[I.first] = I.second;
 }
 
 uint64_t ModuleCallsiteContextGraph::getStackId(uint64_t IdOrIndex) const {
@@ -1642,8 +1718,160 @@ uint64_t IndexCallsiteContextGraph::getStackId(uint64_t IdOrIndex) const {
   return Index.getStackIdAtIndex(IdOrIndex);
 }
 
-bool ModuleCallsiteContextGraph::calleeMatchesFunc(Instruction *Call,
-                                                   const Function *Func) {
+template <typename DerivedCCG, typename FuncTy, typename CallTy>
+bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::calleesMatch(
+    CallTy Call, EdgeIter &EI,
+    MapVector<CallInfo, ContextNode *> &TailCallToContextNodeMap) {
+  auto Edge = *EI;
+  const FuncTy *ProfiledCalleeFunc = NodeToCallingFunc[Edge->Callee];
+  const FuncTy *CallerFunc = NodeToCallingFunc[Edge->Caller];
+  // Will be populated in order of callee to caller if we find a chain of tail
+  // calls between the profiled caller and callee.
+  std::vector<std::pair<CallTy, FuncTy *>> FoundCalleeChain;
+  if (!calleeMatchesFunc(Call, ProfiledCalleeFunc, CallerFunc,
+                         FoundCalleeChain)) {
+    ++EI;
+    return false;
+  }
+
+  // The usual case where the profiled callee matches that of the IR/summary.
+  if (FoundCalleeChain.empty()) {
+    ++EI;
+    return true;
+  }
+
+  auto AddEdge = [Edge, &EI](ContextNode *Caller, ContextNode *Callee) {
+    auto *CurEdge = Callee->findEdgeFromCaller(Caller);
+    // If there is already an edge between these nodes, simply update it and
+    // return.
+    if (CurEdge) {
+      CurEdge->ContextIds.insert(Edge->ContextIds.begin(),
+                                 Edge->ContextIds.end());
+      CurEdge->AllocTypes |= Edge->AllocTypes;
+      return;
+    }
+    // Otherwise, create a new edge and insert it into the caller and callee
+    // lists.
+    auto NewEdge = std::make_shared<ContextEdge>(
+        Callee, Caller, Edge->AllocTypes, Edge->ContextIds);
+    Callee->CallerEdges.push_back(NewEdge);
+    if (Caller == Edge->Caller) {
+      // If we are inserting the new edge into the current edge's caller, insert
+      // the new edge before the current iterator position, and then increment
+      // back to the current edge.
+      EI = Caller->CalleeEdges.insert(EI, NewEdge);
+      ++EI;
+      assert(*EI == Edge);
+    } else
+      Caller->CalleeEdges.push_back(NewEdge);
+  };
+
+  // Create new nodes for each found callee and connect in between the profiled
+  // caller and callee.
+  auto *CurCalleeNode = Edge->Callee;
+  for (auto I : FoundCalleeChain) {
+    auto &NewCall = I.first;
+    auto *Func = I.second;
+    ContextNode *NewNode = nullptr;
+    // First check if we have already synthesized a node for this tail call.
+    if (TailCallToContextNodeMap.count(NewCall)) {
+      NewNode = TailCallToContextNodeMap[NewCall];
+      NewNode->ContextIds.insert(Edge->ContextIds.begin(),
+                                 Edge->ContextIds.end());
+      NewNode->AllocTypes |= Edge->AllocTypes;
+    } else {
+      FuncToCallsWithMetadata[Func].push_back({NewCall});
+      // Create Node and record node info.
+      NodeOwner.push_back(
+          std::make_unique<ContextNode>(/*IsAllocation=*/false, NewCall));
+      NewNode = NodeOwner.back().get();
+      NodeToCallingFunc[NewNode] = Func;
+      TailCallToContextNodeMap[NewCall] = NewNode;
+      NewNode->ContextIds = Edge->ContextIds;
+      NewNode->AllocTypes = Edge->AllocTypes;
+    }
+
+    // Hook up node to its callee node
+    AddEdge(NewNode, CurCalleeNode);
+
+    CurCalleeNode = NewNode;
+  }
+
+  // Hook up edge's original caller to new callee node.
+  AddEdge(Edge->Caller, CurCalleeNode);
+
+  // Remove old edge
+  Edge->Callee->eraseCallerEdge(Edge.get());
+  EI = Edge->Caller->CalleeEdges.erase(EI);
+
+  return true;
+}
+
+bool ModuleCallsiteContextGraph::findProfiledCalleeThroughTailCalls(
+    const Function *ProfiledCallee, Value *CurCallee, unsigned Depth,
+    std::vector<std::pair<Instruction *, Function *>> &FoundCalleeChain) {
+  // Stop recursive search if we have already explored the maximum specified
+  // depth.
+  if (Depth > TailCallSearchDepth)
+    return false;
+
+  auto SaveCallsiteInfo = [&](Instruction *Callsite, Function *F) {
+    FoundCalleeChain.push_back({Callsite, F});
+  };
+
+  auto *CalleeFunc = dyn_cast<Function>(CurCallee);
+  if (!CalleeFunc) {
+    auto *Alias = dyn_cast<GlobalAlias>(CurCallee);
+    assert(Alias);
+    CalleeFunc = dyn_cast<Function>(Alias->getAliasee());
+    assert(CalleeFunc);
+  }
+
+  // Look for tail calls in this function, and check if they either call the
+  // profiled callee directly, or indirectly (via a recursive search).
+  for (auto &BB : *CalleeFunc) {
+    for (auto &I : BB) {
+      auto *CB = dyn_cast<CallBase>(&I);
+      if (!CB || !CB->isTailCall())
+        continue;
+      auto *CalledValue = CB->getCalledOperand();
+      auto *CalledFunction = CB->getCalledFunction();
+      if (CalledValue && !CalledFunction) {
+        CalledValue = CalledValue->stripPointerCasts();
+        // Stripping pointer casts can reveal a called function.
+        CalledFunction = dyn_cast<Function>(CalledValue);
+      }
+      // Check if this is an alias to a function. If so, get the
+      // called aliasee for the checks below.
+      if (auto *GA = dyn_cast<GlobalAlias>(CalledValue)) {
+        assert(!CalledFunction &&
+               "Expected null called function in callsite for alias");
+        CalledFunction = dyn_cast<Function>(GA->getAliaseeObject());
+      }
+      if (!CalledFunction)
+        continue;
+      if (CalledFunction == ProfiledCallee) {
+        FoundProfiledCalleeCount++;
+        FoundProfiledCalleeDepth += Depth;
+        if (Depth > FoundProfiledCalleeMaxDepth)
+          FoundProfiledCalleeMaxDepth = Depth;
+        SaveCallsiteInfo(&I, CalleeFunc);
+        return true;
+      }
+      if (findProfiledCalleeThroughTailCalls(ProfiledCallee, CalledFunction,
+                                             Depth + 1, FoundCalleeChain)) {
+        SaveCallsiteInfo(&I, CalleeFunc);
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
+bool ModuleCallsiteContextGraph::calleeMatchesFunc(
+    Instruction *Call, const Function *Func, const Function *CallerFunc,
+    std::vector<std::pair<Instruction *, Function *>> &FoundCalleeChain) {
   auto *CB = dyn_cast<CallBase>(Call);
   if (!CB->getCalledOperand())
     return false;
@@ -1652,11 +1880,96 @@ bool ModuleCallsiteContextGraph::calleeMatchesFunc(Instruction *Call,
   if (CalleeFunc == Func)
     return true;
   auto *Alias = dyn_cast<GlobalAlias>(CalleeVal);
-  return Alias && Alias->getAliasee() == Func;
+  if (Alias && Alias->getAliasee() == Func)
+    return true;
+
+  // Recursively search for the profiled callee through tail calls starting with
+  // the actual Callee. The discovered tail call chain is saved in
+  // FoundCalleeChain, and we will fix up the graph to include these callsites
+  // after returning.
+  // FIXME: We will currently redo the same recursive walk if we find the same
+  // mismatched callee from another callsite. We can improve this with more
+  // bookkeeping of the created chain of new nodes for each mismatch.
+  unsigned Depth = 1;
+  if (!findProfiledCalleeThroughTailCalls(Func, CalleeVal, Depth,
+                                          FoundCalleeChain)) {
+    LLVM_DEBUG(dbgs() << "Not found through tail calls: " << Func->getName()
+                      << " from " << CallerFunc->getName()
+                      << " that actually called " << CalleeVal->getName()
+                      << "\n");
+    return false;
+  }
+
+  return true;
 }
 
-bool IndexCallsiteContextGraph::calleeMatchesFunc(IndexCall &Call,
-                                                  const FunctionSummary *Func) {
+bool IndexCallsiteContextGraph::findProfiledCalleeThroughTailCalls(
+    ValueInfo ProfiledCallee, ValueInfo CurCallee, unsigned Depth,
+    std::vector<std::pair<IndexCall, FunctionSummary *>> &FoundCalleeChain) {
+  // Stop recursive search if we have already explored the maximum specified
+  // depth.
+  if (Depth > TailCallSearchDepth)
+    return false;
+
+  auto CreateAndSaveCallsiteInfo = [&](ValueInfo Callee, FunctionSummary *FS) {
+    // Make a CallsiteInfo for each discovered callee, if one hasn't already
+    // been synthesized.
+    if (!FunctionCalleesToSynthesizedCallsiteInfos.count(FS) ||
+        !FunctionCalleesToSynthesizedCallsiteInfos[FS].count(Callee))
+      // StackIds is empty (we don't have debug info available in the index for
+      // these callsites)
+      FunctionCalleesToSynthesizedCallsiteInfos[FS][Callee] =
+          std::make_unique<CallsiteInfo>(Callee, SmallVector<unsigned>());
+    CallsiteInfo *NewCallsiteInfo =
+        FunctionCalleesToSynthesizedCallsiteInfos[FS][Callee].get();
+    FoundCalleeChain.push_back({NewCallsiteInfo, FS});
+  };
+
+  // Look for tail calls in this function, and check if they either call the
+  // profiled callee directly, or indirectly (via a recursive search).
+  for (auto &S : CurCallee.getSummaryList()) {
+    if (!GlobalValue::isLocalLinkage(S->linkage()) &&
+        !isPrevailing(CurCallee.getGUID(), S.get()))
+      continue;
+    auto *FS = dyn_cast<FunctionSummary>(S->getBaseObject());
+    if (!FS)
+      continue;
+    auto FSVI = CurCallee;
+    auto *AS = dyn_cast<AliasSummary>(S.get());
+    if (AS)
+      FSVI = AS->getAliaseeVI();
+    for (auto &CallEdge : FS->calls()) {
+      if (!CallEdge.second.hasTailCall())
+        continue;
+      if (CallEdge.first == ProfiledCallee) {
+        FoundProfiledCalleeCount++;
+        FoundProfiledCalleeDepth += Depth;
+        if (Depth > FoundProfiledCalleeMaxDepth)
+          FoundProfiledCalleeMaxDepth = Depth;
+        CreateAndSaveCallsiteInfo(CallEdge.first, FS);
+        // Add FS to FSToVIMap in case it isn't already there.
+        assert(!FSToVIMap.count(FS) || FSToVIMap[FS] == FSVI);
+        FSToVIMap[FS] = FSVI;
+        return true;
+      }
+      if (findProfiledCalleeThroughTailCalls(ProfiledCallee, CallEdge.first,
+                                             Depth + 1, FoundCalleeChain)) {
+        CreateAndSaveCallsiteInfo(CallEdge.first, FS);
+        // Add FS to FSToVIMap in case it isn't already there.
+        assert(!FSToVIMap.count(FS) || FSToVIMap[FS] == FSVI);
+        FSToVIMap[FS] = FSVI;
+        return true;
+      }
+    }
+  }
+
+  return false;
+}
+
+bool IndexCallsiteContextGraph::calleeMatchesFunc(
+    IndexCall &Call, const FunctionSummary *Func,
+    const FunctionSummary *CallerFunc,
+    std::vector<std::pair<IndexCall, FunctionSummary *>> &FoundCalleeChain) {
   ValueInfo Callee =
       dyn_cast_if_present<CallsiteInfo *>(Call.getBase())->Callee;
   // If there is no summary list then this is a call to an externally defined
@@ -1666,11 +1979,31 @@ bool IndexCallsiteContextGraph::calleeMatchesFunc(IndexCall &Call,
           ? nullptr
           : dyn_cast<AliasSummary>(Callee.getSummaryList()[0].get());
   assert(FSToVIMap.count(Func));
-  return Callee == FSToVIMap[Func] ||
-         // If callee is an alias, check the aliasee, since only function
-         // summary base objects will contain the stack node summaries and thus
-         // get a context node.
-         (Alias && Alias->getAliaseeVI() == FSToVIMap[Func]);
+  auto FuncVI = FSToVIMap[Func];
+  if (Callee == FuncVI ||
+      // If callee is an alias, check the aliasee, since only function
+      // summary base objects will contain the stack node summaries and thus
+      // get a context node.
+      (Alias && Alias->getAliaseeVI() == FuncVI))
+    return true;
+
+  // Recursively search for the profiled callee through tail calls starting with
+  // the actual Callee. The discovered tail call chain is saved in
+  // FoundCalleeChain, and we will fix up the graph to include these callsites
+  // after returning.
+  // FIXME: We will currently redo the same recursive walk if we find the same
+  // mismatched callee from another callsite. We can improve this with more
+  // bookkeeping of the created chain of new nodes for each mismatch.
+  unsigned Depth = 1;
+  if (!findProfiledCalleeThroughTailCalls(FuncVI, Callee, Depth,
+                                          FoundCalleeChain)) {
+    LLVM_DEBUG(dbgs() << "Not found through tail calls: " << FuncVI << " from "
+                      << FSToVIMap[CallerFunc] << " that actually called "
+                      << Callee << "\n");
+    return false;
+  }
+
+  return true;
 }
 
 static std::string getAllocTypeString(uint8_t AllocTypes) {
@@ -2533,6 +2866,9 @@ bool CallsiteContextGraph<DerivedCCG, FuncTy, CallTy>::assignFunctions() {
           // that were previously assigned to call PreviousAssignedFuncClone,
           // to record that they now call NewFuncClone.
           for (auto CE : Clone->CallerEdges) {
+            // Skip any that have been removed on an earlier iteration.
+            if (!CE)
+              continue;
             // Ignore any caller that does not have a recorded callsite Call.
             if (!CE->Caller->hasCall())
               continue;
@@ -2945,6 +3281,42 @@ bool MemProfContextDisambiguation::applyImport(Module &M) {
       NumClonesCreated = NumClones;
     };
 
+    auto CloneCallsite = [&](const CallsiteInfo &StackNode, CallBase *CB,
+                             Function *CalledFunction) {
+      // Perform cloning if not yet done.
+      CloneFuncIfNeeded(/*NumClones=*/StackNode.Clones.size());
+
+      // Should have skipped indirect calls via mayHaveMemprofSummary.
+      assert(CalledFunction);
+      assert(!IsMemProfClone(*CalledFunction));
+
+      // Update the calls per the summary info.
+      // Save orig name since it gets updated in the first iteration
+      // below.
+      auto CalleeOrigName = CalledFunction->getName();
+      for (unsigned J = 0; J < StackNode.Clones.size(); J++) {
+        // Do nothing if this version calls the original version of its
+        // callee.
+        if (!StackNode.Clones[J])
+          continue;
+        auto NewF = M.getOrInsertFunction(
+            getMemProfFuncName(CalleeOrigName, StackNode.Clones[J]),
+            CalledFunction->getFunctionType());
+        CallBase *CBClone;
+        // Copy 0 is the original function.
+        if (!J)
+          CBClone = CB;
+        else
+          CBClone = cast<CallBase>((*VMaps[J - 1])[CB]);
+        CBClone->setCalledFunction(NewF);
+        ORE.emit(OptimizationRemark(DEBUG_TYPE, "MemprofCall", CBClone)
+                 << ore::NV("Call", CBClone) << " in clone "
+                 << ore::NV("Caller", CBClone->getFunction())
+                 << " assigned to call function clone "
+                 << ore::NV("Callee", NewF.getCallee()));
+      }
+    };
+
     // Locate the summary for F.
     ValueInfo TheFnVI = findValueInfoForFunc(F, M, ImportSummary);
     // If not found, this could be an imported local (see comment in
@@ -2974,6 +3346,23 @@ bool MemProfContextDisambiguation::applyImport(Module &M) {
     auto SI = FS->callsites().begin();
     auto AI = FS->allocs().begin();
 
+    // To handle callsite infos synthesized for tail calls which have missing
+    // frames in the profiled context, map callee VI to the synthesized callsite
+    // info.
+    DenseMap<ValueInfo, CallsiteInfo> MapTailCallCalleeVIToCallsite;
+    // Iterate the callsites for this function in reverse, since we place all
+    // those synthesized for tail calls at the end.
+    for (auto CallsiteIt = FS->callsites().rbegin();
+         CallsiteIt != FS->callsites().rend(); CallsiteIt++) {
+      auto &Callsite = *CallsiteIt;
+      // Stop as soon as we see a non-synthesized callsite info (see comment
+      // above loop). All the entries added for discovered tail calls have empty
+      // stack ids.
+      if (!Callsite.StackIdIndices.empty())
+        break;
+      MapTailCallCalleeVIToCallsite.insert({Callsite.Callee, Callsite});
+    }
+
     // Assume for now that the instructions are in the exact same order
     // as when the summary was created, but confirm this is correct by
     // matching the stack ids.
@@ -3126,37 +3515,16 @@ bool MemProfContextDisambiguation::applyImport(Module &M) {
           }
 #endif
 
-          // Perform cloning if not yet done.
-          CloneFuncIfNeeded(/*NumClones=*/StackNode.Clones.size());
-
-          // Should have skipped indirect calls via mayHaveMemprofSummary.
-          assert(CalledFunction);
-          assert(!IsMemProfClone(*CalledFunction));
-
-          // Update the calls per the summary info.
-          // Save orig name since it gets updated in the first iteration
-          // below.
-          auto CalleeOrigName = CalledFunction->getName();
-          for (unsigned J = 0; J < StackNode.Clones.size(); J++) {
-            // Do nothing if this version calls the original version of its
-            // callee.
-            if (!StackNode.Clones[J])
-              continue;
-            auto NewF = M.getOrInsertFunction(
-                getMemProfFuncName(CalleeOrigName, StackNode.Clones[J]),
-                CalledFunction->getFunctionType());
-            CallBase *CBClone;
-            // Copy 0 is the original function.
-            if (!J)
-              CBClone = CB;
-            else
-              CBClone = cast<CallBase>((*VMaps[J - 1])[CB]);
-            CBClone->setCalledFunction(NewF);
-            ORE.emit(OptimizationRemark(DEBUG_TYPE, "MemprofCall", CBClone)
-                     << ore::NV("Call", CBClone) << " in clone "
-                     << ore::NV("Caller", CBClone->getFunction())
-                     << " assigned to call function clone "
-                     << ore::NV("Callee", NewF.getCallee()));
+          CloneCallsite(StackNode, CB, CalledFunction);
+        } else if (CB->isTailCall()) {
+          // Locate the synthesized callsite info for the callee VI, if any was
+          // created, and use that for cloning.
+          ValueInfo CalleeVI =
+              findValueInfoForFunc(*CalledFunction, M, ImportSummary);
+          if (CalleeVI && MapTailCallCalleeVIToCallsite.count(CalleeVI)) {
+            auto Callsite = MapTailCallCalleeVIToCallsite.find(CalleeVI);
+            assert(Callsite != MapTailCallCalleeVIToCallsite.end());
+            CloneCallsite(Callsite->second, CB, CalledFunction);
           }
         }
         // Memprof and callsite metadata on memory allocations no longer needed.
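
For readers skimming the patch, here is a condensed standalone sketch of the
depth-limited search implemented above for the IR case (alias handling,
indirect call stripping, and the statistics updates are omitted; the helper
name and signature are hypothetical):

  // Returns true if ProfiledCallee is reachable from CurCallee through a
  // chain of at most MaxDepth tail calls. Chain is filled callee to caller.
  static bool findThroughTailCalls(const Function *ProfiledCallee,
                                   const Function *CurCallee, unsigned Depth,
                                   unsigned MaxDepth,
                                   std::vector<const Instruction *> &Chain) {
    if (Depth > MaxDepth)
      return false;
    for (const BasicBlock &BB : *CurCallee)
      for (const Instruction &I : BB) {
        const auto *CB = dyn_cast<CallBase>(&I);
        if (!CB || !CB->isTailCall())
          continue;
        const Function *Called = CB->getCalledFunction();
        if (!Called)
          continue;
        if (Called == ProfiledCallee ||
            findThroughTailCalls(ProfiledCallee, Called, Depth + 1, MaxDepth,
                                 Chain)) {
          // Record this frame of the discovered chain.
          Chain.push_back(&I);
          return true;
        }
      }
    return false;
  }
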
diff --git a/llvm/test/ThinLTO/X86/memprof-tailcall.ll b/llvm/test/ThinLTO/X86/memprof-tailcall.ll
new file mode 100644
index 00000000000000..4207b8b2caf459
--- /dev/null
+++ b/llvm/test/ThinLTO/X86/memprof-tailcall.ll
@@ -0,0 +1,110 @@
+;; Test to make sure that missing tail call frames in memprof profiles are
+;; identified and cloned as needed for ThinLTO.
+
+;; -stats requires asserts
+; REQUIRES: asserts
+
+; RUN: opt -thinlto-bc %s >%t.o
+; RUN: llvm-lto2 run %t.o -enable-memprof-context-disambiguation \
+; RUN:  -supports-hot-cold-new \
+; RUN:  -r=%t.o,_Z3barv,plx \
+; RUN:  -r=%t.o,_Z3bazv,plx \
+; RUN:  -r=%t.o,_Z3foov,plx \
+; RUN:  -r=%t.o,main,plx \
+; RUN:  -r=%t.o,_Znam, \
+; RUN:  -stats -save-temps \
+; RUN:  -o %t.out 2>&1 | FileCheck %s --check-prefix=STATS
+
+; RUN: llvm-dis %t.out.1.4.opt.bc -o - | FileCheck %s --check-prefix=IR
+
+;; Try again but with distributed ThinLTO
+; RUN: llvm-lto2 run %t.o -enable-memprof-context-disambiguation \
+; RUN:  -supports-hot-cold-new \
+; RUN:  -thinlto-distributed-indexes \
+; RUN:  -r=%t.o,_Z3barv,plx \
+; RUN:  -r=%t.o,_Z3bazv,plx \
+; RUN:  -r=%t.o,_Z3foov,plx \
+; RUN:  -r=%t.o,main,plx \
+; RUN:  -r=%t.o,_Znam, \
+; RUN:  -stats \
+; RUN:  -o %t2.out 2>&1 | FileCheck %s --check-prefix=STATS
+
+;; Run ThinLTO backend
+; RUN: opt -passes=memprof-context-disambiguation \
+; RUN:  -memprof-import-summary=%t.o.thinlto.bc \
+; RUN:  -stats %t.o -S 2>&1 | FileCheck %s --check-prefix=IR
+
+; STATS: 2 memprof-context-disambiguation - Number of profiled callees found via tail calls
+; STATS: 4 memprof-context-disambiguation - Aggregate depth of profiled callees found via tail calls
+; STATS: 2 memprof-context-disambiguation - Maximum depth of profiled callees found via tail calls
+
+source_filename = "memprof-tailcall.cc"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+; Function Attrs: noinline
+; IR-LABEL: @_Z3barv()
+define ptr @_Z3barv() local_unnamed_addr #0 {
+entry:
+  ; IR: call {{.*}}  @_Znam(i64 10) #[[NOTCOLD:[0-9]+]]
+  %call = tail call ptr @_Znam(i64 10) #2, !memprof !0, !callsite !5
+  ret ptr %call
+}
+
+; Function Attrs: nobuiltin allocsize(0)
+declare ptr @_Znam(i64) #1
+
+; Function Attrs: noinline
+; IR-LABEL: @_Z3bazv()
+define ptr @_Z3bazv() #0 {
+entry:
+  ; IR: call ptr @_Z3barv()
+  %call = tail call ptr @_Z3barv()
+  ret ptr %call
+}
+
+; Function Attrs: noinline
+; IR-LABEL: @_Z3foov()
+define ptr @_Z3foov() #0 {
+entry:
+  ; IR: call ptr @_Z3bazv()
+  %call = tail call ptr @_Z3bazv()
+  ret ptr %call
+}
+
+; Function Attrs: noinline
+; IR-LABEL: @main()
+define i32 @main() #0 {
+  ;; The first call to foo is part of the non-cold context, and should use the
+  ;; original functions.
+  ; IR: call ptr @_Z3foov()
+  %call = tail call ptr @_Z3foov(), !callsite !6
+  ;; The second call to foo is part of a cold context, and should call the
+  ;; cloned functions.
+  ; IR: call ptr @_Z3foov.memprof.1()
+  %call1 = tail call ptr @_Z3foov(), !callsite !7
+  ret i32 0
+}
+
+; IR-LABEL: @_Z3barv.memprof.1()
+; IR: call {{.*}}  @_Znam(i64 10) #[[COLD:[0-9]+]]
+; IR-LABEL: @_Z3bazv.memprof.1()
+; IR: call ptr @_Z3barv.memprof.1()
+; IR-LABEL: @_Z3foov.memprof.1()
+; IR: call ptr @_Z3bazv.memprof.1()
+
+; IR: attributes #[[NOTCOLD]] = { builtin allocsize(0) "memprof"="notcold" }
+; IR: attributes #[[COLD]] = { builtin allocsize(0) "memprof"="cold" }
+
+attributes #0 = { noinline }
+attributes #1 = { nobuiltin allocsize(0) }
+attributes #2 = { builtin allocsize(0) }
+
+!0 = !{!1, !3}
+!1 = !{!2, !"notcold"}
+!2 = !{i64 3186456655321080972, i64 8632435727821051414}
+!3 = !{!4, !"cold"}
+!4 = !{i64 3186456655321080972, i64 -3421689549917153178}
+!5 = !{i64 3186456655321080972}
+!6 = !{i64 8632435727821051414}
+!7 = !{i64 -3421689549917153178}
diff --git a/llvm/test/Transforms/MemProfContextDisambiguation/tailcall.ll b/llvm/test/Transforms/MemProfContextDisambiguation/tailcall.ll
new file mode 100644
index 00000000000000..5f09c23d40e6d1
--- /dev/null
+++ b/llvm/test/Transforms/MemProfContextDisambiguation/tailcall.ll
@@ -0,0 +1,84 @@
+;; Test to make sure that missing tail call frames in memprof profiles are
+;; identified and cloned as needed for regular LTO.
+
+;; -stats requires asserts
+; REQUIRES: asserts
+
+; RUN: opt -passes=memprof-context-disambiguation -supports-hot-cold-new \
+; RUN:  -stats %s -S 2>&1 | FileCheck %s --check-prefix=STATS --check-prefix=IR
+
+source_filename = "memprof-tailcall.cc"
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+; Function Attrs: noinline
+; IR-LABEL: @_Z3barv()
+define ptr @_Z3barv() local_unnamed_addr #0 {
+entry:
+  ; IR: call ptr @_Znam(i64 10) #[[NOTCOLD:[0-9]+]]
+  %call = tail call ptr @_Znam(i64 10) #2, !memprof !0, !callsite !5
+  ret ptr %call
+}
+
+; Function Attrs: nobuiltin allocsize(0)
+declare ptr @_Znam(i64) #1
+
+; Function Attrs: noinline
+; IR-LABEL: @_Z3bazv()
+define ptr @_Z3bazv() #0 {
+entry:
+  ; IR: call ptr @_Z3barv()
+  %call = tail call ptr @_Z3barv()
+  ret ptr %call
+}
+
+; Function Attrs: noinline
+; IR-LABEL: @_Z3foov()
+define ptr @_Z3foov() #0 {
+entry:
+  ; IR: call ptr @_Z3bazv()
+  %call = tail call ptr @_Z3bazv()
+  ret ptr %call
+}
+
+; Function Attrs: noinline
+; IR-LABEL: @main()
+define i32 @main() #0 {
+  ;; The first call to foo is part of the non-cold context for the
+  ;; allocation, and should therefore use the original (unchanged) versions
+  ;; of the functions.
+  ; IR: call ptr @_Z3foov()
+  %call = tail call ptr @_Z3foov(), !callsite !6
+  ;; The second call to foo is part of a cold context, and should call the
+  ;; cloned functions.
+  ; IR: call ptr @_Z3foov.memprof.1()
+  %call1 = tail call ptr @_Z3foov(), !callsite !7
+  ret i32 0
+}
+
+; IR-LABEL: @_Z3bazv.memprof.1()
+; IR: call ptr @_Z3barv.memprof.1()
+; IR-LABEL: @_Z3foov.memprof.1()
+; IR: call ptr @_Z3bazv.memprof.1()
+; IR-LABEL: @_Z3barv.memprof.1()
+; IR: call ptr @_Znam(i64 10) #[[COLD:[0-9]+]]
+
+; IR: attributes #[[NOTCOLD]] = { builtin allocsize(0) "memprof"="notcold" }
+; IR: attributes #[[COLD]] = { builtin allocsize(0) "memprof"="cold" }
+
+; STATS: 2 memprof-context-disambiguation - Number of profiled callees found via tail calls
+; STATS: 4 memprof-context-disambiguation - Aggregate depth of profiled callees found via tail calls
+; STATS: 2 memprof-context-disambiguation - Maximum depth of profiled callees found via tail calls
+
+attributes #0 = { noinline }
+attributes #1 = { nobuiltin allocsize(0) }
+attributes #2 = { builtin allocsize(0) }
+
+!0 = !{!1, !3}
+!1 = !{!2, !"notcold"}
+!2 = !{i64 3186456655321080972, i64 8632435727821051414}
+!3 = !{!4, !"cold"}
+!4 = !{i64 3186456655321080972, i64 -3421689549917153178}
+!5 = !{i64 3186456655321080972}
+!6 = !{i64 8632435727821051414}
+!7 = !{i64 -3421689549917153178}
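
The search depth can be adjusted when experimenting with deeper tail call
chains, e.g. (input file name is illustrative):

  opt -passes=memprof-context-disambiguation -supports-hot-cold-new \
      -memprof-tail-call-search-depth=8 input.ll -S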


