[compiler-rt] [llvm] [DO NOT SUBMIT] Contextual iFDO "demo" PR (PR #86036)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 20 16:24:02 PDT 2024


https://github.com/mtrofin created https://github.com/llvm/llvm-project/pull/86036

Main components of the contextual instrumented FDO work. This is a "demo" PR, to accompany the RFC, to give an overall idea of what's involved. The change will be broken down into smaller pieces, to be reviewed piecemeal. Test coverage is minimal in this demo PR.

>From 24b15900137beedd9198e61b4620a73d96334a78 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Tue, 19 Mar 2024 20:16:35 -0700
Subject: [PATCH 1/3] WIP

---
 compiler-rt/lib/profile/CMakeLists.txt        |   5 +
 .../lib/profile/InstrProfilingContextual.cpp  | 259 ++++++++++++++++++
 .../lib/profile/InstrProfilingContextual.h    | 136 +++++++++
 compiler-rt/lib/profile/tests/CMakeLists.txt  |  72 +++++
 .../tests/InstrProfilingContextualTest.cpp    | 121 ++++++++
 compiler-rt/lib/profile/tests/driver.cpp      |  14 +
 llvm/include/llvm/IR/IntrinsicInst.h          |  13 +
 llvm/include/llvm/IR/Intrinsics.td            |   5 +
 .../Instrumentation/PGOCtxProfLowering.h      |  40 +++
 llvm/lib/IR/IntrinsicInst.cpp                 |   4 +
 llvm/lib/Passes/PassBuilder.cpp               |   1 +
 llvm/lib/Passes/PassBuilderPipelines.cpp      |  24 +-
 llvm/lib/Passes/PassRegistry.def              |   1 +
 .../Transforms/Instrumentation/CMakeLists.txt |   1 +
 .../Instrumentation/InstrProfiling.cpp        |   3 +
 .../Instrumentation/PGOCtxProfLowering.cpp    | 214 +++++++++++++++
 .../Instrumentation/PGOInstrumentation.cpp    |  38 ++-
 llvm/test/Transforms/PGOProfile/ctx-basic.ll  |  86 ++++++
 18 files changed, 1023 insertions(+), 14 deletions(-)
 create mode 100644 compiler-rt/lib/profile/InstrProfilingContextual.cpp
 create mode 100644 compiler-rt/lib/profile/InstrProfilingContextual.h
 create mode 100644 compiler-rt/lib/profile/tests/CMakeLists.txt
 create mode 100644 compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp
 create mode 100644 compiler-rt/lib/profile/tests/driver.cpp
 create mode 100644 llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h
 create mode 100644 llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
 create mode 100644 llvm/test/Transforms/PGOProfile/ctx-basic.ll

diff --git a/compiler-rt/lib/profile/CMakeLists.txt b/compiler-rt/lib/profile/CMakeLists.txt
index 45e51648917515..a2913f28ca017f 100644
--- a/compiler-rt/lib/profile/CMakeLists.txt
+++ b/compiler-rt/lib/profile/CMakeLists.txt
@@ -51,6 +51,7 @@ add_compiler_rt_component(profile)
 set(PROFILE_SOURCES
   GCDAProfiling.c
   InstrProfiling.c
+  InstrProfilingContextual.cpp
   InstrProfilingInternal.c
   InstrProfilingValue.c
   InstrProfilingBuffer.c
@@ -142,3 +143,7 @@ else()
     ADDITIONAL_HEADERS ${PROFILE_HEADERS}
     PARENT_TARGET profile)
 endif()
+
+if(COMPILER_RT_INCLUDE_TESTS)
+  add_subdirectory(tests)
+endif()
diff --git a/compiler-rt/lib/profile/InstrProfilingContextual.cpp b/compiler-rt/lib/profile/InstrProfilingContextual.cpp
new file mode 100644
index 00000000000000..e13d0688cca50f
--- /dev/null
+++ b/compiler-rt/lib/profile/InstrProfilingContextual.cpp
@@ -0,0 +1,305 @@
+//===- InstrProfilingContextual.cpp - Contextual PGO runtime --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "InstrProfiling.h"
+
+#include "sanitizer_common/sanitizer_atomic.h"
+#include "sanitizer_common/sanitizer_atomic_clang.h"
+#include "sanitizer_common/sanitizer_common.h"
+#include "sanitizer_common/sanitizer_dense_map.h"
+#include "sanitizer_common/sanitizer_mutex.h"
+#include "sanitizer_common/sanitizer_placement_new.h"
+#include "sanitizer_common/sanitizer_thread_safety.h"
+#include "sanitizer_common/sanitizer_vector.h"
+
+#include "InstrProfilingContextual.h"
+
+using namespace __profile;
+
+/// Allocate an arena with \p Size usable bytes, preceded by its Arena header.
+/// If \p Prev is non-null, the new arena is linked right after it.
+Arena *Arena::allocate(size_t Size, Arena *Prev) {
+  Arena *NewArena =
+      new (__sanitizer::InternalAlloc(Size + sizeof(Arena))) Arena(Size);
+  if (Prev)
+    Prev->Next = NewArena;
+  return NewArena;
+}
+
+/// Construct a ContextNode at \p Place, which must have room for
+/// getAllocSize(NrCounters, NrCallsites) bytes.
+inline ContextNode *ContextNode::alloc(char *Place, GUID Guid,
+                                       uint32_t NrCounters,
+                                       uint32_t NrCallsites,
+                                       ContextNode *Next) {
+  return new (Place) ContextNode(Guid, NrCounters, NrCallsites, Next);
+}
+
+/// Zero this node's counters, then recursively reset every subcontext
+/// reachable from each of its callsites.
+void ContextNode::reset() {
+  for (uint32_t I = 0; I < NrCounters; ++I)
+    counters()[I] = 0;
+  for (uint32_t I = 0; I < NrCallsites; ++I)
+    for (auto *Next = subContexts()[I]; Next; Next = Next->Next)
+      Next->reset();
+}
+
+namespace {
+// Guards the list of all registered roots. Each root's own data is guarded by
+// its ContextRoot::Taken lock.
+__sanitizer::SpinMutex AllContextsMutex;
+SANITIZER_GUARDED_BY(AllContextsMutex)
+__sanitizer::Vector<ContextRoot *> AllContextRoots;
+
+// Tag \p Ctx as a scratch context by setting its low bit (see isScratch()).
+ContextNode *markAsScratch(const ContextNode *Ctx) {
+  return reinterpret_cast<ContextNode *>(reinterpret_cast<uint64_t>(Ctx) | 1);
+}
+
+// Return the current value of \p V and reset \p V to zero.
+template <typename T> T consume(T &V) {
+  auto R = V;
+  V = {0};
+  return R;
+}
+
+// Size of the per-thread scratch buffer, and the minimum size of an arena.
+constexpr size_t kPower = 20;
+constexpr size_t kBuffSize = 1 << kPower;
+
+// Usable size for a new arena that must fit a \p Needed byte allocation:
+// kBuffSize, or twice \p Needed for allocations at least that large.
+size_t getArenaAllocSize(size_t Needed) {
+  if (Needed >= kBuffSize)
+    return 2 * Needed;
+  return kBuffSize;
+}
+
+// Check the integrity of \p Root's arenas: walking each arena from its first
+// usable byte, every node address must be unique, and every subcontext
+// pointer must refer to one of those nodes.
+bool validate(const ContextRoot *Root) {
+  __sanitizer::DenseMap<uint64_t, bool> ContextStartAddrs;
+  for (const auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
+    // Nodes are laid out back to back after the Arena header.
+    const char *Pos = Mem->start();
+    while (Pos < Mem->pos()) {
+      const auto *Ctx = reinterpret_cast<const ContextNode *>(Pos);
+      if (!ContextStartAddrs.insert({reinterpret_cast<uint64_t>(Ctx), true})
+               .second)
+        return false;
+      Pos += Ctx->size();
+    }
+  }
+
+  for (const auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
+    const char *Pos = Mem->start();
+    while (Pos < Mem->pos()) {
+      const auto *Ctx = reinterpret_cast<const ContextNode *>(Pos);
+      for (uint32_t I = 0; I < Ctx->callsites_size(); ++I)
+        for (auto *Sub = Ctx->subContexts()[I]; Sub; Sub = Sub->next())
+          if (!ContextStartAddrs.find(reinterpret_cast<uint64_t>(Sub)))
+            return false;
+      Pos += Ctx->size();
+    }
+  }
+  return true;
+}
+} // namespace
+
+extern "C" {
+// Per-thread scratch buffer. Contexts returned when no real context is
+// available point into it, tagged with markAsScratch().
+__thread char __Buffer[kBuffSize] = {0};
+
+#define TheNullContext markAsScratch(reinterpret_cast<ContextNode *>(__Buffer))
+__thread void *volatile __llvm_instrprof_expected_callee[2] = {nullptr, nullptr};
+__thread ContextNode **volatile __llvm_instrprof_callsite[2] = {0, 0};
+
+COMPILER_RT_VISIBILITY __thread ContextRoot
+    *volatile __llvm_instrprof_current_context_root = nullptr;
+
+// Allocate a node for \p Guid from the current root's arenas, chaining a new
+// arena if the current one is full, and push it at the front of the list
+// headed by *InsertionPoint.
+COMPILER_RT_VISIBILITY ContextNode *
+__llvm_instprof_slow_get_callsite(uint64_t Guid, ContextNode **InsertionPoint,
+                                  uint32_t NrCounters, uint32_t NrCallsites) {
+  auto AllocSize = ContextNode::getAllocSize(NrCounters, NrCallsites);
+  auto *Mem = __llvm_instrprof_current_context_root->CurrentMem;
+  char *AllocPlace = Mem->tryAllocate(AllocSize);
+  if (!AllocPlace) {
+    // getArenaAllocSize sizes the new arena so this request always fits.
+    __llvm_instrprof_current_context_root->CurrentMem = Mem =
+        Arena::allocate(getArenaAllocSize(AllocSize), Mem);
+    AllocPlace = Mem->tryAllocate(AllocSize);
+  }
+  auto *Ret = ContextNode::alloc(AllocPlace, Guid, NrCounters, NrCallsites,
+                                 *InsertionPoint);
+  *InsertionPoint = Ret;
+  return Ret;
+}
+
+// Return the context for this activation of \p Callee, whose GUID is \p Guid.
+// A scratch context is returned if this thread isn't collecting a root, if no
+// usable callsite was recorded, or if the recorded expected callee isn't
+// \p Callee. The recorded callsite (and, when it is usable, the expected
+// callee) is cleared as it is read.
+COMPILER_RT_VISIBILITY ContextNode *
+__llvm_instrprof_get_context(void *Callee, GUID Guid, uint32_t NrCounters,
+                            uint32_t NrCallsites) {
+  if (!__llvm_instrprof_current_context_root)
+    return TheNullContext;
+  auto **CallsiteContext = consume(__llvm_instrprof_callsite[0]);
+  if (!CallsiteContext || isScratch(*CallsiteContext))
+    return TheNullContext;
+  auto *ExpectedCallee = consume(__llvm_instrprof_expected_callee[0]);
+  if (ExpectedCallee != Callee)
+    return TheNullContext;
+
+  // Reuse the callsite's existing context for Guid, or allocate one.
+  auto *Callsite = *CallsiteContext;
+  while (Callsite && Callsite->guid() != Guid)
+    Callsite = Callsite->next();
+  auto *Ret = Callsite ? Callsite
+                       : __llvm_instprof_slow_get_callsite(
+                             Guid, CallsiteContext, NrCounters, NrCallsites);
+  if (Ret->callsites_size() != NrCallsites || Ret->counters_size() != NrCounters)
+    __sanitizer::Printf("[ctxprof] Returned ctx differs from what's asked: "
+                        "Context: %p, Asked: %llu %u %u, Got: %llu %u %u \n",
+                        Ret, static_cast<unsigned long long>(Guid),
+                        NrCallsites, NrCounters,
+                        static_cast<unsigned long long>(Ret->guid()),
+                        Ret->callsites_size(), Ret->counters_size());
+  Ret->onEntry();
+  return Ret;
+}
+
+// Allocate \p Root's first arena and root node, and register \p Root in
+// AllContextRoots. Only the first caller for a given root does any work.
+COMPILER_RT_VISIBILITY void
+__llvm_instprof_setup_context(ContextRoot *Root, GUID Guid, uint32_t NrCounters,
+                              uint32_t NrCallsites) {
+  __sanitizer::GenericScopedLock<__sanitizer::SpinMutex> Lock(
+      &AllContextsMutex);
+  // Re-check - we got here without having had taken a lock.
+  if (__atomic_load_n(&Root->FirstMemBlock, __ATOMIC_ACQUIRE))
+    return;
+  const auto Needed = ContextNode::getAllocSize(NrCounters, NrCallsites);
+  auto *M = Arena::allocate(getArenaAllocSize(Needed));
+  Root->CurrentMem = M;
+  Root->FirstNode =
+      ContextNode::alloc(M->tryAllocate(Needed), Guid, NrCounters, NrCallsites);
+  AllContextRoots.PushBack(Root);
+  // Publish FirstMemBlock last. __llvm_instrprof_start_context reads it
+  // without holding AllContextsMutex, and must then see the fields set above.
+  __atomic_store_n(&Root->FirstMemBlock, M, __ATOMIC_RELEASE);
+}
+
+// Enter \p Root, setting it up on first use. If this thread can take the root,
+// it becomes this thread's current root and its first node is returned;
+// otherwise (the root is already taken) a scratch context is returned.
+// NOTE(review): entering a root while this thread already holds one (the same
+// or another) clears or overwrites the current-root TLS, so the outer
+// __llvm_instrprof_release_context won't unlock the outer root. Confirm roots
+// are never nested.
+COMPILER_RT_VISIBILITY ContextNode *__llvm_instrprof_start_context(
+    ContextRoot *Root, GUID Guid, uint32_t Counters,
+    uint32_t Callsites) SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
+  if (!__atomic_load_n(&Root->FirstMemBlock, __ATOMIC_ACQUIRE))
+    __llvm_instprof_setup_context(Root, Guid, Counters, Callsites);
+  if (Root->Taken.TryLock()) {
+    __llvm_instrprof_current_context_root = Root;
+    Root->FirstNode->onEntry();
+    return Root->FirstNode;
+  }
+  __llvm_instrprof_current_context_root = nullptr;
+  return TheNullContext;
+}
+
+// Leave \p Root: if this thread is collecting, stop and unlock \p Root.
+COMPILER_RT_VISIBILITY void __llvm_instrprof_release_context(ContextRoot *Root)
+    SANITIZER_NO_THREAD_SAFETY_ANALYSIS {
+  if (__llvm_instrprof_current_context_root) {
+    __llvm_instrprof_current_context_root = nullptr;
+    Root->Taken.Unlock();
+  }
+}
+
+// Zero the counters of every registered root's context tree.
+// NOTE(review): InstrProfiling.c appears to define
+// __llvm_profile_reset_counters too; confirm both definitions can't end up in
+// the same link.
+COMPILER_RT_VISIBILITY void __llvm_profile_reset_counters(void) {
+  size_t NrMemUnits = 0;
+  __sanitizer::GenericScopedLock<__sanitizer::SpinMutex> Lock(
+      &AllContextsMutex);
+  for (uint32_t I = 0; I < AllContextRoots.Size(); ++I) {
+    auto *Root = AllContextRoots[I];
+    __sanitizer::GenericScopedLock<__sanitizer::StaticSpinMutex> TakenLock(
+        &Root->Taken);
+    for (auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next())
+      ++NrMemUnits;
+
+    Root->FirstNode->reset();
+  }
+  __sanitizer::Printf("[ctxprof] Initial NrMemUnits: %zu \n", NrMemUnits);
+}
+
+// Write every registered root's arenas to \p Filename. Each arena is written
+// as its start address (uint64_t), the number of bytes in use (uint64_t), and
+// then those bytes. Returns 0 on success and -1 on failure.
+COMPILER_RT_VISIBILITY
+int __llvm_ctx_profile_dump(const char *Filename) {
+  __sanitizer::Printf("[ctxprof] Start Dump\n");
+  if (!Filename) {
+    PROF_ERR("Failed to write file : %s\n", "Filename not set");
+    return -1;
+  }
+  FILE *F = fopen(Filename, "w");
+  if (!F) {
+    PROF_ERR("Failed to open file : %s\n", Filename);
+    return -1;
+  }
+
+  __sanitizer::GenericScopedLock<__sanitizer::SpinMutex> Lock(
+      &AllContextsMutex);
+  bool Failed = false;
+  for (size_t I = 0, E = AllContextRoots.Size(); I < E && !Failed; ++I) {
+    auto *Root = AllContextRoots[I];
+    // Hold the root so its arenas don't change while they are written.
+    __sanitizer::GenericScopedLock<__sanitizer::StaticSpinMutex> TakenLock(
+        &Root->Taken);
+    size_t NrMemUnits = 0;
+    size_t Allocated = 0;
+    for (const auto *Mem = Root->FirstMemBlock; Mem; Mem = Mem->next()) {
+      ++NrMemUnits;
+      Allocated += static_cast<size_t>(Mem->pos() - Mem->start());
+    }
+    __sanitizer::Printf("[ctxprof] Root %p: NrMemUnits: %zu, Allocated: %zu\n",
+                        Root, NrMemUnits, Allocated);
+    if (!validate(Root))
+      PROF_ERR("Context root %p failed validation\n",
+               static_cast<void *>(Root));
+
+    for (const auto *Mem = Root->FirstMemBlock; Mem && !Failed;
+         Mem = Mem->next()) {
+      const uint64_t MemStartAddr = reinterpret_cast<uint64_t>(Mem->start());
+      const uint64_t MemSize = static_cast<uint64_t>(Mem->pos() - Mem->start());
+      Failed = fwrite(&MemStartAddr, sizeof(MemStartAddr), 1, F) != 1 ||
+               fwrite(&MemSize, sizeof(MemSize), 1, F) != 1 ||
+               fwrite(Mem->start(), sizeof(char), MemSize, F) != MemSize;
+    }
+  }
+  // Close the file on every path; a failed close also fails the dump.
+  if (fclose(F) != 0)
+    Failed = true;
+  return Failed ? -1 : 0;
+}
+}
diff --git a/compiler-rt/lib/profile/InstrProfilingContextual.h b/compiler-rt/lib/profile/InstrProfilingContextual.h
new file mode 100644
index 00000000000000..7525a32cc32b13
--- /dev/null
+++ b/compiler-rt/lib/profile/InstrProfilingContextual.h
@@ -0,0 +1,145 @@
+/*===- InstrProfilingContextual.h - Contextual profiling runtime ----------===*\
+|*
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+|* See https://llvm.org/LICENSE.txt for license information.
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+|*
+\*===----------------------------------------------------------------------===*/
+
+#ifndef PROFILE_INSTRPROFILINGCONTEXTUAL_H_
+#define PROFILE_INSTRPROFILINGCONTEXTUAL_H_
+
+#include "InstrProfiling.h"
+#include "sanitizer_common/sanitizer_mutex.h"
+
+namespace __profile {
+using GUID = uint64_t;
+
+/// Arena forming a linked list, if more space is needed. Intentionally not
+/// thread safe.
+///
+/// The Arena header is immediately followed by size() usable bytes, which
+/// tryAllocate() hands out front to back.
+class Arena final {
+public:
+  static Arena *allocate(size_t Size, Arena *Prev = nullptr);
+  size_t size() const { return Size; }
+  /// Reserve \p S bytes, or return nullptr if fewer than that are left.
+  char *tryAllocate(size_t S) {
+    if (S > Size - Pos)
+      return nullptr;
+    Pos += S;
+    return start() + (Pos - S);
+  }
+  Arena *next() const { return Next; }
+  /// The first usable byte.
+  const char *start() const { return const_cast<Arena *>(this)->start(); }
+  /// One past the last allocated byte.
+  const char *pos() const { return start() + Pos; }
+
+private:
+  explicit Arena(size_t Size) : Size(Size) {}
+  char *start() { return reinterpret_cast<char *>(&this[1]); }
+
+  Arena *Next = nullptr;
+  size_t Pos = 0;
+  const size_t Size;
+};
+
+/// The counters of one function in one calling context, plus, for each of its
+/// callsites, a list of the contexts of the functions called from there.
+///
+/// The node is immediately followed by NrCounters uint64_t counters and then
+/// NrCallsites ContextNode* list heads; list elements are chained via next().
+class ContextNode final {
+  const GUID Guid;
+  ContextNode *const Next;
+  const uint32_t NrCounters;
+  const uint32_t NrCallsites;
+
+public:
+  ContextNode(GUID Guid, uint32_t NrCounters, uint32_t NrCallsites,
+              ContextNode *Next = nullptr)
+      : Guid(Guid), Next(Next), NrCounters(NrCounters),
+        NrCallsites(NrCallsites) {}
+  static inline ContextNode *alloc(char *Place, GUID Guid, uint32_t NrCounters,
+                                   uint32_t NrCallsites,
+                                   ContextNode *Next = nullptr);
+
+  /// Bytes needed for a node together with its counters and callsite slots.
+  static inline size_t getAllocSize(uint32_t NrCounters, uint32_t NrCallsites) {
+    return sizeof(ContextNode) + sizeof(uint64_t) * NrCounters +
+           sizeof(ContextNode *) * NrCallsites;
+  }
+
+  uint64_t *counters() { return reinterpret_cast<uint64_t *>(&this[1]); }
+  const uint64_t *counters() const {
+    return const_cast<ContextNode *>(this)->counters();
+  }
+
+  uint32_t counters_size() const { return NrCounters; }
+  uint32_t callsites_size() const { return NrCallsites; }
+
+  ContextNode **subContexts() {
+    return reinterpret_cast<ContextNode **>(&(counters()[NrCounters]));
+  }
+
+  ContextNode *const *subContexts() const {
+    return const_cast<ContextNode *>(this)->subContexts();
+  }
+
+  GUID guid() const { return Guid; }
+  ContextNode *next() const { return Next; }
+
+  size_t size() const { return getAllocSize(NrCounters, NrCallsites); }
+
+  void reset();
+
+  /// Counter 0 is the entry count; the node must have at least one counter.
+  void onEntry() { ++counters()[0]; }
+
+  uint64_t entrycount() const { return counters()[0]; }
+};
+
+// Exposed for test. Constructed and zero-initialized by LLVM. Implicitly,
+// LLVM must know the shape of this.
+struct ContextRoot {
+  ContextNode *FirstNode = nullptr;
+  Arena *FirstMemBlock = nullptr;
+  Arena *CurrentMem = nullptr;
+  // This is init-ed by the static zero initializer in LLVM.
+  ::__sanitizer::StaticSpinMutex Taken;
+};
+
+/// Whether \p Ctx is a scratch context, i.e. has its low pointer bit set.
+inline bool isScratch(const ContextNode *Ctx) {
+  return (reinterpret_cast<uint64_t>(Ctx) & 1);
+}
+
+} // namespace __profile
+
+extern "C" {
+
+// Position 0 is used when the current context isn't scratch, 1 when it is.
+extern __thread void *volatile __llvm_instrprof_expected_callee[2];
+extern __thread __profile::ContextNode **volatile __llvm_instrprof_callsite[2];
+
+// The root this thread is collecting into, or null if none.
+extern __thread __profile::ContextRoot
+    *volatile __llvm_instrprof_current_context_root;
+
+// Enter a root function: returns its context, or a scratch context.
+COMPILER_RT_VISIBILITY __profile::ContextNode *
+__llvm_instrprof_start_context(__profile::ContextRoot *Root,
+                              __profile::GUID Guid, uint32_t Counters,
+                              uint32_t Callsites);
+// Leave a root function entered via __llvm_instrprof_start_context.
+COMPILER_RT_VISIBILITY void
+__llvm_instrprof_release_context(__profile::ContextRoot *Root);
+
+// Get the context of non-root function \p Callee, on function entry.
+COMPILER_RT_VISIBILITY __profile::ContextNode *
+__llvm_instrprof_get_context(void *Callee, __profile::GUID Guid,
+                            uint32_t NrCounters, uint32_t NrCallsites);
+}
+#endif
\ No newline at end of file
diff --git a/compiler-rt/lib/profile/tests/CMakeLists.txt b/compiler-rt/lib/profile/tests/CMakeLists.txt
new file mode 100644
index 00000000000000..c221d15b541829
--- /dev/null
+++ b/compiler-rt/lib/profile/tests/CMakeLists.txt
@@ -0,0 +1,69 @@
+include(CheckCXXCompilerFlag)
+include(CompilerRTCompile)
+include(CompilerRTLink)
+
+set(PROFILE_UNITTEST_CFLAGS
+  ${COMPILER_RT_UNITTEST_CFLAGS}
+  ${COMPILER_RT_GTEST_CFLAGS}
+  ${COMPILER_RT_GMOCK_CFLAGS}
+  ${SANITIZER_TEST_CXX_CFLAGS}
+  -I${COMPILER_RT_SOURCE_DIR}/lib/
+  -DSANITIZER_COMMON_NO_REDEFINE_BUILTINS
+  -O2
+  -g
+  -fno-rtti
+  -Wno-pedantic
+  -fno-omit-frame-pointer)
+
+# Suppress warnings for gmock variadic macros for clang and gcc respectively.
+append_list_if(SUPPORTS_GNU_ZERO_VARIADIC_MACRO_ARGUMENTS_FLAG -Wno-gnu-zero-variadic-macro-arguments PROFILE_UNITTEST_CFLAGS)
+append_list_if(COMPILER_RT_HAS_WVARIADIC_MACROS_FLAG -Wno-variadic-macros PROFILE_UNITTEST_CFLAGS)
+
+# The runtime sources under test are compiled directly into the test binary.
+set(PROFILE_SOURCES
+  ../InstrProfilingContextual.cpp)
+
+set(PROFILE_UNITTESTS
+  InstrProfilingContextualTest.cpp
+  driver.cpp)
+
+include_directories(../../../include)
+
+set(PROFILE_UNITTEST_LINK_FLAGS
+  ${COMPILER_RT_UNITTEST_LINK_FLAGS})
+
+if(NOT WIN32)
+  list(APPEND PROFILE_UNITTEST_LINK_FLAGS -pthread)
+endif()
+
+set(PROFILE_UNITTEST_LINK_LIBRARIES
+  ${COMPILER_RT_UNWINDER_LINK_LIBS}
+  ${SANITIZER_TEST_CXX_LIBRARIES})
+# Only link libdl on platforms where it was found as a separate library.
+append_list_if(COMPILER_RT_HAS_LIBDL dl PROFILE_UNITTEST_LINK_LIBRARIES)
+
+if(COMPILER_RT_DEFAULT_TARGET_ARCH IN_LIST PROFILE_SUPPORTED_ARCH)
+  # Profile unit tests are only run on the host machine.
+  set(arch ${COMPILER_RT_DEFAULT_TARGET_ARCH})
+
+  add_executable(ProfileUnitTests
+    ${PROFILE_UNITTESTS}
+    ${COMPILER_RT_GTEST_SOURCE}
+    ${COMPILER_RT_GMOCK_SOURCE}
+    ${PROFILE_SOURCES}
+    $<TARGET_OBJECTS:RTSanitizerCommon.${arch}>
+    $<TARGET_OBJECTS:RTSanitizerCommonCoverage.${arch}>
+    $<TARGET_OBJECTS:RTSanitizerCommonLibc.${arch}>
+    $<TARGET_OBJECTS:RTSanitizerCommonSymbolizer.${arch}>
+    $<TARGET_OBJECTS:RTSanitizerCommonSymbolizerInternal.${arch}>)
+  set_target_compile_flags(ProfileUnitTests ${PROFILE_UNITTEST_CFLAGS})
+  set_target_link_flags(ProfileUnitTests ${PROFILE_UNITTEST_LINK_FLAGS})
+  target_link_libraries(ProfileUnitTests ${PROFILE_UNITTEST_LINK_LIBRARIES})
+
+  if (TARGET cxx-headers OR HAVE_LIBCXX)
+    add_dependencies(ProfileUnitTests cxx-headers)
+  endif()
+
+  set_target_properties(ProfileUnitTests PROPERTIES
+    RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
+endif()
diff --git a/compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp b/compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp
new file mode 100644
index 00000000000000..b53fbb5bda7c49
--- /dev/null
+++ b/compiler-rt/lib/profile/tests/InstrProfilingContextualTest.cpp
@@ -0,0 +1,137 @@
+//===-- InstrProfilingContextualTest.cpp - Contextual runtime tests -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../InstrProfilingContextual.h"
+#include "gtest/gtest.h"
+#include <atomic>
+#include <cstring>
+#include <mutex>
+#include <thread>
+
+using namespace __profile;
+
+TEST(ArenaTest, Basic) {
+  Arena *A = Arena::allocate(1024);
+  EXPECT_EQ(A->size(), 1024U);
+  EXPECT_EQ(A->next(), nullptr);
+
+  auto *M1 = A->tryAllocate(1020);
+  EXPECT_NE(M1, nullptr);
+  auto *M2 = A->tryAllocate(4);
+  EXPECT_NE(M2, nullptr);
+  EXPECT_EQ(M1 + 1020, M2);
+  EXPECT_EQ(A->tryAllocate(1), nullptr);
+  Arena *A2 = Arena::allocate(2024, A);
+  EXPECT_EQ(A->next(), A2);
+  EXPECT_EQ(A2->next(), nullptr);
+}
+
+TEST(ContextTest, Basic) {
+  ContextRoot Root;
+  memset(&Root, 0, sizeof(ContextRoot));
+  auto *Ctx = __llvm_instrprof_start_context(&Root, 1, 10, 4);
+  EXPECT_NE(Ctx, nullptr);
+  EXPECT_EQ(Ctx->size(), sizeof(ContextNode) + 10 * sizeof(uint64_t) +
+                             4 * sizeof(ContextNode *));
+  EXPECT_EQ(Ctx->counters_size(), 10U);
+  EXPECT_EQ(Ctx->callsites_size(), 4U);
+  EXPECT_EQ(__llvm_instrprof_current_context_root, &Root);
+  Root.Taken.CheckLocked();
+  EXPECT_FALSE(Root.Taken.TryLock());
+  __llvm_instrprof_release_context(&Root);
+  EXPECT_EQ(__llvm_instrprof_current_context_root, nullptr);
+  EXPECT_TRUE(Root.Taken.TryLock());
+  Root.Taken.Unlock();
+}
+
+TEST(ContextTest, Callsite) {
+  ContextRoot Root;
+  memset(&Root, 0, sizeof(ContextRoot));
+  auto *Ctx = __llvm_instrprof_start_context(&Root, 1, 10, 4);
+  int OpaqueValue = 0;
+  const bool IsScratch = isScratch(Ctx);
+  EXPECT_FALSE(IsScratch);
+  // Simulate the caller's instrumentation: record callee and callsite slot.
+  __llvm_instrprof_expected_callee[0] = &OpaqueValue;
+  __llvm_instrprof_callsite[0] = &Ctx->subContexts()[2];
+  auto *Subctx = __llvm_instrprof_get_context(&OpaqueValue, 2, 3, 1);
+  EXPECT_EQ(Ctx->subContexts()[2], Subctx);
+  EXPECT_EQ(Subctx->counters_size(), 3U);
+  EXPECT_EQ(Subctx->callsites_size(), 1U);
+  EXPECT_EQ(__llvm_instrprof_expected_callee[0], nullptr);
+  EXPECT_EQ(__llvm_instrprof_callsite[0], nullptr);
+
+  EXPECT_EQ(Subctx->size(), sizeof(ContextNode) + 3 * sizeof(uint64_t) +
+                                1 * sizeof(ContextNode *));
+  __llvm_instrprof_release_context(&Root);
+}
+
+TEST(ContextTest, ScratchNoCollection) {
+  EXPECT_EQ(__llvm_instrprof_current_context_root, nullptr);
+  int OpaqueValue = 0;
+  // this would be the very first function executing this. the TLS is empty,
+  // too.
+  auto *Ctx = __llvm_instrprof_get_context(&OpaqueValue, 2, 3, 1);
+  EXPECT_TRUE(isScratch(Ctx));
+}
+
+TEST(ContextTest, ScratchDuringCollection) {
+  ContextRoot Root;
+  memset(&Root, 0, sizeof(ContextRoot));
+  auto *Ctx = __llvm_instrprof_start_context(&Root, 1, 10, 4);
+  int OpaqueValue = 0;
+  int OtherOpaqueValue = 0;
+  // The recorded expected callee doesn't match the actual one: scratch.
+  __llvm_instrprof_expected_callee[0] = &OpaqueValue;
+  __llvm_instrprof_callsite[0] = &Ctx->subContexts()[2];
+  auto *Subctx = __llvm_instrprof_get_context(&OtherOpaqueValue, 2, 3, 1);
+  EXPECT_TRUE(isScratch(Subctx));
+  EXPECT_EQ(__llvm_instrprof_expected_callee[0], nullptr);
+  EXPECT_EQ(__llvm_instrprof_callsite[0], nullptr);
+
+  // Calls made from a scratch context record into slot 1; they get scratch.
+  int ThirdOpaqueValue = 0;
+  __llvm_instrprof_expected_callee[1] = &ThirdOpaqueValue;
+  __llvm_instrprof_callsite[1] = &Subctx->subContexts()[0];
+
+  auto *Subctx2 = __llvm_instrprof_get_context(&ThirdOpaqueValue, 3, 0, 0);
+  EXPECT_TRUE(isScratch(Subctx2));
+
+  __llvm_instrprof_release_context(&Root);
+}
+
+TEST(ContextTest, ConcurrentRootCollection) {
+  ContextRoot Root;
+  memset(&Root, 0, sizeof(ContextRoot));
+  std::atomic<int> NonScratch = 0;
+  std::atomic<int> Executions = 0;
+
+  __sanitizer::Semaphore GotCtx;
+
+  // Two threads enter the same root; only one of them may collect into it.
+  // The collecting thread waits until both threads have entered, so the other
+  // one is guaranteed to find the root taken.
+  auto Entrypoint = [&]() {
+    ++Executions;
+    auto *Ctx = __llvm_instrprof_start_context(&Root, 1, 10, 4);
+    GotCtx.Post();
+    const bool IS = isScratch(Ctx);
+    NonScratch += (!IS);
+    if (!IS) {
+      GotCtx.Wait();
+      GotCtx.Wait();
+    }
+    __llvm_instrprof_release_context(&Root);
+  };
+  std::thread T1(Entrypoint);
+  std::thread T2(Entrypoint);
+  T1.join();
+  T2.join();
+  EXPECT_EQ(NonScratch.load(), 1);
+  EXPECT_EQ(Executions.load(), 2);
+}
\ No newline at end of file
diff --git a/compiler-rt/lib/profile/tests/driver.cpp b/compiler-rt/lib/profile/tests/driver.cpp
new file mode 100644
index 00000000000000..b402cec1126b33
--- /dev/null
+++ b/compiler-rt/lib/profile/tests/driver.cpp
@@ -0,0 +1,14 @@
+//===-- driver.cpp - Profile runtime unit test driver -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gtest/gtest.h"
+
+int main(int argc, char **argv) {
+  testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index c07b83a81a63e1..6a4cce541aef29 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1435,6 +1435,7 @@ class InstrProfInstBase : public IntrinsicInst {
     case Intrinsic::instrprof_cover:
     case Intrinsic::instrprof_increment:
     case Intrinsic::instrprof_increment_step:
+    case Intrinsic::instrprof_callsite:
     case Intrinsic::instrprof_timestamp:
     case Intrinsic::instrprof_value_profile:
       return true;
@@ -1508,6 +1509,18 @@ class InstrProfIncrementInst : public InstrProfCntrInstBase {
   Value *getStep() const;
 };
 
+/// This represents the llvm.instrprof.callsite intrinsic.
+class InstrProfCallsite : public InstrProfCntrInstBase {
+public:
+  static bool classof(const IntrinsicInst *I) {
+    return I->getIntrinsicID() == Intrinsic::instrprof_callsite;
+  }
+  static bool classof(const Value *V) {
+    return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+  }
+  Value *getCallee() const;
+};
+
 /// This represents the llvm.instrprof.increment.step intrinsic.
 class InstrProfIncrementInstStep : public InstrProfIncrementInst {
 public:
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index 144298fd7c0162..af57042f0e8e14 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -909,6 +909,11 @@ def int_instrprof_increment_step : Intrinsic<[],
                                         [llvm_ptr_ty, llvm_i64_ty,
                                          llvm_i32_ty, llvm_i32_ty, llvm_i64_ty]>;
 
+// Callsite instrumentation for contextual profiling. Last operand: the callee.
+def int_instrprof_callsite : Intrinsic<[],
+                                        [llvm_ptr_ty, llvm_i64_ty,
+                                         llvm_i32_ty, llvm_i32_ty, llvm_ptr_ty]>;
+
 // A timestamp for instrumentation based profiling.
 def int_instrprof_timestamp : Intrinsic<[], [llvm_ptr_ty, llvm_i64_ty,
                                              llvm_i32_ty, llvm_i32_ty]>;
diff --git a/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h b/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h
new file mode 100644
index 00000000000000..975775f18a5152
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h
@@ -0,0 +1,48 @@
+//===- PGOCtxProfLowering.h - Contextual profile lowering -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares PGOCtxProfLoweringPass, the module pass added to the PGO
+// instrumentation pipeline in place of InstrProfilingLoweringPass when
+// contextual instrumented PGO is enabled.
+//
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFLOWERING_H
+#define LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFLOWERING_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+class Constant;
+class Function;
+class GlobalVariable;
+class Type;
+
+class PGOCtxProfLoweringPass : public PassInfoMixin<PGOCtxProfLoweringPass> {
+public:
+  explicit PGOCtxProfLoweringPass() = default;
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
+  /// Whether the pipeline should use this pass (contextual IRPGO) instead of
+  /// the regular instrumentation lowering.
+  static bool isContextualIRPGOEnabled();
+
+private:
+  Type *ContextNodeTy = nullptr;
+  Type *ContextRootTy = nullptr;
+
+  DenseMap<const Function *, Constant *> ContextRootMap;
+  Function *StartCtx = nullptr;
+  Function *GetCtx = nullptr;
+  Function *ReleaseCtx = nullptr;
+  GlobalVariable *ExpectedCalleeTLS = nullptr;
+  GlobalVariable *CallsiteInfoTLS = nullptr;
+
+  void lowerFunction(Function &F);
+};
+} // namespace llvm
+#endif
\ No newline at end of file
diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp
index 89403e1d7fcb4d..9bf8dae89458f3 100644
--- a/llvm/lib/IR/IntrinsicInst.cpp
+++ b/llvm/lib/IR/IntrinsicInst.cpp
@@ -291,6 +291,10 @@ Value *InstrProfIncrementInst::getStep() const {
   return ConstantInt::get(Type::getInt64Ty(Context), 1);
 }
 
+Value *InstrProfCallsite::getCallee() const {
+  return getArgOperand(4); // The last operand of llvm.instrprof.callsite.
+}
+
 std::optional<RoundingMode> ConstrainedFPIntrinsic::getRoundingMode() const {
   unsigned NumOperands = arg_size();
   Metadata *MD = nullptr;
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 4d1eb10d2d41c6..ffc23988284196 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -174,6 +174,7 @@
 #include "llvm/Transforms/Instrumentation/KCFI.h"
 #include "llvm/Transforms/Instrumentation/MemProfiler.h"
 #include "llvm/Transforms/Instrumentation/MemorySanitizer.h"
+#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
 #include "llvm/Transforms/Instrumentation/PoisonChecking.h"
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index cb892e30c4a0b9..d2550774400d69 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -74,6 +74,7 @@
 #include "llvm/Transforms/Instrumentation/InstrOrderFile.h"
 #include "llvm/Transforms/Instrumentation/InstrProfiling.h"
 #include "llvm/Transforms/Instrumentation/MemProfiler.h"
+#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
 #include "llvm/Transforms/Scalar/ADCE.h"
@@ -826,16 +827,19 @@ void PassBuilder::addPGOInstrPasses(ModulePassManager &MPM,
             /*UseBlockFrequencyInfo=*/false),
         PTO.EagerlyInvalidateAnalyses));
   }
-
-  // Add the profile lowering pass.
-  InstrProfOptions Options;
-  if (!ProfileFile.empty())
-    Options.InstrProfileOutput = ProfileFile;
-  // Do counter promotion at Level greater than O0.
-  Options.DoCounterPromotion = true;
-  Options.UseBFIInPromotion = IsCS;
-  Options.Atomic = AtomicCounterUpdate;
-  MPM.addPass(InstrProfilingLoweringPass(Options, IsCS));
+  if (PGOCtxProfLoweringPass::isContextualIRPGOEnabled()) {
+    MPM.addPass(PGOCtxProfLoweringPass());
+  } else {
+    // Add the profile lowering pass.
+    InstrProfOptions Options;
+    if (!ProfileFile.empty())
+      Options.InstrProfileOutput = ProfileFile;
+    // Do counter promotion at Level greater than O0.
+    Options.DoCounterPromotion = true;
+    Options.UseBFIInPromotion = IsCS;
+    Options.Atomic = AtomicCounterUpdate;
+    MPM.addPass(InstrProfilingLoweringPass(Options, IsCS));
+  }
 }
 
 void PassBuilder::addPGOInstrPassesForO0(
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 41f16d0915bf23..4b6f927d5b6c46 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -77,6 +77,7 @@ MODULE_PASS("inliner-wrapper-no-mandatory-first",
 MODULE_PASS("insert-gcov-profiling", GCOVProfilerPass())
 MODULE_PASS("instrorderfile", InstrOrderFilePass())
 MODULE_PASS("instrprof", InstrProfilingLoweringPass())
+MODULE_PASS("pgo-ctx-instr-lower", PGOCtxProfLoweringPass())
 MODULE_PASS("internalize", InternalizePass())
 MODULE_PASS("invalidate<all>", InvalidateAllAnalysesPass())
 MODULE_PASS("iroutliner", IROutlinerPass())
diff --git a/llvm/lib/Transforms/Instrumentation/CMakeLists.txt b/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
index b23a6ed1f08415..e3a1a1a7209e5a 100644
--- a/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
+++ b/llvm/lib/Transforms/Instrumentation/CMakeLists.txt
@@ -13,6 +13,7 @@ add_llvm_component_library(LLVMInstrumentation
   InstrOrderFile.cpp
   InstrProfiling.cpp
   KCFI.cpp
+  PGOCtxProfLowering.cpp
   PGOForceFunctionAttrs.cpp
   PGOInstrumentation.cpp
   PGOMemOPSizeOpt.cpp
diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
index c42c53edd51190..8f27c4738f7ef6 100644
--- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -636,6 +636,9 @@ bool InstrLowerer::lowerIntrinsics(Function *F) {
       } else if (auto *IPTU = dyn_cast<InstrProfMCDCCondBitmapUpdate>(&Instr)) {
         lowerMCDCCondBitmapUpdate(IPTU);
         MadeChange = true;
+      } else if (isa<InstrProfCallsite>(Instr)) {
+        Instr.eraseFromParent();
+        MadeChange = true;
       }
     }
   }
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
new file mode 100644
index 00000000000000..1212f7acf53ca7
--- /dev/null
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
@@ -0,0 +1,214 @@
+//===- PGOCtxProfLowering.cpp - Contextual  PGO Instrumentation -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+
+#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
+#include "llvm/IR/Analysis.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
+
+using namespace llvm;
+
+// Functions to treat as roots of contextual profiles. A listed function is
+// only instrumented as a root if it is defined in the module being lowered.
+// Naming at least one root switches the pipeline to contextual lowering.
+static cl::list<std::string> ContextRoots(
+    "profile-context-root",
+    cl::desc("Name of a function to treat as the root of a contextual "
+             "profile; may be repeated."));
+
+/// Contextual instrumentation is enabled iff at least one root was requested.
+bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {
+  return !ContextRoots.empty();
+}
+
+/// Module entry point. Declares the context types, the runtime entry points,
+/// one root global per configured root, and the thread-local callsite
+/// globals. It then lowers every function in \p M and always reports all
+/// analyses as invalidated.
+PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
+                                              ModuleAnalysisManager &MAM) {
+  // Start from a clean slate in case this pass object already ran on another
+  // module; entries keyed by that module's functions would otherwise linger.
+  ContextRootMap.clear();
+
+  auto *PointerTy = PointerType::get(M.getContext(), 0);
+  auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
+  auto *I32Ty = Type::getInt32Ty(M.getContext());
+  auto *I64Ty = Type::getInt64Ty(M.getContext());
+
+  // NOTE(review): these layouts presumably mirror the compiler-rt runtime's
+  // context root / context node definitions -- keep them in sync. The node
+  // header is also mirrored by `ContextNode` in llvm-ctx-ifdo.
+  ContextRootTy =
+      StructType::get(M.getContext(), {
+                                          PointerTy,          /*FirstNode*/
+                                          PointerTy,          /*FirstMemBlock*/
+                                          PointerTy,          /*CurrentMem*/
+                                          SanitizerMutexType, /*Taken*/
+                                      });
+  ContextNodeTy = StructType::get(M.getContext(), {
+                                                      I64Ty,     /*Guid*/
+                                                      PointerTy, /*Next*/
+                                                      I32Ty,     /*NrCounters*/
+                                                      I32Ty,     /*NrCallsites*/
+                                                  });
+
+  // One zero-initialized `<name>_ctx_root` global per root defined in this
+  // module; roots that are absent or only declared here are ignored.
+  for (const auto &Fname : ContextRoots) {
+    if (const auto *F = M.getFunction(Fname)) {
+      if (F->isDeclaration())
+        continue;
+      auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
+      cast<GlobalVariable>(G)->setInitializer(
+          Constant::getNullValue(ContextRootTy));
+      ContextRootMap.insert(std::make_pair(F, G));
+    }
+  }
+
+  // Runtime entry points. With opaque pointers every pointer parameter and
+  // return type is simply `ptr`.
+  StartCtx = cast<Function>(
+      M.getOrInsertFunction("__llvm_instrprof_start_context",
+                            FunctionType::get(PointerTy,
+                                              {PointerTy, /*ContextRoot*/
+                                               I64Ty,     /*Guid*/
+                                               I32Ty,     /*NrCounters*/
+                                               I32Ty},    /*NrCallsites*/
+                                              false))
+          .getCallee());
+  GetCtx = cast<Function>(
+      M.getOrInsertFunction("__llvm_instrprof_get_context",
+                            FunctionType::get(PointerTy,
+                                              {PointerTy, /*Callee*/
+                                               I64Ty,     /*Guid*/
+                                               I32Ty,     /*NrCounters*/
+                                               I32Ty},    /*NrCallsites*/
+                                              false))
+          .getCallee());
+  ReleaseCtx = cast<Function>(
+      M.getOrInsertFunction("__llvm_instrprof_release_context",
+                            FunctionType::get(Type::getVoidTy(M.getContext()),
+                                              {PointerTy /*ContextRoot*/},
+                                              false))
+          .getCallee());
+
+  // Thread-local slots the lowered callsites write to. Reuse existing
+  // globals of these names rather than creating renamed duplicates.
+  CallsiteInfoTLS = cast<GlobalVariable>(
+      M.getOrInsertGlobal("__llvm_instrprof_callsite", PointerTy));
+  CallsiteInfoTLS->setThreadLocal(true);
+  CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
+  ExpectedCalleeTLS = cast<GlobalVariable>(
+      M.getOrInsertGlobal("__llvm_instrprof_expected_callee", PointerTy));
+  ExpectedCalleeTLS->setThreadLocal(true);
+  ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
+
+  for (auto &F : M)
+    lowerFunction(F);
+  return PreservedAnalyses::none();
+}
+
+/// Lower the instrprof intrinsics of one function to contextual-profile
+/// instrumentation.
+///
+/// In the entry block, the first InstrProfIncrementInst (asserted to be
+/// counter index 0) is replaced by the call that obtains this function's
+/// context. Configured roots call __llvm_instrprof_start_context; all other
+/// functions call __llvm_instrprof_get_context.
+///
+/// Each counter increment then updates the context's counter array in place.
+/// Each InstrProfCallsite becomes two volatile thread-local stores. For
+/// roots, the context is released before every `ret`. Functions without an
+/// entry-block increment are skipped.
+void PGOCtxProfLoweringPass::lowerFunction(Function &F) {
+  if (F.isDeclaration())
+    return;
+
+  Value *Guid = nullptr;
+  uint32_t NrCounters = 0;
+  uint32_t NrCallsites = 0;
+  // Callsite intrinsics are all emitted with the function's total callsite
+  // count (see instrumentOneFunc). Increments presumably carry the total
+  // counter count the same way. So the first occurrence of each suffices.
+  [&]() {
+    for (const auto &BB : F)
+      for (const auto &I : BB) {
+        if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
+          if (!NrCounters)
+            NrCounters =
+                static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
+        } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
+          if (!NrCallsites)
+            NrCallsites =
+                static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
+        }
+        if (NrCounters && NrCallsites)
+          return;
+      }
+  }();
+
+  // Context pointer as returned by the runtime; its lowest bit may be set.
+  Value *Context = nullptr;
+  // Context with the lowest bit cleared; base address for counter updates.
+  Value *RealContext = nullptr;
+
+  StructType *ThisContextType = nullptr;
+  Value *TheRootContext = nullptr;
+  Value *ExpectedCalleeTLSAddr = nullptr;
+  Value *CallsiteInfoTLSAddr = nullptr;
+
+  auto &Head = F.getEntryBlock();
+  for (auto &I : Head) {
+    if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
+      assert(Mark->getIndex()->isZero());
+
+      IRBuilder<> Builder(Mark);
+      // TODO!!!! use InstrProfSymtab::getCanonicalName
+      Guid = Builder.getInt64(F.getGUID());
+      // This function's context: node header, NrCounters i64 counters, then
+      // NrCallsites pointers.
+      ThisContextType = StructType::get(
+          F.getContext(),
+          {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
+           ArrayType::get(Builder.getPtrTy(), NrCallsites)});
+      auto Iter = ContextRootMap.find(&F);
+      if (Iter != ContextRootMap.end()) {
+        TheRootContext = Iter->second;
+        Context = Builder.CreateCall(
+            StartCtx,
+            {TheRootContext, Guid, Builder.getInt32(NrCounters),
+             Builder.getInt32(NrCallsites)});
+      } else {
+        Context =
+            Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters),
+                                        Builder.getInt32(NrCallsites)});
+      }
+      auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
+      if (NrCallsites > 0) {
+        // The context's low bit selects which of the two thread-local slots
+        // the callsite stores below write to.
+        auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
+        ExpectedCalleeTLSAddr = Builder.CreateGEP(
+            Builder.getInt8Ty()->getPointerTo(),
+            Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
+        // NOTE(review): this slot receives a pointer (see the callsite
+        // lowering below), yet the GEP strides by i32, unlike
+        // ExpectedCalleeTLSAddr above. Slot 1 therefore lands 4 bytes in and
+        // overlaps slot 0. It is kept as-is because ctx-basic.ll checks for
+        // `getelementptr i32`; confirm against the runtime before changing.
+        CallsiteInfoTLSAddr = Builder.CreateGEP(
+            Builder.getInt32Ty(),
+            Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
+      }
+      RealContext = Builder.CreateIntToPtr(
+          Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
+          ThisContextType->getPointerTo());
+      I.eraseFromParent();
+      break;
+    }
+  }
+  if (!Context) {
+    // Debug-only: an uninstrumented function is an ordinary case, not
+    // something to report on stderr in every build.
+    DEBUG_WITH_TYPE("pgo-ctx-instr-lower",
+                    dbgs() << "[instprof] Function doesn't have "
+                              "instrumentation, skipping "
+                           << F.getName() << "\n");
+    return;
+  }
+
+  for (auto &BB : F) {
+    for (auto &I : llvm::make_early_inc_range(BB)) {
+      if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
+        IRBuilder<> Builder(Instr);
+        switch (Instr->getIntrinsicID()) {
+        case llvm::Intrinsic::instrprof_increment:
+        case llvm::Intrinsic::instrprof_increment_step: {
+          // Counters[Index] += Step, addressed via the untagged context.
+          auto *AsStep = cast<InstrProfIncrementInst>(Instr);
+          auto *GEP = Builder.CreateGEP(
+              ThisContextType, RealContext,
+              {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
+          Builder.CreateStore(
+              Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
+                                AsStep->getStep()),
+              GEP);
+        } break;
+        case llvm::Intrinsic::instrprof_callsite: {
+          // Publish, through the slots chosen in the entry block, the callee
+          // this callsite expects and the address of the callsite's entry in
+          // this context. Both stores are volatile.
+          auto *CSIntrinsic = cast<InstrProfCallsite>(Instr);
+          Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
+                              true);
+          // NOTE(review): based on the possibly-tagged `Context`, not
+          // `RealContext`; presumably deliberate so the runtime still sees
+          // the tag -- confirm.
+          Builder.CreateStore(
+              Builder.CreateGEP(ThisContextType, Context,
+                                {Builder.getInt32(0), Builder.getInt32(2),
+                                 CSIntrinsic->getIndex()}),
+              CallsiteInfoTLSAddr, true);
+          break;
+        }
+        default:
+          // Any other InstrProfCntrInstBase intrinsic has no contextual
+          // lowering here; it is simply erased below.
+          break;
+        }
+        I.eraseFromParent();
+      } else if (TheRootContext && isa<ReturnInst>(I)) {
+        // Roots release their context before every return.
+        IRBuilder<> Builder(&I);
+        Builder.CreateCall(ReleaseCtx, {TheRootContext});
+      }
+    }
+  }
+}
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 55728709cde556..5d393c1fda124d 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -110,6 +110,7 @@
 #include "llvm/Transforms/Instrumentation.h"
 #include "llvm/Transforms/Instrumentation/BlockCoverageInference.h"
 #include "llvm/Transforms/Instrumentation/CFGMST.h"
+#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/MisExpect.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -330,6 +331,11 @@ extern cl::opt<std::string> ViewBlockFreqFuncName;
 extern cl::opt<InstrProfCorrelator::ProfCorrelatorKind> ProfileCorrelate;
 } // namespace llvm
 
+/// Whether every function's entry block gets a counter. Contextual profiling
+/// forces it on: its lowering turns the entry-block counter (index 0) into
+/// the call that fetches the function's context. File-local, hence static.
+static bool shouldInstrumentEntryBB() {
+  return PGOInstrumentEntry ||
+         PGOCtxProfLoweringPass::isContextualIRPGOEnabled();
+}
+
 // Return a string describing the branch condition that can be
 // used in static branch probability heuristics:
 static std::string getBranchCondString(Instruction *TI) {
@@ -376,7 +382,7 @@ static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) {
   uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF);
   if (IsCS)
     ProfileVersion |= VARIANT_MASK_CSIR_PROF;
-  if (PGOInstrumentEntry)
+  if (shouldInstrumentEntryBB())
     ProfileVersion |= VARIANT_MASK_INSTR_ENTRY;
   if (DebugInfoCorrelate || ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO)
     ProfileVersion |= VARIANT_MASK_DBG_CORRELATE;
@@ -856,7 +862,7 @@ static void instrumentOneFunc(
   }
 
   FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo(
-      F, TLI, ComdatMembers, true, BPI, BFI, IsCS, PGOInstrumentEntry,
+      F, TLI, ComdatMembers, true, BPI, BFI, IsCS, shouldInstrumentEntryBB(),
       PGOBlockCoverage);
 
   auto Name = FuncInfo.FuncNameVar;
@@ -878,6 +884,31 @@ static void instrumentOneFunc(
   unsigned NumCounters =
       InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts();
 
+  auto *CSIntrinsic =
+      Intrinsic::getDeclaration(M, Intrinsic::instrprof_callsite);
+  auto Visit = [&](llvm::function_ref<void(CallBase * CB)> Visitor) {
+    for (auto &BB : F)
+      for (auto &Instr : BB)
+        if (auto *CS = dyn_cast<CallBase>(&Instr)) {
+          if ((CS->getCalledFunction() &&
+               CS->getCalledFunction()->isIntrinsic()) ||
+              dyn_cast<InlineAsm>(CS->getCalledOperand()))
+            continue;
+          Visitor(CS);
+        }
+  };
+  uint32_t TotalNrCallsites = 0;
+  Visit([&TotalNrCallsites](auto *) { ++TotalNrCallsites; });
+  uint32_t CallsiteIndex = 0;
+
+  Visit([&](auto *CB){
+    IRBuilder<> Builder(CB);
+    Builder.CreateCall(CSIntrinsic,
+                       {Name, CFGHash, Builder.getInt32(TotalNrCallsites),
+                        Builder.getInt32(CallsiteIndex++),
+                        CB->getCalledOperand()});
+  });
+
   uint32_t I = 0;
   if (PGOTemporalInstrumentation) {
     NumCounters += PGOBlockCoverage ? 8 : 1;
@@ -2001,8 +2032,7 @@ static bool annotateAllFunctions(
   // If the profile marked as always instrument the entry BB, do the
   // same. Note this can be overwritten by the internal option in CFGMST.h
   bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled();
-  if (PGOInstrumentEntry.getNumOccurrences() > 0)
-    InstrumentFuncEntry = PGOInstrumentEntry;
+  InstrumentFuncEntry = shouldInstrumentEntryBB();
   bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage();
   for (auto &F : M) {
     if (skipPGOUse(F))
diff --git a/llvm/test/Transforms/PGOProfile/ctx-basic.ll b/llvm/test/Transforms/PGOProfile/ctx-basic.ll
new file mode 100644
index 00000000000000..5b7df61a033d48
--- /dev/null
+++ b/llvm/test/Transforms/PGOProfile/ctx-basic.ll
@@ -0,0 +1,86 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 4
+; RUN: opt -passes=pgo-instr-gen,pgo-ctx-instr-lower -profile-context-root=an_entrypoint \
+; RUN:   -S < %s | FileCheck %s
+
+declare void @bar()
+
+;.
+; CHECK: @__llvm_profile_raw_version = hidden constant i64 360287970189639690, comdat
+; CHECK: @__profn_foo = private constant [3 x i8] c"foo"
+; CHECK: @__profn_an_entrypoint = private constant [13 x i8] c"an_entrypoint"
+; CHECK: @an_entrypoint_ctx_root = global { ptr, ptr, ptr, i8 } zeroinitializer
+; CHECK: @__llvm_instrprof_callsite = external hidden thread_local global ptr
+; CHECK: @__llvm_instrprof_expected_callee = external hidden thread_local global ptr
+;.
+; @foo is not passed via -profile-context-root, so it gets its context from
+; __llvm_instrprof_get_context. Its single call (to @bar) is lowered into the
+; two volatile thread-local stores that precede that call.
+define void @foo(i32 %a) {
+; CHECK-LABEL: define void @foo(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_instrprof_get_context(ptr @foo, i64 6699318081062747564, i32 2, i32 1)
+; CHECK-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 1
+; CHECK-NEXT:    [[TMP4:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_instrprof_expected_callee)
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr ptr, ptr [[TMP4]], i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP6:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_instrprof_callsite)
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i32, ptr [[TMP6]], i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP8:%.*]] = and i64 [[TMP2]], -2
+; CHECK-NEXT:    [[TMP9:%.*]] = inttoptr i64 [[TMP8]] to ptr
+; CHECK-NEXT:    [[T:%.*]] = icmp eq i32 [[A]], 0
+; CHECK-NEXT:    br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
+; CHECK:       yes:
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [1 x ptr] }, ptr [[TMP9]], i32 0, i32 1, i32 1
+; CHECK-NEXT:    [[TMP11:%.*]] = load i64, ptr [[TMP10]], align 4
+; CHECK-NEXT:    [[TMP12:%.*]] = add i64 [[TMP11]], 1
+; CHECK-NEXT:    store i64 [[TMP12]], ptr [[TMP10]], align 4
+; CHECK-NEXT:    br label [[EXIT:%.*]]
+; CHECK:       no:
+; CHECK-NEXT:    store volatile ptr @bar, ptr [[TMP5]], align 8
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [1 x ptr] }, ptr [[TMP1]], i32 0, i32 2, i32 0
+; CHECK-NEXT:    store volatile ptr [[TMP13]], ptr [[TMP7]], align 8
+; CHECK-NEXT:    call void @bar()
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+  %t = icmp eq i32 %a, 0
+  br i1 %t, label %yes, label %no
+yes:
+  br label %exit
+no:
+  call void @bar()
+  br label %exit
+exit:
+  ret void
+}
+
+; @an_entrypoint is passed via -profile-context-root. It therefore starts its
+; context from @an_entrypoint_ctx_root and releases it before each return.
+define void @an_entrypoint(i32 %a) {
+; CHECK-LABEL: define void @an_entrypoint(
+; CHECK-SAME: i32 [[A:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_instrprof_start_context(ptr @an_entrypoint_ctx_root, i64 4909520559318251808, i32 2, i32 0)
+; CHECK-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], -2
+; CHECK-NEXT:    [[TMP4:%.*]] = inttoptr i64 [[TMP3]] to ptr
+; CHECK-NEXT:    [[T:%.*]] = icmp eq i32 [[A]], 0
+; CHECK-NEXT:    br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
+; CHECK:       yes:
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [0 x ptr] }, ptr [[TMP4]], i32 0, i32 1, i32 1
+; CHECK-NEXT:    [[TMP6:%.*]] = load i64, ptr [[TMP5]], align 4
+; CHECK-NEXT:    [[TMP7:%.*]] = add i64 [[TMP6]], 1
+; CHECK-NEXT:    store i64 [[TMP7]], ptr [[TMP5]], align 4
+; CHECK-NEXT:    call void @__llvm_instrprof_release_context(ptr @an_entrypoint_ctx_root)
+; CHECK-NEXT:    ret void
+; CHECK:       no:
+; CHECK-NEXT:    call void @__llvm_instrprof_release_context(ptr @an_entrypoint_ctx_root)
+; CHECK-NEXT:    ret void
+;
+  %t = icmp eq i32 %a, 0
+  br i1 %t, label %yes, label %no
+
+yes:
+  ret void
+no:
+  ret void
+}
+;.
+; CHECK: attributes #[[ATTR0:[0-9]+]] = { nounwind }
+; CHECK: attributes #[[ATTR1:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+;.

>From 215a7c6b631dfe5668f99519ea1ede3c14dfc6a6 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 20 Mar 2024 16:07:56 -0700
Subject: [PATCH 2/3] converter (raw -> bitstream)

---
 llvm/tools/llvm-ctx-ifdo/CMakeLists.txt    |   8 ++
 llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp | 130 +++++++++++++++++++++
 2 files changed, 138 insertions(+)
 create mode 100644 llvm/tools/llvm-ctx-ifdo/CMakeLists.txt
 create mode 100644 llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp

diff --git a/llvm/tools/llvm-ctx-ifdo/CMakeLists.txt b/llvm/tools/llvm-ctx-ifdo/CMakeLists.txt
new file mode 100644
index 00000000000000..bdc6bb37898375
--- /dev/null
+++ b/llvm/tools/llvm-ctx-ifdo/CMakeLists.txt
@@ -0,0 +1,8 @@
+# llvm-ctx-ifdo: converts a raw contextual profile into the bitstream format
+# read by ContextualInstrProfReader.
+set(LLVM_LINK_COMPONENTS
+  Core
+  Support
+  )
+
+add_llvm_tool(llvm-ctx-ifdo
+  llvm-ctx-ifdo.cpp
+  )
diff --git a/llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp b/llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp
new file mode 100644
index 00000000000000..91ab95eb64a999
--- /dev/null
+++ b/llvm/tools/llvm-ctx-ifdo/llvm-ctx-ifdo.cpp
@@ -0,0 +1,130 @@
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Bitstream/BitCodeEnums.h"
+#include "llvm/Bitstream/BitstreamWriter.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cstdint>
+#include <functional>
+#include <map>
+#include <system_error>
+#include <utility>
+
+using namespace llvm;
+
+// Command-line options. These are file-local to the tool, hence static.
+static cl::opt<std::string> InputFile("raw-prof", cl::NotHidden,
+                                      cl::Positional,
+                                      cl::desc("<Raw contextual profile>"));
+static cl::opt<std::string>
+    OutputFile("output", cl::NotHidden, cl::Positional,
+               cl::desc("<Converted contextual profile>"));
+
+/// Record codes written into each context block. The numeric values are part
+/// of the on-disk format; they must match the reader's own `Codes` enum.
+enum Codes {
+  Invalid = 0,
+  Guid = 1,
+  CalleeIndex = 2,
+  Counters = 3,
+};
+
+/// Header of a context node as stored in the raw profile. It must match the
+/// {i64, ptr, i32, i32} node type emitted by the contextual lowering pass,
+/// assuming 64-bit pointers in the profiled program. The node's NrCounters
+/// 64-bit counters follow it directly, then NrCallsites 64-bit callsite
+/// addresses.
+struct ContextNode {
+  uint64_t Guid = 0;
+  uint64_t Next = 0;
+  uint32_t NrCounters = 0;
+  uint32_t NrCallsites = 0;
+};
+// The pointer arithmetic in writeContext relies on this exact layout.
+static_assert(sizeof(ContextNode) == 24,
+              "ContextNode must match the 24-byte raw node header");
+
+/// Return the bytes of the loaded page that contains \p Addr, starting at
+/// \p Addr and running to the end of that page. If no loaded page covers
+/// \p Addr, further pages are pulled in via \p Load until one does. (main's
+/// Load inserts each page it reads into the same map as \p Pages.) Returns
+/// std::nullopt once \p Load is exhausted.
+std::optional<StringRef>
+getContext(uint64_t Addr, const std::map<uint64_t, StringRef> &Pages,
+           std::function<std::optional<StringRef>()> Load) {
+  while (true) {
+    // Look for the last page starting at or below Addr. upper_bound returns
+    // the first page starting after Addr. If that is begin(), no candidate
+    // exists (this includes an empty map), and decrementing it would be UB.
+    auto It = Pages.upper_bound(Addr);
+    if (It != Pages.begin()) {
+      --It;
+      const uint64_t Offset = Addr - It->first;
+      if (Offset < It->second.size())
+        return It->second.substr(Offset);
+    }
+    if (!Load())
+      return std::nullopt;
+  }
+}
+
+/// Resolve a raw-profile address to the ContextNode stored there in the
+/// loaded pages. Address 0 terminates a list and maps to nullptr. An address
+/// that no page covers is fatal: value() on an empty optional does not
+/// return normally.
+const ContextNode *
+convertAddressToContext(uint64_t Addr,
+                        const std::map<uint64_t, StringRef> &Pages,
+                        std::function<std::optional<StringRef>()> Load) {
+  if (Addr != 0) {
+    std::optional<StringRef> Bytes = getContext(Addr, Pages, Load);
+    return reinterpret_cast<const ContextNode *>(Bytes.value().data());
+  }
+  return nullptr;
+}
+
+/// Emit the list of contexts starting at \p FirstPage, following each node's
+/// Next link, as consecutive bitstream blocks with ID 100. Each block holds:
+/// - a Guid record;
+/// - a CalleeIndex record carrying \p Index, when one is given;
+/// - an unabbreviated Counters record;
+/// - nested blocks for the contexts of every non-null callsite.
+void writeContext(StringRef FirstPage, BitstreamWriter &Writer,
+                  std::map<uint64_t, StringRef> &Pages,
+                  std::function<std::optional<StringRef>()> Load,
+                  std::optional<uint32_t> Index = std::nullopt) {
+  const auto *Node = reinterpret_cast<const ContextNode *>(FirstPage.data());
+  while (Node) {
+    Writer.EnterSubblock(100, 2);
+    Writer.EmitRecord(Codes::Guid, SmallVector<uint64_t, 1>{Node->Guid});
+    if (Index)
+      Writer.EmitRecord(Codes::CalleeIndex, SmallVector<uint32_t, 1>{*Index});
+
+    // Counters, then callsite addresses, sit right after the node header.
+    const auto *Counters = reinterpret_cast<const uint64_t *>(Node + 1);
+    const auto *Callsites = Counters + Node->NrCounters;
+
+    // Hand-rolled unabbreviated record: code, element count, elements.
+    Writer.EmitCode(bitc::UNABBREV_RECORD);
+    Writer.EmitVBR(Codes::Counters, 6);
+    Writer.EmitVBR(Node->NrCounters, 6);
+    for (uint32_t C = 0; C != Node->NrCounters; ++C)
+      Writer.EmitVBR64(Counters[C], 6);
+
+    for (uint32_t CS = 0; CS < Node->NrCallsites; ++CS) {
+      const uint64_t Addr = Callsites[CS];
+      if (Addr == 0)
+        continue;
+      if (std::optional<StringRef> Sub = getContext(Addr, Pages, Load))
+        writeContext(*Sub, Writer, Pages, Load, CS);
+    }
+    Writer.ExitBlock();
+    Node = convertAddressToContext(Node->Next, Pages, Load);
+  }
+}
+
+/// Convert a raw contextual profile into bitstream form. The raw input is a
+/// sequence of pages, each encoded as a 64-bit start address, a 64-bit
+/// length, and that many bytes. The output is a 32-bit magic followed by the
+/// bitstream emitted by writeContext.
+int main(int argc, const char *argv[]) {
+  InitLLVM X(argc, argv);
+  cl::ParseCommandLineOptions(argc, argv,
+                              "LLVM Contextual Profile Converter\n");
+  std::error_code EC;
+  raw_fd_stream Out(OutputFile, EC);
+  if (EC) {
+    errs() << "error: cannot open '" << OutputFile << "': " << EC.message()
+           << "\n";
+    return 1;
+  }
+  auto Input = MemoryBuffer::getFileOrSTDIN(InputFile);
+  if (!Input) {
+    errs() << "error: cannot read '" << InputFile
+           << "': " << Input.getError().message() << "\n";
+    return 1;
+  }
+  StringRef In = (*Input)->getBuffer();
+
+  // Set when a page header claims more bytes than remain in the input.
+  bool Truncated = false;
+  std::map<uint64_t, StringRef> Pages;
+  // Consume the next page from In, record it in Pages keyed by its start
+  // address, and return its bytes. Returns std::nullopt when no complete
+  // page remains.
+  auto Load = [&]() -> std::optional<StringRef> {
+    if (In.size() < 2 * sizeof(uint64_t))
+      return std::nullopt;
+    const auto *Header = reinterpret_cast<const uint64_t *>(In.data());
+    const uint64_t Start = Header[0];
+    const uint64_t Len = Header[1];
+    In = In.substr(2 * sizeof(uint64_t));
+    if (Len > In.size()) {
+      // A short page would later be read past its end; refuse it.
+      Truncated = true;
+      In = StringRef();
+      return std::nullopt;
+    }
+    auto It = Pages.insert({Start, In.substr(0, Len)});
+    In = In.substr(Len);
+    return It.first->second;
+  };
+
+  const uint32_t Magic = 0xfafababa;
+  // NOTE(review): written in host byte order, while the reader decodes it
+  // with BitstreamCursor::Read(32); confirm this holds on big-endian hosts.
+  Out.write(reinterpret_cast<const char *>(&Magic), sizeof(uint32_t));
+  {
+    // The writer holds a pointer to Out. Scope it so its lifetime ends
+    // before Out is flushed and closed below.
+    SmallVector<char, 1 << 20> Buff;
+    BitstreamWriter Writer(Buff, &Out, 0);
+    while (auto S = Load())
+      writeContext(*S, Writer, Pages, Load);
+  }
+  Out.flush();
+  Out.close();
+  if (Truncated) {
+    errs() << "error: truncated raw contextual profile '" << InputFile
+           << "'\n";
+    return 1;
+  }
+  return 0;
+}

>From 603eccfd4b76ed4c07c9e22e43b44bef62d4adba Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 20 Mar 2024 16:15:39 -0700
Subject: [PATCH 3/3] bitstream profile data reader

---
 .../llvm/ProfileData/CtxInstrProfileReader.h  | 196 ++++++++++++++++++
 1 file changed, 196 insertions(+)
 create mode 100644 llvm/include/llvm/ProfileData/CtxInstrProfileReader.h

diff --git a/llvm/include/llvm/ProfileData/CtxInstrProfileReader.h b/llvm/include/llvm/ProfileData/CtxInstrProfileReader.h
new file mode 100644
index 00000000000000..7d597643df2030
--- /dev/null
+++ b/llvm/include/llvm/ProfileData/CtxInstrProfileReader.h
@@ -0,0 +1,196 @@
+//===--- CtxInstrProfileReader.h - Ctx iFDO profile reader ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+///
+/// Reader for contextual iFDO profile, in bitstream format.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_PROFILEDATA_CTXINSTRPROFILEREADER_H
+#define LLVM_PROFILEDATA_CTXINSTRPROFILEREADER_H
+
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Bitstream/BitstreamReader.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include <map>
+#include <vector>
+
+namespace llvm {
+/// One node of a contextual profile tree: the counters of one function (by
+/// GUID) in one calling context. For each callsite index, it also holds the
+/// callee contexts observed there, keyed by callee GUID.
+class ContextualProfile final {
+  friend class ContextualInstrProfReader;
+  GlobalValue::GUID GUID = 0;
+  SmallVector<uint64_t, 16> Counters;
+  std::vector<std::map<GlobalValue::GUID, ContextualProfile>> Callsites;
+
+  ContextualProfile(GlobalValue::GUID G, SmallVectorImpl<uint64_t> &&Counters)
+      : GUID(G), Counters(std::move(Counters)) {}
+
+  /// Add a callee context with GUID \p G under callsite \p Index, growing the
+  /// callsite vector as needed. Fails if \p G is already present there.
+  Expected<ContextualProfile &>
+  getOrEmplace(uint32_t Index, GlobalValue::GUID G,
+               SmallVectorImpl<uint64_t> &&Counters) {
+    if (Index >= Callsites.size()) {
+      // Grow to Index + 1 entries without computing `Index + 1`. In uint32_t
+      // arithmetic that wraps to 0 for a file-supplied Index of UINT32_MAX,
+      // which would leave Callsites[Index] out of bounds.
+      // NOTE(review): a huge Index still triggers a huge allocation.
+      Callsites.resize(Index);
+      Callsites.emplace_back();
+    }
+    auto I =
+        Callsites[Index].insert({G, ContextualProfile(G, std::move(Counters))});
+    if (!I.second)
+      return make_error<StringError>(llvm::errc::invalid_argument,
+                                     "Duplicate GUID for same callsite.");
+    return I.first->second;
+  }
+
+public:
+  ContextualProfile(const ContextualProfile &) = delete;
+  ContextualProfile &operator=(const ContextualProfile &) = delete;
+  ContextualProfile(ContextualProfile &&) = default;
+  ContextualProfile &operator=(ContextualProfile &&) = default;
+
+  GlobalValue::GUID guid() const { return GUID; }
+  const SmallVector<uint64_t, 16> &counters() const { return Counters; }
+  /// Per callsite index, the callee contexts keyed by callee GUID. Indices
+  /// with no recorded callee hold an empty map.
+  const std::vector<std::map<GlobalValue::GUID, ContextualProfile>> &
+  callsites() const {
+    return Callsites;
+  }
+
+  /// Insert into \p Guids this node's GUID and, recursively, the GUIDs of all
+  /// callee contexts beneath it.
+  void getContainedGuids(DenseSet<GlobalValue::GUID> &Guids) const {
+    Guids.insert(GUID);
+    for (const auto &Callsite : Callsites)
+      for (const auto &[_, Callee] : Callsite)
+        Callee.getContainedGuids(Guids);
+  }
+};
+
+/// Reader for the bitstream produced by llvm-ctx-ifdo. The stream is a
+/// 32-bit magic followed by one block (ID 100) per root context. Each block
+/// holds a Guid record, then a CalleeIndex record (present exactly for
+/// non-root contexts), then a Counters record. Nested blocks for its callee
+/// contexts follow these records.
+class ContextualInstrProfReader final {
+  /// Record codes; must stay in sync with the codes llvm-ctx-ifdo emits.
+  enum Codes {
+    Invalid,
+    Guid,
+    CalleeIndex,
+    Counters,
+  };
+
+  /// Block ID shared by every context block.
+  static constexpr unsigned ContextBlockID = 100;
+
+  BitstreamCursor Cursor;
+
+  struct ContextData {
+    GlobalValue::GUID GUID;
+    std::optional<uint32_t> Index;
+    SmallVector<uint64_t, 16> Counters;
+  };
+
+  /// Read one unabbreviated record into \p Vals and return its code. When
+  /// \p ExpectedCode is given, any other code is an error.
+  Expected<unsigned>
+  readUnabbrevRecord(SmallVectorImpl<uint64_t> &Vals,
+                     std::optional<Codes> ExpectedCode = std::nullopt) {
+    auto Code = Cursor.ReadCode();
+    if (!Code)
+      return Code.takeError();
+    if (*Code != bitc::UNABBREV_RECORD)
+      return make_error<StringError>(llvm::errc::invalid_argument,
+                                     "Invalid code.");
+    auto Record = Cursor.readRecord(bitc::UNABBREV_RECORD, Vals);
+    if (!Record)
+      return Record.takeError();
+    if (!ExpectedCode)
+      return *Record;
+    if (*Record != *ExpectedCode)
+      return make_error<StringError>(llvm::errc::invalid_argument,
+                                     "Unexpected code.");
+    return *Record;
+  }
+
+  /// Read the records of the context block the cursor has just entered.
+  Expected<ContextData> readContextData() {
+    ContextData Ret;
+    SmallVector<uint64_t, 1> Data64;
+    auto GuidRec = readUnabbrevRecord(Data64, Codes::Guid);
+    if (!GuidRec)
+      return GuidRec.takeError();
+    if (Data64.empty())
+      return make_error<StringError>(llvm::errc::invalid_argument,
+                                     "Empty GUID record.");
+    Ret.GUID = Data64[0];
+    auto IndexOrCounters = readUnabbrevRecord(Ret.Counters);
+    if (!IndexOrCounters)
+      return IndexOrCounters.takeError();
+
+    if (*IndexOrCounters == Codes::CalleeIndex) {
+      if (Ret.Counters.empty())
+        return make_error<StringError>(llvm::errc::invalid_argument,
+                                       "Empty callee index record.");
+      const uint64_t RawIndex = Ret.Counters[0];
+      // Reject indices that would silently truncate to 32 bits.
+      if (static_cast<uint32_t>(RawIndex) != RawIndex)
+        return make_error<StringError>(llvm::errc::invalid_argument,
+                                       "Callee index out of range.");
+      Ret.Index = static_cast<uint32_t>(RawIndex);
+      Ret.Counters.clear();
+      auto NextRecord = readUnabbrevRecord(Ret.Counters, Codes::Counters);
+      if (!NextRecord)
+        return NextRecord.takeError();
+    } else if (*IndexOrCounters != Codes::Counters) {
+      return make_error<StringError>(llvm::errc::invalid_argument,
+                                     "Expected counters.");
+    }
+    return Ret;
+  }
+
+  /// Advance to the next entry of the current scope.
+  /// - Returns true after entering a context block.
+  /// - Returns false when the scope has no more entries: the enclosing block
+  ///   ended (the cursor pops it) or the stream ended.
+  /// - Returns an error for anything else.
+  /// Reporting the normal end as `false`, not as an Error, avoids dropping an
+  /// unchecked llvm::Error, which aborts when ABI-breaking checks are on.
+  Expected<bool> tryEnterContextBlock() {
+    auto MaybeEntry =
+        Cursor.advance(BitstreamCursor::AF_DontAutoprocessAbbrevs);
+    if (!MaybeEntry)
+      return MaybeEntry.takeError();
+    switch (MaybeEntry->Kind) {
+    case BitstreamEntry::EndBlock:
+    case BitstreamEntry::Error:
+      return false;
+    case BitstreamEntry::Record:
+      return make_error<StringError>(llvm::errc::invalid_argument,
+                                     "Expected a subblock.");
+    case BitstreamEntry::SubBlock:
+      break;
+    }
+    if (MaybeEntry->ID != ContextBlockID)
+      return make_error<StringError>(llvm::errc::invalid_argument,
+                                     "Expected subblock ID 100.");
+    if (auto Err = Cursor.EnterSubBlock(MaybeEntry->ID))
+      return std::move(Err);
+    return true;
+  }
+
+  /// Read every callee context nested in the current block into \p Parent.
+  Error readSubContexts(ContextualProfile &Parent) {
+    while (true) {
+      Expected<bool> Entered = tryEnterContextBlock();
+      if (!Entered)
+        return Entered.takeError();
+      if (!*Entered)
+        return Error::success();
+      auto Ctx = readContextData();
+      if (!Ctx)
+        return Ctx.takeError();
+      if (!Ctx->Index)
+        return make_error<StringError>(
+            llvm::errc::invalid_argument,
+            "Invalid subcontext: should have an index.");
+      auto P =
+          Parent.getOrEmplace(*Ctx->Index, Ctx->GUID, std::move(Ctx->Counters));
+      if (!P)
+        return P.takeError();
+      if (auto Err = readSubContexts(*P))
+        return Err;
+    }
+  }
+
+public:
+  ContextualInstrProfReader(StringRef ProfileFile) : Cursor(ProfileFile) {}
+
+  /// Load all root contexts, keyed by root GUID.
+  Expected<std::map<GlobalValue::GUID, ContextualProfile>> loadContexts() {
+    auto MaybeMagic = Cursor.Read(32);
+    if (!MaybeMagic)
+      return MaybeMagic.takeError();
+    // NOTE(review): llvm-ctx-ifdo writes the magic in host byte order;
+    // confirm this comparison also holds on big-endian hosts.
+    if (*MaybeMagic != 0xfafababa)
+      return make_error<StringError>(llvm::errc::invalid_argument,
+                                     "Invalid magic.");
+    std::map<GlobalValue::GUID, ContextualProfile> Ret;
+    while (true) {
+      Expected<bool> Entered = tryEnterContextBlock();
+      if (!Entered)
+        return Entered.takeError();
+      if (!*Entered)
+        break;
+      auto Ctx = readContextData();
+      if (!Ctx)
+        return Ctx.takeError();
+      if (Ctx->Index)
+        return make_error<StringError>(llvm::errc::invalid_argument,
+                                       "Invalid root: should have no index.");
+      auto Ins = Ret.insert(
+          {Ctx->GUID, ContextualProfile(Ctx->GUID, std::move(Ctx->Counters))});
+      if (!Ins.second)
+        return make_error<StringError>(llvm::errc::invalid_argument,
+                                       "Duplicate GUID for same root.");
+      // Error -> Expected needs an explicit move (the constructor takes the
+      // Error by value).
+      if (auto Err = readSubContexts(Ins.first->second))
+        return std::move(Err);
+    }
+    return Ret;
+  }
+};
+} // namespace llvm
+#endif



More information about the llvm-commits mailing list