[compiler-rt] [llvm] [ctxprof] root autodetection mechanism (PR #133147)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 31 12:25:57 PDT 2025


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/133147

>From 1a12853fd9f31f02ca289ed2bce8339d1ef181d5 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Mon, 24 Mar 2025 12:01:10 -0700
Subject: [PATCH] RootAutodetect

---
 compiler-rt/lib/ctx_profile/CMakeLists.txt    |   2 +-
 .../lib/ctx_profile/CtxInstrContextNode.h     |   1 +
 .../lib/ctx_profile/CtxInstrProfiling.cpp     | 119 +++++++----
 .../lib/ctx_profile/CtxInstrProfiling.h       |   2 +-
 .../lib/ctx_profile/RootAutoDetector.cpp      |  94 +++++++++
 .../lib/ctx_profile/RootAutoDetector.h        |  31 +++
 .../TestCases/autodetect-roots.cpp            | 188 ++++++++++++++++++
 .../TestCases/generate-context.cpp            |   5 +-
 .../llvm/ProfileData/CtxInstrContextNode.h    |   1 +
 .../Instrumentation/PGOCtxProfLowering.cpp    |  26 ++-
 .../PGOProfile/ctx-instrumentation.ll         |  50 ++++-
 11 files changed, 457 insertions(+), 62 deletions(-)
 create mode 100644 compiler-rt/test/ctx_profile/TestCases/autodetect-roots.cpp

diff --git a/compiler-rt/lib/ctx_profile/CMakeLists.txt b/compiler-rt/lib/ctx_profile/CMakeLists.txt
index bb606449c61b1..446ebc96408dd 100644
--- a/compiler-rt/lib/ctx_profile/CMakeLists.txt
+++ b/compiler-rt/lib/ctx_profile/CMakeLists.txt
@@ -27,7 +27,7 @@ endif()
 add_compiler_rt_runtime(clang_rt.ctx_profile
   STATIC
   ARCHS ${CTX_PROFILE_SUPPORTED_ARCH}
-  OBJECT_LIBS RTSanitizerCommon RTSanitizerCommonLibc
+  OBJECT_LIBS RTSanitizerCommon RTSanitizerCommonLibc RTSanitizerCommonSymbolizer
   CFLAGS ${EXTRA_FLAGS}
   SOURCES ${CTX_PROFILE_SOURCES}
   ADDITIONAL_HEADERS ${CTX_PROFILE_HEADERS}
diff --git a/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h b/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
index a42bf9ebb01ea..55423d95b3088 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
+++ b/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
@@ -127,6 +127,7 @@ class ContextNode final {
 /// MUTEXDECL takes one parameter, the name of a field that is a mutex.
 #define CTXPROF_FUNCTION_DATA(PTRDECL, VOLATILE_PTRDECL, MUTEXDECL)            \
   PTRDECL(FunctionData, Next)                                                  \
+  VOLATILE_PTRDECL(void, EntryAddress)                                         \
   VOLATILE_PTRDECL(ContextRoot, CtxRoot)                                       \
   VOLATILE_PTRDECL(ContextNode, FlatCtx)                                       \
   MUTEXDECL(Mutex)
diff --git a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
index da291e0bbabdd..09ed607cde3aa 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
+++ b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "CtxInstrProfiling.h"
+#include "RootAutoDetector.h"
 #include "sanitizer_common/sanitizer_allocator_internal.h"
 #include "sanitizer_common/sanitizer_atomic.h"
 #include "sanitizer_common/sanitizer_atomic_clang.h"
@@ -43,6 +44,12 @@ Arena *FlatCtxArena = nullptr;
 __thread bool IsUnderContext = false;
 __sanitizer::atomic_uint8_t ProfilingStarted = {};
 
+__sanitizer::atomic_uintptr_t RootDetector = {};
+RootAutoDetector *getRootDetector() {
+  return reinterpret_cast<RootAutoDetector *>(
+      __sanitizer::atomic_load_relaxed(&RootDetector));
+}
+
 // utility to taint a pointer by setting the LSB. There is an assumption
 // throughout that the addresses of contexts are even (really, they should be
 // align(8), but "even"-ness is the minimum assumption)
@@ -201,7 +208,7 @@ ContextNode *getCallsiteSlow(GUID Guid, ContextNode **InsertionPoint,
   return Ret;
 }
 
-ContextNode *getFlatProfile(FunctionData &Data, GUID Guid,
+ContextNode *getFlatProfile(FunctionData &Data, void *Callee, GUID Guid,
                             uint32_t NumCounters) {
   if (ContextNode *Existing = Data.FlatCtx)
     return Existing;
@@ -232,6 +239,7 @@ ContextNode *getFlatProfile(FunctionData &Data, GUID Guid,
     auto *Ret = allocContextNode(AllocBuff, Guid, NumCounters, 0);
     Data.FlatCtx = Ret;
 
+    Data.EntryAddress = Callee;
     Data.Next = reinterpret_cast<FunctionData *>(
         __sanitizer::atomic_load_relaxed(&AllFunctionsData));
     while (!__sanitizer::atomic_compare_exchange_strong(
@@ -277,8 +285,29 @@ ContextRoot *FunctionData::getOrAllocateContextRoot() {
   return Root;
 }
 
-ContextNode *getUnhandledContext(FunctionData &Data, GUID Guid,
-                                 uint32_t NumCounters) {
+ContextNode *tryStartContextGivenRoot(ContextRoot *Root, GUID Guid,
+                                      uint32_t Counters, uint32_t Callsites)
+    SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
+  IsUnderContext = true;
+  __sanitizer::atomic_fetch_add(&Root->TotalEntries, 1,
+                                __sanitizer::memory_order_relaxed);
+
+  if (!Root->FirstMemBlock) {
+    setupContext(Root, Guid, Counters, Callsites);
+  }
+  if (Root->Taken.TryLock()) {
+    __llvm_ctx_profile_current_context_root = Root;
+    onContextEnter(*Root->FirstNode);
+    return Root->FirstNode;
+  }
+  // If this thread couldn't take the lock, return scratch context.
+  __llvm_ctx_profile_current_context_root = nullptr;
+  return TheScratchContext;
+}
+
+ContextNode *getUnhandledContext(FunctionData &Data, void *Callee, GUID Guid,
+                                 uint32_t NumCounters, uint32_t NumCallsites,
+                                 ContextRoot *CtxRoot) {
 
   // 1) if we are currently collecting a contextual profile, fetch a ContextNode
   // in the `Unhandled` set. We want to do this regardless of `ProfilingStarted`
@@ -297,27 +326,32 @@ ContextNode *getUnhandledContext(FunctionData &Data, GUID Guid,
   // entered once and never exit. They should be assumed to be entered before
   // profiling starts - because profiling should start after the server is up
   // and running (which is equivalent to "message pumps are set up").
-  ContextRoot *R = __llvm_ctx_profile_current_context_root;
-  if (!R) {
+  if (!CtxRoot) {
+    if (auto *RAD = getRootDetector())
+      RAD->sample();
+    else if (auto *CR = Data.CtxRoot)
+      return tryStartContextGivenRoot(CR, Guid, NumCounters, NumCallsites);
     if (IsUnderContext || !__sanitizer::atomic_load_relaxed(&ProfilingStarted))
       return TheScratchContext;
     else
       return markAsScratch(
-          onContextEnter(*getFlatProfile(Data, Guid, NumCounters)));
+          onContextEnter(*getFlatProfile(Data, Callee, Guid, NumCounters)));
   }
-  auto [Iter, Ins] = R->Unhandled.insert({Guid, nullptr});
+  auto [Iter, Ins] = CtxRoot->Unhandled.insert({Guid, nullptr});
   if (Ins)
-    Iter->second =
-        getCallsiteSlow(Guid, &R->FirstUnhandledCalleeNode, NumCounters, 0);
+    Iter->second = getCallsiteSlow(Guid, &CtxRoot->FirstUnhandledCalleeNode,
+                                   NumCounters, 0);
   return markAsScratch(onContextEnter(*Iter->second));
 }
 
 ContextNode *__llvm_ctx_profile_get_context(FunctionData *Data, void *Callee,
                                             GUID Guid, uint32_t NumCounters,
                                             uint32_t NumCallsites) {
+  auto *CtxRoot = __llvm_ctx_profile_current_context_root;
   // fast "out" if we're not even doing contextual collection.
-  if (!__llvm_ctx_profile_current_context_root)
-    return getUnhandledContext(*Data, Guid, NumCounters);
+  if (!CtxRoot)
+    return getUnhandledContext(*Data, Callee, Guid, NumCounters, NumCallsites,
+                               nullptr);
 
   // also fast "out" if the caller is scratch. We can see if it's scratch by
   // looking at the interior pointer into the subcontexts vector that the caller
@@ -326,7 +360,8 @@ ContextNode *__llvm_ctx_profile_get_context(FunctionData *Data, void *Callee,
   // precisely, aligned - 8 values)
   auto **CallsiteContext = consume(__llvm_ctx_profile_callsite[0]);
   if (!CallsiteContext || isScratch(CallsiteContext))
-    return getUnhandledContext(*Data, Guid, NumCounters);
+    return getUnhandledContext(*Data, Callee, Guid, NumCounters, NumCallsites,
+                               CtxRoot);
 
   // if the callee isn't the expected one, return scratch.
   // Signal handler(s) could have been invoked at any point in the execution.
@@ -344,7 +379,8 @@ ContextNode *__llvm_ctx_profile_get_context(FunctionData *Data, void *Callee,
   // for that case.
   auto *ExpectedCallee = consume(__llvm_ctx_profile_expected_callee[0]);
   if (ExpectedCallee != Callee)
-    return getUnhandledContext(*Data, Guid, NumCounters);
+    return getUnhandledContext(*Data, Callee, Guid, NumCounters, NumCallsites,
+                               CtxRoot);
 
   auto *Callsite = *CallsiteContext;
   // in the case of indirect calls, we will have all seen targets forming a
@@ -366,40 +402,26 @@ ContextNode *__llvm_ctx_profile_get_context(FunctionData *Data, void *Callee,
   return Ret;
 }
 
-ContextNode *__llvm_ctx_profile_start_context(
-    FunctionData *FData, GUID Guid, uint32_t Counters,
-    uint32_t Callsites) SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
-  IsUnderContext = true;
-
-  auto *Root = FData->getOrAllocateContextRoot();
-
-  __sanitizer::atomic_fetch_add(&Root->TotalEntries, 1,
-                                __sanitizer::memory_order_relaxed);
+ContextNode *__llvm_ctx_profile_start_context(FunctionData *FData, GUID Guid,
+                                              uint32_t Counters,
+                                              uint32_t Callsites) {
 
-  if (!Root->FirstMemBlock) {
-    setupContext(Root, Guid, Counters, Callsites);
-  }
-  if (Root->Taken.TryLock()) {
-    __llvm_ctx_profile_current_context_root = Root;
-    onContextEnter(*Root->FirstNode);
-    return Root->FirstNode;
-  }
-  // If this thread couldn't take the lock, return scratch context.
-  __llvm_ctx_profile_current_context_root = nullptr;
-  return TheScratchContext;
+  return tryStartContextGivenRoot(FData->getOrAllocateContextRoot(), Guid,
+                                  Counters, Callsites);
 }
 
 void __llvm_ctx_profile_release_context(FunctionData *FData)
     SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
+  const auto *CurrentRoot = __llvm_ctx_profile_current_context_root;
+  if (!CurrentRoot || FData->CtxRoot != CurrentRoot)
+    return;
   IsUnderContext = false;
-  if (__llvm_ctx_profile_current_context_root) {
-    __llvm_ctx_profile_current_context_root = nullptr;
-    assert(FData->CtxRoot);
-    FData->CtxRoot->Taken.Unlock();
-  }
+  assert(FData->CtxRoot);
+  __llvm_ctx_profile_current_context_root = nullptr;
+  FData->CtxRoot->Taken.Unlock();
 }
 
-void __llvm_ctx_profile_start_collection() {
+void __llvm_ctx_profile_start_collection(unsigned AutodetectDuration) {
   size_t NumMemUnits = 0;
   __sanitizer::GenericScopedLock<__sanitizer::SpinMutex> Lock(
       &AllContextsMutex);
@@ -415,12 +437,24 @@ void __llvm_ctx_profile_start_collection() {
       resetContextNode(*Root->FirstUnhandledCalleeNode);
     __sanitizer::atomic_store_relaxed(&Root->TotalEntries, 0);
   }
+  if (AutodetectDuration) {
+    auto *RD = new (__sanitizer::InternalAlloc(sizeof(RootAutoDetector)))
+        RootAutoDetector(AllFunctionsData, RootDetector, AutodetectDuration);
+    RD->start();
+  } else {
+    __sanitizer::Printf("[ctxprof] Initial NumMemUnits: %zu \n", NumMemUnits);
+  }
   __sanitizer::atomic_store_relaxed(&ProfilingStarted, true);
-  __sanitizer::Printf("[ctxprof] Initial NumMemUnits: %zu \n", NumMemUnits);
 }
 
 bool __llvm_ctx_profile_fetch(ProfileWriter &Writer) {
   __sanitizer::atomic_store_relaxed(&ProfilingStarted, false);
+  if (auto *RD = getRootDetector()) {
+    __sanitizer::Printf("[ctxprof] Expected the root autodetector to have "
+                        "finished well before attempting to fetch a context\n");
+    RD->join();
+  }
+
   __sanitizer::GenericScopedLock<__sanitizer::SpinMutex> Lock(
       &AllContextsMutex);
 
@@ -445,8 +479,9 @@ bool __llvm_ctx_profile_fetch(ProfileWriter &Writer) {
   const auto *Pos = reinterpret_cast<const FunctionData *>(
       __sanitizer::atomic_load_relaxed(&AllFunctionsData));
   for (; Pos; Pos = Pos->Next)
-    Writer.writeFlat(Pos->FlatCtx->guid(), Pos->FlatCtx->counters(),
-                     Pos->FlatCtx->counters_size());
+    if (!Pos->CtxRoot)
+      Writer.writeFlat(Pos->FlatCtx->guid(), Pos->FlatCtx->counters(),
+                       Pos->FlatCtx->counters_size());
   Writer.endFlatSection();
   return true;
 }
diff --git a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
index 6326beaa53085..4983f086d230d 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
+++ b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
@@ -207,7 +207,7 @@ ContextNode *__llvm_ctx_profile_get_context(__ctx_profile::FunctionData *FData,
 
 /// Prepares for collection. Currently this resets counter values but preserves
 /// internal context tree structure.
-void __llvm_ctx_profile_start_collection();
+void __llvm_ctx_profile_start_collection(unsigned AutodetectDuration = 0);
 
 /// Completely free allocated memory.
 void __llvm_ctx_profile_free();
diff --git a/compiler-rt/lib/ctx_profile/RootAutoDetector.cpp b/compiler-rt/lib/ctx_profile/RootAutoDetector.cpp
index 483c55c25eefe..281ce5e33865a 100644
--- a/compiler-rt/lib/ctx_profile/RootAutoDetector.cpp
+++ b/compiler-rt/lib/ctx_profile/RootAutoDetector.cpp
@@ -8,6 +8,7 @@
 
 #include "RootAutoDetector.h"
 
+#include "CtxInstrProfiling.h"
 #include "sanitizer_common/sanitizer_common.h"
 #include "sanitizer_common/sanitizer_placement_new.h" // IWYU pragma: keep (DenseMap)
 #include <assert.h>
@@ -17,6 +18,99 @@
 using namespace __ctx_profile;
 template <typename T> using Set = DenseMap<T, bool>;
 
+namespace __sanitizer {
+void BufferedStackTrace::UnwindImpl(uptr pc, uptr bp, void *context,
+                                    bool request_fast, u32 max_depth) {
+  // We can't implement the fast variant. The fast variant ends up invoking an
+  // external allocator, because of pthread_attr_getstack. If this happens
+  // during an allocation of the program being instrumented, a non-reentrant
+  // lock may be taken (this was observed). The allocator called by
+  // pthread_attr_getstack will also try to take that lock.
+  UnwindSlow(pc, max_depth);
+}
+} // namespace __sanitizer
+
+RootAutoDetector::PerThreadSamples::PerThreadSamples(RootAutoDetector &Parent) {
+  GenericScopedLock<SpinMutex> L(&Parent.AllSamplesMutex);
+  Parent.AllSamples.PushBack(this);
+}
+
+void RootAutoDetector::start() {
+  atomic_store_relaxed(&Self, reinterpret_cast<uintptr_t>(this));
+  pthread_create(
+      &WorkerThread, nullptr,
+      +[](void *Ctx) -> void * {
+        RootAutoDetector *RAD = reinterpret_cast<RootAutoDetector *>(Ctx);
+        SleepForSeconds(RAD->WaitSeconds);
+        // To avoid holding the AllSamplesMutex, make a snapshot of all the
+        // thread samples collected so far
+        Vector<PerThreadSamples *> SamplesSnapshot;
+        {
+          GenericScopedLock<SpinMutex> M(&RAD->AllSamplesMutex);
+          SamplesSnapshot.Resize(RAD->AllSamples.Size());
+          for (uptr I = 0; I < RAD->AllSamples.Size(); ++I)
+            SamplesSnapshot[I] = RAD->AllSamples[I];
+        }
+        DenseMap<uptr, uint64_t> AllRoots;
+        for (uptr I = 0; I < SamplesSnapshot.Size(); ++I) {
+          GenericScopedLock<SpinMutex> L(&SamplesSnapshot[I]->M);
+          SamplesSnapshot[I]->TrieRoot.determineRoots().forEach([&](auto &KVP) {
+            auto [FAddr, Count] = KVP;
+            AllRoots[FAddr] += Count;
+            return true;
+          });
+        }
+        // FIXME: as a next step, establish a minimum relative nr of samples
+        // per root that would qualify it as a root.
+        for (auto *FD = reinterpret_cast<FunctionData *>(
+                 atomic_load_relaxed(&RAD->FunctionDataListHead));
+             FD; FD = FD->Next) {
+          if (AllRoots.contains(reinterpret_cast<uptr>(FD->EntryAddress))) {
+            FD->getOrAllocateContextRoot();
+          }
+        }
+        atomic_store_relaxed(&RAD->Self, 0);
+        return nullptr;
+      },
+      this);
+}
+
+void RootAutoDetector::join() { pthread_join(WorkerThread, nullptr); }
+
+void RootAutoDetector::sample() {
+  // tracking reentry in case we want to re-explore fast stack unwind - which
+  // does potentially re-enter the runtime because it calls the instrumented
+  // allocator because of pthread_attr_getstack. See the notes also on
+  // UnwindImpl above.
+  static thread_local bool Entered = false;
+  static thread_local uint64_t Entries = 0;
+  if (Entered || (++Entries % SampleRate))
+    return;
+  Entered = true;
+  collectStack();
+  Entered = false;
+}
+
+void RootAutoDetector::collectStack() {
+  GET_CALLER_PC_BP;
+  BufferedStackTrace CurrentStack;
+  CurrentStack.Unwind(pc, bp, nullptr, false);
+  // 2 stack frames would be very unlikely to mean anything, since at least the
+  // compiler-rt frame - which can't be inlined - should be observable, which
+  // counts as 1; we can be even more aggressive with this number.
+  if (CurrentStack.size <= 2)
+    return;
+  static thread_local PerThreadSamples *ThisThreadSamples =
+      new (__sanitizer::InternalAlloc(sizeof(PerThreadSamples)))
+          PerThreadSamples(*this);
+
+  if (!ThisThreadSamples->M.TryLock())
+    return;
+
+  ThisThreadSamples->TrieRoot.insertStack(CurrentStack);
+  ThisThreadSamples->M.Unlock();
+}
+
 uptr PerThreadCallsiteTrie::getFctStartAddr(uptr CallsiteAddress) const {
   // this requires --linkopt=-Wl,--export-dynamic
   Dl_info Info;
diff --git a/compiler-rt/lib/ctx_profile/RootAutoDetector.h b/compiler-rt/lib/ctx_profile/RootAutoDetector.h
index 85dd5ef1c32d9..5c2abaeb1d0fa 100644
--- a/compiler-rt/lib/ctx_profile/RootAutoDetector.h
+++ b/compiler-rt/lib/ctx_profile/RootAutoDetector.h
@@ -12,6 +12,7 @@
 #include "sanitizer_common/sanitizer_dense_map.h"
 #include "sanitizer_common/sanitizer_internal_defs.h"
 #include "sanitizer_common/sanitizer_stacktrace.h"
+#include "sanitizer_common/sanitizer_vector.h"
 #include <pthread.h>
 #include <sanitizer/common_interface_defs.h>
 
@@ -53,5 +54,35 @@ class PerThreadCallsiteTrie {
   /// thread, together with the number of samples that included them.
   DenseMap<uptr, uint64_t> determineRoots() const;
 };
+
+class RootAutoDetector final {
+  static const uint64_t SampleRate = 6113;
+  const unsigned WaitSeconds;
+  pthread_t WorkerThread;
+
+  struct PerThreadSamples {
+    PerThreadSamples(RootAutoDetector &Parent);
+
+    PerThreadCallsiteTrie TrieRoot;
+    SpinMutex M;
+  };
+  SpinMutex AllSamplesMutex;
+  SANITIZER_GUARDED_BY(AllSamplesMutex)
+  Vector<PerThreadSamples *> AllSamples;
+  atomic_uintptr_t &FunctionDataListHead;
+  atomic_uintptr_t &Self;
+  void collectStack();
+
+public:
+  RootAutoDetector(atomic_uintptr_t &FunctionDataListHead,
+                   atomic_uintptr_t &Self, unsigned WaitSeconds)
+      : WaitSeconds(WaitSeconds), FunctionDataListHead(FunctionDataListHead),
+        Self(Self) {}
+
+  void sample();
+  void start();
+  void join();
+};
+
 } // namespace __ctx_profile
 #endif
diff --git a/compiler-rt/test/ctx_profile/TestCases/autodetect-roots.cpp b/compiler-rt/test/ctx_profile/TestCases/autodetect-roots.cpp
new file mode 100644
index 0000000000000..d4d4eb0230fc6
--- /dev/null
+++ b/compiler-rt/test/ctx_profile/TestCases/autodetect-roots.cpp
@@ -0,0 +1,188 @@
+// Root autodetection test for contextual profiling
+//
+// Copy the header defining ContextNode.
+// RUN: mkdir -p %t_include
+// RUN: cp %llvm_src/include/llvm/ProfileData/CtxInstrContextNode.h %t_include/
+//
+// Compile with ctx instrumentation "on". We use -profile-context-root as signal
+// that we want contextual profiling, but we can specify anything there, that
+// won't be matched with any function, and result in the behavior we are aiming
+// for here.
+//
+// RUN: %clangxx %s %ctxprofilelib -I%t_include -O2 -o %t.bin \
+// RUN:   -mllvm -profile-context-root="<autodetect>" -g -Wl,-export-dynamic
+//
+// Run the binary, and observe the profile fetch handler's output.
+// RUN: %t.bin | FileCheck %s
+
+#include "CtxInstrContextNode.h"
+#include <atomic>
+#include <cstdio>
+#include <iostream>
+#include <thread>
+
+using namespace llvm::ctx_profile;
+extern "C" void __llvm_ctx_profile_start_collection(unsigned);
+extern "C" bool __llvm_ctx_profile_fetch(ProfileWriter &);
+
+// avoid name mangling
+extern "C" {
+__attribute__((noinline)) void anotherFunction() {}
+__attribute__((noinline)) void mock1() {}
+__attribute__((noinline)) void mock2() {}
+__attribute__((noinline)) void someFunction(int I) {
+  if (I % 2)
+    mock1();
+  else
+    mock2();
+  anotherFunction();
+}
+
+// block inlining because the pre-inliner otherwise will inline this - it's
+// too small.
+__attribute__((noinline)) void theRoot() {
+  someFunction(1);
+#pragma nounroll
+  for (auto I = 0; I < 2; ++I) {
+    someFunction(I);
+  }
+  anotherFunction();
+}
+}
+
+class TestProfileWriter : public ProfileWriter {
+  void printProfile(const ContextNode &Node, const std::string &Indent,
+                    const std::string &Increment) {
+    std::cout << Indent << "Guid: " << Node.guid() << std::endl;
+    std::cout << Indent << "Entries: " << Node.entrycount() << std::endl;
+    std::cout << Indent << Node.counters_size() << " counters and "
+              << Node.callsites_size() << " callsites" << std::endl;
+    std::cout << Indent << "Counter values: ";
+    for (uint32_t I = 0U; I < Node.counters_size(); ++I)
+      std::cout << Node.counters()[I] << " ";
+    std::cout << std::endl;
+    for (uint32_t I = 0U; I < Node.callsites_size(); ++I)
+      for (const auto *N = Node.subContexts()[I]; N; N = N->next()) {
+        std::cout << Indent << "At Index " << I << ":" << std::endl;
+        printProfile(*N, Indent + Increment, Increment);
+      }
+  }
+
+  void startContextSection() override {
+    std::cout << "Entered Context Section" << std::endl;
+  }
+
+  void endContextSection() override {
+    std::cout << "Exited Context Section" << std::endl;
+  }
+
+  void writeContextual(const ContextNode &RootNode,
+                       const ContextNode *Unhandled,
+                       uint64_t EntryCount) override {
+    std::cout << "Entering Root " << RootNode.guid()
+              << " with total entry count " << EntryCount << std::endl;
+    for (const auto *P = Unhandled; P; P = P->next())
+      std::cout << "Unhandled GUID: " << P->guid() << " entered "
+                << P->entrycount() << " times" << std::endl;
+    printProfile(RootNode, " ", " ");
+  }
+
+  void startFlatSection() override {
+    std::cout << "Entered Flat Section" << std::endl;
+  }
+
+  void writeFlat(GUID Guid, const uint64_t *Buffer,
+                 size_t BufferSize) override {
+    std::cout << "Flat: " << Guid << " " << Buffer[0];
+    for (size_t I = 1U; I < BufferSize; ++I)
+      std::cout << "," << Buffer[I];
+    std::cout << std::endl;
+  };
+
+  void endFlatSection() override {
+    std::cout << "Exited Flat Section" << std::endl;
+  }
+};
+
+// Guid:3950394326069683896 is anotherFunction
+// Guid:6759619411192316602 is someFunction
+// These are expected to be the auto-detected roots. This is because we cannot
+// discerne (with the current autodetection mechanism) if theRoot
+// (Guid:8657661246551306189) is ever re-entered.
+//
+// CHECK:      Entered Context Section
+// CHECK-NEXT: Entering Root 6759619411192316602 with total entry count 12463157
+// CHECK-NEXT: Guid: 6759619411192316602
+// CHECK-NEXT:  Entries: 5391142
+// CHECK-NEXT:  2 counters and 3 callsites
+// CHECK-NEXT:  Counter values: 5391142 1832357
+// CHECK-NEXT:  At Index 0:
+// CHECK-NEXT:   Guid: 434762725428799310
+// CHECK-NEXT:   Entries: 3558785
+// CHECK-NEXT:   1 counters and 0 callsites
+// CHECK-NEXT:   Counter values: 3558785
+// CHECK-NEXT:  At Index 1:
+// CHECK-NEXT:   Guid: 5578595117440393467
+// CHECK-NEXT:   Entries: 1832357
+// CHECK-NEXT:   1 counters and 0 callsites
+// CHECK-NEXT:   Counter values: 1832357
+// CHECK-NEXT:  At Index 2:
+// CHECK-NEXT:   Guid: 3950394326069683896
+// CHECK-NEXT:   Entries: 5391142
+// CHECK-NEXT:   1 counters and 0 callsites
+// CHECK-NEXT:   Counter values: 5391142
+// CHECK-NEXT: Entering Root 3950394326069683896 with total entry count 11226401
+// CHECK-NEXT:  Guid: 3950394326069683896
+// CHECK-NEXT:  Entries: 10767423
+// CHECK-NEXT:  1 counters and 0 callsites
+// CHECK-NEXT:  Counter values: 10767423
+// CHECK-NEXT: Exited Context Section
+// CHECK-NEXT: Entered Flat Section
+// CHECK-NEXT: Flat: 2597020043743142491 1
+// CHECK-NEXT: Flat: 4321328481998485159 1
+// CHECK-NEXT: Flat: 8657661246551306189 9114175,18099613
+// CHECK-NEXT: Flat: 434762725428799310 10574815
+// CHECK-NEXT: Flat: 5578595117440393467 5265754
+// CHECK-NEXT: Flat: 12566320182004153844 1
+// CHECK-NEXT: Exited Flat Section
+
+bool profileWriter() {
+  TestProfileWriter W;
+  return __llvm_ctx_profile_fetch(W);
+}
+
+int main(int argc, char **argv) {
+  std::atomic<bool> Stop = false;
+  std::atomic<int> Started = 0;
+  std::thread T1([&]() {
+    ++Started;
+    while (!Stop) {
+      theRoot();
+    }
+  });
+
+  std::thread T2([&]() {
+    ++Started;
+    while (!Stop) {
+      theRoot();
+    }
+  });
+
+  std::thread T3([&]() {
+    while (Started < 2) {
+    }
+    __llvm_ctx_profile_start_collection(5);
+  });
+
+  T3.join();
+  using namespace std::chrono_literals;
+
+  std::this_thread::sleep_for(10s);
+  Stop = true;
+  T1.join();
+  T2.join();
+
+  // This would be implemented in a specific RPC handler, but here we just call
+  // it directly.
+  return !profileWriter();
+}
diff --git a/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp b/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp
index 3dc53637a35d8..08f366cbcd51a 100644
--- a/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp
+++ b/compiler-rt/test/ctx_profile/TestCases/generate-context.cpp
@@ -16,7 +16,8 @@
 #include <iostream>
 
 using namespace llvm::ctx_profile;
-extern "C" void __llvm_ctx_profile_start_collection();
+extern "C" void
+__llvm_ctx_profile_start_collection(unsigned AutodetectDuration = 0);
 extern "C" bool __llvm_ctx_profile_fetch(ProfileWriter &);
 
 // avoid name mangling
@@ -97,7 +98,7 @@ class TestProfileWriter : public ProfileWriter {
     for (const auto *P = Unhandled; P; P = P->next())
       std::cout << "Unhandled GUID: " << P->guid() << " entered "
                 << P->entrycount() << " times" << std::endl;
-    printProfile(RootNode, "", "");
+    printProfile(RootNode, " ", " ");
   }
 
   void startFlatSection() override {
diff --git a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
index a42bf9ebb01ea..55423d95b3088 100644
--- a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
+++ b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
@@ -127,6 +127,7 @@ class ContextNode final {
 /// MUTEXDECL takes one parameter, the name of a field that is a mutex.
 #define CTXPROF_FUNCTION_DATA(PTRDECL, VOLATILE_PTRDECL, MUTEXDECL)            \
   PTRDECL(FunctionData, Next)                                                  \
+  VOLATILE_PTRDECL(void, EntryAddress)                                         \
   VOLATILE_PTRDECL(ContextRoot, CtxRoot)                                       \
   VOLATILE_PTRDECL(ContextNode, FlatCtx)                                       \
   MUTEXDECL(Mutex)
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
index 58748a19db972..7b36661c9a404 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
@@ -219,6 +219,14 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
   Value *TheRootFuctionData = nullptr;
   Value *ExpectedCalleeTLSAddr = nullptr;
   Value *CallsiteInfoTLSAddr = nullptr;
+  const bool HasMusttail = [&]() {
+    for (auto &BB : F)
+      for (auto &I : BB)
+        if (auto *CB = dyn_cast<CallBase>(&I))
+          if (CB->isMustTailCall())
+            return true;
+    return false;
+  }();
 
   auto &Head = F.getEntryBlock();
   for (auto &I : Head) {
@@ -243,19 +251,18 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
       // regular function)
       // Don't set a name, they end up taking a lot of space and we don't need
       // them.
-      auto *FData = new GlobalVariable(M, FunctionDataTy, false,
-                                       GlobalVariable::InternalLinkage,
-                                       Constant::getNullValue(FunctionDataTy));
+      TheRootFuctionData = new GlobalVariable(
+          M, FunctionDataTy, false, GlobalVariable::InternalLinkage,
+          Constant::getNullValue(FunctionDataTy));
 
       if (ContextRootSet.contains(&F)) {
         Context = Builder.CreateCall(
-            StartCtx, {FData, Guid, Builder.getInt32(NumCounters),
+            StartCtx, {TheRootFuctionData, Guid, Builder.getInt32(NumCounters),
                        Builder.getInt32(NumCallsites)});
-        TheRootFuctionData = FData;
         ORE.emit(
             [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
       } else {
-        Context = Builder.CreateCall(GetCtx, {FData, &F, Guid,
+        Context = Builder.CreateCall(GetCtx, {TheRootFuctionData, &F, Guid,
                                               Builder.getInt32(NumCounters),
                                               Builder.getInt32(NumCallsites)});
         ORE.emit([&] {
@@ -339,7 +346,7 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
           break;
         }
         I.eraseFromParent();
-      } else if (TheRootFuctionData && isa<ReturnInst>(I)) {
+      } else if (!HasMusttail && isa<ReturnInst>(I)) {
         // Remember to release the context if we are an entrypoint.
         IRBuilder<> Builder(&I);
         Builder.CreateCall(ReleaseCtx, {TheRootFuctionData});
@@ -351,9 +358,10 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
   // to disallow this, (so this then stays as an error), another is to detect
   // that and then do a wrapper or disallow the tail call. This only affects
   // instrumentation, when we want to detect the call graph.
-  if (TheRootFuctionData && !ContextWasReleased)
+  if (!HasMusttail && !ContextWasReleased)
     F.getContext().emitError(
-        "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
+        "[ctx_prof] A function that doesn't have musttail calls was "
+        "instrumented but it has no `ret` "
         "instructions above which to release the context: " +
         F.getName());
   return true;
diff --git a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
index ed3cb0824c504..75f292deb71c2 100644
--- a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
+++ b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
@@ -11,13 +11,14 @@ declare void @bar()
 ;.
 ; LOWERING: @__llvm_ctx_profile_callsite = external hidden thread_local global ptr
 ; LOWERING: @__llvm_ctx_profile_expected_callee = external hidden thread_local global ptr
-; LOWERING: @[[GLOB0:[0-9]+]] = internal global { ptr, ptr, ptr, i8 } zeroinitializer
-; LOWERING: @[[GLOB1:[0-9]+]] = internal global { ptr, ptr, ptr, i8 } zeroinitializer
-; LOWERING: @[[GLOB2:[0-9]+]] = internal global { ptr, ptr, ptr, i8 } zeroinitializer
-; LOWERING: @[[GLOB3:[0-9]+]] = internal global { ptr, ptr, ptr, i8 } zeroinitializer
-; LOWERING: @[[GLOB4:[0-9]+]] = internal global { ptr, ptr, ptr, i8 } zeroinitializer
-; LOWERING: @[[GLOB5:[0-9]+]] = internal global { ptr, ptr, ptr, i8 } zeroinitializer
-; LOWERING: @[[GLOB6:[0-9]+]] = internal global { ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @[[GLOB0:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @[[GLOB1:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @[[GLOB2:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @[[GLOB3:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @[[GLOB4:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @[[GLOB5:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @[[GLOB6:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @[[GLOB7:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
 ;.
 define void @foo(i32 %a, ptr %fct) {
 ; INSTRUMENT-LABEL: define void @foo(
@@ -67,6 +68,7 @@ define void @foo(i32 %a, ptr %fct) {
 ; LOWERING-NEXT:    call void @bar()
 ; LOWERING-NEXT:    br label [[EXIT]]
 ; LOWERING:       exit:
+; LOWERING-NEXT:    call void @__llvm_ctx_profile_release_context(ptr @[[GLOB0]])
 ; LOWERING-NEXT:    ret void
 ;
   %t = icmp eq i32 %a, 0
@@ -185,6 +187,7 @@ define void @simple(i32 %a) {
 ; LOWERING-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
 ; LOWERING-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], -2
 ; LOWERING-NEXT:    [[TMP4:%.*]] = inttoptr i64 [[TMP3]] to ptr
+; LOWERING-NEXT:    call void @__llvm_ctx_profile_release_context(ptr @[[GLOB3]])
 ; LOWERING-NEXT:    ret void
 ;
   ret void
@@ -216,8 +219,10 @@ define i32 @no_callsites(i32 %a) {
 ; LOWERING-NEXT:    [[TMP6:%.*]] = load i64, ptr [[TMP5]], align 4
 ; LOWERING-NEXT:    [[TMP7:%.*]] = add i64 [[TMP6]], 1
 ; LOWERING-NEXT:    store i64 [[TMP7]], ptr [[TMP5]], align 4
+; LOWERING-NEXT:    call void @__llvm_ctx_profile_release_context(ptr @[[GLOB4]])
 ; LOWERING-NEXT:    ret i32 1
 ; LOWERING:       no:
+; LOWERING-NEXT:    call void @__llvm_ctx_profile_release_context(ptr @[[GLOB4]])
 ; LOWERING-NEXT:    ret i32 0
 ;
   %c = icmp eq i32 %a, 0
@@ -250,6 +255,7 @@ define void @no_counters() {
 ; LOWERING-NEXT:    [[TMP10:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [1 x i64], [1 x ptr] }, ptr [[TMP1]], i32 0, i32 2, i32 0
 ; LOWERING-NEXT:    store volatile ptr [[TMP10]], ptr [[TMP7]], align 8
 ; LOWERING-NEXT:    call void @bar()
+; LOWERING-NEXT:    call void @__llvm_ctx_profile_release_context(ptr @[[GLOB5]])
 ; LOWERING-NEXT:    ret void
 ;
   call void @bar()
@@ -270,11 +276,40 @@ define void @inlineasm() {
 ; LOWERING-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], -2
 ; LOWERING-NEXT:    [[TMP4:%.*]] = inttoptr i64 [[TMP3]] to ptr
 ; LOWERING-NEXT:    call void asm "nop", ""()
+; LOWERING-NEXT:    call void @__llvm_ctx_profile_release_context(ptr @[[GLOB6]])
 ; LOWERING-NEXT:    ret void
 ;
   call void asm "nop", ""()
   ret void
 }
+
+define void @has_musttail_calls() {
+; INSTRUMENT-LABEL: define void @has_musttail_calls() {
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.increment(ptr @has_musttail_calls, i64 742261418966908927, i32 1, i32 0)
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.callsite(ptr @has_musttail_calls, i64 742261418966908927, i32 1, i32 0, ptr @bar)
+; INSTRUMENT-NEXT:    musttail call void @bar()
+; INSTRUMENT-NEXT:    ret void
+;
+; LOWERING-LABEL: define void @has_musttail_calls(
+; LOWERING-SAME: ) !guid [[META7:![0-9]+]] {
+; LOWERING-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_ctx_profile_get_context(ptr @[[GLOB7]], ptr @has_musttail_calls, i64 -4680624981836544329, i32 1, i32 1)
+; LOWERING-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+; LOWERING-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 1
+; LOWERING-NEXT:    [[TMP4:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_ctx_profile_expected_callee)
+; LOWERING-NEXT:    [[TMP5:%.*]] = getelementptr ptr, ptr [[TMP4]], i64 [[TMP3]]
+; LOWERING-NEXT:    [[TMP6:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_ctx_profile_callsite)
+; LOWERING-NEXT:    [[TMP7:%.*]] = getelementptr i32, ptr [[TMP6]], i64 [[TMP3]]
+; LOWERING-NEXT:    [[TMP8:%.*]] = and i64 [[TMP2]], -2
+; LOWERING-NEXT:    [[TMP9:%.*]] = inttoptr i64 [[TMP8]] to ptr
+; LOWERING-NEXT:    store volatile ptr @bar, ptr [[TMP5]], align 8
+; LOWERING-NEXT:    [[TMP10:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [1 x i64], [1 x ptr] }, ptr [[TMP1]], i32 0, i32 2, i32 0
+; LOWERING-NEXT:    store volatile ptr [[TMP10]], ptr [[TMP7]], align 8
+; LOWERING-NEXT:    musttail call void @bar()
+; LOWERING-NEXT:    ret void
+;
+  musttail call void @bar()
+  ret void
+}
 ;.
 ; LOWERING: attributes #[[ATTR0:[0-9]+]] = { nounwind }
 ; LOWERING: attributes #[[ATTR1:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
@@ -288,4 +323,5 @@ define void @inlineasm() {
 ; LOWERING: [[META4]] = !{i64 5679753335911435902}
 ; LOWERING: [[META5]] = !{i64 5458232184388660970}
 ; LOWERING: [[META6]] = !{i64 -3771893999295659109}
+; LOWERING: [[META7]] = !{i64 -4680624981836544329}
 ;.



More information about the llvm-commits mailing list