[compiler-rt] [llvm] [ctxprof] Handle instrumenting functions with `musttail` calls (PR #135121)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 14 08:11:59 PDT 2025


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/135121

>From 5d3c9e9955b02241ed8d3b831f9ad0edd74f795c Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 8 Apr 2025 14:09:51 -0700
Subject: [PATCH] [ctxprof] Handle musttail

---
 .../lib/ctx_profile/CtxInstrContextNode.h     |  5 +-
 .../lib/ctx_profile/CtxInstrProfiling.cpp     | 32 +++++++----
 .../lib/ctx_profile/CtxInstrProfiling.h       |  9 ++-
 .../lib/ctx_profile/RootAutoDetector.cpp      | 11 +++-
 .../tests/CtxInstrProfilingTest.cpp           | 10 ++++
 .../llvm/ProfileData/CtxInstrContextNode.h    |  5 +-
 .../Instrumentation/PGOCtxProfLowering.cpp    | 57 ++++++++++++++-----
 .../PGOProfile/ctx-instrumentation.ll         |  2 +-
 8 files changed, 98 insertions(+), 33 deletions(-)

diff --git a/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h b/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
index 55423d95b3088..e4e310b2e987d 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
+++ b/compiler-rt/lib/ctx_profile/CtxInstrContextNode.h
@@ -125,10 +125,11 @@ class ContextNode final {
 /// VOLATILE_PTRDECL is the same as above, but for volatile pointers;
 ///
 /// MUTEXDECL takes one parameter, the name of a field that is a mutex.
-#define CTXPROF_FUNCTION_DATA(PTRDECL, VOLATILE_PTRDECL, MUTEXDECL)            \
+#define CTXPROF_FUNCTION_DATA(PTRDECL, CONTEXT_PTR, VOLATILE_PTRDECL,          \
+                              MUTEXDECL)                                       \
   PTRDECL(FunctionData, Next)                                                  \
   VOLATILE_PTRDECL(void, EntryAddress)                                         \
-  VOLATILE_PTRDECL(ContextRoot, CtxRoot)                                       \
+  CONTEXT_PTR                                                                  \
   VOLATILE_PTRDECL(ContextNode, FlatCtx)                                       \
   MUTEXDECL(Mutex)
 
diff --git a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
index 4cf852fe3f667..2d173f0fcb19a 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
+++ b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
@@ -272,6 +272,8 @@ void setupContext(ContextRoot *Root, GUID Guid, uint32_t NumCounters,
 
 ContextRoot *FunctionData::getOrAllocateContextRoot() {
   auto *Root = CtxRoot;
+  if (!canBeRoot(Root))
+    return Root;
   if (Root)
     return Root;
   __sanitizer::GenericScopedLock<__sanitizer::StaticSpinMutex> L(&Mutex);
@@ -328,8 +330,10 @@ ContextNode *getUnhandledContext(FunctionData &Data, void *Callee, GUID Guid,
   if (!CtxRoot) {
     if (auto *RAD = getRootDetector())
       RAD->sample();
-    else if (auto *CR = Data.CtxRoot)
-      return tryStartContextGivenRoot(CR, Guid, NumCounters, NumCallsites);
+    else if (auto *CR = Data.CtxRoot) {
+      if (canBeRoot(CR))
+        return tryStartContextGivenRoot(CR, Guid, NumCounters, NumCallsites);
+    }
     if (IsUnderContext || !__sanitizer::atomic_load_relaxed(&ProfilingStarted))
       return TheScratchContext;
     else
@@ -404,20 +408,21 @@ ContextNode *__llvm_ctx_profile_get_context(FunctionData *Data, void *Callee,
 ContextNode *__llvm_ctx_profile_start_context(FunctionData *FData, GUID Guid,
                                               uint32_t Counters,
                                               uint32_t Callsites) {
-
-  return tryStartContextGivenRoot(FData->getOrAllocateContextRoot(), Guid,
-                                  Counters, Callsites);
+  auto *Root = FData->getOrAllocateContextRoot();
+  assert(canBeRoot(Root));
+  return tryStartContextGivenRoot(Root, Guid, Counters, Callsites);
 }
 
 void __llvm_ctx_profile_release_context(FunctionData *FData)
     SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
   const auto *CurrentRoot = __llvm_ctx_profile_current_context_root;
-  if (!CurrentRoot || FData->CtxRoot != CurrentRoot)
+  auto *CR = FData->CtxRoot;
+  if (!CurrentRoot || CR != CurrentRoot)
     return;
   IsUnderContext = false;
-  assert(FData->CtxRoot);
+  assert(CR && canBeRoot(CR));
   __llvm_ctx_profile_current_context_root = nullptr;
-  FData->CtxRoot->Taken.Unlock();
+  CR->Taken.Unlock();
 }
 
 void __llvm_ctx_profile_start_collection(unsigned AutodetectDuration) {
@@ -481,10 +486,13 @@ bool __llvm_ctx_profile_fetch(ProfileWriter &Writer) {
   // traversing it.
   const auto *Pos = reinterpret_cast<const FunctionData *>(
       __sanitizer::atomic_load_relaxed(&AllFunctionsData));
-  for (; Pos; Pos = Pos->Next)
-    if (!Pos->CtxRoot)
-      Writer.writeFlat(Pos->FlatCtx->guid(), Pos->FlatCtx->counters(),
-                       Pos->FlatCtx->counters_size());
+  for (; Pos; Pos = Pos->Next) {
+    const auto *CR = Pos->CtxRoot;
+    if (!CR || !canBeRoot(CR)) { // rootless, or can never be a root
+      const auto *FP = Pos->FlatCtx;
+      Writer.writeFlat(FP->guid(), FP->counters(), FP->counters_size());
+    }
+  }
   Writer.endFlatSection();
   return true;
 }
diff --git a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
index 4983f086d230d..9ca6e769e216f 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
+++ b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
@@ -145,7 +145,9 @@ struct FunctionData {
 #define _PTRDECL(T, N) T *N = nullptr;
 #define _VOLATILE_PTRDECL(T, N) T *volatile N = nullptr;
 #define _MUTEXDECL(N) ::__sanitizer::SpinMutex N;
-  CTXPROF_FUNCTION_DATA(_PTRDECL, _VOLATILE_PTRDECL, _MUTEXDECL)
+#define _CONTEXT_PTR ContextRoot *CtxRoot = nullptr;
+  CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_PTR, _VOLATILE_PTRDECL, _MUTEXDECL)
+#undef _CONTEXT_PTR
 #undef _PTRDECL
 #undef _VOLATILE_PTRDECL
 #undef _MUTEXDECL
@@ -167,6 +169,11 @@ inline bool isScratch(const void *Ctx) {
   return (reinterpret_cast<uint64_t>(Ctx) & 1);
 }
 
+// False only if Ctx is the 0x1 "cannot be root" marker; true for nullptr too.
+inline bool canBeRoot(const ContextRoot *Ctx) {
+  return reinterpret_cast<uintptr_t>(Ctx) != 1U;
+}
+
 } // namespace __ctx_profile
 
 extern "C" {
diff --git a/compiler-rt/lib/ctx_profile/RootAutoDetector.cpp b/compiler-rt/lib/ctx_profile/RootAutoDetector.cpp
index 4aa169e202ea3..7bb3bbc63bd6e 100644
--- a/compiler-rt/lib/ctx_profile/RootAutoDetector.cpp
+++ b/compiler-rt/lib/ctx_profile/RootAutoDetector.cpp
@@ -66,7 +66,16 @@ void RootAutoDetector::start() {
                  atomic_load_relaxed(&RAD->FunctionDataListHead));
              FD; FD = FD->Next) {
           if (AllRoots.contains(reinterpret_cast<uptr>(FD->EntryAddress))) {
-            FD->getOrAllocateContextRoot();
+            if (canBeRoot(FD->CtxRoot)) {
+              FD->getOrAllocateContextRoot();
+            } else {
+              // FIXME: address this by informing the root detection algorithm
+              // to skip over such functions and pick the next down in the
+              // stack. At that point, this becomes an assert.
+              Printf("[ctxprof] Root auto-detector selected a musttail "
+                     "function for root (%p). Ignoring\n",
+                     FD->EntryAddress);
+            }
           }
         }
         atomic_store_relaxed(&RAD->Self, 0);
diff --git a/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp b/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
index 83756fed0d6e6..80a9a96f2a16b 100644
--- a/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
+++ b/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
@@ -301,3 +301,13 @@ TEST_F(ContextTest, Dump) {
   EXPECT_EQ(W2.FlatsWritten, 1);
   EXPECT_EQ(W2.ExitedFlatCount, 1);
 }
+
+TEST_F(ContextTest, MustNotBeRoot) {
+  FunctionData FData;
+  FData.CtxRoot = reinterpret_cast<ContextRoot *>(1U);
+  int FakeCalleeAddress = 0;
+  __llvm_ctx_profile_start_collection();
+  auto *Subctx =
+      __llvm_ctx_profile_get_context(&FData, &FakeCalleeAddress, 2, 3, 1);
+  EXPECT_TRUE(isScratch(Subctx));
+}
diff --git a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
index 55423d95b3088..e4e310b2e987d 100644
--- a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
+++ b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
@@ -125,10 +125,11 @@ class ContextNode final {
 /// VOLATILE_PTRDECL is the same as above, but for volatile pointers;
 ///
 /// MUTEXDECL takes one parameter, the name of a field that is a mutex.
-#define CTXPROF_FUNCTION_DATA(PTRDECL, VOLATILE_PTRDECL, MUTEXDECL)            \
+#define CTXPROF_FUNCTION_DATA(PTRDECL, CONTEXT_PTR, VOLATILE_PTRDECL,          \
+                              MUTEXDECL)                                       \
   PTRDECL(FunctionData, Next)                                                  \
   VOLATILE_PTRDECL(void, EntryAddress)                                         \
-  VOLATILE_PTRDECL(ContextRoot, CtxRoot)                                       \
+  CONTEXT_PTR                                                                  \
   VOLATILE_PTRDECL(ContextNode, FlatCtx)                                       \
   MUTEXDECL(Mutex)
 
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
index a314457423819..f99d7b9d03e02 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
@@ -12,9 +12,11 @@
 #include "llvm/Analysis/CtxProfAnalysis.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/IR/Analysis.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
@@ -55,7 +57,7 @@ class CtxInstrumentationLowerer final {
   Module &M;
   ModuleAnalysisManager &MAM;
   Type *ContextNodeTy = nullptr;
-  Type *FunctionDataTy = nullptr;
+  StructType *FunctionDataTy = nullptr;
 
   DenseSet<const Function *> ContextRootSet;
   Function *StartCtx = nullptr;
@@ -63,6 +65,7 @@ class CtxInstrumentationLowerer final {
   Function *ReleaseCtx = nullptr;
   GlobalVariable *ExpectedCalleeTLS = nullptr;
   GlobalVariable *CallsiteInfoTLS = nullptr;
+  Constant *CannotBeRootInitializer = nullptr;
 
 public:
   CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
@@ -117,12 +120,29 @@ CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
 
 #define _PTRDECL(_, __) PointerTy,
 #define _VOLATILE_PTRDECL(_, __) PointerTy,
+#define _CONTEXT_ROOT PointerTy,
 #define _MUTEXDECL(_) SanitizerMutexType,
 
   FunctionDataTy = StructType::get(
-      M.getContext(),
-      {CTXPROF_FUNCTION_DATA(_PTRDECL, _VOLATILE_PTRDECL, _MUTEXDECL)});
+      M.getContext(), {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT,
+                                             _VOLATILE_PTRDECL, _MUTEXDECL)});
 #undef _PTRDECL
+#undef _CONTEXT_ROOT
+#undef _VOLATILE_PTRDECL
+#undef _MUTEXDECL
+
+#define _PTRDECL(_, __) Constant::getNullValue(PointerTy),
+#define _VOLATILE_PTRDECL(_, __) _PTRDECL(_, __)
+#define _MUTEXDECL(_) Constant::getNullValue(SanitizerMutexType),
+#define _CONTEXT_ROOT                                                          \
+  Constant::getIntegerValue(                                                   \
+      PointerTy,                                                               \
+      APInt(M.getDataLayout().getPointerTypeSizeInBits(PointerTy), 1U)),
+  CannotBeRootInitializer = ConstantStruct::get(
+      FunctionDataTy, {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT,
+                                             _VOLATILE_PTRDECL, _MUTEXDECL)});
+#undef _PTRDECL
+#undef _CONTEXT_ROOT
 #undef _VOLATILE_PTRDECL
 #undef _MUTEXDECL
 
@@ -134,8 +154,8 @@ CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
                                                       I32Ty, /*NumCallsites*/
                                                   });
 
-  // Define a global for each entrypoint. We'll reuse the entrypoint's name as
-  // prefix. We assume the entrypoint names to be unique.
+  // Define a global for each entrypoint. We'll reuse the entrypoint's name
+  // as prefix. We assume the entrypoint names to be unique.
   for (const auto &Fname : ContextRoots) {
     if (const auto *F = M.getFunction(Fname)) {
       if (F->isDeclaration())
@@ -145,10 +165,10 @@ CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
         for (const auto &I : BB)
           if (const auto *CB = dyn_cast<CallBase>(&I))
             if (CB->isMustTailCall()) {
-              M.getContext().emitError(
-                  "The function " + Fname +
-                  " was indicated as a context root, but it features musttail "
-                  "calls, which is not supported.");
+              M.getContext().emitError("The function " + Fname +
+                                       " was indicated as a context root, "
+                                       "but it features musttail "
+                                       "calls, which is not supported.");
             }
     }
   }
@@ -240,6 +260,14 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
     return false;
   }();
 
+  if (HasMusttail && ContextRootSet.contains(&F)) {
+    F.getContext().emitError(
+        "[ctx_prof] A function with musttail calls was explicitly requested as "
+        "root. That is not supported because we cannot instrument a return "
+        "instruction to release the context: " +
+        F.getName());
+    return false;
+  }
   auto &Head = F.getEntryBlock();
   for (auto &I : Head) {
     // Find the increment intrinsic in the entry basic block.
@@ -263,9 +291,14 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
       // regular function)
       // Don't set a name, they end up taking a lot of space and we don't need
       // them.
+
+      // Zero-initialize the FunctionData, except for functions that have
+      // musttail calls. There, we set the CtxRoot field to 1, which will be
+      // treated as a "can't be set as root".
       TheRootFuctionData = new GlobalVariable(
           M, FunctionDataTy, false, GlobalVariable::InternalLinkage,
-          Constant::getNullValue(FunctionDataTy));
+          HasMusttail ? CannotBeRootInitializer
+                      : Constant::getNullValue(FunctionDataTy));
 
       if (ContextRootSet.contains(&F)) {
         Context = Builder.CreateCall(
@@ -366,10 +399,6 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
       }
     }
   }
-  // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
-  // to disallow this, (so this then stays as an error), another is to detect
-  // that and then do a wrapper or disallow the tail call. This only affects
-  // instrumentation, when we want to detect the call graph.
   if (!HasMusttail && !ContextWasReleased)
     F.getContext().emitError(
         "[ctx_prof] A function that doesn't have musttail calls was "
diff --git a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
index 8f72711a9c8b1..6b2f25a585ec3 100644
--- a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
+++ b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
@@ -18,7 +18,7 @@ declare void @bar()
 ; LOWERING: @[[GLOB4:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
 ; LOWERING: @[[GLOB5:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
 ; LOWERING: @[[GLOB6:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
-; LOWERING: @[[GLOB7:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @[[GLOB7:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } { ptr null, ptr null, ptr inttoptr (i64 1 to ptr), ptr null, i8 0 }
 ;.
 define void @foo(i32 %a, ptr %fct) {
 ; INSTRUMENT-LABEL: define void @foo(



More information about the llvm-commits mailing list