[compiler-rt] [llvm] Reentry (PR #135656)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 16 10:41:12 PDT 2025


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/135656

>From b10c2e29194224dac69f6305d8b4296ca61fd363 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Mon, 14 Apr 2025 07:19:58 -0700
Subject: [PATCH] Reentry

---
 .../lib/ctx_profile/CtxInstrProfiling.cpp     | 151 ++++++++++++------
 .../tests/CtxInstrProfilingTest.cpp           | 115 ++++++++++++-
 .../llvm/ProfileData/CtxInstrContextNode.h    |   6 +-
 .../Instrumentation/PGOCtxProfLowering.cpp    |  82 ++++++----
 .../PGOProfile/ctx-instrumentation.ll         |   4 +-
 5 files changed, 269 insertions(+), 89 deletions(-)

diff --git a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
index 2d173f0fcb19a..2e26541c1acea 100644
--- a/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
+++ b/compiler-rt/lib/ctx_profile/CtxInstrProfiling.cpp
@@ -41,7 +41,47 @@ Arena *FlatCtxArena = nullptr;
 
-// Set to true when we enter a root, and false when we exit - regardless if this
-// thread collects a contextual profile for that root.
-__thread bool IsUnderContext = false;
+// Per-thread root (re)entry tracking, regardless if this thread collects a
+// contextual profile for that root: EnteredContextAddress is the entry address
+// of the outermost root entered, UnderContextRefCount its live activations.
+__thread int UnderContextRefCount = 0;
+__thread void *volatile EnteredContextAddress = nullptr;
+
+void onFunctionEntered(void *Address) {
+  UnderContextRefCount += (Address == EnteredContextAddress);
+  assert(UnderContextRefCount > 0);
+}
+
+void onFunctionExited(void *Address) {
+  // Exits need not match an entry: e.g. a function with a null EntryAddress,
+  // released while no root is entered, compares equal to EnteredContextAddress.
+  if (UnderContextRefCount > 0)
+    UnderContextRefCount -= (Address == EnteredContextAddress);
+  assert(UnderContextRefCount >= 0);
+}
+
+// Returns true if this is the outermost root activation on this thread.
+bool rootEnterIsFirst(void *Address) {
+  bool Ret = false;
+  if (!EnteredContextAddress) {
+    EnteredContextAddress = Address;
+    assert(UnderContextRefCount == 0);
+    Ret = true;
+  }
+  onFunctionEntered(Address);
+  return Ret;
+}
+
+// Returns true if no root activation remains on this thread after this exit.
+bool exitsRoot(void *Address) {
+  onFunctionExited(Address);
+  if (UnderContextRefCount == 0) {
+    EnteredContextAddress = nullptr;
+    return true;
+  }
+  return false;
+}
+
+bool hasEnteredARoot() { return UnderContextRefCount > 0; }
+
 __sanitizer::atomic_uint8_t ProfilingStarted = {};
 
 __sanitizer::atomic_uintptr_t RootDetector = {};
@@ -287,62 +327,69 @@ ContextRoot *FunctionData::getOrAllocateContextRoot() {
   return Root;
 }
 
-ContextNode *tryStartContextGivenRoot(ContextRoot *Root, GUID Guid,
-                                      uint32_t Counters, uint32_t Callsites)
-    SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
-  IsUnderContext = true;
-  __sanitizer::atomic_fetch_add(&Root->TotalEntries, 1,
-                                __sanitizer::memory_order_relaxed);
+ContextNode *tryStartContextGivenRoot(
+    ContextRoot *Root, void *EntryAddress, GUID Guid, uint32_t Counters,
+    uint32_t Callsites) SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
+  // Only the outermost root activation on this thread may start collection.
+  // __llvm_ctx_profile_release_context unlocks only when that activation
+  // exits, so the lock of another root entered beneath it would never be
+  // released. Nested activations (recursive ones included) get scratch.
+  if (!rootEnterIsFirst(EntryAddress))
+    return TheScratchContext;
+  __sanitizer::atomic_fetch_add(&Root->TotalEntries, 1,
+                                __sanitizer::memory_order_relaxed);
   if (!Root->FirstMemBlock) {
     setupContext(Root, Guid, Counters, Callsites);
   }
   if (Root->Taken.TryLock()) {
+    assert(__llvm_ctx_profile_current_context_root == nullptr);
     __llvm_ctx_profile_current_context_root = Root;
     onContextEnter(*Root->FirstNode);
     return Root->FirstNode;
   }
   // If this thread couldn't take the lock, return scratch context.
-  __llvm_ctx_profile_current_context_root = nullptr;
   return TheScratchContext;
 }
 
+ContextNode *getOrStartContextOutsideCollection(FunctionData &Data,
+                                                ContextRoot *OwnCtxRoot,
+                                                void *Callee, GUID Guid,
+                                                uint32_t NumCounters,
+                                                uint32_t NumCallsites) {
+  // This must only be called when __llvm_ctx_profile_current_context_root is
+  // null.
+  assert(__llvm_ctx_profile_current_context_root == nullptr);
+  // OwnCtxRoot is Data.CtxRoot. Since it's volatile, and is used by the caller,
+  // pre-load it.
+  assert(Data.CtxRoot == OwnCtxRoot);
+  // If we have a root detector, try sampling.
+  // Otherwise - regardless if we started profiling or not, if Data.CtxRoot is
+  // allocated, try starting a context tree - basically, as-if
+  // __llvm_ctx_profile_start_context were called.
+  if (auto *RAD = getRootDetector())
+    RAD->sample();
+  else if (reinterpret_cast<uintptr_t>(OwnCtxRoot) > 1)
+    return tryStartContextGivenRoot(OwnCtxRoot, Data.EntryAddress, Guid,
+                                    NumCounters, NumCallsites);
+
+  // If we didn't start profiling, or if we are under a context, just not
+  // collecting, return the scratch buffer.
+  if (hasEnteredARoot() ||
+      !__sanitizer::atomic_load_relaxed(&ProfilingStarted))
+    return TheScratchContext;
+  return markAsScratch(
+      onContextEnter(*getFlatProfile(Data, Callee, Guid, NumCounters)));
+}
+
 ContextNode *getUnhandledContext(FunctionData &Data, void *Callee, GUID Guid,
                                  uint32_t NumCounters, uint32_t NumCallsites,
-                                 ContextRoot *CtxRoot) {
-
-  // 1) if we are currently collecting a contextual profile, fetch a ContextNode
-  // in the `Unhandled` set. We want to do this regardless of `ProfilingStarted`
-  // to (hopefully) offset the penalty of creating these contexts to before
-  // profiling.
-  //
-  // 2) if we are under a root (regardless if this thread is collecting or not a
-  // contextual profile for that root), do not collect a flat profile. We want
-  // to keep flat profiles only for activations that can't happen under a root,
-  // to avoid confusing profiles. We can, for example, combine flattened and
-  // flat profiles meaningfully, as we wouldn't double-count anything.
-  //
-  // 3) to avoid lengthy startup, don't bother with flat profiles until the
-  // profiling has started. We would reset them anyway when profiling starts.
-  // HOWEVER. This does lose profiling for message pumps: those functions are
-  // entered once and never exit. They should be assumed to be entered before
-  // profiling starts - because profiling should start after the server is up
-  // and running (which is equivalent to "message pumps are set up").
-  if (!CtxRoot) {
-    if (auto *RAD = getRootDetector())
-      RAD->sample();
-    else if (auto *CR = Data.CtxRoot) {
-      if (canBeRoot(CR))
-        return tryStartContextGivenRoot(CR, Guid, NumCounters, NumCallsites);
-    }
-    if (IsUnderContext || !__sanitizer::atomic_load_relaxed(&ProfilingStarted))
-      return TheScratchContext;
-    else
-      return markAsScratch(
-          onContextEnter(*getFlatProfile(Data, Callee, Guid, NumCounters)));
-  }
-  auto [Iter, Ins] = CtxRoot->Unhandled.insert({Guid, nullptr});
+                                 ContextRoot &CtxRoot) {
+  // This must only be called when
+  // __llvm_ctx_profile_current_context_root is not null
+  assert(__llvm_ctx_profile_current_context_root != nullptr);
+  auto [Iter, Ins] = CtxRoot.Unhandled.insert({Guid, nullptr});
   if (Ins)
-    Iter->second = getCallsiteSlow(Guid, &CtxRoot->FirstUnhandledCalleeNode,
+    Iter->second = getCallsiteSlow(Guid, &CtxRoot.FirstUnhandledCalleeNode,
                                    NumCounters, 0);
   return markAsScratch(onContextEnter(*Iter->second));
 }
@@ -351,10 +398,13 @@ ContextNode *__llvm_ctx_profile_get_context(FunctionData *Data, void *Callee,
                                             GUID Guid, uint32_t NumCounters,
                                             uint32_t NumCallsites) {
   auto *CtxRoot = __llvm_ctx_profile_current_context_root;
-  // fast "out" if we're not even doing contextual collection.
+  auto *OwnCtxRoot = Data->CtxRoot;
   if (!CtxRoot)
-    return getUnhandledContext(*Data, Callee, Guid, NumCounters, NumCallsites,
-                               nullptr);
+    return getOrStartContextOutsideCollection(*Data, OwnCtxRoot, Callee, Guid,
+                                              NumCounters, NumCallsites);
+  onFunctionEntered(Callee);
+  assert(canBeRoot(CtxRoot));
+  // TODO: should re-entering the root being collected be handled specially?
 
   // also fast "out" if the caller is scratch. We can see if it's scratch by
   // looking at the interior pointer into the subcontexts vector that the caller
@@ -364,7 +414,7 @@ ContextNode *__llvm_ctx_profile_get_context(FunctionData *Data, void *Callee,
   auto **CallsiteContext = consume(__llvm_ctx_profile_callsite[0]);
   if (!CallsiteContext || isScratch(CallsiteContext))
     return getUnhandledContext(*Data, Callee, Guid, NumCounters, NumCallsites,
-                               CtxRoot);
+                               *CtxRoot);
 
   // if the callee isn't the expected one, return scratch.
   // Signal handler(s) could have been invoked at any point in the execution.
@@ -383,7 +433,7 @@ ContextNode *__llvm_ctx_profile_get_context(FunctionData *Data, void *Callee,
   auto *ExpectedCallee = consume(__llvm_ctx_profile_expected_callee[0]);
   if (ExpectedCallee != Callee)
     return getUnhandledContext(*Data, Callee, Guid, NumCounters, NumCallsites,
-                               CtxRoot);
+                               *CtxRoot);
 
   auto *Callsite = *CallsiteContext;
   // in the case of indirect calls, we will have all seen targets forming a
@@ -410,16 +460,20 @@ ContextNode *__llvm_ctx_profile_start_context(FunctionData *FData, GUID Guid,
                                               uint32_t Callsites) {
   auto *Root = FData->getOrAllocateContextRoot();
   assert(canBeRoot(Root));
-  return tryStartContextGivenRoot(Root, Guid, Counters, Callsites);
+  auto *EntryAddress = FData->EntryAddress;
+  return tryStartContextGivenRoot(Root, EntryAddress, Guid, Counters,
+                                  Callsites);
 }
 
 void __llvm_ctx_profile_release_context(FunctionData *FData)
     SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
+  if (!exitsRoot(FData->EntryAddress))
+    return;
   const auto *CurrentRoot = __llvm_ctx_profile_current_context_root;
   auto *CR = FData->CtxRoot;
   if (!CurrentRoot || CR != CurrentRoot)
     return;
-  IsUnderContext = false;
+
   assert(CR && canBeRoot(CR));
   __llvm_ctx_profile_current_context_root = nullptr;
   CR->Taken.Unlock();
@@ -500,6 +554,10 @@ bool __llvm_ctx_profile_fetch(ProfileWriter &Writer) {
 void __llvm_ctx_profile_free() {
   __sanitizer::atomic_store_relaxed(&ProfilingStarted, false);
   {
+    // Join the root detector, if any, before releasing the contexts below.
+    if (auto *RD = getRootDetector()) {
+      RD->join();
+    }
     __sanitizer::GenericScopedLock<__sanitizer::SpinMutex> Lock(
         &AllContextsMutex);
     for (int I = 0, E = AllContextRoots.Size(); I < E; ++I)
@@ -522,5 +580,7 @@ void __llvm_ctx_profile_free() {
     }
 
     FlatCtxArenaHead = nullptr;
+    UnderContextRefCount = 0;
+    EnteredContextAddress = nullptr;
   }
 }
diff --git a/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp b/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
index 80a9a96f2a16b..39a225ac1dd93 100644
--- a/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
+++ b/compiler-rt/lib/ctx_profile/tests/CtxInstrProfilingTest.cpp
@@ -3,13 +3,26 @@
 #include <thread>
 
 using namespace __ctx_profile;
 
 class ContextTest : public ::testing::Test {
-  void SetUp() override { Root.getOrAllocateContextRoot(); }
+  int SomethingWithAddress = 0;
+  void SetUp() override {
+    Root.EntryAddress = &SomethingWithAddress;
+    Root.getOrAllocateContextRoot();
+  }
   void TearDown() override { __llvm_ctx_profile_free(); }
 
 public:
   FunctionData Root;
+  void initializeFData(std::vector<FunctionData> &FData,
+                       std::vector<int> &FuncAddresses, bool AsRoots) {
+    ASSERT_EQ(FData.size(), FuncAddresses.size());
+    for (size_t I = 0, E = FData.size(); I < E; ++I) {
+      FData[I].EntryAddress = &FuncAddresses[I];
+      if (AsRoots)
+        FData[I].getOrAllocateContextRoot();
+    }
+  }
 };
 
 TEST(ArenaTest, ZeroInit) {
@@ -85,7 +98,11 @@ TEST_F(ContextTest, Callsite) {
 
   EXPECT_EQ(Subctx->size(), sizeof(ContextNode) + 3 * sizeof(uint64_t) +
                                 1 * sizeof(ContextNode *));
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
+  __llvm_ctx_profile_release_context(&FData);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
   __llvm_ctx_profile_release_context(&Root);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, nullptr);
 }
 
 TEST_F(ContextTest, ScratchNoCollectionProfilingNotStarted) {
@@ -122,11 +139,42 @@ TEST_F(ContextTest, ScratchNoCollectionProfilingStarted) {
   EXPECT_NE(FData.FlatCtx, nullptr);
   EXPECT_EQ(reinterpret_cast<uintptr_t>(FData.FlatCtx) + 1,
             reinterpret_cast<uintptr_t>(Ctx));
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, nullptr);
+  __llvm_ctx_profile_release_context(&FData);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, nullptr);
+}
+
+TEST_F(ContextTest, RootCallingRootDoesNotChangeCurrentContext) {
+  ASSERT_EQ(__llvm_ctx_profile_current_context_root, nullptr);
+  int FakeCalleeAddress[2]{0, 0};
+  FunctionData FData[2];
+  FData[0].EntryAddress = &FakeCalleeAddress[0];
+  FData[1].EntryAddress = &FakeCalleeAddress[1];
+  FData[0].getOrAllocateContextRoot();
+  FData[1].getOrAllocateContextRoot();
+  __llvm_ctx_profile_start_collection();
+  auto *Ctx1 = __llvm_ctx_profile_get_context(&FData[0], &FakeCalleeAddress[0],
+                                              1234U, 1U, 1U);
+  EXPECT_EQ(Ctx1, FData[0].CtxRoot->FirstNode);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, FData[0].CtxRoot);
+
+  __llvm_ctx_profile_expected_callee[0] = &FakeCalleeAddress[1];
+  __llvm_ctx_profile_callsite[0] = &Ctx1->subContexts()[0];
+  auto *Ctx2 =
+      __llvm_ctx_profile_get_context(&FData[1], &FakeCalleeAddress[1], 2, 1, 0);
+  EXPECT_FALSE(isScratch(Ctx2));
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, FData[0].CtxRoot);
+  __llvm_ctx_profile_release_context(&FData[1]);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, FData[0].CtxRoot);
+  __llvm_ctx_profile_release_context(&FData[0]);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, nullptr);
 }
 
 TEST_F(ContextTest, ScratchDuringCollection) {
   __llvm_ctx_profile_start_collection();
   auto *Ctx = __llvm_ctx_profile_start_context(&Root, 1, 10, 4);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
+
   int FakeCalleeAddress = 0;
   int OtherFakeCalleeAddress = 0;
   __llvm_ctx_profile_expected_callee[0] = &FakeCalleeAddress;
@@ -164,7 +212,79 @@ TEST_F(ContextTest, ScratchDuringCollection) {
   EXPECT_TRUE(isScratch(Subctx3));
   EXPECT_EQ(FData[2].FlatCtx, nullptr);
 
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
+  __llvm_ctx_profile_release_context(&FData[2]);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
+  __llvm_ctx_profile_release_context(&FData[1]);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
+  __llvm_ctx_profile_release_context(&FData[0]);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
+  __llvm_ctx_profile_release_context(&Root);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, nullptr);
+}
+
+TEST_F(ContextTest, RecursiveRootExplicitlyRegistered) {
+  __llvm_ctx_profile_start_collection();
+  auto *Ctx = __llvm_ctx_profile_start_context(&Root, 1, 10, 4);
+  EXPECT_EQ(Ctx, Root.CtxRoot->FirstNode);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
+
+  auto *Subctx = __llvm_ctx_profile_start_context(&Root, 1, 10, 4);
+  EXPECT_TRUE(isScratch(Subctx));
+
+  EXPECT_EQ(__sanitizer::atomic_load_relaxed(&Root.CtxRoot->TotalEntries), 1U);
+
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
+  __llvm_ctx_profile_release_context(&Root);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
+  __llvm_ctx_profile_release_context(&Root);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, nullptr);
+}
+
+TEST_F(ContextTest, RecursiveRootAutoDiscovered) {
+  __llvm_ctx_profile_start_collection();
+  auto *Ctx =
+      __llvm_ctx_profile_get_context(&Root, Root.EntryAddress, 1, 10, 4);
+  EXPECT_EQ(Ctx, Root.CtxRoot->FirstNode);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
+
+  auto *Subctx =
+      __llvm_ctx_profile_get_context(&Root, Root.EntryAddress, 1, 10, 4);
+  EXPECT_TRUE(isScratch(Subctx));
+
+  EXPECT_EQ(__sanitizer::atomic_load_relaxed(&Root.CtxRoot->TotalEntries), 1U);
+
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
   __llvm_ctx_profile_release_context(&Root);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Root.CtxRoot);
+  __llvm_ctx_profile_release_context(&Root);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, nullptr);
+}
+
+TEST_F(ContextTest, RootEntersOtherRoot) {
+  __llvm_ctx_profile_start_collection();
+  std::vector<FunctionData> Roots(2);
+  std::vector<int> Addresses(2);
+  initializeFData(Roots, Addresses, true);
+  auto *Ctx = __llvm_ctx_profile_start_context(&Roots[0], 1, 10, 4);
+  EXPECT_EQ(Ctx, Roots[0].CtxRoot->FirstNode);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Roots[0].CtxRoot);
+
+  // Entering another root while under Roots[0] must not start collecting it,
+  // nor count as an entry into it.
+  auto *Subctx = __llvm_ctx_profile_start_context(&Roots[1], 1, 10, 4);
+  EXPECT_TRUE(isScratch(Subctx));
+
+  EXPECT_EQ(__sanitizer::atomic_load_relaxed(&Roots[0].CtxRoot->TotalEntries),
+            1U);
+  EXPECT_EQ(__sanitizer::atomic_load_relaxed(&Roots[1].CtxRoot->TotalEntries),
+            0U);
+
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Roots[0].CtxRoot);
+  __llvm_ctx_profile_release_context(&Roots[1]);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, Roots[0].CtxRoot);
+  __llvm_ctx_profile_release_context(&Roots[0]);
+  EXPECT_EQ(__llvm_ctx_profile_current_context_root, nullptr);
 }
 
 TEST_F(ContextTest, NeedMoreMemory) {
@@ -185,6 +305,7 @@ TEST_F(ContextTest, NeedMoreMemory) {
   EXPECT_EQ(Ctx->subContexts()[2], Subctx);
   EXPECT_NE(CurrentMem, CtxRoot.CurrentMem);
   EXPECT_NE(CtxRoot.CurrentMem, nullptr);
+  __llvm_ctx_profile_release_context(&Root);
 }
 
 TEST_F(ContextTest, ConcurrentRootCollection) {
diff --git a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
index e4e310b2e987d..a33549fb327ba 100644
--- a/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
+++ b/llvm/include/llvm/ProfileData/CtxInstrContextNode.h
@@ -125,10 +125,13 @@ class ContextNode final {
 /// VOLATILE_PTRDECL is the same as above, but for volatile pointers;
 ///
+/// ENTRY_ADDRESS takes no parameters: it is expanded in place of the
+/// EntryAddress field, letting each user supply that field's declaration.
+///
 /// MUTEXDECL takes one parameter, the name of a field that is a mutex.
-#define CTXPROF_FUNCTION_DATA(PTRDECL, CONTEXT_PTR, VOLATILE_PTRDECL,          \
-                              MUTEXDECL)                                       \
+#define CTXPROF_FUNCTION_DATA(PTRDECL, CONTEXT_PTR, ENTRY_ADDRESS,             \
+                              VOLATILE_PTRDECL, MUTEXDECL)                     \
   PTRDECL(FunctionData, Next)                                                  \
-  VOLATILE_PTRDECL(void, EntryAddress)                                         \
+  ENTRY_ADDRESS                                                                \
   CONTEXT_PTR                                                                  \
   VOLATILE_PTRDECL(ContextNode, FlatCtx)                                       \
   MUTEXDECL(Mutex)
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
index d741695d4e53c..22ffb727fbcf6 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
@@ -66,7 +66,12 @@ class CtxInstrumentationLowerer final {
   Function *ReleaseCtx = nullptr;
   GlobalVariable *ExpectedCalleeTLS = nullptr;
   GlobalVariable *CallsiteInfoTLS = nullptr;
-  Constant *CannotBeRootInitializer = nullptr;
+  Type *PointerTy = nullptr;
+  Type *SanitizerMutexType = nullptr;
+  Type *I32Ty = nullptr;
+  Type *I64Ty = nullptr;
+
+  Constant *getFunctionDataInitializer(Function &F, bool HasMusttail);
 
 public:
   CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
@@ -114,42 +119,55 @@ void emitUnsupportedRootError(const Function &F, StringRef Reason) {
 }
 } // namespace
 
+Constant *
+CtxInstrumentationLowerer::getFunctionDataInitializer(Function &F,
+                                                      bool HasMusttail) {
+#define _PTRDECL(_, __) Constant::getNullValue(PointerTy),
+#define _VOLATILE_PTRDECL(_, __) _PTRDECL(_, __)
+#define _MUTEXDECL(_) Constant::getNullValue(SanitizerMutexType),
+#define _ENTRY_ADDRESS                                                         \
+  (ContextRootSet.contains(&F) ? &F : Constant::getNullValue(PointerTy)),
+#define _CONTEXT_ROOT                                                          \
+  (HasMusttail                                                                 \
+       ? Constant::getIntegerValue(                                            \
+             PointerTy,                                                        \
+             APInt(M.getDataLayout().getPointerTypeSizeInBits(PointerTy), 1U)) \
+       : Constant::getNullValue(PointerTy)),
+  return ConstantStruct::get(
+      FunctionDataTy,
+      {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT, _ENTRY_ADDRESS,
+                             _VOLATILE_PTRDECL, _MUTEXDECL)});
+#undef _PTRDECL
+#undef _CONTEXT_ROOT
+#undef _ENTRY_ADDRESS
+#undef _VOLATILE_PTRDECL
+#undef _MUTEXDECL
+}
+
 // set up tie-in with compiler-rt.
 // NOTE!!!
 // These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
 CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
                                                      ModuleAnalysisManager &MAM)
     : M(M), MAM(MAM) {
-  auto *PointerTy = PointerType::get(M.getContext(), 0);
-  auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
-  auto *I32Ty = Type::getInt32Ty(M.getContext());
-  auto *I64Ty = Type::getInt64Ty(M.getContext());
+  PointerTy = PointerType::get(M.getContext(), 0);
+  SanitizerMutexType = Type::getInt8Ty(M.getContext());
+  I32Ty = Type::getInt32Ty(M.getContext());
+  I64Ty = Type::getInt64Ty(M.getContext());
 
 #define _PTRDECL(_, __) PointerTy,
 #define _VOLATILE_PTRDECL(_, __) PointerTy,
 #define _CONTEXT_ROOT PointerTy,
+#define _ENTRY_ADDRESS PointerTy,
 #define _MUTEXDECL(_) SanitizerMutexType,
 
   FunctionDataTy = StructType::get(
-      M.getContext(), {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT,
-                                             _VOLATILE_PTRDECL, _MUTEXDECL)});
-#undef _PTRDECL
-#undef _CONTEXT_ROOT
-#undef _VOLATILE_PTRDECL
-#undef _MUTEXDECL
-
-#define _PTRDECL(_, __) Constant::getNullValue(PointerTy),
-#define _VOLATILE_PTRDECL(_, __) _PTRDECL(_, __)
-#define _MUTEXDECL(_) Constant::getNullValue(SanitizerMutexType),
-#define _CONTEXT_ROOT                                                          \
-  Constant::getIntegerValue(                                                   \
-      PointerTy,                                                               \
-      APInt(M.getDataLayout().getPointerTypeSizeInBits(PointerTy), 1U)),
-  CannotBeRootInitializer = ConstantStruct::get(
-      FunctionDataTy, {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT,
-                                             _VOLATILE_PTRDECL, _MUTEXDECL)});
+      M.getContext(),
+      {CTXPROF_FUNCTION_DATA(_PTRDECL, _CONTEXT_ROOT, _ENTRY_ADDRESS,
+                             _VOLATILE_PTRDECL, _MUTEXDECL)});
 #undef _PTRDECL
 #undef _CONTEXT_ROOT
+#undef _ENTRY_ADDRESS
 #undef _VOLATILE_PTRDECL
 #undef _MUTEXDECL
 
@@ -286,9 +304,8 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
       // NumCallsites and NumCounters. We delcare it here because it's more
       // convenient - we have the Builder.
       ThisContextType = StructType::get(
-          F.getContext(),
-          {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NumCounters),
-           ArrayType::get(Builder.getPtrTy(), NumCallsites)});
+          F.getContext(), {ContextNodeTy, ArrayType::get(I64Ty, NumCounters),
+                           ArrayType::get(PointerTy, NumCallsites)});
       // Figure out which way we obtain the context object for this function -
       // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
       // former case, we also set TheRootFuctionData since we need to release it
@@ -302,8 +319,7 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
       // treated as a "can't be set as root".
       TheRootFuctionData = new GlobalVariable(
           M, FunctionDataTy, false, GlobalVariable::InternalLinkage,
-          HasMusttail ? CannotBeRootInitializer
-                      : Constant::getNullValue(FunctionDataTy));
+          getFunctionDataInitializer(F, HasMusttail));
 
       if (ContextRootSet.contains(&F)) {
         Context = Builder.CreateCall(
@@ -320,7 +336,7 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
         });
       }
       // The context could be scratch.
-      auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
+      auto *CtxAsInt = Builder.CreatePtrToInt(Context, I64Ty);
       if (NumCallsites > 0) {
         // Figure out which index of the TLS 2-element buffers to use.
         // Scratch context => we use index == 1. Real contexts => index == 0.
@@ -330,8 +346,7 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
             PointerType::getUnqual(F.getContext()),
             Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
         CallsiteInfoTLSAddr = Builder.CreateGEP(
-            Builder.getInt32Ty(),
-            Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
+            I32Ty, Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
       }
       // Because the context pointer may have LSB set (to indicate scratch),
       // clear it for the value we use as base address for the counter vector.
@@ -367,10 +382,9 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
           auto *GEP = Builder.CreateGEP(
               ThisContextType, RealContext,
               {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
-          Builder.CreateStore(
-              Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
-                                AsStep->getStep()),
-              GEP);
+          Builder.CreateStore(Builder.CreateAdd(Builder.CreateLoad(I64Ty, GEP),
+                                                AsStep->getStep()),
+                              GEP);
         } break;
         case llvm::Intrinsic::instrprof_callsite:
           // callsite lowering: write the called value in the expected callee
diff --git a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
index 71d54f98d26e1..ffcda01c06be7 100644
--- a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
+++ b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
@@ -12,8 +12,8 @@ declare void @bar()
 ; LOWERING: @__llvm_ctx_profile_callsite = external hidden thread_local global ptr
 ; LOWERING: @__llvm_ctx_profile_expected_callee = external hidden thread_local global ptr
 ; LOWERING: @[[GLOB0:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
-; LOWERING: @[[GLOB1:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
-; LOWERING: @[[GLOB2:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @[[GLOB1:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } { ptr null, ptr @an_entrypoint, ptr null, ptr null, i8 0 }
+; LOWERING: @[[GLOB2:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } { ptr null, ptr @another_entrypoint_no_callees, ptr null, ptr null, i8 0 }
 ; LOWERING: @[[GLOB3:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
 ; LOWERING: @[[GLOB4:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer
 ; LOWERING: @[[GLOB5:[0-9]+]] = internal global { ptr, ptr, ptr, ptr, i8 } zeroinitializer



More information about the llvm-commits mailing list