[llvm] [llvm][ctx_profile] Add instrumentation lowering (PR #90821)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Wed May 8 15:18:26 PDT 2024


https://github.com/mtrofin updated https://github.com/llvm/llvm-project/pull/90821

>From 8aaef96563bfc5812ec6317b9d5d1cda52bef80d Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 1 May 2024 22:31:17 -0700
Subject: [PATCH 1/5] [llvm][ctx_profile] Add instrumentation lowering

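This adds the instrumentation lowering pass for contextual profiling
(ctx_profile). It replaces the llvm.instrprof.increment[.step] and
llvm.instrprof.callsite intrinsics with calls into the compiler-rt
ctx_profile runtime (__llvm_ctx_profile_start_context, _get_context and
_release_context) and with stores to its thread-local expected-callee and
callsite variables. When -profile-context-root is specified, the pass runs
in place of the regular profile lowering.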

(Tracking Issue: #89287, RFC referenced there)
---
 .../Instrumentation/PGOCtxProfLowering.h      |   5 +-
 llvm/lib/Passes/PassBuilder.cpp               |   1 +
 llvm/lib/Passes/PassBuilderPipelines.cpp      |   5 +
 llvm/lib/Passes/PassRegistry.def              |   1 +
 .../Instrumentation/PGOCtxProfLowering.cpp    | 301 ++++++++++++++++++
 .../PGOProfile/ctx-instrumentation.ll         | 161 ++++++++++
 6 files changed, 473 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h b/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h
index 38afa0c6fd32..5256aff56205 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/PGOCtxProfLowering.h
@@ -12,13 +12,16 @@
 #ifndef LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFLOWERING_H
 #define LLVM_TRANSFORMS_INSTRUMENTATION_PGOCTXPROFLOWERING_H
 
+#include "llvm/IR/PassManager.h"
 namespace llvm {
 class Type;
 
-class PGOCtxProfLoweringPass {
+class PGOCtxProfLoweringPass : public PassInfoMixin<PGOCtxProfLoweringPass> {
 public:
   explicit PGOCtxProfLoweringPass() = default;
   static bool isContextualIRPGOEnabled();
+
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
 };
 } // namespace llvm
 #endif
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 30d3e7a1ec05..22fd2aef4ea6 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -175,6 +175,7 @@
 #include "llvm/Transforms/Instrumentation/LowerAllowCheckPass.h"
 #include "llvm/Transforms/Instrumentation/MemProfiler.h"
 #include "llvm/Transforms/Instrumentation/MemorySanitizer.h"
+#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
 #include "llvm/Transforms/Instrumentation/PoisonChecking.h"
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 100889c0845b..1d7f0510450c 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -74,6 +74,7 @@
 #include "llvm/Transforms/Instrumentation/InstrOrderFile.h"
 #include "llvm/Transforms/Instrumentation/InstrProfiling.h"
 #include "llvm/Transforms/Instrumentation/MemProfiler.h"
+#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Transforms/Instrumentation/PGOForceFunctionAttrs.h"
 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
 #include "llvm/Transforms/Scalar/ADCE.h"
@@ -834,6 +835,10 @@ void PassBuilder::addPGOInstrPasses(ModulePassManager &MPM,
         PTO.EagerlyInvalidateAnalyses));
   }
 
+  if (PGOCtxProfLoweringPass::isContextualIRPGOEnabled()) {
+    MPM.addPass(PGOCtxProfLoweringPass());
+    return;
+  }
   // Add the profile lowering pass.
   InstrProfOptions Options;
   if (!ProfileFile.empty())
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 9b670e4e3a44..8f79601d0351 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -77,6 +77,7 @@ MODULE_PASS("inliner-wrapper-no-mandatory-first",
 MODULE_PASS("insert-gcov-profiling", GCOVProfilerPass())
 MODULE_PASS("instrorderfile", InstrOrderFilePass())
 MODULE_PASS("instrprof", InstrProfilingLoweringPass())
+MODULE_PASS("pgo-ctx-instr-lower", PGOCtxProfLoweringPass())
 MODULE_PASS("internalize", InternalizePass())
 MODULE_PASS("invalidate<all>", InvalidateAllAnalysesPass())
 MODULE_PASS("iroutliner", IROutlinerPass())
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
index 9d6dd5ccb38b..7442d8010ab0 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
@@ -8,10 +8,19 @@
 //
 
 #include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PassManager.h"
 #include "llvm/Support/CommandLine.h"
+#include <utility>
 
 using namespace llvm;
 
+#define DEBUG_TYPE "ctx-profile-lower"
+
 static cl::list<std::string> ContextRoots(
     "profile-context-root", cl::Hidden,
     cl::desc(
@@ -22,3 +31,295 @@ static cl::list<std::string> ContextRoots(
 bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {
   return !ContextRoots.empty();
 }
+
+// the names of symbols we expect in compiler-rt. Using a namespace for
+// readability.
+namespace CompilerRtAPINames {
+static auto StartCtx = "__llvm_ctx_profile_start_context";
+static auto ReleaseCtx = "__llvm_ctx_profile_release_context";
+static auto GetCtx = "__llvm_ctx_profile_get_context";
+static auto ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee";
+static auto CallsiteTLS = "__llvm_ctx_profile_callsite";
+} // namespace CompilerRtAPINames
+
+namespace {
+// The lowering logic and state.
+class CtxInstrumentationLowerer final {
+  Module &M;
+  ModuleAnalysisManager &MAM;
+  Type *ContextNodeTy = nullptr;
+  Type *ContextRootTy = nullptr;
+
+  DenseMap<const Function *, Constant *> ContextRootMap;
+  Function *StartCtx = nullptr;
+  Function *GetCtx = nullptr;
+  Function *ReleaseCtx = nullptr;
+  GlobalVariable *ExpectedCalleeTLS = nullptr;
+  GlobalVariable *CallsiteInfoTLS = nullptr;
+
+public:
+  CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
+  void lowerFunction(Function &F);
+};
+
+std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {
+  uint32_t NrCounters = 0;
+  uint32_t NrCallsites = 0;
+  for (const auto &BB : F) {
+    for (const auto &I : BB) {
+      if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
+        if (!NrCounters)
+          NrCounters =
+              static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
+      } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
+        if (!NrCallsites)
+          NrCallsites =
+              static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
+      }
+      if (NrCounters && NrCallsites)
+        return std::make_pair(NrCounters, NrCallsites);
+    }
+  }
+  return {0, 0};
+}
+} // namespace
+
+// set up tie-in with compiler-rt.
+// NOTE!!!
+// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
+CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
+                                                     ModuleAnalysisManager &MAM)
+    : M(M), MAM(MAM) {
+  auto *PointerTy = PointerType::get(M.getContext(), 0);
+  auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
+  auto *I32Ty = Type::getInt32Ty(M.getContext());
+  auto *I64Ty = Type::getInt64Ty(M.getContext());
+
+  // The ContextRoot type
+  ContextRootTy =
+      StructType::get(M.getContext(), {
+                                          PointerTy,          /*FirstNode*/
+                                          PointerTy,          /*FirstMemBlock*/
+                                          PointerTy,          /*CurrentMem*/
+                                          SanitizerMutexType, /*Taken*/
+                                      });
+  // The Context header.
+  ContextNodeTy = StructType::get(M.getContext(), {
+                                                      I64Ty,     /*Guid*/
+                                                      PointerTy, /*Next*/
+                                                      I32Ty,     /*NrCounters*/
+                                                      I32Ty,     /*NrCallsites*/
+                                                  });
+
+  // Define a global for each entrypoint. We'll reuse the entrypoint's name as
+  // prefix. We assume the entrypoint names to be unique.
+  for (const auto &Fname : ContextRoots) {
+    if (const auto *F = M.getFunction(Fname)) {
+      if (F->isDeclaration())
+        continue;
+      auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
+      cast<GlobalVariable>(G)->setInitializer(
+          Constant::getNullValue(ContextRootTy));
+      ContextRootMap.insert(std::make_pair(F, G));
+    }
+  }
+
+  // Declare the functions we will call.
+  StartCtx = cast<Function>(
+      M.getOrInsertFunction(
+           CompilerRtAPINames::StartCtx,
+           FunctionType::get(ContextNodeTy->getPointerTo(),
+                             {ContextRootTy->getPointerTo(), /*ContextRoot*/
+                              I64Ty, /*Guid*/ I32Ty,
+                              /*NrCounters*/ I32Ty /*NrCallsites*/},
+                             false))
+          .getCallee());
+  GetCtx = cast<Function>(
+      M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
+                            FunctionType::get(ContextNodeTy->getPointerTo(),
+                                              {PointerTy, /*Callee*/
+                                               I64Ty,     /*Guid*/
+                                               I32Ty,     /*NrCounters*/
+                                               I32Ty},    /*NrCallsites*/
+                                              false))
+          .getCallee());
+  ReleaseCtx = cast<Function>(
+      M.getOrInsertFunction(
+           CompilerRtAPINames::ReleaseCtx,
+           FunctionType::get(Type::getVoidTy(M.getContext()),
+                             {
+                                 ContextRootTy->getPointerTo(), /*ContextRoot*/
+                             },
+                             false))
+          .getCallee());
+
+  // Declare the TLSes we will need to use.
+  CallsiteInfoTLS =
+      new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
+                         nullptr, CompilerRtAPINames::CallsiteTLS);
+  CallsiteInfoTLS->setThreadLocal(true);
+  CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
+  ExpectedCalleeTLS =
+      new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
+                         nullptr, CompilerRtAPINames::ExpectedCalleeTLS);
+  ExpectedCalleeTLS->setThreadLocal(true);
+  ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
+}
+
+PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
+                                              ModuleAnalysisManager &MAM) {
+  CtxInstrumentationLowerer Lowerer(M, MAM);
+  for (auto &F : M)
+    Lowerer.lowerFunction(F);
+  return PreservedAnalyses::none();
+}
+
+void CtxInstrumentationLowerer::lowerFunction(Function &F) {
+  if (F.isDeclaration())
+    return;
+  auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+
+  Value *Guid = nullptr;
+  auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F);
+
+  Value *Context = nullptr;
+  Value *RealContext = nullptr;
+
+  StructType *ThisContextType = nullptr;
+  Value *TheRootContext = nullptr;
+  Value *ExpectedCalleeTLSAddr = nullptr;
+  Value *CallsiteInfoTLSAddr = nullptr;
+
+  auto &Head = F.getEntryBlock();
+  for (auto &I : Head) {
+    // Find the increment intrinsic in the entry basic block.
+    if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
+      assert(Mark->getIndex()->isZero());
+
+      IRBuilder<> Builder(Mark);
+      // FIXME(mtrofin): use InstrProfSymtab::getCanonicalName
+      Guid = Builder.getInt64(F.getGUID());
+      // The type of the context of this function is now knowable since we have
+      // NrCallsites and NrCounters. We delcare it here because it's more
+      // convenient - we have the Builder.
+      ThisContextType = StructType::get(
+          F.getContext(),
+          {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
+           ArrayType::get(Builder.getPtrTy(), NrCallsites)});
+      // Figure out which way we obtain the context object for this function -
+      // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
+      // former case, we also set TheRootContext since we need it to release it
+      // at the end (plus it can be used to know if we have an entrypoint or a
+      // regular function)
+      auto Iter = ContextRootMap.find(&F);
+      if (Iter != ContextRootMap.end()) {
+        TheRootContext = Iter->second;
+        Context = Builder.CreateCall(StartCtx, {TheRootContext, Guid,
+                                                Builder.getInt32(NrCounters),
+                                                Builder.getInt32(NrCallsites)});
+        ORE.emit(
+            [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
+      } else {
+        Context =
+            Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters),
+                                        Builder.getInt32(NrCallsites)});
+        ORE.emit([&] {
+          return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
+        });
+      }
+      // The context could be scratch.
+      auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
+      if (NrCallsites > 0) {
+        // Figure out which index of the TLS 2-element buffers to use.
+        // Scratch context => we use index == 1. Real contexts => index == 0.
+        auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
+        // The GEPs corresponding to that index, in the respective TLS.
+        ExpectedCalleeTLSAddr = Builder.CreateGEP(
+            Builder.getInt8Ty()->getPointerTo(),
+            Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
+        CallsiteInfoTLSAddr = Builder.CreateGEP(
+            Builder.getInt32Ty(),
+            Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
+      }
+      // Because the context pointer may have LSB set (to indicate scratch),
+      // clear it for the value we use as base address for the counter vector.
+      // This way, if later we want to have "real" (not clobbered) buffers
+      // acting as scratch, the lowering (at least this part of it that deals
+      // with counters) stays the same.
+      RealContext = Builder.CreateIntToPtr(
+          Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
+          ThisContextType->getPointerTo());
+      I.eraseFromParent();
+      break;
+    }
+  }
+  if (!Context) {
+    ORE.emit([&] {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
+             << "Function doesn't have instrumentation, skipping";
+    });
+    return;
+  }
+
+  bool ContextWasReleased = false;
+  for (auto &BB : F) {
+    for (auto &I : llvm::make_early_inc_range(BB)) {
+      if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
+        IRBuilder<> Builder(Instr);
+        switch (Instr->getIntrinsicID()) {
+        case llvm::Intrinsic::instrprof_increment:
+        case llvm::Intrinsic::instrprof_increment_step: {
+          // Increments (or increment-steps) are just a typical load - increment
+          // - store in the RealContext.
+          auto *AsStep = cast<InstrProfIncrementInst>(Instr);
+          auto *GEP = Builder.CreateGEP(
+              ThisContextType, RealContext,
+              {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
+          Builder.CreateStore(
+              Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
+                                AsStep->getStep()),
+              GEP);
+        } break;
+        case llvm::Intrinsic::instrprof_callsite:
+          // callsite lowering: write the called value in the expected callee
+          // TLS we treat the TLS as volatile because of signal handlers and to
+          // avoid these being moved away from the callsite they decorate.
+          auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
+          Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
+                              true);
+          // write the GEP of the slot in the sub-contexts portion of the
+          // context in TLS. Now, here, we use the actual Context value - as
+          // returned from compiler-rt - which may have the LSB set if the
+          // Context was scratch. Since the header of the context object and
+          // then the values are all 8-aligned (or, really, insofar as we care,
+          // they are even) - if the context is scratch (meaning, an odd value),
+          // so will the GEP. This is important because this is then visible to
+          // compiler-rt which will produce scratch contexts for callers that
+          // have a scratch context.
+          Builder.CreateStore(
+              Builder.CreateGEP(ThisContextType, Context,
+                                {Builder.getInt32(0), Builder.getInt32(2),
+                                 CSIntrinsic->getIndex()}),
+              CallsiteInfoTLSAddr, true);
+          break;
+        }
+        I.eraseFromParent();
+      } else if (TheRootContext && isa<ReturnInst>(I)) {
+        // Remember to release the context if we are an entrypoint.
+        IRBuilder<> Builder(&I);
+        Builder.CreateCall(ReleaseCtx, {TheRootContext});
+        ContextWasReleased = true;
+      }
+    }
+  }
+  // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
+  // to disallow this, (so this then stays as an error), another is to detect
+  // that and then do a wrapper or disallow the tail call. This only affects
+  // instrumentation, when we want to detect the call graph.
+  if (TheRootContext && !ContextWasReleased)
+    F.getContext().emitError(
+        "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
+        "instructions above which to release the context: " +
+        F.getName());
+}
diff --git a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
index 2ad95ab51cc6..7fa14f6cd30b 100644
--- a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
+++ b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
@@ -1,11 +1,27 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 4
 ; RUN: opt -passes=pgo-instr-gen -profile-context-root=an_entrypoint \
 ; RUN:   -S < %s | FileCheck --check-prefix=INSTRUMENT %s
+; RUN: opt -passes=pgo-instr-gen,pgo-ctx-instr-lower -profile-context-root=an_entrypoint \
+; RUN:   -profile-context-root=another_entrypoint_no_callees \
+; RUN:   -S < %s | FileCheck --check-prefix=LOWERING %s
+
 
 declare void @bar()
 
 ;.
 ; INSTRUMENT: @__profn_foo = private constant [3 x i8] c"foo"
+; INSTRUMENT: @__profn_an_entrypoint = private constant [13 x i8] c"an_entrypoint"
+; INSTRUMENT: @__profn_another_entrypoint_no_callees = private constant [29 x i8] c"another_entrypoint_no_callees"
+; INSTRUMENT: @__profn_simple = private constant [6 x i8] c"simple"
+;.
+; LOWERING: @__profn_foo = private constant [3 x i8] c"foo"
+; LOWERING: @__profn_an_entrypoint = private constant [13 x i8] c"an_entrypoint"
+; LOWERING: @__profn_another_entrypoint_no_callees = private constant [29 x i8] c"another_entrypoint_no_callees"
+; LOWERING: @__profn_simple = private constant [6 x i8] c"simple"
+; LOWERING: @an_entrypoint_ctx_root = global { ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @another_entrypoint_no_callees_ctx_root = global { ptr, ptr, ptr, i8 } zeroinitializer
+; LOWERING: @__llvm_ctx_profile_callsite = external hidden thread_local global ptr
+; LOWERING: @__llvm_ctx_profile_expected_callee = external hidden thread_local global ptr
 ;.
 define void @foo(i32 %a, ptr %fct) {
 ; INSTRUMENT-LABEL: define void @foo(
@@ -24,6 +40,38 @@ define void @foo(i32 %a, ptr %fct) {
 ; INSTRUMENT-NEXT:    br label [[EXIT]]
 ; INSTRUMENT:       exit:
 ; INSTRUMENT-NEXT:    ret void
+;
+; LOWERING-LABEL: define void @foo(
+; LOWERING-SAME: i32 [[A:%.*]], ptr [[FCT:%.*]]) {
+; LOWERING-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_ctx_profile_get_context(ptr @foo, i64 6699318081062747564, i32 2, i32 2)
+; LOWERING-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+; LOWERING-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 1
+; LOWERING-NEXT:    [[TMP4:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_ctx_profile_expected_callee)
+; LOWERING-NEXT:    [[TMP5:%.*]] = getelementptr ptr, ptr [[TMP4]], i64 [[TMP3]]
+; LOWERING-NEXT:    [[TMP6:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_ctx_profile_callsite)
+; LOWERING-NEXT:    [[TMP7:%.*]] = getelementptr i32, ptr [[TMP6]], i64 [[TMP3]]
+; LOWERING-NEXT:    [[TMP8:%.*]] = and i64 [[TMP2]], -2
+; LOWERING-NEXT:    [[TMP9:%.*]] = inttoptr i64 [[TMP8]] to ptr
+; LOWERING-NEXT:    [[T:%.*]] = icmp eq i32 [[A]], 0
+; LOWERING-NEXT:    br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
+; LOWERING:       yes:
+; LOWERING-NEXT:    [[TMP10:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [2 x ptr] }, ptr [[TMP9]], i32 0, i32 1, i32 1
+; LOWERING-NEXT:    [[TMP11:%.*]] = load i64, ptr [[TMP10]], align 4
+; LOWERING-NEXT:    [[TMP12:%.*]] = add i64 [[TMP11]], 1
+; LOWERING-NEXT:    store i64 [[TMP12]], ptr [[TMP10]], align 4
+; LOWERING-NEXT:    store volatile ptr [[FCT]], ptr [[TMP5]], align 8
+; LOWERING-NEXT:    [[TMP13:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [2 x ptr] }, ptr [[TMP1]], i32 0, i32 2, i32 0
+; LOWERING-NEXT:    store volatile ptr [[TMP13]], ptr [[TMP7]], align 8
+; LOWERING-NEXT:    call void [[FCT]](i32 [[A]])
+; LOWERING-NEXT:    br label [[EXIT:%.*]]
+; LOWERING:       no:
+; LOWERING-NEXT:    store volatile ptr @bar, ptr [[TMP5]], align 8
+; LOWERING-NEXT:    [[TMP14:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [2 x ptr] }, ptr [[TMP1]], i32 0, i32 2, i32 1
+; LOWERING-NEXT:    store volatile ptr [[TMP14]], ptr [[TMP7]], align 8
+; LOWERING-NEXT:    call void @bar()
+; LOWERING-NEXT:    br label [[EXIT]]
+; LOWERING:       exit:
+; LOWERING-NEXT:    ret void
 ;
   %t = icmp eq i32 %a, 0
   br i1 %t, label %yes, label %no
@@ -36,6 +84,119 @@ no:
 exit:
   ret void
 }
+
+define void @an_entrypoint(i32 %a) {
+; INSTRUMENT-LABEL: define void @an_entrypoint(
+; INSTRUMENT-SAME: i32 [[A:%.*]]) {
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.increment(ptr @__profn_an_entrypoint, i64 784007058953177093, i32 2, i32 0)
+; INSTRUMENT-NEXT:    [[T:%.*]] = icmp eq i32 [[A]], 0
+; INSTRUMENT-NEXT:    br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
+; INSTRUMENT:       yes:
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.increment(ptr @__profn_an_entrypoint, i64 784007058953177093, i32 2, i32 1)
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.callsite(ptr @__profn_an_entrypoint, i64 784007058953177093, i32 1, i32 0, ptr @foo)
+; INSTRUMENT-NEXT:    call void @foo(i32 1, ptr null)
+; INSTRUMENT-NEXT:    ret void
+; INSTRUMENT:       no:
+; INSTRUMENT-NEXT:    ret void
+;
+; LOWERING-LABEL: define void @an_entrypoint(
+; LOWERING-SAME: i32 [[A:%.*]]) {
+; LOWERING-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_ctx_profile_start_context(ptr @an_entrypoint_ctx_root, i64 4909520559318251808, i32 2, i32 1)
+; LOWERING-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+; LOWERING-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 1
+; LOWERING-NEXT:    [[TMP4:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_ctx_profile_expected_callee)
+; LOWERING-NEXT:    [[TMP5:%.*]] = getelementptr ptr, ptr [[TMP4]], i64 [[TMP3]]
+; LOWERING-NEXT:    [[TMP6:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_ctx_profile_callsite)
+; LOWERING-NEXT:    [[TMP7:%.*]] = getelementptr i32, ptr [[TMP6]], i64 [[TMP3]]
+; LOWERING-NEXT:    [[TMP8:%.*]] = and i64 [[TMP2]], -2
+; LOWERING-NEXT:    [[TMP9:%.*]] = inttoptr i64 [[TMP8]] to ptr
+; LOWERING-NEXT:    [[T:%.*]] = icmp eq i32 [[A]], 0
+; LOWERING-NEXT:    br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
+; LOWERING:       yes:
+; LOWERING-NEXT:    [[TMP10:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [1 x ptr] }, ptr [[TMP9]], i32 0, i32 1, i32 1
+; LOWERING-NEXT:    [[TMP11:%.*]] = load i64, ptr [[TMP10]], align 4
+; LOWERING-NEXT:    [[TMP12:%.*]] = add i64 [[TMP11]], 1
+; LOWERING-NEXT:    store i64 [[TMP12]], ptr [[TMP10]], align 4
+; LOWERING-NEXT:    store volatile ptr @foo, ptr [[TMP5]], align 8
+; LOWERING-NEXT:    [[TMP13:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [1 x ptr] }, ptr [[TMP1]], i32 0, i32 2, i32 0
+; LOWERING-NEXT:    store volatile ptr [[TMP13]], ptr [[TMP7]], align 8
+; LOWERING-NEXT:    call void @foo(i32 1, ptr null)
+; LOWERING-NEXT:    call void @__llvm_ctx_profile_release_context(ptr @an_entrypoint_ctx_root)
+; LOWERING-NEXT:    ret void
+; LOWERING:       no:
+; LOWERING-NEXT:    call void @__llvm_ctx_profile_release_context(ptr @an_entrypoint_ctx_root)
+; LOWERING-NEXT:    ret void
+;
+  %t = icmp eq i32 %a, 0
+  br i1 %t, label %yes, label %no
+
+yes:
+  call void @foo(i32 1, ptr null)
+  ret void
+no:
+  ret void
+}
+
+define void @another_entrypoint_no_callees(i32 %a) {
+; INSTRUMENT-LABEL: define void @another_entrypoint_no_callees(
+; INSTRUMENT-SAME: i32 [[A:%.*]]) {
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.increment(ptr @__profn_another_entrypoint_no_callees, i64 784007058953177093, i32 2, i32 0)
+; INSTRUMENT-NEXT:    [[T:%.*]] = icmp eq i32 [[A]], 0
+; INSTRUMENT-NEXT:    br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
+; INSTRUMENT:       yes:
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.increment(ptr @__profn_another_entrypoint_no_callees, i64 784007058953177093, i32 2, i32 1)
+; INSTRUMENT-NEXT:    ret void
+; INSTRUMENT:       no:
+; INSTRUMENT-NEXT:    ret void
+;
+; LOWERING-LABEL: define void @another_entrypoint_no_callees(
+; LOWERING-SAME: i32 [[A:%.*]]) {
+; LOWERING-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_ctx_profile_start_context(ptr @another_entrypoint_no_callees_ctx_root, i64 -6371873725078000974, i32 0, i32 0)
+; LOWERING-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+; LOWERING-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], -2
+; LOWERING-NEXT:    [[TMP4:%.*]] = inttoptr i64 [[TMP3]] to ptr
+; LOWERING-NEXT:    [[T:%.*]] = icmp eq i32 [[A]], 0
+; LOWERING-NEXT:    br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
+; LOWERING:       yes:
+; LOWERING-NEXT:    [[TMP5:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [0 x i64], [0 x ptr] }, ptr [[TMP4]], i32 0, i32 1, i32 1
+; LOWERING-NEXT:    [[TMP6:%.*]] = load i64, ptr [[TMP5]], align 4
+; LOWERING-NEXT:    [[TMP7:%.*]] = add i64 [[TMP6]], 1
+; LOWERING-NEXT:    store i64 [[TMP7]], ptr [[TMP5]], align 4
+; LOWERING-NEXT:    call void @__llvm_ctx_profile_release_context(ptr @another_entrypoint_no_callees_ctx_root)
+; LOWERING-NEXT:    ret void
+; LOWERING:       no:
+; LOWERING-NEXT:    call void @__llvm_ctx_profile_release_context(ptr @another_entrypoint_no_callees_ctx_root)
+; LOWERING-NEXT:    ret void
+;
+  %t = icmp eq i32 %a, 0
+  br i1 %t, label %yes, label %no
+
+yes:
+  ret void
+no:
+  ret void
+}
+
+define void @simple(i32 %a) {
+; INSTRUMENT-LABEL: define void @simple(
+; INSTRUMENT-SAME: i32 [[A:%.*]]) {
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.increment(ptr @__profn_simple, i64 742261418966908927, i32 1, i32 0)
+; INSTRUMENT-NEXT:    ret void
+;
+; LOWERING-LABEL: define void @simple(
+; LOWERING-SAME: i32 [[A:%.*]]) {
+; LOWERING-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_ctx_profile_get_context(ptr @simple, i64 -3006003237940970099, i32 0, i32 0)
+; LOWERING-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+; LOWERING-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], -2
+; LOWERING-NEXT:    [[TMP4:%.*]] = inttoptr i64 [[TMP3]] to ptr
+; LOWERING-NEXT:    ret void
+;
+  ret void
+}
+
 ;.
 ; INSTRUMENT: attributes #[[ATTR0:[0-9]+]] = { nounwind }
 ;.
+; LOWERING: attributes #[[ATTR0:[0-9]+]] = { nounwind }
+; LOWERING: attributes #[[ATTR1:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+;.

>From ecaad76f812aefcefe930e0b3a32dc6dd15611eb Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 1 May 2024 22:43:53 -0700
Subject: [PATCH 2/5] Don't report changes if no function was lowered

---
 .../Instrumentation/PGOCtxProfLowering.cpp       | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
index 7442d8010ab0..b3b0197a775c 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
@@ -9,6 +9,7 @@
 
 #include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/IR/Analysis.h"
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
@@ -59,7 +60,8 @@ class CtxInstrumentationLowerer final {
 
 public:
   CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
-  void lowerFunction(Function &F);
+  // return true if lowering happened (i.e. a change was made)
+  bool lowerFunction(Function &F);
 };
 
 std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {
@@ -169,14 +171,15 @@ CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
 PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
                                               ModuleAnalysisManager &MAM) {
   CtxInstrumentationLowerer Lowerer(M, MAM);
+  bool Changed = false;
   for (auto &F : M)
-    Lowerer.lowerFunction(F);
-  return PreservedAnalyses::none();
+    Changed |= Lowerer.lowerFunction(F);
+  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
 }
 
-void CtxInstrumentationLowerer::lowerFunction(Function &F) {
+bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
   if (F.isDeclaration())
-    return;
+    return false;
   auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
   auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
 
@@ -259,7 +262,7 @@ void CtxInstrumentationLowerer::lowerFunction(Function &F) {
       return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
              << "Function doesn't have instrumentation, skipping";
     });
-    return;
+    return false;
   }
 
   bool ContextWasReleased = false;
@@ -322,4 +325,5 @@ void CtxInstrumentationLowerer::lowerFunction(Function &F) {
         "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
         "instructions above which to release the context: " +
         F.getName());
+  return true;
 }

>From ab5895fbd8d1ebbc34bf465b982db9ac44f8f4e3 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 8 May 2024 11:23:34 -0700
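Subject: [PATCH 3/5] Address review feedback

- Rename the pass to ctx-instr-lower and make DEBUG_TYPE match.
- In getNrCountersAndCallsites, assert that all
  llvm.instrprof.increment[.step] intrinsics in a function carry the same
  total number of counters.
- Add ctx-instrumentation-invalid-roots.ll and extend ctx-instrumentation.ll.
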
---
 llvm/lib/Passes/PassRegistry.def              |  2 +-
 .../Instrumentation/PGOCtxProfLowering.cpp    | 39 +++++++---
 .../ctx-instrumentation-invalid-roots.ll      | 17 +++++
 .../PGOProfile/ctx-instrumentation.ll         | 76 ++++++++++++++++++-
 4 files changed, 120 insertions(+), 14 deletions(-)
 create mode 100644 llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll

diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 8f79601d0351..ffbac27efe3a 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -77,7 +77,7 @@ MODULE_PASS("inliner-wrapper-no-mandatory-first",
 MODULE_PASS("insert-gcov-profiling", GCOVProfilerPass())
 MODULE_PASS("instrorderfile", InstrOrderFilePass())
 MODULE_PASS("instrprof", InstrProfilingLoweringPass())
-MODULE_PASS("pgo-ctx-instr-lower", PGOCtxProfLoweringPass())
+MODULE_PASS("ctx-instr-lower", PGOCtxProfLoweringPass())
 MODULE_PASS("internalize", InternalizePass())
 MODULE_PASS("invalidate<all>", InvalidateAllAnalysesPass())
 MODULE_PASS("iroutliner", IROutlinerPass())
diff --git a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
index b3b0197a775c..76afa2f22461 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOCtxProfLowering.cpp
@@ -20,7 +20,7 @@
 
 using namespace llvm;
 
-#define DEBUG_TYPE "ctx-profile-lower"
+#define DEBUG_TYPE "ctx-instr-lower"
 
 static cl::list<std::string> ContextRoots(
     "profile-context-root", cl::Hidden,
@@ -64,25 +64,37 @@ class CtxInstrumentationLowerer final {
   bool lowerFunction(Function &F);
 };
 
+// llvm.instrprof.increment[.step] captures the total number of counters as one
+// of its parameters, and llvm.instrprof.callsite captures the total number of
+// callsites. Those values must be the same for all instances in F: debug builds
+// assert this by scanning F fully; release builds stop once both are known.
 std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {
   uint32_t NrCounters = 0;
   uint32_t NrCallsites = 0;
   for (const auto &BB : F) {
     for (const auto &I : BB) {
       if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
-        if (!NrCounters)
-          NrCounters =
-              static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
+        uint32_t V =
+            static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
+        assert((!NrCounters || V == NrCounters) &&
+               "expected all llvm.instrprof.increment[.step] intrinsics to "
+               "have the same total nr of counters parameter");
+        NrCounters = V;
       } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
-        if (!NrCallsites)
-          NrCallsites =
-              static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
+        uint32_t V =
+            static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
+        assert((!NrCallsites || V == NrCallsites) &&
+               "expected all llvm.instrprof.callsite intrinsics to have the "
+               "same total nr of callsites parameter");
+        NrCallsites = V;
       }
+#ifdef NDEBUG
       if (NrCounters && NrCallsites)
         return std::make_pair(NrCounters, NrCallsites);
+#endif
     }
   }
-  return {0, 0};
+  return {NrCounters, NrCallsites};
 }
 } // namespace
 
@@ -123,6 +135,15 @@ CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
       cast<GlobalVariable>(G)->setInitializer(
           Constant::getNullValue(ContextRootTy));
       ContextRootMap.insert(std::make_pair(F, G));
+      for (const auto &BB : *F)
+        for (const auto &I : BB)
+          if (const auto *CB = dyn_cast<CallBase>(&I))
+            if (CB->isMustTailCall()) {
+              M.getContext().emitError(
+                  "The function " + Fname +
+                  " was indicated as a context root, but it features musttail "
+                  "calls, which is not supported.");
+            }
     }
   }
 
@@ -212,7 +233,7 @@ bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
            ArrayType::get(Builder.getPtrTy(), NrCallsites)});
       // Figure out which way we obtain the context object for this function -
       // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
-      // former case, we also set TheRootContext since we need it to release it
+      // former case, we also set TheRootContext since we need to release it
       // at the end (plus it can be used to know if we have an entrypoint or a
       // regular function)
       auto Iter = ContextRootMap.find(&F);
diff --git a/llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll b/llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll
new file mode 100644
index 000000000000..99c7762a67df
--- /dev/null
+++ b/llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll
@@ -0,0 +1,17 @@
+; RUN: not opt -passes=pgo-instr-gen,ctx-instr-lower -profile-context-root=good \
+; RUN:   -profile-context-root=bad \
+; RUN:   -S < %s 2>&1 | FileCheck %s
+
+declare void @foo()
+
+define void @good() {
+  call void @foo()
+  ret void
+}
+
+define void @bad() {
+  musttail call void @foo()
+  ret void
+}
+
+; CHECK: error: The function bad was indicated as a context root, but it features musttail calls, which is not supported.
diff --git a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
index 7fa14f6cd30b..56c7c7519f69 100644
--- a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
+++ b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 4
 ; RUN: opt -passes=pgo-instr-gen -profile-context-root=an_entrypoint \
 ; RUN:   -S < %s | FileCheck --check-prefix=INSTRUMENT %s
-; RUN: opt -passes=pgo-instr-gen,pgo-ctx-instr-lower -profile-context-root=an_entrypoint \
+; RUN: opt -passes=pgo-instr-gen,ctx-instr-lower -profile-context-root=an_entrypoint \
 ; RUN:   -profile-context-root=another_entrypoint_no_callees \
 ; RUN:   -S < %s | FileCheck --check-prefix=LOWERING %s
 
@@ -13,11 +13,15 @@ declare void @bar()
 ; INSTRUMENT: @__profn_an_entrypoint = private constant [13 x i8] c"an_entrypoint"
 ; INSTRUMENT: @__profn_another_entrypoint_no_callees = private constant [29 x i8] c"another_entrypoint_no_callees"
 ; INSTRUMENT: @__profn_simple = private constant [6 x i8] c"simple"
+; INSTRUMENT: @__profn_no_callsites = private constant [12 x i8] c"no_callsites"
+; INSTRUMENT: @__profn_no_counters = private constant [11 x i8] c"no_counters"
 ;.
 ; LOWERING: @__profn_foo = private constant [3 x i8] c"foo"
 ; LOWERING: @__profn_an_entrypoint = private constant [13 x i8] c"an_entrypoint"
 ; LOWERING: @__profn_another_entrypoint_no_callees = private constant [29 x i8] c"another_entrypoint_no_callees"
 ; LOWERING: @__profn_simple = private constant [6 x i8] c"simple"
+; LOWERING: @__profn_no_callsites = private constant [12 x i8] c"no_callsites"
+; LOWERING: @__profn_no_counters = private constant [11 x i8] c"no_counters"
 ; LOWERING: @an_entrypoint_ctx_root = global { ptr, ptr, ptr, i8 } zeroinitializer
 ; LOWERING: @another_entrypoint_no_callees_ctx_root = global { ptr, ptr, ptr, i8 } zeroinitializer
 ; LOWERING: @__llvm_ctx_profile_callsite = external hidden thread_local global ptr
@@ -151,14 +155,14 @@ define void @another_entrypoint_no_callees(i32 %a) {
 ;
 ; LOWERING-LABEL: define void @another_entrypoint_no_callees(
 ; LOWERING-SAME: i32 [[A:%.*]]) {
-; LOWERING-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_ctx_profile_start_context(ptr @another_entrypoint_no_callees_ctx_root, i64 -6371873725078000974, i32 0, i32 0)
+; LOWERING-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_ctx_profile_start_context(ptr @another_entrypoint_no_callees_ctx_root, i64 -6371873725078000974, i32 2, i32 0)
 ; LOWERING-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
 ; LOWERING-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], -2
 ; LOWERING-NEXT:    [[TMP4:%.*]] = inttoptr i64 [[TMP3]] to ptr
 ; LOWERING-NEXT:    [[T:%.*]] = icmp eq i32 [[A]], 0
 ; LOWERING-NEXT:    br i1 [[T]], label [[YES:%.*]], label [[NO:%.*]]
 ; LOWERING:       yes:
-; LOWERING-NEXT:    [[TMP5:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [0 x i64], [0 x ptr] }, ptr [[TMP4]], i32 0, i32 1, i32 1
+; LOWERING-NEXT:    [[TMP5:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [0 x ptr] }, ptr [[TMP4]], i32 0, i32 1, i32 1
 ; LOWERING-NEXT:    [[TMP6:%.*]] = load i64, ptr [[TMP5]], align 4
 ; LOWERING-NEXT:    [[TMP7:%.*]] = add i64 [[TMP6]], 1
 ; LOWERING-NEXT:    store i64 [[TMP7]], ptr [[TMP5]], align 4
@@ -185,7 +189,7 @@ define void @simple(i32 %a) {
 ;
 ; LOWERING-LABEL: define void @simple(
 ; LOWERING-SAME: i32 [[A:%.*]]) {
-; LOWERING-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_ctx_profile_get_context(ptr @simple, i64 -3006003237940970099, i32 0, i32 0)
+; LOWERING-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_ctx_profile_get_context(ptr @simple, i64 -3006003237940970099, i32 1, i32 0)
 ; LOWERING-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
 ; LOWERING-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], -2
 ; LOWERING-NEXT:    [[TMP4:%.*]] = inttoptr i64 [[TMP3]] to ptr
@@ -194,6 +198,70 @@ define void @simple(i32 %a) {
   ret void
 }
 
+
+define i32 @no_callsites(i32 %a) {
+; INSTRUMENT-LABEL: define i32 @no_callsites(
+; INSTRUMENT-SAME: i32 [[A:%.*]]) {
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.increment(ptr @__profn_no_callsites, i64 784007058953177093, i32 2, i32 0)
+; INSTRUMENT-NEXT:    [[C:%.*]] = icmp eq i32 [[A]], 0
+; INSTRUMENT-NEXT:    br i1 [[C]], label [[YES:%.*]], label [[NO:%.*]]
+; INSTRUMENT:       yes:
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.increment(ptr @__profn_no_callsites, i64 784007058953177093, i32 2, i32 1)
+; INSTRUMENT-NEXT:    ret i32 1
+; INSTRUMENT:       no:
+; INSTRUMENT-NEXT:    ret i32 0
+;
+; LOWERING-LABEL: define i32 @no_callsites(
+; LOWERING-SAME: i32 [[A:%.*]]) {
+; LOWERING-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_ctx_profile_get_context(ptr @no_callsites, i64 5679753335911435902, i32 2, i32 0)
+; LOWERING-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+; LOWERING-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], -2
+; LOWERING-NEXT:    [[TMP4:%.*]] = inttoptr i64 [[TMP3]] to ptr
+; LOWERING-NEXT:    [[C:%.*]] = icmp eq i32 [[A]], 0
+; LOWERING-NEXT:    br i1 [[C]], label [[YES:%.*]], label [[NO:%.*]]
+; LOWERING:       yes:
+; LOWERING-NEXT:    [[TMP5:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [2 x i64], [0 x ptr] }, ptr [[TMP4]], i32 0, i32 1, i32 1
+; LOWERING-NEXT:    [[TMP6:%.*]] = load i64, ptr [[TMP5]], align 4
+; LOWERING-NEXT:    [[TMP7:%.*]] = add i64 [[TMP6]], 1
+; LOWERING-NEXT:    store i64 [[TMP7]], ptr [[TMP5]], align 4
+; LOWERING-NEXT:    ret i32 1
+; LOWERING:       no:
+; LOWERING-NEXT:    ret i32 0
+;
+  %c = icmp eq i32 %a, 0
+  br i1 %c, label %yes, label %no
+yes:
+  ret i32 1
+no:
+  ret i32 0
+}
+
+define void @no_counters() {
+; INSTRUMENT-LABEL: define void @no_counters() {
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.increment(ptr @__profn_no_counters, i64 742261418966908927, i32 1, i32 0)
+; INSTRUMENT-NEXT:    call void @llvm.instrprof.callsite(ptr @__profn_no_counters, i64 742261418966908927, i32 1, i32 0, ptr @bar)
+; INSTRUMENT-NEXT:    call void @bar()
+; INSTRUMENT-NEXT:    ret void
+;
+; LOWERING-LABEL: define void @no_counters() {
+; LOWERING-NEXT:    [[TMP1:%.*]] = call ptr @__llvm_ctx_profile_get_context(ptr @no_counters, i64 5458232184388660970, i32 1, i32 1)
+; LOWERING-NEXT:    [[TMP2:%.*]] = ptrtoint ptr [[TMP1]] to i64
+; LOWERING-NEXT:    [[TMP3:%.*]] = and i64 [[TMP2]], 1
+; LOWERING-NEXT:    [[TMP4:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_ctx_profile_expected_callee)
+; LOWERING-NEXT:    [[TMP5:%.*]] = getelementptr ptr, ptr [[TMP4]], i64 [[TMP3]]
+; LOWERING-NEXT:    [[TMP6:%.*]] = call ptr @llvm.threadlocal.address.p0(ptr @__llvm_ctx_profile_callsite)
+; LOWERING-NEXT:    [[TMP7:%.*]] = getelementptr i32, ptr [[TMP6]], i64 [[TMP3]]
+; LOWERING-NEXT:    [[TMP8:%.*]] = and i64 [[TMP2]], -2
+; LOWERING-NEXT:    [[TMP9:%.*]] = inttoptr i64 [[TMP8]] to ptr
+; LOWERING-NEXT:    store volatile ptr @bar, ptr [[TMP5]], align 8
+; LOWERING-NEXT:    [[TMP10:%.*]] = getelementptr { { i64, ptr, i32, i32 }, [1 x i64], [1 x ptr] }, ptr [[TMP1]], i32 0, i32 2, i32 0
+; LOWERING-NEXT:    store volatile ptr [[TMP10]], ptr [[TMP7]], align 8
+; LOWERING-NEXT:    call void @bar()
+; LOWERING-NEXT:    ret void
+;
+  call void @bar()
+  ret void
+}
 ;.
 ; INSTRUMENT: attributes #[[ATTR0:[0-9]+]] = { nounwind }
 ;.

>From 601387f84012aeadd7eaf553a8a388217fd816c7 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 8 May 2024 14:49:42 -0700
Subject: [PATCH 4/5] updated documentation

---
 llvm/docs/LangRef.rst | 48 ++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 43 insertions(+), 5 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index 6291a4e57919..b70f8b4bf244 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -14111,6 +14111,25 @@ structures and the code to increment the appropriate value, in a
 format that can be written out by a compiler runtime and consumed via
 the ``llvm-profdata`` tool.
 
+.. FIXME: write complete doc on contextual instrumentation and link from here
+.. and from llvm.instrprof.callsite.
+
+The intrinsic is lowered differently for contextual profiling by the
+``-ctx-instr-lower`` pass. Here:
+
+* the entry basic block increment counter is lowered as a call to compiler-rt,
+  to either ``__llvm_ctx_profile_start_context`` or
+  ``__llvm_ctx_profile_get_context``. Either returns a pointer to a context object
+  which contains a buffer into which counter increments can happen. Note that the
+  pointer value returned by compiler-rt may have its LSB set; counter increments
+  are applied at offsets from that address with the LSB cleared.
+
+* all the other lowerings of ``llvm.instrprof.increment[.step]`` happen within
+  that context.
+
+* the context is assumed to be local to the function, so LLVM does not need to
+  handle any concurrency concerns.
+
 '``llvm.instrprof.increment.step``' Intrinsic
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -14156,10 +14175,10 @@ Syntax:
 Overview:
 """""""""
 
-.. FIXME: detail when it's emitted once the support is added
-
 The '``llvm.instrprof.callsite``' intrinsic should be emitted before a callsite
-that's not to a "fake" callee (like another intrinsic or asm).
+that's not to a "fake" callee (like another intrinsic or asm). It is used by
+contextual profiling and is side-effectful. Its lowering happens in IR, and
+target-specific backends should never encounter it.
 
 Arguments:
 """"""""""
@@ -14172,9 +14191,28 @@ The last argument is the called value of the callsite this intrinsic precedes.
 
 Semantics:
 """"""""""
-.. FIXME: detail how when the lowering pass is added.
 
-This is lowered by contextual profiling.
+This is lowered by contextual profiling. In contextual profiling, functions get,
+from compiler-rt, a pointer to a context object. The context object consists of
+a buffer LLVM can use to perform counter increments (i.e. the lowering of
+``llvm.instrprof.increment[.step]``). The address range following the counter
+buffer, of size ``<num-counters>`` x ``sizeof(ptr)``, is expected to contain
+pointers to contexts of functions called from this function ("subcontexts").
+LLVM never dereferences that memory region; it only computes GEPs into it.
+
+The lowering of ``llvm.instrprof.callsite`` consists of:
+
+* write to ``__llvm_ctx_profile_expected_callee`` the ``<callsite>`` value;
+
+* write to ``__llvm_ctx_profile_callsite`` the address, within this function's
+  context, of the ``<index>``-th slot in the subcontexts region.
+
+
+``__llvm_ctx_profile_{expected_callee|callsite}`` are initialized by compiler-rt
+and are TLS. They are both 2-sized vectors of pointers. The index into each is
+determined when the current function obtains the pointer to its context from
+compiler-rt. The pointer's LSB gives the index.
+
 
 '``llvm.instrprof.timestamp``' Intrinsic
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

>From bb1d5fbbd9bd20927d17d8fa465c430acb39b8e6 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Wed, 8 May 2024 15:17:45 -0700
Subject: [PATCH 5/5] doc fixes

---
 llvm/docs/LangRef.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index b70f8b4bf244..dd870dbd0589 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -14177,7 +14177,7 @@ Overview:
 
 The '``llvm.instrprof.callsite``' intrinsic should be emitted before a callsite
 that's not to a "fake" callee (like another intrinsic or asm). It is used by
-contextual profiling and is side-effectful. Its lowering happens in IR, and
+contextual profiling and has side-effects. Its lowering happens in IR, and
 target-specific backends should never encounter it.
 
 Arguments:
@@ -14209,7 +14209,7 @@ The lowering of ``llvm.instrprof.callsite`` consists of:
 
 
 ``__llvm_ctx_profile_{expected_callee|callsite}`` are initialized by compiler-rt
-and are TLS. They are both 2-sized vectors of pointers. The index into each is
+and are TLS. They are both arrays of two pointers. The index into each is
 determined when the current function obtains the pointer to its context from
 compiler-rt. The pointer's LSB gives the index.
 



More information about the llvm-commits mailing list