[compiler-rt] [llvm] [DO NOT SUBMIT] Contextual iFDO "demo" PR (PR #86036)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 25 09:20:59 PDT 2024


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/86036

>From 24b15900137beedd9198e61b4620a73d96334a78 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 19 Mar 2024 20:16:35 -0700
Subject: [PATCH 1/5] WIP

---
 compiler-rt/lib/profile/CMakeLists.txt        |   5 +
 .../lib/profile/InstrProfilingContextual.cpp  | 259 ++++++++++++++++++
 .../lib/profile/InstrProfilingContextual.h    | 136 +++++++++
 compiler-rt/lib/profile/tests/CMakeLists.txt  |  72 +++++
 .../tests/InstrProfilingContextualTest.cpp    | 121 ++++++++
 compiler-rt/lib/profile/tests/driver.cpp      |  14 +
 llvm/include/llvm/IR/IntrinsicInst.h          |  13 +
 llvm/include/llvm/IR/Intrinsics.td            |   5 +
 .../Instrumentation/PGOCtxProfLowering.h      |  40 +++
 llvm/lib/IR/IntrinsicInst.cpp                 |   4 +
 llvm/lib/Passes/PassBuilder.cpp               |   1 +
 llvm/lib/Passes/PassBuilderPipelines.cpp      |  24 +-
 llvm/lib/Passes/PassRegistry.def              |   1 +
 .../Transforms/Instrumentation/CMakeLists.txt |   1 +
 .../Instrumentation/InstrProfiling.cpp        |   3 +
 .../Instrumentation/PGOCtxProfLowering.cpp    | 214 +++++++++++++++
 .../Instrumentation/PGOInstrumentation.cpp    |  38 ++-
 llvm/test/Transforms/PGOProfile/ctx-basic.ll  |  86 ++++++
 18 files changed, 1023 insertions(+), 14 deletions(-)
 create mode 100644 compiler-rt/lib/profile/InstrProfilingContextual.cpp
 create mode 100644 compiler-rt/lib/profile/InstrProfilingContextual.h
 create mode 100644 compiler-rt/lib/profile/tests/CMakeLists.txt
 create mode 100644 compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp
 create mode 100644 compiler-rt/lib/profile/tests/driver.cpp
 create mode 100644 llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h
 create mode 100644 llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
 create mode 100644 llvm/test/Transforms/PGOProfile/ctx-basic.ll

diff --git a/compiler-rt/lib/profile/CMakeLists.txt b/compiler-rt/lib/profile/CMakeLists.txt
index 45e51648917515..a2913f28ca017f 100644
--- a/compiler-rt/lib/profile/CMakeLists.txt
+++ b/compiler-rt/lib/profile/CMakeLists.txt
@@ -51,6 +51,7 @@ add_compiler_rt_component(profile)
 set(PROFILE_SOURCES
   GCDAProfiling.c
   InstrProfiling.c
+  InstrProfilingContextual.cpp
   InstrProfilingInternal.c
   InstrProfilingValue.c
   InstrProfilingBuffer.c
@@ -142,3 +143,7 @@ else()
     ADDITIONAL_HEADERS ${PROFILE_HEADERS}
     PARENT_TARGET profile)
 endif()
+
+if(COMPILER_RT_INCLUDE_TESTS)
+  add_subdirectory(tests)
+endif()
diff --git a/compiler-rt/lib/profile/InstrProfilingContextual.cpp b/compiler-rt/lib/profile/InstrProfilingContextual.cpp
new file mode 100644
index 00000000000000..e13d0688cca50f
--- /dev/null
+++ b/compiler-rt/lib/profile/InstrProfilingContextual.cpp
@@ -0,0 +1,259 @@
+//===- InstrProfilingContextual.cpp - PGO runtime initialization ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "InstrProfiling.h"
+
+#include "sanitizer_common/sanitizer_atomic.h"
+#include "sanitizer_common/sanitizer_atomic_clang.h"
+#include "sanitizer_common/sanitizer_common.h"
+#include "sanitizer_common/sanitizer_dense_map.h"
+#include "sanitizer_common/sanitizer_mutex.h"
+#include "sanitizer_common/sanitizer_placement_new.h"
+#include "sanitizer_common/sanitizer_thread_safety.h"
+#include "sanitizer_common/sanitizer_vector.h"
+
+#include "InstrProfilingContextual.h"
+
+using namespace __profile;
+
+Arena *Arena::allocate(size_t Size, Arena *Prev) {
+  auto *A = new (__sanitizer::InternalAlloc(Size + sizeof(Arena))) Arena(Size);
+  __sanitizer::internal_memset(A->start(), 0, Size); // Nodes expect zeroes.
+  if (Prev)
+    Prev->Next = A;
+  return A;
+}
+
+inline ContextNode *ContextNode::alloc(char *Place, GUID Guid,
+                                       uint32_t NrCounters,
+                                       uint32_t NrCallsites,
+                                       ContextNode *Next) {
+  return new (Place) ContextNode(Guid, NrCounters, NrCallsites, Next);
+}
+
+void ContextNode::reset() { // Zeroes counters here and in all sub-contexts.
+  for (uint32_t C = 0; C != NrCounters; ++C)
+    counters()[C] = 0;
+  for (uint32_t S = 0; S != NrCallsites; ++S)
+    for (auto *Sub = subContexts()[S]; Sub; Sub = Sub->next())
+      Sub->reset();
+}
+
+namespace {
+__sanitizer::SpinMutex AllContextsMutex;
+SANITIZER_GUARDED_BY(AllContextsMutex)
+__sanitizer::Vector<ContextRoot *> AllContextRoots;
+
+ContextNode * markAsScratch(const ContextNode* Ctx) {
+  return reinterpret_cast<ContextNode*>(reinterpret_cast<uint64_t>(Ctx) | 1);
+}
+
+template<typename T>
+T consume(T& V) {
+  auto R = V;
+  V = {0};
+  return R;
+}
+
+constexpr size_t kPower = 20;
+constexpr size_t kBuffSize = 1 << kPower;
+
+// Returns the arena payload size to request for a node of Needed bytes.
+size_t getArenaAllocSize(size_t Needed) {
+  // Compare against the byte size (kBuffSize), not its exponent (kPower).
+  return Needed >= kBuffSize ? 2 * Needed : kBuffSize;
+}
+
+// Validates Root: node addresses are unique and sub-contexts point at nodes.
+bool validate(const ContextRoot *Root) {
+  __sanitizer::DenseMap<uint64_t, bool> ContextStartAddrs;
+  for (const Arena *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
+    // Nodes live in the payload that follows the Arena header, not at Mem.
+    const char *Pos = Mem->start();
+    while (Pos < Mem->pos()) {
+      if (!ContextStartAddrs.insert({reinterpret_cast<uint64_t>(Pos), true})
+               .second)
+        return false;
+      Pos += reinterpret_cast<const ContextNode *>(Pos)->size();
+    }
+  }
+  // Second pass: each linked sub-context must start at a recorded address.
+  for (const Arena *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
+    const char *Pos = Mem->start();
+    while (Pos < Mem->pos()) {
+      const auto *Ctx = reinterpret_cast<const ContextNode *>(Pos);
+      for (uint32_t I = 0; I < Ctx->callsites_size(); ++I)
+        for (auto *Sub = Ctx->subContexts()[I]; Sub; Sub = Sub->next())
+          if (!ContextStartAddrs.find(reinterpret_cast<uint64_t>(Sub)))
+            return false;
+      Pos += Ctx->size();
+    }
+  }
+  return true;
+}
+} // namespace
+
+extern "C" {
+__thread char __Buffer[kBuffSize] = {0};
+
+#define TheNullContext markAsScratch(reinterpret_cast<ContextNode *>(__Buffer))
+__thread void *volatile __llvm_instrprof_expected_callee[2] = {nullptr, nullptr};
+__thread ContextNode **volatile __llvm_instrprof_callsite[2] = {0, 0};
+
+COMPILER_RT_VISIBILITY __thread ContextRoot
+    *volatile __llvm_instrprof_current_context_root = nullptr;
+
+// Slow path: allocates a node for Guid as the new head of *InsertionPoint.
+COMPILER_RT_VISIBILITY ContextNode *
+__llvm_instprof_slow_get_callsite(uint64_t Guid, ContextNode **InsertionPoint,
+                                  uint32_t NrCounters, uint32_t NrCallsites) {
+  auto AllocSize = ContextNode::getAllocSize(NrCounters, NrCallsites);
+  auto *Mem = __llvm_instrprof_current_context_root->CurrentMem;
+  char *AllocPlace = Mem->tryAllocate(AllocSize);
+  if (!AllocPlace) {
+    __llvm_instrprof_current_context_root->CurrentMem = Mem =
+        Arena::allocate(getArenaAllocSize(AllocSize), Mem);
+    AllocPlace = Mem->tryAllocate(AllocSize); // Retry in the fresh arena.
+  }
+  return *InsertionPoint = ContextNode::alloc(
+             AllocPlace, Guid, NrCounters, NrCallsites, *InsertionPoint);
+}
+
+// Returns the node for Guid under the callsite slot read from
+// __llvm_instrprof_callsite[0], allocating it if absent. Returns the scratch
+// context if no root is active, the slot is unusable, or Callee is unexpected.
+COMPILER_RT_VISIBILITY ContextNode *
+__llvm_instrprof_get_context(void *Callee, GUID Guid, uint32_t NrCounters,
+                            uint32_t NrCallsites) {
+  if (!__llvm_instrprof_current_context_root)
+    return TheNullContext;
+  auto **Slot = consume(__llvm_instrprof_callsite[0]);
+  if (!Slot || isScratch(*Slot))
+    return TheNullContext;
+  if (consume(__llvm_instrprof_expected_callee[0]) != Callee)
+    return TheNullContext;
+  auto *Node = *Slot;
+  for (; Node && Node->guid() != Guid; Node = Node->next())
+    ;
+  if (!Node)
+    Node = __llvm_instprof_slow_get_callsite(Guid, Slot, NrCounters,
+                                             NrCallsites);
+  if (Node->callsites_size() != NrCallsites ||
+      Node->counters_size() != NrCounters)
+    __sanitizer::Printf("[ctxprof] Returned ctx differs from what's asked: "
+                        "Context: %p, Asked: %zu %u %u, Got: %zu %u %u \n",
+                        Node, Guid, NrCallsites, NrCounters, Node->guid(),
+                        Node->callsites_size(), Node->counters_size());
+  Node->onEntry();
+  return Node;
+}
+
+COMPILER_RT_VISIBILITY void
+__llvm_instprof_setup_context(ContextRoot *Root, GUID Guid, uint32_t NrCounters,
+                              uint32_t NrCallsites) {
+  __sanitizer::GenericScopedLock<__sanitizer::SpinMutex> Lock(
+      &AllContextsMutex);
+  // Re-check under the lock: the caller tested FirstMemBlock without it.
+  if (Root->FirstMemBlock)
+    return;
+  const auto Needed = ContextNode::getAllocSize(NrCounters, NrCallsites);
+  auto *M = Arena::allocate(getArenaAllocSize(Needed));
+  Root->FirstNode =
+      ContextNode::alloc(M->tryAllocate(Needed), Guid, NrCounters, NrCallsites);
+  AllContextRoots.PushBack(Root);
+  Root->FirstMemBlock = Root->CurrentMem = M; // Set last: marks Root as set up.
+}
+
+// Tries to make Root this thread's collecting root; scratch if already taken.
+COMPILER_RT_VISIBILITY ContextNode *__llvm_instrprof_start_context(
+    ContextRoot *Root, GUID Guid, uint32_t Counters,
+    uint32_t Callsites) SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
+  if (!Root->FirstMemBlock)
+    __llvm_instprof_setup_context(Root, Guid, Counters, Callsites);
+  if (!Root->Taken.TryLock()) {
+    __llvm_instrprof_current_context_root = nullptr;
+    return TheNullContext;
+  }
+  __llvm_instrprof_current_context_root = Root;
+  Root->FirstNode->onEntry();
+  return Root->FirstNode;
+}
+
+COMPILER_RT_VISIBILITY void __llvm_instrprof_release_context(ContextRoot *Root)
+    SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
+  if (__llvm_instrprof_current_context_root) {
+    __llvm_instrprof_current_context_root = nullptr;
+    Root->Taken.Unlock();
+  }
+}
+
+// Zeroes the counters of all registered roots and reports the arena count.
+COMPILER_RT_VISIBILITY void __llvm_profile_reset_counters(void) {
+  size_t NrMemUnits = 0;
+  __sanitizer::GenericScopedLock<__sanitizer::SpinMutex> Lock(
+      &AllContextsMutex);
+  for (uint32_t I = 0; I < AllContextRoots.Size(); ++I) {
+    auto *Root = AllContextRoots[I];
+    __sanitizer::GenericScopedLock<__sanitizer::StaticSpinMutex> TakenLock(
+        &Root->Taken);
+    for (auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next())
+      ++NrMemUnits;
+    Root->FirstNode->reset();
+  }
+  __sanitizer::Printf("[ctxprof] Initial NrMemUnits: %zu \n", NrMemUnits);
+}
+
+// Dumps all registered context roots to Filename. For each arena it writes
+// the payload start address, kBuffSize, and kBuffSize raw bytes starting at
+// the Arena header. Returns -1 on error, otherwise the result of fclose.
+// NOTE(review): the address is the payload's but the bytes begin at the
+// header - confirm that the reader expects this.
+COMPILER_RT_VISIBILITY
+int __llvm_ctx_profile_dump(const char* Filename) {
+  __sanitizer::Printf("[ctxprof] Start Dump\n");
+  __sanitizer::GenericScopedLock<__sanitizer::SpinMutex> Lock(
+      &AllContextsMutex);
+
+  for (int I = 0, E = AllContextRoots.Size(); I < E; ++I) {
+    auto *Root = AllContextRoots[I];
+    __sanitizer::GenericScopedLock<__sanitizer::StaticSpinMutex> TakenLock(
+        &Root->Taken);
+    // NOTE(review): the validation result is currently ignored.
+    (void)validate(Root);
+  }
+
+  if (!Filename) {
+    PROF_ERR("Failed to write file : %s\n", "Filename not set");
+    return -1;
+  }
+  FILE *F = fopen(Filename, "w");
+  if (!F) {
+    PROF_ERR("Failed to open file : %s\n", Filename);
+    return -1;
+  }
+
+  // Writes N bytes from Data; true only if all of them reached the stream.
+  const auto Write = [F](const void *Data, size_t N) {
+    return fwrite(Data, sizeof(char), N, F) == N;
+  };
+  for (int I = 0, E = AllContextRoots.Size(); I < E; ++I) {
+    const auto *Root = AllContextRoots[I];
+    for (const auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
+      const uint64_t MemStartAddr =
+          reinterpret_cast<const uint64_t>(Mem->start());
+      if (!Write(&MemStartAddr, sizeof(uint64_t)) ||
+          !Write(&kBuffSize, sizeof(uint64_t)) ||
+          !Write(Mem, kBuffSize)) {
+        fclose(F); // Error path: close the stream instead of leaking it.
+        return -1;
+      }
+    }
+  }
+  return fclose(F);
+}
+}
diff --git a/compiler-rt/lib/profile/InstrProfilingContextual.h b/compiler-rt/lib/profile/InstrProfilingContextual.h
new file mode 100644
index 00000000000000..7525a32cc32b13
--- /dev/null
+++ b/compiler-rt/lib/profile/InstrProfilingContextual.h
@@ -0,0 +1,136 @@
+/*===- InstrProfilingContextual.h - Contextual profiling runtime ----------===*\
+|*
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+|* See https://llvm.org/LICENSE.txt for license information.
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+|*
+\*===----------------------------------------------------------------------===*/
+
+#ifndef PROFILE_INSTRPROFILINGARENA_H_
+#define PROFILE_INSTRPROFILINGARENA_H_
+
+#include "InstrProfiling.h"
+#include "sanitizer_common/sanitizer_mutex.h"
+
+namespace __profile {
+using GUID = uint64_t;
+
+/// Arena forming a linked list, if more space is needed. Intentionally not
+/// thread safe.
+class Arena final {
+public:
+  static Arena *allocate(size_t Size, Arena *Prev = nullptr);
+  size_t size() const { return Size; }
+  char *tryAllocate(size_t S) {
+    if (Pos + S > Size)
+      return nullptr;
+    Pos += S;
+    return start() + (Pos - S);
+  }
+  Arena *next() const { return Next; }
+  const char *start() const { return const_cast<Arena*>(this)->start(); }
+  const char *pos() const { return start() + Pos; }
+
+private:
+  explicit Arena(size_t Size) : Size(Size) {}
+  char *start() { return reinterpret_cast<char *>(&this[1]); }
+
+  Arena *Next = nullptr;
+  size_t Pos = 0;
+  const size_t Size;
+};
+
+class ContextNode final {
+  const GUID Guid;
+  ContextNode *const Next;
+  const uint32_t NrCounters;
+  const uint32_t NrCallsites;
+
+public:
+  ContextNode(GUID Guid, uint32_t NrCounters, uint32_t NrCallsites,
+              ContextNode *Next = nullptr)
+      : Guid(Guid), Next(Next), NrCounters(NrCounters),
+        NrCallsites(NrCallsites) {}
+  static inline ContextNode *alloc(char *Place, GUID Guid, uint32_t NrCounters,
+                                   uint32_t NrCallsites,
+                                   ContextNode *Next = nullptr);
+
+  static inline size_t getAllocSize(uint32_t NrCounters, uint32_t NrCallsites) {
+    return sizeof(ContextNode) + sizeof(uint64_t) * NrCounters +
+           sizeof(ContextNode *) * NrCallsites;
+  }
+
+  uint64_t *counters() {
+    ContextNode *addr_after = &(this[1]);
+    return reinterpret_cast<uint64_t *>(reinterpret_cast<char *>(addr_after));
+  }
+
+  uint32_t counters_size() const { return NrCounters; }
+  uint32_t callsites_size() const { return NrCallsites; }
+
+  const uint64_t *counters() const {
+    return const_cast<ContextNode *>(this)->counters();
+  }
+
+  ContextNode **subContexts() {
+    return reinterpret_cast<ContextNode **>(&(counters()[NrCounters]));
+  }
+
+  ContextNode *const *subContexts() const {
+    return const_cast<ContextNode *>(this)->subContexts();
+  }
+
+  GUID guid() const { return Guid; }
+  ContextNode *next() { return Next; }
+
+  size_t size() const { return getAllocSize(NrCounters, NrCallsites); }
+
+  void reset();
+
+  void onEntry() {
+    ++counters()[0];
+  }
+
+  uint64_t entrycount() const {
+    return counters()[0];
+  }
+};
+
+
+// Exposed for test. Constructed and zero-initialized by LLVM. Implicitly,
+// LLVM must know the shape of this.
+struct ContextRoot {
+  ContextNode *FirstNode = nullptr;
+  Arena *FirstMemBlock = nullptr;
+  Arena *CurrentMem = nullptr;
+  // This is init-ed by the static zero initializer in LLVM.
+  ::__sanitizer::StaticSpinMutex Taken;
+};
+
+inline bool isScratch(const ContextNode* Ctx) {
+  return (reinterpret_cast<uint64_t>(Ctx) & 1);
+}
+
+} // namespace __profile
+
+extern "C" {
+
+// position 0 is used when the current context isn't scratch, 1 when it is.
+extern __thread void *volatile __llvm_instrprof_expected_callee[2];
+extern __thread __profile::ContextNode **volatile __llvm_instrprof_callsite[2];
+
+extern __thread __profile::ContextRoot
+    *volatile __llvm_instrprof_current_context_root;
+
+COMPILER_RT_VISIBILITY __profile::ContextNode *
+__llvm_instrprof_start_context(__profile::ContextRoot *Root,
+                              __profile::GUID Guid, uint32_t Counters,
+                              uint32_t Callsites);
+COMPILER_RT_VISIBILITY void
+__llvm_instrprof_release_context(__profile::ContextRoot *Root);
+
+COMPILER_RT_VISIBILITY __profile::ContextNode *
+__llvm_instrprof_get_context(void *Callee, __profile::GUID Guid,
+                            uint32_t NrCounters, uint32_t NrCallsites);
+}
+#endif
\ No newline at end of file
diff --git a/compiler-rt/lib/profile/tests/CMakeLists.txt b/compiler-rt/lib/profile/tests/CMakeLists.txt
new file mode 100644
index 00000000000000..c221d15b541829
--- /dev/null
+++ b/compiler-rt/lib/profile/tests/CMakeLists.txt
@@ -0,0 +1,72 @@
+include(CheckCXXCompilerFlag)
+include(CompilerRTCompile)
+include(CompilerRTLink)
+
+set(PROFILE_UNITTEST_CFLAGS
+  ${COMPILER_RT_UNITTEST_CFLAGS}
+  ${COMPILER_RT_GTEST_CFLAGS}
+  ${COMPILER_RT_GMOCK_CFLAGS}
+  ${SANITIZER_TEST_CXX_CFLAGS}
+  -I${COMPILER_RT_SOURCE_DIR}/lib/
+  -DSANITIZER_COMMON_NO_REDEFINE_BUILTINS
+  -O2
+  -g
+  -fno-rtti
+  -Wno-pedantic
+  -fno-omit-frame-pointer)
+
+# Suppress warnings for gmock variadic macros for clang and gcc respectively.
+append_list_if(SUPPORTS_GNU_ZERO_VARIADIC_MACRO_ARGUMENTS_FLAG -Wno-gnu-zero-variadic-macro-arguments PROFILE_UNITTEST_CFLAGS)
+append_list_if(COMPILER_RT_HAS_WVARIADIC_MACROS_FLAG -Wno-variadic-macros PROFILE_UNITTEST_CFLAGS)
+
+file(GLOB PROFILE_HEADERS ../*.h)
+
+set(PROFILE_SOURCES
+  ../InstrProfilingContextual.cpp)
+
+set(PROFILE_UNITTESTS
+  InstrProfilingContextualTest.cpp
+  driver.cpp)
+
+include_directories(../../../include)
+
+set(PROFILE_UNIT_TEST_HEADERS
+  ${PROFILE_HEADERS})
+
+set(PROFILE_UNITTEST_LINK_FLAGS
+  ${COMPILER_RT_UNITTEST_LINK_FLAGS})
+
+if(NOT WIN32)
+  list(APPEND PROFILE_UNITTEST_LINK_FLAGS -pthread)
+endif()
+
+set(PROFILE_UNITTEST_LINK_LIBRARIES
+  ${COMPILER_RT_UNWINDER_LINK_LIBS}
+  ${SANITIZER_TEST_CXX_LIBRARIES})
+list(APPEND PROFILE_UNITTEST_LINK_LIBRARIES "dl")
+
+if(COMPILER_RT_DEFAULT_TARGET_ARCH IN_LIST PROFILE_SUPPORTED_ARCH)
+  # Profile unit tests are only run on the host machine.
+  set(arch ${COMPILER_RT_DEFAULT_TARGET_ARCH})
+
+  add_executable(ProfileUnitTests 
+    ${PROFILE_UNITTESTS}
+    ${COMPILER_RT_GTEST_SOURCE}
+    ${COMPILER_RT_GMOCK_SOURCE}
+    ${PROFILE_SOURCES}
+    $<TARGET_OBJECTS:RTSanitizerCommon.${arch}>
+    $<TARGET_OBJECTS:RTSanitizerCommonCoverage.${arch}>
+    $<TARGET_OBJECTS:RTSanitizerCommonLibc.${arch}>
+    $<TARGET_OBJECTS:RTSanitizerCommonSymbolizer.${arch}>
+    $<TARGET_OBJECTS:RTSanitizerCommonSymbolizerInternal.${arch}>)
+  set_target_compile_flags(ProfileUnitTests ${PROFILE_UNITTEST_CFLAGS})
+  set_target_link_flags(ProfileUnitTests ${PROFILE_UNITTEST_LINK_FLAGS})
+  target_link_libraries(ProfileUnitTests ${PROFILE_UNITTEST_LINK_LIBRARIES})
+
+  if (TARGET cxx-headers OR HAVE_LIBCXX)
+    add_dependencies(ProfileUnitTests cxx-headers)
+  endif()
+
+  set_target_properties(ProfileUnitTests PROPERTIES
+    RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+endif()
diff --git a/compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp b/compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp
new file mode 100644
index 00000000000000..b53fbb5bda7c49
--- /dev/null
+++ b/compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp
@@ -0,0 +1,121 @@
+#include "../InstrProfilingContextual.h"
+#include "gtest/gtest.h"
+#include <mutex>
+#include <thread>
+
+using namespace __profile;
+
+TEST(ArenaTest, Basic) {
+  Arena * A = Arena::allocate(1024);
+  EXPECT_EQ(A->size(), 1024U);
+  EXPECT_EQ(A->next(), nullptr);
+
+  auto *M1 = A->tryAllocate(1020); 
+  EXPECT_NE(M1, nullptr);
+  auto *M2 = A->tryAllocate(4);
+  EXPECT_NE(M2, nullptr);
+  EXPECT_EQ(M1 + 1020, M2);
+  EXPECT_EQ(A->tryAllocate(1), nullptr);
+  Arena *A2 = Arena::allocate(2024, A);
+  EXPECT_EQ(A->next(), A2);
+  EXPECT_EQ(A2->next(), nullptr);
+}
+
+TEST(ContextTest, Basic) {
+  ContextRoot Root;
+  memset(&Root, 0, sizeof(ContextRoot));
+  auto *Ctx = __llvm_instrprof_start_context(&Root, 1, 10, 4);
+  EXPECT_NE(Ctx, nullptr);
+  EXPECT_EQ(Ctx->size(), sizeof(ContextNode) + 10 * sizeof(uint64_t) +
+                             4 * sizeof(ContextNode *));
+  EXPECT_EQ(Ctx->counters_size(), 10U);
+  EXPECT_EQ(Ctx->callsites_size(), 4U);
+  EXPECT_EQ(__llvm_instrprof_current_context_root, &Root);
+  Root.Taken.CheckLocked();
+  EXPECT_FALSE(Root.Taken.TryLock());
+  __llvm_instrprof_release_context(&Root);
+  EXPECT_EQ(__llvm_instrprof_current_context_root, nullptr);
+  EXPECT_TRUE(Root.Taken.TryLock());
+  Root.Taken.Unlock();
+}
+
+TEST(ContextTest, Callsite) {
+  ContextRoot Root;
+  memset(&Root, 0, sizeof(ContextRoot));
+  auto *Ctx = __llvm_instrprof_start_context(&Root, 1, 10, 4);
+  int OpaqueValue = 0;
+  const bool IsScratch = isScratch(Ctx);
+  EXPECT_FALSE(IsScratch);
+  __llvm_instrprof_expected_callee[0] = &OpaqueValue;
+  __llvm_instrprof_callsite[0] = &Ctx->subContexts()[2];
+  auto *Subctx = __llvm_instrprof_get_context(&OpaqueValue, 2, 3, 1);
+  EXPECT_EQ(Ctx->subContexts()[2], Subctx);
+  EXPECT_EQ(Subctx->counters_size(), 3U);
+  EXPECT_EQ(Subctx->callsites_size(), 1U);
+  EXPECT_EQ(__llvm_instrprof_expected_callee[0], nullptr);
+  EXPECT_EQ(__llvm_instrprof_callsite[0], nullptr);
+  
+  EXPECT_EQ(Subctx->size(), sizeof(ContextNode) + 3 * sizeof(uint64_t) +
+                                1 * sizeof(ContextNode *));
+  __llvm_instrprof_release_context(&Root);
+}
+
+TEST(ContextTest, ScratchNoCollection) {
+  EXPECT_EQ(__llvm_instrprof_current_context_root, nullptr);
+  int OpaqueValue = 0;
+  // this would be the very first function executing this. the TLS is empty,
+  // too.
+  auto *Ctx = __llvm_instrprof_get_context(&OpaqueValue, 2, 3, 1);
+  EXPECT_TRUE(isScratch(Ctx));
+}
+
+TEST(ContextTest, ScratchDuringCollection) {
+  ContextRoot Root;
+  memset(&Root, 0, sizeof(ContextRoot));
+  auto *Ctx = __llvm_instrprof_start_context(&Root, 1, 10, 4);
+  int OpaqueValue = 0;
+  int OtherOpaqueValue = 0;
+  __llvm_instrprof_expected_callee[0] = &OpaqueValue;
+  __llvm_instrprof_callsite[0] = &Ctx->subContexts()[2];
+  auto *Subctx = __llvm_instrprof_get_context(&OtherOpaqueValue, 2, 3, 1);
+  EXPECT_TRUE(isScratch(Subctx));
+  EXPECT_EQ(__llvm_instrprof_expected_callee[0], nullptr);
+  EXPECT_EQ(__llvm_instrprof_callsite[0], nullptr);
+  
+  int ThirdOpaqueValue = 0;
+  __llvm_instrprof_expected_callee[1] = &ThirdOpaqueValue;
+  __llvm_instrprof_callsite[1] = &Subctx->subContexts()[0];
+
+  auto *Subctx2 = __llvm_instrprof_get_context(&ThirdOpaqueValue, 3, 0, 0);
+  EXPECT_TRUE(isScratch(Subctx2));
+  
+  __llvm_instrprof_release_context(&Root);
+}
+
+TEST(ContextTest, ConcurrentRootCollection) {
+  ContextRoot Root;
+  memset(&Root, 0, sizeof(ContextRoot));
+  std::atomic<int> NonScratch = 0;
+  std::atomic<int> Executions = 0;
+
+  __sanitizer::Semaphore GotCtx;
+
+  auto Entrypoint = [&]() {
+    ++Executions;
+    auto *Ctx = __llvm_instrprof_start_context(&Root, 1, 10, 4);
+    GotCtx.Post();
+    const bool IS = isScratch(Ctx);
+    NonScratch += (!IS);
+    if (!IS) {
+      GotCtx.Wait();
+      GotCtx.Wait();
+    }
+    __llvm_instrprof_release_context(&Root);
+  };
+  std::thread T1(Entrypoint);
+  std::thread T2(Entrypoint);
+  T1.join();
+  T2.join();
+  EXPECT_EQ(NonScratch, 1);
+  EXPECT_EQ(Executions, 2);
+}
\ No newline at end of file
diff --git a/compiler-rt/lib/profile/tests/driver.cpp b/compiler-rt/lib/profile/tests/driver.cpp
new file mode 100644
index 00000000000000..b402cec1126b33
--- /dev/null
+++ b/compiler-rt/lib/profile/tests/driver.cpp
@@ -0,0 +1,14 @@
+//===-- driver.cpp ----------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gtest/gtest.h"
+
+int main(int argc, char **argv) {
+  testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index c07b83a81a63e1..6a4cce541aef29 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1435,6 +1435,7 @@ class InstrProfInstBase : public IntrinsicInst {
     case Intrinsic::instrprof_cover:
     case Intrinsic::instrprof_increment:
     case Intrinsic::instrprof_increment_step:
+    case Intrinsic::instrprof_callsite:
     case Intrinsic::instrprof_timestamp:
     case Intrinsic::instrprof_value_profile:
       return true;
@@ -1508,6 +1509,18 @@ class InstrProfIncrementInst : public InstrProfCntrInstBase {
   Value *getStep() const;
 };
 
+/// This represents the llvm.instrprof.callsite intrinsic.
+class InstrProfCallsite : public InstrProfCntrInstBase {
+public:
+  static bool classof(const IntrinsicInst *I) {
+    return I->getIntrinsicID() == Intrinsic::instrprof_callsite;
+  }
+  static bool classof(const Value *V) {
+    return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+  }
+  Value *getCallee() const;
+};
+
 /// This represents the llvm.instrprof.increment.step intrinsic.
 class InstrProfIncrementInstStep : public InstrProfIncrementInst {
 public:
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 144298fd7c0162..af57042f0e8e14 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -909,6 +909,11 @@ def int_instrprof_increment_step : Intrinsic<[],
                                         [llvm_ptr_ty, llvm_i64_ty,
                                          llvm_i32_ty, llvm_i32_ty, llvm_i64_ty]>;
 
+// Callsite instrumentation for contextual profiling
+def int_instrprof_callsite : Intrinsic<[],
+                                        [llvm_ptr_ty, llvm_i64_ty,
+                                         llvm_i32_ty, llvm_i32_ty, llvm_ptr_ty]>;
+
 // A timestamp for instrumentation based profiling.
 def int_instrprof_timestamp : Intrinsic<[], [llvm_ptr_ty, llvm_i64_ty,
                                              llvm_i32_ty, llvm_i32_ty]>;
diff --git a/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h b/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h
new file mode 100644
index 00000000000000..975775f18a5152
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h
@@ -0,0 +1,40 @@
+//===- PGOCtxProfLowering.h - Contextual PGO instr. lowering ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the PGOCtxProfLoweringPass class.
+//
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFLOWERING_H
+#define LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFLOWERING_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+class Type;
+
+class PGOCtxProfLoweringPass : public PassInfoMixin<PGOCtxProfLoweringPass> {
+public:
+  explicit PGOCtxProfLoweringPass() = default;
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
+  static bool isContextualIRPGOEnabled();
+
+private:
+  Type *ContextNodeTy = nullptr;
+  Type *ContextRootTy = nullptr;
+
+  DenseMap<const Function*, Constant*> ContextRootMap;
+  Function *StartCtx = nullptr;
+  Function *GetCtx = nullptr;
+  Function *ReleaseCtx = nullptr;
+  GlobalVariable *ExpectedCalleeTLS = nullptr;
+  GlobalVariable *CallsiteInfoTLS = nullptr;
+
+  void lowerFunction(Function &F);
+};
+} // namespace llvm
+#endif
\ No newline at end of file
diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp
index 89403e1d7fcb4d..9bf8dae89458f3 100644
--- a/llvm/lib/IR/IntrinsicInst.cpp
+++ b/llvm/lib/IR/IntrinsicInst.cpp
@@ -291,6 +291,10 @@ Value *InstrProfIncrementInst::getStep() const {
   return ConstantInt::get(Type::getInt64Ty(Context), 1);
 }
 
+Value *InstrProfCallsite::getCallee() const {
+  return getArgOperand(4);
+}
+
 std::optional<RoundingMode> ConstrainedFPIntrinsic::getRoundingMode() const {
   unsigned NumOperands = arg_size();
   Metadata *MD = nullptr;
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 4d1eb10d2d41c6..ffc23988284196 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -174,6 +174,7 @@
 #include "llvm/Transforms/Instrumentation/KCFI.h"
 #include "llvm/Transforms/Instrumentation/MemProfiler.h"
 #include "llvm/Transforms/Instrumentation/MemorySanitizer.h"
+#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
 #include "llvm/Transforms/Instrumentation/PoisonChecking.h"
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index cb892e30c4a0b9..d2550774400d69 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -74,6 +74,7 @@
 #include "llvm/Transforms/Instrumentation/InstrOrderFile.h"
 #include "llvm/Transforms/Instrumentation/InstrProfiling.h"
 #include "llvm/Transforms/Instrumentation/MemProfiler.h"
+#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
 #include "llvm/Transforms/Scalar/ADCE.h"
@@ -826,16 +827,19 @@ void PassBuilder::addPGOInstrPasses(ModulePassManager &MPM,
             /*UseBlockFrequencyInfo=*/false),
         PTO.EagerlyInvalidateAnalyses));
   }
-
-  // Add the profile lowering pass.
-  InstrProfOptions Options;
-  if (!ProfileFile.empty())
-    Options.InstrProfileOutput = ProfileFile;
-  // Do counter promotion at Level greater than O0.
-  Options.DoCounterPromotion = true;
-  Options.UseBFIInPromotion = IsCS;
-  Options.Atomic = AtomicCounterUpdate;
-  MPM.addPass(InstrProfilingLoweringPass(Options, IsCS));
+  if (PGOCtxProfLoweringPass::isContextualIRPGOEnabled()) {
+    MPM.addPass(PGOCtxProfLoweringPass());
+  } else {
+    // Add the profile lowering pass.
+    InstrProfOptions Options;
+    if (!ProfileFile.empty())
+      Options.InstrProfileOutput = ProfileFile;
+    // Do counter promotion at Level greater than O0.
+    Options.DoCounterPromotion = true;
+    Options.UseBFIInPromotion = IsCS;
+    Options.Atomic = AtomicCounterUpdate;
+    MPM.addPass(InstrProfilingLoweringPass(Options, IsCS));
+  }
 }
 
 void PassBuilder::addPGOInstrPassesForO0(
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 41f16d0915bf23..4b6f927d5b6c46 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -77,6 +77,7 @@ MODULE_PASS("inliner-wrapper-no-mandatory-first",
 MODULE_PASS("insert-gcov-profiling", GCOVProfilerPass())
 MODULE_PASS("instrorderfile", InstrOrderFilePass())
 MODULE_PASS("instrprof", InstrProfilingLoweringPass())
+MODULE_PASS("pgo-ctx-instr-lower", PGOCtxProfLoweringPass())
 MODULE_PASS("internalize", InternalizePass())
 MODULE_PASS("invalidate<all>", InvalidateAllAnalysesPass())
 MODULE_PASS("iroutliner", IROutlinerPass())
diff --git a/llvm/lib/Transforms/Instrumentation/CMakeLists.txt b/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
index b23a6ed1f08415..e3a1a1a7209e5a 100644
--- a/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
+++ b/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
@@ -13,6 +13,7 @@ add_llvm_component_library(LLVMInstrumentation
   InstrOrderFile.cpp
   InstrProfiling.cpp
   KCFI.cpp
+  PGOCtxProfLowering.cpp
   PGOForceFunctionAttrs.cpp
   PGOInstrumentation.cpp
   PGOMemOPSizeOpt.cpp
diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
index c42c53edd51190..8f27c4738f7ef6 100644
--- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -636,6 +636,9 @@ bool InstrLowerer::lowerIntrinsics(Function *F) {
       } else if (auto *IPTU = dyn_cast<InstrProfMCDCCondBitmapUpdate>(&Instr)) {
         lowerMCDCCondBitmapUpdate(IPTU);
         MadeChange = true;
+      } else if (isa<InstrProfCallsite>(Instr)) {
+        Instr.eraseFromParent();
+        MadeChange = true;
       }
     }
   }
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
new file mode 100644
index 00000000000000..1212f7acf53ca7
--- /dev/null
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
@@ -0,0 +1,214 @@
+//===- PGOCtxProfLowering.cpp - Contextual PGO Instrumentation ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+
+#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
+#include "llvm/IR/Analysis.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+
+using namespace llvm;
+
+static cl::list<std::string> ContextRoots("profile-context-root");
+// Reports whether at least one -profile-context-root value was supplied.
+bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {
+  return !ContextRoots.empty();
+}
+
+// Sets up everything the lowering needs in M - the context struct types, one
+// zero-initialized "<name>_ctx_root" global per requested root defined in M,
+// the runtime entry points and the two TLS slots - then lowers each function.
+// No analyses are preserved.
+PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
+                                              ModuleAnalysisManager &MAM) {
+  auto &Ctx = M.getContext();
+  auto *PointerTy = PointerType::get(Ctx, 0);
+  auto *SanitizerMutexType = Type::getInt8Ty(Ctx);
+  auto *I32Ty = Type::getInt32Ty(Ctx);
+  auto *I64Ty = Type::getInt64Ty(Ctx);
+
+  // Field roles follow the /*...*/ annotations; presumably these mirror the
+  // runtime's structs - TODO confirm against InstrProfilingContextual.h.
+  ContextRootTy = StructType::get(Ctx, {
+                                           PointerTy,          /*FirstNode*/
+                                           PointerTy,          /*FirstMemBlock*/
+                                           PointerTy,          /*CurrentMem*/
+                                           SanitizerMutexType, /*Taken*/
+                                       });
+  ContextNodeTy = StructType::get(Ctx, {
+                                           I64Ty,     /*Guid*/
+                                           PointerTy, /*Next*/
+                                           I32Ty,     /*NrCounters*/
+                                           I32Ty,     /*NrCallsites*/
+                                       });
+
+  // Names that are absent from M, or only declared in it, get no root.
+  for (const auto &Fname : ContextRoots) {
+    const auto *F = M.getFunction(Fname);
+    if (!F || F->isDeclaration())
+      continue;
+    auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
+    cast<GlobalVariable>(G)->setInitializer(
+        Constant::getNullValue(ContextRootTy));
+    // Remembered so that lowerFunction can hand the root to the runtime.
+    ContextRootMap.insert(std::make_pair(F, G));
+  }
+
+  // Pointers are opaque, so PointerTy also serves wherever a pointer to
+  // ContextRootTy or ContextNodeTy is meant. Both context-acquiring entry
+  // points then have the same type:
+  //   ptr(ptr ContextRoot|Callee, i64 Guid, i32 NrCounters, i32 NrCallsites)
+  auto *AcquireCtxTy =
+      FunctionType::get(PointerTy, {PointerTy, I64Ty, I32Ty, I32Ty}, false);
+  StartCtx = cast<Function>(
+      M.getOrInsertFunction("__llvm_instrprof_start_context", AcquireCtxTy)
+          .getCallee());
+  GetCtx = cast<Function>(
+      M.getOrInsertFunction("__llvm_instrprof_get_context", AcquireCtxTy)
+          .getCallee());
+  // void(ptr ContextRoot)
+  auto *ReleaseCtxTy =
+      FunctionType::get(Type::getVoidTy(Ctx), {PointerTy}, false);
+  ReleaseCtx = cast<Function>(
+      M.getOrInsertFunction("__llvm_instrprof_release_context", ReleaseCtxTy)
+          .getCallee());
+
+  // Each TLS slot is an external, hidden, thread-local global of pointer type
+  // with no initializer.
+  auto DeclareTLSSlot = [&](const char *Name) {
+    auto *Slot = new GlobalVariable(
+        M, PointerTy, false, GlobalValue::ExternalLinkage, nullptr, Name);
+    Slot->setThreadLocal(true);
+    Slot->setVisibility(llvm::GlobalValue::HiddenVisibility);
+    return Slot;
+  };
+  CallsiteInfoTLS = DeclareTLSSlot("__llvm_instrprof_callsite");
+  ExpectedCalleeTLS = DeclareTLSSlot("__llvm_instrprof_expected_callee");
+
+  // With the module-level declarations in place, rewrite each function.
+  for (auto &F : M)
+    lowerFunction(F);
+  return PreservedAnalyses::none();
+}
+
+// Lowers one function definition: the entry block's instrprof.increment is
+// replaced by the call obtaining the context, the remaining counter intrinsics
+// are rewritten against it, and context roots release it before each return.
+void PGOCtxProfLoweringPass::lowerFunction(Function &F) {
+  if (F.isDeclaration())
+    return;
+  uint32_t NrCounters = 0;  // From the first instrprof.increment found.
+  uint32_t NrCallsites = 0; // From the first callsite intrinsic found.
+  [&]() {
+    for (const auto &BB : F)
+      for (const auto &I : BB) {
+        if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
+          if (!NrCounters)
+            NrCounters =
+                static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
+        } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
+          if (!NrCallsites)
+            NrCallsites =
+                static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
+        }
+        if (NrCounters && NrCallsites)
+          return;
+      }
+  }();
+
+  Value *Context = nullptr;     // Pointer returned by the runtime call.
+  Value *RealContext = nullptr; // Context with bit 0 cleared.
+  StructType *ThisContextType = nullptr;
+  Value *TheRootContext = nullptr; // Set only if F is a context root.
+  Value *ExpectedCalleeTLSAddr = nullptr;
+  Value *CallsiteInfoTLSAddr = nullptr;
+
+  for (auto &I : F.getEntryBlock()) {
+    if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
+      assert(Mark->getIndex()->isZero());
+      IRBuilder<> Builder(Mark);
+      // TODO!!!! use InstrProfSymtab::getCanonicalName
+      Value *Guid = Builder.getInt64(F.getGUID());
+      ThisContextType = StructType::get(
+          F.getContext(),
+          {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
+           ArrayType::get(Builder.getPtrTy(), NrCallsites)});
+      auto Iter = ContextRootMap.find(&F);
+      if (Iter != ContextRootMap.end()) {
+        TheRootContext = Iter->second;
+        Context = Builder.CreateCall(
+            StartCtx, {TheRootContext, Guid, Builder.getInt32(NrCounters),
+                       Builder.getInt32(NrCallsites)});
+      } else {
+        Context =
+            Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters),
+                                        Builder.getInt32(NrCallsites)});
+      }
+      auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
+      if (NrCallsites > 0) {
+        auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
+        ExpectedCalleeTLSAddr = Builder.CreateGEP(
+            Builder.getInt8Ty()->getPointerTo(),
+            Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
+        CallsiteInfoTLSAddr = Builder.CreateGEP(
+            Builder.getInt32Ty(),
+            Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
+      }
+      RealContext = Builder.CreateIntToPtr(
+          Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
+          ThisContextType->getPointerTo());
+      I.eraseFromParent();
+      break;
+    }
+  }
+  if (!Context) {
+    dbgs() << "[instprof] Function doesn't have instrumentation, skipping "
+           << F.getName() << "\n";
+    return;
+  }
+
+  for (auto &BB : F) {
+    for (auto &I : llvm::make_early_inc_range(BB)) {
+      if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
+        IRBuilder<> Builder(Instr);
+        switch (Instr->getIntrinsicID()) {
+        case llvm::Intrinsic::instrprof_increment:
+        case llvm::Intrinsic::instrprof_increment_step: {
+          auto *AsStep = cast<InstrProfIncrementInst>(Instr);
+          auto *GEP = Builder.CreateGEP(
+              ThisContextType, RealContext,
+              {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
+          Builder.CreateStore(
+              Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
+                                AsStep->getStep()),
+              GEP);
+        } break;
+        case llvm::Intrinsic::instrprof_callsite: {
+          auto *CSIntrinsic = cast<InstrProfCallsite>(Instr);
+          Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
+                              true);
+          // NOTE(review): this GEP starts from the tagged Context, whereas
+          // the counter updates use RealContext - confirm this is intended.
+          Builder.CreateStore(
+              Builder.CreateGEP(ThisContextType, Context,
+                                {Builder.getInt32(0), Builder.getInt32(2),
+                                 CSIntrinsic->getIndex()}),
+              CallsiteInfoTLSAddr, true);
+        } break;
+        }
+        I.eraseFromParent();
+      } else if (TheRootContext && isa<ReturnInst>(I)) {
+        IRBuilder<> Builder(&I);
+        Builder.CreateCall(ReleaseCtx, {TheRootContext});
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 55728709cde556..5d393c1fda124d 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -110,6 +110,7 @@
 #include "llvm/Transforms/Instrumentation.h"
 #include "llvm/Transforms/Instrumentation/BlockCoverageInference.h"
 #include "llvm/Transforms/Instrumentation/CFGMST.h"
+#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/MisExpect.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -330,6 +331,11 @@ extern cl::opt<std::string> ViewBlockFreqFuncName;
 extern cl::opt<InstrProfCorrelator::ProfCorrelatorKind> ProfileCorrelate;
 } // namespace llvm
 
+bool shouldInstrumentEntryBB() {
+  return PGOInstrumentEntry ||
+         PGOCtxProfLoweringPass::isContextualIRPGOEnabled();
+}
+
 // Return a string describing the branch condition that can be
 // used in static branch probability heuristics:
 static std::string getBranchCondString(Instruction *TI) {
@@ -376,7 +382,7 @@ static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) {
   uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF);
   if (IsCS)
     ProfileVersion |= VARIANT_MASK_CSIR_PROF;
-  if (PGOInstrumentEntry)
+  if (shouldInstrumentEntryBB())
     ProfileVersion |= VARIANT_MASK_INSTR_ENTRY;
   if (DebugInfoCorrelate || ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO)
     ProfileVersion |= VARIANT_MASK_DBG_CORRELATE;
@@ -856,7 +862,7 @@ static void instrumentOneFunc(
   }
 
   FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo(
-      F, TLI, ComdatMembers, true, BPI, BFI, IsCS, PGOInstrumentEntry,
+      F, TLI, ComdatMembers, true, BPI, BFI, IsCS, shouldInstrumentEntryBB(),
       PGOBlockCoverage);
 
   auto Name = FuncInfo.FuncNameVar;
@@ -878,6 +884,31 @@ static void instrumentOneFunc(
   unsigned NumCounters =
       InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts();
 
+  auto *CSIntrinsic =
+      Intrinsic::getDeclaration(M, Intrinsic::instrprof_callsite);
+  auto Visit = [&](llvm::function_ref<void(CallBase * CB)> Visitor) {
+    for (auto &BB : F)
+      for (auto &Instr : BB)
+        if (auto *CS = dyn_cast<CallBase>(&Instr)) {
+          if ((CS->getCalledFunction() &&
+               CS->getCalledFunction()->isIntrinsic()) ||
+              dyn_cast<InlineAsm>(CS->getCalledOperand()))
+            continue;
+          Visitor(CS);
+        }
+  };
+  uint32_t TotalNrCallsites = 0;
+  Visit([&TotalNrCallsites](auto *) { ++TotalNrCallsites; });
+  uint32_t CallsiteIndex = 0;
+
+  Visit([&](auto *CB){
+    IRBuilder<> Builder(CB);
+    Builder.CreateCall(CSIntrinsic,
+                       {Name, CFGHash, Builder.getInt32(TotalNrCallsites),
+                        Builder.getInt32(CallsiteIndex++),
+                        CB->getCalledOperand()});
+  });
+
   uint32_t I = 0;
   if (PGOTemporalInstrumentation) {
     NumCounters += PGOBlockCoverage ? 8 : 1;
@@ -2001,8 +2032,7 @@ static bool annotateAllFunctions(
   // If the profile marked as always instrument the entry BB, do the
   // same. Note this can be overwritten by the internal option in CFGMST.h
   bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled();
-  if (PGOInstrumentEntry.getNumOccurrences() > 0)
-    InstrumentFuncEntry = PGOInstrumentEntry;
+  InstrumentFuncEntry = shouldInstrumentEntryBB();
   bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage();
   for (auto &F : M) {
     if (skipPGOUse(F))
diff --git a/llvm/test/Transforms/PGOProfile/ctx-basic.ll b/llvm/test/Transforms/PGOProfile/ctx-basic.ll
new file mode 100644
index 00000000000000..5b7df61a033d48
--- /dev/null
+++ b/llvm/test/Transforms/PGOProfile/ctx-basic.ll
@@ -0,0 +1,86 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 4
+; RUN: opt -passes=pgo-instr-gen,pgo-ctx-instr-lower -profile-context-root=an_entrypoint \
+; RUN:   -S < %s | FileCheck %s
+
+declare void @bar()
+
+;.
+; CHECK: @__llvm_profile_raw_version = hidden constant i64 360287970189639690, comdat
+; CHECK: @__profn_foo = private constant [3 x i8] c"foo"
+; CHECK: @__profn_an_entrypoint = private constant [13 x i8] c"an_entrypoint"
+; CHECK: @an_entrypoint_ctx_root = global { ptr, ptr, ptr, i8 } zeroinitializer
+; CHECK: @__llvm_instrprof_callsite = external hidden thread_local global ptr
+; CHECK: @__llvm_instrprof_expected_callee = external hidden thread_local global ptr
+;.
+define void @foo(i32 %a) {
+; CHECK-LABEL: define void @foo(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_instrprof_get_context(ptr @foo, i64 6699318081062747564, i32 2, i32 1)
+; CHECK-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 1
+; CHECK-NEXT:    [[TMP4:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_instrprof_expected_callee)
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr ptr, ptr [[TMP4]], i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP6:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_instrprof_callsite)
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i32, ptr [[TMP6]], i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP8:%.*]] = and i64 [[TMP2]], -2
+; CHECK-NEXT:    [[TMP9:%.*]] = inttoptr i64 [[TMP8]] to ptr
+; CHECK-NEXT:    [[T:%.*]] = icmp eq i32 [[A]], 0
+; CHECK-NEXT:    br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
+; CHECK:       yes:
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [1 x ptr] }, ptr [[TMP9]], i32 0, i32 1, i32 1
+; CHECK-NEXT:    [[TMP11:%.*]] = load i64, ptr [[TMP10]], align 4
+; CHECK-NEXT:    [[TMP12:%.*]] = add i64 [[TMP11]], 1
+; CHECK-NEXT:    store i64 [[TMP12]], ptr [[TMP10]], align 4
+; CHECK-NEXT:    br label [[EXIT:%.*]]
+; CHECK:       no:
+; CHECK-NEXT:    store volatile ptr @bar, ptr [[TMP5]], align 8
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [1 x ptr] }, ptr [[TMP1]], i32 0, i32 2, i32 0
+; CHECK-NEXT:    store volatile ptr [[TMP13]], ptr [[TMP7]], align 8
+; CHECK-NEXT:    call void @bar()
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+  %t = icmp eq i32 %a, 0
+  br i1 %t, label %yes, label %no
+yes:
+  br label %exit
+no:
+  call void @bar()
+  br label %exit
+exit:
+  ret void
+}
+
+define void @an_entrypoint(i32 %a) {
+; CHECK-LABEL: define void @an_entrypoint(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_instrprof_start_context(ptr @an_entrypoint_ctx_root, i64 4909520559318251808, i32 2, i32 0)
+; CHECK-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], -2
+; CHECK-NEXT:    [[TMP4:%.*]] = inttoptr i64 [[TMP3]] to ptr
+; CHECK-NEXT:    [[T:%.*]] = icmp eq i32 [[A]], 0
+; CHECK-NEXT:    br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
+; CHECK:       yes:
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [0 x ptr] }, ptr [[TMP4]], i32 0, i32 1, i32 1
+; CHECK-NEXT:    [[TMP6:%.*]] = load i64, ptr [[TMP5]], align 4
+; CHECK-NEXT:    [[TMP7:%.*]] = add i64 [[TMP6]], 1
+; CHECK-NEXT:    store i64 [[TMP7]], ptr [[TMP5]], align 4
+; CHECK-NEXT:    call void @__llvm_instrprof_release_context(ptr @an_entrypoint_ctx_root)
+; CHECK-NEXT:    ret void
+; CHECK:       no:
+; CHECK-NEXT:    call void @__llvm_instrprof_release_context(ptr @an_entrypoint_ctx_root)
+; CHECK-NEXT:    ret void
+;
+  %t = icmp eq i32 %a, 0
+  br i1 %t, label %yes, label %no
+
+yes:
+  ret void
+no:
+  ret void
+}
+;.
+; CHECK: attributes #[[ATTR0:[0-9]+]] = { nounwind }
+; CHECK: attributes #[[ATTR1:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+;.

>From 215a7c6b631dfe5668f99519ea1ede3c14dfc6a6 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 20 Mar 2024 16:07:56 -0700
Subject: [PATCH 2/5] converter (raw -> bitstream)

---
 llvm/tools/llvm-ctx-ifdo/CMakeLists.txt    |   8 ++
 llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp | 130 +++++++++++++++++++++
 2 files changed, 138 insertions(+)
 create mode 100644 llvm/tools/llvm-ctx-ifdo/CMakeLists.txt
 create mode 100644 llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp

diff --git a/llvm/tools/llvm-ctx-ifdo/CMakeLists.txt b/llvm/tools/llvm-ctx-ifdo/CMakeLists.txt
new file mode 100644
index 00000000000000..bdc6bb37898375
--- /dev/null
+++ b/llvm/tools/llvm-ctx-ifdo/CMakeLists.txt
@@ -0,0 +1,8 @@
+set(LLVM_LINK_COMPONENTS
+  Core
+  Support
+  )
+
+add_llvm_tool(llvm-ctx-ifdo
+  llvm-ctx-ifdo.cpp
+  )
diff --git a/llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp b/llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp
new file mode 100644
index 00000000000000..91ab95eb64a999
--- /dev/null
+++ b/llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp
@@ -0,0 +1,130 @@
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Bitstream/BitCodeEnums.h"
+#include "llvm/Bitstream/BitstreamWriter.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
+#include <functional>
+#include <map>
+#include <system_error>
+#include <utility>
+
+using namespace llvm;
+
+cl::opt<std::string> InputFile("raw-prof", cl::NotHidden, cl::Positional,
+                               cl::desc("<Raw contextual profile>"));
+cl::opt<std::string> OutputFile("output", cl::NotHidden, cl::Positional,
+                                cl::desc("<Converted contextual profile>"));
+
+enum Codes {
+  Invalid,
+  Guid,        // Record holding a context node's GUID.
+  CalleeIndex, // Record written only when writeContext is given an Index.
+  Counters,    // Record holding a context node's counter values.
+};
+
+struct ContextNode {
+  uint64_t Guid = 0;
+  uint64_t Next = 0;        // Address of the next node in the chain; 0 if none.
+  uint32_t NrCounters = 0;  // uint64_t counters read from right after the node.
+  uint32_t NrCallsites = 0; // uint64_t addresses read from after the counters.
+};
+
+// Returns the bytes from Addr to the end of the page containing it, calling
+// Load to bring in further pages as needed; std::nullopt once Load runs dry.
+std::optional<StringRef>
+getContext(uint64_t Addr, const std::map<uint64_t, StringRef> &Pages,
+           std::function<std::optional<StringRef>()> Load) {
+  while (true) {
+    // Only the page preceding upper_bound(Addr), if any, can contain Addr.
+    auto It = Pages.upper_bound(Addr);
+    if (It != Pages.begin()) {
+      const auto &[Start, Page] = *std::prev(It);
+      if (Addr - Start < Page.size())
+        return Page.substr(Addr - Start);
+    }
+    if (!Load())
+      return std::nullopt;
+  }
+}
+
+// Maps a raw address to the ContextNode stored there; address 0 maps to null.
+const ContextNode *
+convertAddressToContext(uint64_t Addr,
+                        const std::map<uint64_t, StringRef> &Pages,
+                        std::function<std::optional<StringRef>()> Load) {
+  return Addr ? reinterpret_cast<const ContextNode *>(
+                    getContext(Addr, Pages, Load).value().data())
+              : nullptr;
+}
+
+// Serializes the chain of context nodes starting at the front of FirstPage,
+// one subblock per node: GUID, the callee Index when one is given, counters,
+// and then, nested, the contexts reachable from each non-null callsite slot.
+void writeContext(StringRef FirstPage, BitstreamWriter &Writer,
+                  std::map<uint64_t, StringRef> &Pages,
+                  std::function<std::optional<StringRef>()> Load,
+                  std::optional<uint32_t> Index = std::nullopt) {
+  const auto *Node = reinterpret_cast<const ContextNode *>(FirstPage.data());
+  for (; Node; Node = convertAddressToContext(Node->Next, Pages, Load)) {
+    Writer.EnterSubblock(100, 2);
+    Writer.EmitRecord(Codes::Guid, SmallVector<uint64_t, 1>{Node->Guid});
+    if (Index)
+      Writer.EmitRecord(Codes::CalleeIndex, SmallVector<uint32_t, 1>{*Index});
+    // The counters are read from right after the node header, the callsite
+    // slots from right after the counters.
+    const auto *Counters = reinterpret_cast<const uint64_t *>(Node + 1);
+    const uint64_t *Callsites = Counters + Node->NrCounters;
+    // Without an abbreviation, EmitRecord writes UNABBREV_RECORD, the code,
+    // the element count and then each element, all as 6-bit-chunk VBRs.
+    Writer.EmitRecord(Codes::Counters,
+                      ArrayRef<uint64_t>(Counters, Node->NrCounters));
+    // Recurse into every populated callsite, passing the slot's index along.
+    for (size_t I = 0; I < Node->NrCallsites; ++I) {
+      if (Callsites[I] == 0)
+        continue;
+      if (auto S = getContext(Callsites[I], Pages, Load))
+        writeContext(*S, Writer, Pages, Load, I);
+    }
+    Writer.ExitBlock();
+  }
+}
+
+int main(int argc, const char *argv[]) {
+  InitLLVM X(argc, argv);
+  cl::ParseCommandLineOptions(argc, argv,
+                              "LLVM Contextual Profile Converter\n");
+  SmallVector<char, 1 << 20> Buff;
+  std::error_code EC;
+  raw_fd_stream Out(OutputFile, EC);
+  if (EC)
+    return 1;
+  auto Input = MemoryBuffer::getFileOrSTDIN(InputFile);
+  if (!Input)
+    return 1;
+  auto In = (*Input).get()->getBuffer(); // Unconsumed tail of the raw profile.
+  BitstreamWriter Writer(Buff, &Out, 0);
+  std::map<uint64_t, StringRef> Pages; // Page start address -> page bytes.
+  auto Load = [&]() -> std::optional<StringRef> { // Consumes one page from In.
+    if (In.size() < 2 * sizeof(uint64_t))
+      return std::nullopt;
+    auto *Data = reinterpret_cast<const uint64_t *>(In.data());
+    uint64_t Start = Data[0]; // Page header: start address, then byte length.
+    uint64_t Len = Data[1];
+    In = In.substr(2 * sizeof(uint64_t));
+    auto It = Pages.insert({Start, In.substr(0, Len)});
+    In = In.substr(Len);
+    return It.first->second;
+  };
+  uint32_t Magic = 0xfafababa;
+  Out.write(reinterpret_cast<const char *>(&Magic), sizeof(uint32_t));
+  while (auto S = Load()) {
+    writeContext(*S, Writer, Pages, Load);
+  }
+  Out.flush();
+  Out.close();
+  return 0;
+}
\ No newline at end of file

>From 603eccfd4b76ed4c07c9e22e43b44bef62d4adba Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 20 Mar 2024 16:15:39 -0700
Subject: [PATCH 3/5] bitstream profile data reader

---
 .../llvm/ProfileData/CtxInstrProfileReader.h  | 196 ++++++++++++++++++
 1 file changed, 196 insertions(+)
 create mode 100644 llvm/include/llvm/ProfileData/CtxInstrProfileReader.h

diff --git a/llvm/include/llvm/ProfileData/CtxInstrProfileReader.h b/llvm/include/llvm/ProfileData/CtxInstrProfileReader.h
new file mode 100644
index 00000000000000..7d597643df2030
--- /dev/null
+++ b/llvm/include/llvm/ProfileData/CtxInstrProfileReader.h
@@ -0,0 +1,196 @@
+//===--- CtxInstrProfileReader.h - Ctx iFDO profile reader ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+///
+/// Reader for contextual iFDO profile, in bitstream format.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_PROFILEDATA_CTXINSTRPROFILEREADER_H
+#define LLVM_PROFILEDATA_CTXINSTRPROFILEREADER_H
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Bitstream/BitstreamReader.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include <map>
+#include <vector>
+
+namespace llvm {
+class ContextualProfile final {
+  friend class ContextualInstrProfReader;
+  GlobalValue::GUID GUID = 0;
+  SmallVector<uint64_t, 16> Counters;
+  std::vector<std::map<GlobalValue::GUID, ContextualProfile>> Callsites;
+  /// Private; moves the values out of \p Counters into the new profile.
+  ContextualProfile(GlobalValue::GUID G, SmallVectorImpl<uint64_t> &&Counters)
+      : GUID(G), Counters(std::move(Counters)) {}
+  /// Adds a new callee \p G under callsite \p Index; an existing \p G fails.
+  Expected<ContextualProfile &>
+  getOrEmplace(uint32_t Index, GlobalValue::GUID G,
+               SmallVectorImpl<uint64_t> &&Counters) {
+    if (Callsites.size() <= Index)
+      Callsites.resize(Index + 1);
+    auto I =
+        Callsites[Index].insert({G, ContextualProfile(G, std::move(Counters))});
+    if (!I.second)
+      return make_error<StringError>(llvm::errc::invalid_argument,
+                                     "Duplicate GUID for same callsite.");
+    return I.first->second;
+  }
+
+public:
+  ContextualProfile(const ContextualProfile &) = delete;
+  ContextualProfile &operator=(const ContextualProfile &) = delete;
+  ContextualProfile(ContextualProfile &&) = default;
+  ContextualProfile &operator=(ContextualProfile &&) = default;
+
+  GlobalValue::GUID guid() const { return GUID; }
+  const SmallVector<uint64_t, 16> &counters() const { return Counters; }
+  const std::vector<std::map<GlobalValue::GUID, ContextualProfile>> &
+  callsites() const {
+    return Callsites;
+  }
+  /// Inserts this profile's GUID and, recursively, its callees' into \p Guids.
+  void getContainedGuids(DenseSet<GlobalValue::GUID> &Guids) const {
+    Guids.insert(GUID);
+    for (const auto &Callsite : Callsites)
+      for (const auto &[_, Callee] : Callsite)
+        Callee.getContainedGuids(Guids);
+  }
+};
+
+/// Reads a contextual profile bitstream: a 32-bit magic, then nested subblocks
+/// (ID 100), one per context: GUID, optional callee index, counters, callees.
+class ContextualInstrProfReader final {
+  enum Codes {
+    Invalid,
+    Guid,
+    CalleeIndex,
+    Counters,
+  };
+
+  BitstreamCursor Cursor;
+
+  struct ContextData {
+    GlobalValue::GUID GUID;
+    std::optional<uint32_t> Index;
+    SmallVector<uint64_t, 16> Counters;
+  };
+
+  static Error invalidArg(const char *Msg) {
+    return make_error<StringError>(llvm::errc::invalid_argument, Msg);
+  }
+
+  /// Reads one unabbreviated record into Vals and returns its record code;
+  /// fails if ExpectedCode is given and the record has a different code.
+  Expected<unsigned>
+  readUnabbrevRecord(SmallVectorImpl<uint64_t> &Vals,
+                     std::optional<Codes> ExpectedCode = std::nullopt) {
+    auto Code = Cursor.ReadCode();
+    if (!Code)
+      return Code.takeError();
+    if (*Code != bitc::UNABBREV_RECORD)
+      return invalidArg("Invalid code.");
+    auto Record = Cursor.readRecord(bitc::UNABBREV_RECORD, Vals);
+    if (!Record)
+      return Record.takeError();
+    if (ExpectedCode && *Record != *ExpectedCode)
+      return invalidArg("Unexpected code.");
+    return *Record;
+  }
+
+  /// Reads a context's own records: GUID, optional callee index, counters.
+  Expected<ContextData> readContextData() {
+    ContextData Ret;
+    SmallVector<uint64_t, 1> Data64;
+    auto GuidRec = readUnabbrevRecord(Data64, Codes::Guid);
+    if (!GuidRec)
+      return GuidRec.takeError();
+    if (Data64.empty()) // A malformed profile may carry an empty record.
+      return invalidArg("Empty GUID record.");
+    Ret.GUID = Data64[0];
+    auto IndexOrCounters = readUnabbrevRecord(Ret.Counters);
+    if (!IndexOrCounters)
+      return IndexOrCounters.takeError();
+    if (*IndexOrCounters == Codes::CalleeIndex) {
+      if (Ret.Counters.empty())
+        return invalidArg("Empty callee index record.");
+      Ret.Index = Ret.Counters[0];
+      Ret.Counters.clear();
+      auto NextRecord = readUnabbrevRecord(Ret.Counters, Codes::Counters);
+      if (!NextRecord)
+        return NextRecord.takeError();
+    } else if (*IndexOrCounters != Codes::Counters) {
+      return invalidArg("Expected counters.");
+    }
+    return Ret;
+  }
+
+  /// Succeeds only if the next entry is a subblock with ID 100, entering it.
+  Error failIfCannotEnterSubBlock() {
+    auto Entry = Cursor.advance(BitstreamCursor::AF_DontAutoprocessAbbrevs);
+    if (!Entry)
+      return Entry.takeError();
+    if (Entry->Kind != BitstreamEntry::SubBlock)
+      return invalidArg("Expected a subblock.");
+    if (Entry->ID != 100)
+      return invalidArg("Expected subblock ID 100.");
+    return Cursor.EnterSubBlock(Entry->ID);
+  }
+
+  /// Reads the callee contexts nested in the current subblock into Parent.
+  Error failIfCannotReadSubContexts(ContextualProfile &Parent) {
+    // errorToBool (unlike operator bool) also marks a failure as handled.
+    while (!errorToBool(failIfCannotEnterSubBlock())) {
+      auto Ctx = readContextData();
+      if (!Ctx)
+        return Ctx.takeError();
+      if (!Ctx->Index)
+        return invalidArg("Invalid subcontext: should have an index.");
+      auto P =
+          Parent.getOrEmplace(*Ctx->Index, Ctx->GUID, std::move(Ctx->Counters));
+      if (!P)
+        return P.takeError();
+      if (auto Err = failIfCannotReadSubContexts(*P))
+        return Err;
+    }
+    return Error::success();
+  }
+
+public:
+  ContextualInstrProfReader(StringRef ProfileFile) : Cursor(ProfileFile) {}
+
+  /// Checks the magic, then loads every root context, keyed by GUID.
+  Expected<std::map<GlobalValue::GUID, ContextualProfile>> loadContexts() {
+    auto MaybeMagic = Cursor.Read(32);
+    if (!MaybeMagic)
+      return MaybeMagic.takeError();
+    if (*MaybeMagic != 0xfafababa)
+      return invalidArg("Invalid magic.");
+    std::map<GlobalValue::GUID, ContextualProfile> Ret;
+    while (!errorToBool(failIfCannotEnterSubBlock())) {
+      auto Ctx = readContextData();
+      if (!Ctx)
+        return Ctx.takeError();
+      if (Ctx->Index)
+        return invalidArg("Invalid root: should have no index.");
+      auto Ins = Ret.insert(
+          {Ctx->GUID, ContextualProfile(Ctx->GUID, std::move(Ctx->Counters))});
+      if (!Ins.second)
+        return invalidArg("Duplicate GUID for same root.");
+      if (auto Err = failIfCannotReadSubContexts(Ins.first->second))
+        return std::move(Err);
+    }
+    return Ret;
+  }
+};
+} // namespace llvm
+#endif
\ No newline at end of file

>From 71557485f74151470de98ad4dd8b43dcdf570ab4 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 21 Mar 2024 16:31:17 -0700
Subject: [PATCH 4/5] disable vp when ctx prof; also don't emit callsite info
 unless ctx prof

simpler testing
---
 .../lib/profile/InstrProfilingContextual.cpp  | 10 +---
 .../Instrumentation/PGOInstrumentation.cpp    | 58 ++++++++++---------
 2 files changed, 35 insertions(+), 33 deletions(-)

diff --git a/compiler-rt/lib/profile/InstrProfilingContextual.cpp b/compiler-rt/lib/profile/InstrProfilingContextual.cpp
index e13d0688cca50f..1acc9ad02160e0 100644
--- a/compiler-rt/lib/profile/InstrProfilingContextual.cpp
+++ b/compiler-rt/lib/profile/InstrProfilingContextual.cpp
@@ -218,14 +218,10 @@ int __llvm_ctx_profile_dump(const char* Filename) {
     auto *Root = AllContextRoots[I];
     __sanitizer::GenericScopedLock<__sanitizer::StaticSpinMutex> TakenLock(
         &Root->Taken);
-    size_t NrMemUnits = 0;
-    size_t Allocated = 0;
-    for (auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
-      ++NrMemUnits;
-      Allocated += reinterpret_cast<uint64_t>(Mem->pos()) -
-                    reinterpret_cast<uint64_t>(Mem);
+    if (!validate(Root)) {
+      PROF_ERR("Contextual Profile is %s\n", "invalid");
+      return 1;
     }
-    auto Valid = validate(Root);
   }
 
   if (!Filename) {
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 5d393c1fda124d..45c231d4dac70e 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -336,6 +336,10 @@ bool shouldInstrumentEntryBB() {
          PGOCtxProfLoweringPass::isContextualIRPGOEnabled();
 }
 
+bool isValueProfilingDisabled() { // True if DisableValueProfiling is set or contextual IRPGO is enabled.
+  return DisableValueProfiling ||
+         PGOCtxProfLoweringPass::isContextualIRPGOEnabled();
+}
 // Return a string describing the branch condition that can be
 // used in static branch probability heuristics:
 static std::string getBranchCondString(Instruction *TI) {
@@ -884,30 +888,32 @@ static void instrumentOneFunc(
   unsigned NumCounters =
       InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts();
 
-  auto *CSIntrinsic =
-      Intrinsic::getDeclaration(M, Intrinsic::instrprof_callsite);
-  auto Visit = [&](llvm::function_ref<void(CallBase * CB)> Visitor) {
-    for (auto &BB : F)
-      for (auto &Instr : BB)
-        if (auto *CS = dyn_cast<CallBase>(&Instr)) {
-          if ((CS->getCalledFunction() &&
-               CS->getCalledFunction()->isIntrinsic()) ||
-              dyn_cast<InlineAsm>(CS->getCalledOperand()))
-            continue;
-          Visitor(CS);
-        }
-  };
-  uint32_t TotalNrCallsites = 0;
-  Visit([&TotalNrCallsites](auto *) { ++TotalNrCallsites; });
-  uint32_t CallsiteIndex = 0;
-
-  Visit([&](auto *CB){
-    IRBuilder<> Builder(CB);
-    Builder.CreateCall(CSIntrinsic,
-                       {Name, CFGHash, Builder.getInt32(TotalNrCallsites),
-                        Builder.getInt32(CallsiteIndex++),
-                        CB->getCalledOperand()});
-  });
+  if (PGOCtxProfLoweringPass::isContextualIRPGOEnabled()) {
+    auto *CSIntrinsic =
+        Intrinsic::getDeclaration(M, Intrinsic::instrprof_callsite);
+    auto Visit = [&](llvm::function_ref<void(CallBase * CB)> Visitor) {
+      for (auto &BB : F)
+        for (auto &Instr : BB)
+          if (auto *CS = dyn_cast<CallBase>(&Instr)) {
+            if ((CS->getCalledFunction() &&
+                 CS->getCalledFunction()->isIntrinsic()) ||
+                dyn_cast<InlineAsm>(CS->getCalledOperand()))
+              continue;
+            Visitor(CS);
+          }
+    };
+    uint32_t TotalNrCallsites = 0;
+    Visit([&TotalNrCallsites](auto *) { ++TotalNrCallsites; });
+    uint32_t CallsiteIndex = 0;
+
+    Visit([&](auto *CB) {
+      IRBuilder<> Builder(CB);
+      Builder.CreateCall(CSIntrinsic,
+                         {Name, CFGHash, Builder.getInt32(TotalNrCallsites),
+                          Builder.getInt32(CallsiteIndex++),
+                          CB->getCalledOperand()});
+    });
+  }
 
   uint32_t I = 0;
   if (PGOTemporalInstrumentation) {
@@ -940,7 +946,7 @@ static void instrumentOneFunc(
                                        FuncInfo.FunctionHash);
   assert(I == NumCounters);
 
-  if (DisableValueProfiling)
+  if (isValueProfilingDisabled())
     return;
 
   NumOfPGOICall += FuncInfo.ValueSites[IPVK_IndirectCallTarget].size();
@@ -1701,7 +1707,7 @@ void SelectInstVisitor::visitSelectInst(SelectInst &SI) {
 
 // Traverse all valuesites and annotate the instructions for all value kind.
 void PGOUseFunc::annotateValueSites() {
-  if (DisableValueProfiling)
+  if (isValueProfilingDisabled())
     return;
 
   // Create the PGOFuncName meta data.

>From b586ca5fcb5a50e5958dc89979003bc8c047c8b9 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Mon, 25 Mar 2024 17:20:32 +0100
Subject: [PATCH 5/5] Integration test and fixes

---
 compiler-rt/lib/profile/CMakeLists.txt        |  1 +
 .../lib/profile/InstrProfilingContextual.cpp  | 56 ++++++++++---------
 .../lib/profile/InstrProfilingContextual.h    |  9 +--
 .../tests/InstrProfilingContextualTest.cpp    |  2 +
 compiler-rt/test/profile/Linux/contextual.cpp | 31 ++++++++++
 compiler-rt/test/profile/lit.cfg.py           |  7 +++
 llvm/test/lit.cfg.py                          |  1 +
 llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp    |  7 ++-
 8 files changed, 83 insertions(+), 31 deletions(-)
 create mode 100644 compiler-rt/test/profile/Linux/contextual.cpp

diff --git a/compiler-rt/lib/profile/CMakeLists.txt b/compiler-rt/lib/profile/CMakeLists.txt
index a2913f28ca017f..3f2f1d43b1eb9f 100644
--- a/compiler-rt/lib/profile/CMakeLists.txt
+++ b/compiler-rt/lib/profile/CMakeLists.txt
@@ -138,6 +138,7 @@ else()
   add_compiler_rt_runtime(clang_rt.profile
     STATIC
     ARCHS ${PROFILE_SUPPORTED_ARCH}
+    OBJECT_LIBS RTSanitizerCommon RTSanitizerCommonLibc
     CFLAGS ${EXTRA_FLAGS}
     SOURCES ${PROFILE_SOURCES}
     ADDITIONAL_HEADERS ${PROFILE_HEADERS}
diff --git a/compiler-rt/lib/profile/InstrProfilingContextual.cpp b/compiler-rt/lib/profile/InstrProfilingContextual.cpp
index 1acc9ad02160e0..3ba95c83f4b5ea 100644
--- a/compiler-rt/lib/profile/InstrProfilingContextual.cpp
+++ b/compiler-rt/lib/profile/InstrProfilingContextual.cpp
@@ -64,34 +64,34 @@ constexpr size_t kPower = 20;
 constexpr size_t kBuffSize = 1 << kPower;
 
 size_t getArenaAllocSize(size_t Needed) {
-  if (Needed >= kPower)
+  if (Needed >= kBuffSize)
     return 2 * Needed;
-  return kPower;
+  return kBuffSize;
 }
 
 bool validate(const ContextRoot *Root) {
   __sanitizer::DenseMap<uint64_t, bool> ContextStartAddrs;
-  for (auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
-    auto *Ctx = reinterpret_cast<ContextNode *>(Mem);
-    while (reinterpret_cast<char *>(Ctx) < Mem->pos()) {
+  for (const auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
+    const auto *Pos = Mem->start();
+    while (Pos < Mem->pos()) {
+      const auto *Ctx = reinterpret_cast<const ContextNode *>(Pos);
       if (!ContextStartAddrs.insert({reinterpret_cast<uint64_t>(Ctx), true})
                .second)
         return false;
-      Ctx = reinterpret_cast<ContextNode *>(reinterpret_cast<char *>(Ctx) +
-                                            Ctx->size());
+      Pos += Ctx->size();
     }
   }
 
-  for (auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
-    auto *Ctx = reinterpret_cast<ContextNode *>(Mem);
-    while (reinterpret_cast<char *>(Ctx) < Mem->pos()) {
+  for (const auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
+    const auto *Pos = Mem->start();
+    while (Pos < Mem->pos()) {
+      const auto *Ctx = reinterpret_cast<const ContextNode *>(Pos);
       for (uint32_t I = 0; I < Ctx->callsites_size(); ++I)
         for (auto *Sub = Ctx->subContexts()[I]; Sub; Sub = Sub->next())
           if (!ContextStartAddrs.find(reinterpret_cast<uint64_t>(Sub)))
             return false;
 
-      Ctx = reinterpret_cast<ContextNode *>(reinterpret_cast<char *>(Ctx) +
-                                            Ctx->size());
+      Pos += Ctx->size();
     }
   }
   return true;
@@ -99,9 +99,11 @@ bool validate(const ContextRoot *Root) {
 } // namespace
 
 extern "C" {
+
 __thread char __Buffer[kBuffSize] = {0};
 
-#define TheNullContext markAsScratch(reinterpret_cast<ContextNode *>(__Buffer))
+#define TheScratchContext                                                      \
+  markAsScratch(reinterpret_cast<ContextNode *>(__Buffer))
 __thread void *volatile __llvm_instrprof_expected_callee[2] = {nullptr, nullptr};
 __thread ContextNode **volatile __llvm_instrprof_callsite[2] = {0, 0};
 
@@ -128,14 +130,15 @@ COMPILER_RT_VISIBILITY ContextNode *
 __llvm_instrprof_get_context(void *Callee, GUID Guid, uint32_t NrCounters,
                             uint32_t NrCallsites) {
   if (!__llvm_instrprof_current_context_root) {
-    return TheNullContext;
+    return TheScratchContext;
   }
   auto **CallsiteContext = consume(__llvm_instrprof_callsite[0]);
   if (!CallsiteContext || isScratch(*CallsiteContext))
-    return TheNullContext;
+    return TheScratchContext;
+
   auto *ExpectedCallee = consume(__llvm_instrprof_expected_callee[0]);
   if (ExpectedCallee != Callee)
-    return TheNullContext;
+    return TheScratchContext;
 
   auto *Callsite = *CallsiteContext;
   while (Callsite && Callsite->guid() != Guid) {
@@ -146,7 +149,7 @@ __llvm_instrprof_get_context(void *Callee, GUID Guid, uint32_t NrCounters,
                              Guid, CallsiteContext, NrCounters, NrCallsites);
   if (Ret->callsites_size() != NrCallsites || Ret->counters_size() != NrCounters)
     __sanitizer::Printf("[ctxprof] Returned ctx differs from what's asked: "
-                        "Context: %p, Asked: %zu %u %u, Got: %zu %u %u \n",
+                        "Context: %p, Asked: %lu %u %u, Got: %lu %u %u \n",
                         Ret, Guid, NrCallsites, NrCounters, Ret->guid(),
                         Ret->callsites_size(), Ret->counters_size());
   Ret->onEntry();
@@ -163,6 +166,7 @@ __llvm_instprof_setup_context(ContextRoot *Root, GUID Guid, uint32_t NrCounters,
     return;
   const auto Needed = ContextNode::getAllocSize(NrCounters, NrCallsites);
   auto *M = Arena::allocate(getArenaAllocSize(Needed));
+  Root->FirstMemBlock = M;
   Root->CurrentMem = M;
   Root->FirstNode =
       ContextNode::alloc(M->tryAllocate(Needed), Guid, NrCounters, NrCallsites);
@@ -181,7 +185,7 @@ COMPILER_RT_VISIBILITY ContextNode *__llvm_instrprof_start_context(
     return Root->FirstNode;
   }
   __llvm_instrprof_current_context_root = nullptr;
-  return TheNullContext;
+  return TheScratchContext;
 }
 
 COMPILER_RT_VISIBILITY void __llvm_instrprof_release_context(ContextRoot *Root)
@@ -192,7 +196,7 @@ COMPILER_RT_VISIBILITY void __llvm_instrprof_release_context(ContextRoot *Root)
   }
 }
 
-COMPILER_RT_VISIBILITY void __llvm_profile_reset_counters(void) {
+COMPILER_RT_VISIBILITY void __llvm_profile_reset_ctx_counters(void) {
   size_t NrMemUnits = 0;
   __sanitizer::GenericScopedLock<__sanitizer::SpinMutex> Lock(
       &AllContextsMutex);
@@ -237,19 +241,21 @@ int __llvm_ctx_profile_dump(const char* Filename) {
   for (int I = 0, E = AllContextRoots.Size(); I < E; ++I) {
     const auto *Root = AllContextRoots[I];
     for (const auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
-      const uint64_t MemStartAddr =
+      const uint64_t ContextStartAddr =
           reinterpret_cast<const uint64_t>(Mem->start());
-      if (fwrite(reinterpret_cast<const char *>(&MemStartAddr),
+      if (fwrite(reinterpret_cast<const char *>(&ContextStartAddr),
                  sizeof(uint64_t), 1, F) != 1)
         return -1;
-      if (fwrite(reinterpret_cast<const char *>(&kBuffSize), sizeof(uint64_t),
-                 1, F) != 1)
+      const uint64_t Size = Mem->size();
+      if (fwrite(reinterpret_cast<const char *>(&Size), sizeof(uint64_t), 1,
+                 F) != 1)
         return -1;
-      if (fwrite(reinterpret_cast<const char *>(Mem), sizeof(char), kBuffSize,
-                 F) != kBuffSize)
+      if (fwrite(reinterpret_cast<const char *>(Mem->start()), sizeof(char),
+                 Size, F) != Size)
         return -1;
     }
   }
+  __sanitizer::Printf("[ctxprof] End Dump. Closing file.\n");
   return fclose(F);
 }
 }
diff --git a/compiler-rt/lib/profile/InstrProfilingContextual.h b/compiler-rt/lib/profile/InstrProfilingContextual.h
index 7525a32cc32b13..fa205a480ea824 100644
--- a/compiler-rt/lib/profile/InstrProfilingContextual.h
+++ b/compiler-rt/lib/profile/InstrProfilingContextual.h
@@ -11,6 +11,7 @@
 
 #include "InstrProfiling.h"
 #include "sanitizer_common/sanitizer_mutex.h"
+#include "sanitizer_common/sanitizer_vector.h"
 
 namespace __profile {
 using GUID = uint64_t;
@@ -20,7 +21,7 @@ using GUID = uint64_t;
 class Arena final {
 public:
   static Arena *allocate(size_t Size, Arena *Prev = nullptr);
-  size_t size() const { return Size; }
+  uint64_t size() const { return Size; }
   char *tryAllocate(size_t S) {
     if (Pos + S > Size)
       return nullptr;
@@ -32,12 +33,12 @@ class Arena final {
   const char *pos() const { return start() + Pos; }
 
 private:
-  explicit Arena(size_t Size) : Size(Size) {}
+  explicit Arena(uint64_t Size) : Size(Size) {} // 64-bit like the Size member, so large sizes are not truncated.
   char *start() { return reinterpret_cast<char *>(&this[1]); }
 
   Arena *Next = nullptr;
-  size_t Pos = 0;
-  const size_t Size;
+  uint64_t Pos = 0;
+  const uint64_t Size;
 };
 
 class ContextNode final {
diff --git a/compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp b/compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp
index b53fbb5bda7c49..c5a0ddab2387eb 100644
--- a/compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp
+++ b/compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp
@@ -26,6 +26,8 @@ TEST(ContextTest, Basic) {
   memset(&Root, 0, sizeof(ContextRoot));
   auto *Ctx = __llvm_instrprof_start_context(&Root, 1, 10, 4);
   EXPECT_NE(Ctx, nullptr);
+  EXPECT_NE(Root.CurrentMem, nullptr);
+  EXPECT_EQ(Root.FirstMemBlock, Root.CurrentMem);
   EXPECT_EQ(Ctx->size(), sizeof(ContextNode) + 10 * sizeof(uint64_t) +
                              4 * sizeof(ContextNode *));
   EXPECT_EQ(Ctx->counters_size(), 10U);
diff --git a/compiler-rt/test/profile/Linux/contextual.cpp b/compiler-rt/test/profile/Linux/contextual.cpp
new file mode 100644
index 00000000000000..8be69eb4af2870
--- /dev/null
+++ b/compiler-rt/test/profile/Linux/contextual.cpp
@@ -0,0 +1,31 @@
+// RUN: %clangxx_pgogen %s -O2 -g -o %t.bin -fno-exceptions -mllvm -profile-context-root=the_root
+// RUN: %t.bin %t.rawprof
+// RUN: %llvm-ctx-ifdo %t.rawprof %t.bitstream
+// RUN: stat -c%%s %t.rawprof | FileCheck %s --check-prefix=RAW
+// RUN: llvm-bcanalyzer --dump %t.bitstream 2>&1 | FileCheck %s --check-prefix=BC
+
+// RAW: 1048592
+// BC:      <UnknownBlock100 NumWords=4 BlockCodeSize=2>
+// BC-NEXT:   <UnknownCode1 op0=-7380956406374790822/>
+// BC-NEXT:   <UnknownCode3 op0=1/>
+// BC-NEXT: </UnknownBlock100>
+
+#include <cstdio>
+extern "C" int __llvm_ctx_profile_dump(const char *Filename);
+
+extern "C" {
+void someFunction() { printf("check 2\n"); } // Called twice from the_root.
+
+// Named by -profile-context-root on the RUN line. Kept out of line because
+// the pre-inliner would otherwise inline it - the function is too small.
+__attribute__((noinline)) void the_root() {
+  printf("check 1\n");
+  someFunction();
+  someFunction();
+}
+}
+
+int main(int argc, char **argv) {
+  the_root(); // Execute the root once before dumping.
+  return __llvm_ctx_profile_dump(argv[1]); // argv[1] is the output file name; the dump result becomes the exit code.
+}
\ No newline at end of file
diff --git a/compiler-rt/test/profile/lit.cfg.py b/compiler-rt/test/profile/lit.cfg.py
index d3ba115731c5dc..06d8b64b82c372 100644
--- a/compiler-rt/test/profile/lit.cfg.py
+++ b/compiler-rt/test/profile/lit.cfg.py
@@ -158,6 +158,13 @@ def exclude_unsupported_files_for_aix(dirname):
     )
 )
 
+config.substitutions.append(
+    (
+        "%llvm-ctx-ifdo",
+        os.path.join(config.llvm_tools_dir, "llvm-ctx-ifdo")
+    )
+)
+
 if config.host_os not in [
     "Windows",
     "Darwin",
diff --git a/llvm/test/lit.cfg.py b/llvm/test/lit.cfg.py
index 4c05317036d1a3..451e538c32f0a0 100644
--- a/llvm/test/lit.cfg.py
+++ b/llvm/test/lit.cfg.py
@@ -182,6 +182,7 @@ def get_asan_rtlib():
         "llvm-bitcode-strip",
         "llvm-config",
         "llvm-cov",
+        "llvm-ctx-ifdo",
         "llvm-cxxdump",
         "llvm-cvtres",
         "llvm-debuginfod-find",
diff --git a/llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp b/llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp
index 91ab95eb64a999..364bd5ac6f817f 100644
--- a/llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp
+++ b/llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp
@@ -3,6 +3,7 @@
 #include "llvm/Bitstream/BitCodeEnums.h"
 #include "llvm/Bitstream/BitstreamWriter.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/raw_ostream.h"
@@ -70,7 +71,7 @@ void writeContext(StringRef FirstPage, BitstreamWriter &Writer,
     Writer.EnterSubblock(100, 2);
     Writer.EmitRecord(Codes::Guid, SmallVector<uint64_t, 1>{N->Guid});
     if (Index)
-      Writer.EmitRecord(Codes::CalleeIndex, SmallVector<uint32_t, 1>{*Index});
+      Writer.EmitRecord(Codes::CalleeIndex, SmallVector<uint32_t, 1>{*Index});
     //--- these go together to emit an array
     Writer.EmitCode(bitc::UNABBREV_RECORD);
     Writer.EmitVBR(Codes::Counters, 6);
@@ -100,8 +101,10 @@ int main(int argc, const char *argv[]) {
   SmallVector<char, 1 << 20> Buff;
   std::error_code EC;
   raw_fd_stream Out(OutputFile, EC);
-  if (EC)
+  if (EC) {
+    errs() << "Could not open output file: " << EC.message() << "\n";
     return 1;
+  }
   auto Input = MemoryBuffer::getFileOrSTDIN(InputFile);
   if (!Input)
     return 1;



More information about the llvm-commits mailing list