[llvm] [llvm][ctx_profile] Add instrumentation lowering (PR #90821)

Snehasish Kumar via llvm-commits llvm-commits at lists.llvm.org
Wed May 8 09:39:29 PDT 2024


================
@@ -22,3 +32,298 @@ static cl::list<std::string> ContextRoots(
 bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {
   return !ContextRoots.empty();
 }
+
+// the names of symbols we expect in compiler-rt. Using a namespace for
+// readability.
+namespace CompilerRtAPINames {
+static auto StartCtx = "__llvm_ctx_profile_start_context";
+static auto ReleaseCtx = "__llvm_ctx_profile_release_context";
+static auto GetCtx = "__llvm_ctx_profile_get_context";
+static auto ExpectedCalleeTLS = "__llvm_ctx_profile_expected_callee";
+static auto CallsiteTLS = "__llvm_ctx_profile_callsite";
+} // namespace CompilerRtAPINames
+
+namespace {
+// The lowering logic and state.
+class CtxInstrumentationLowerer final {
+  Module &M;
+  ModuleAnalysisManager &MAM;
+  Type *ContextNodeTy = nullptr;
+  Type *ContextRootTy = nullptr;
+
+  DenseMap<const Function *, Constant *> ContextRootMap;
+  Function *StartCtx = nullptr;
+  Function *GetCtx = nullptr;
+  Function *ReleaseCtx = nullptr;
+  GlobalVariable *ExpectedCalleeTLS = nullptr;
+  GlobalVariable *CallsiteInfoTLS = nullptr;
+
+public:
+  CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
+  // return true if lowering happened (i.e. a change was made)
+  bool lowerFunction(Function &F);
+};
+
+std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {
+  uint32_t NrCounters = 0;
+  uint32_t NrCallsites = 0;
+  for (const auto &BB : F) {
+    for (const auto &I : BB) {
+      if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
+        if (!NrCounters)
+          NrCounters =
+              static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
+      } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
+        if (!NrCallsites)
+          NrCallsites =
+              static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
+      }
+      if (NrCounters && NrCallsites)
+        return std::make_pair(NrCounters, NrCallsites);
+    }
+  }
+  return {0, 0};
+}
+} // namespace
+
+// set up tie-in with compiler-rt.
+// NOTE!!!
+// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
+CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
+                                                     ModuleAnalysisManager &MAM)
+    : M(M), MAM(MAM) {
+  auto *PointerTy = PointerType::get(M.getContext(), 0);
+  auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
+  auto *I32Ty = Type::getInt32Ty(M.getContext());
+  auto *I64Ty = Type::getInt64Ty(M.getContext());
+
+  // The ContextRoot type
+  ContextRootTy =
+      StructType::get(M.getContext(), {
+                                          PointerTy,          /*FirstNode*/
+                                          PointerTy,          /*FirstMemBlock*/
+                                          PointerTy,          /*CurrentMem*/
+                                          SanitizerMutexType, /*Taken*/
+                                      });
+  // The Context header.
+  ContextNodeTy = StructType::get(M.getContext(), {
+                                                      I64Ty,     /*Guid*/
+                                                      PointerTy, /*Next*/
+                                                      I32Ty,     /*NrCounters*/
+                                                      I32Ty,     /*NrCallsites*/
+                                                  });
+
+  // Define a global for each entrypoint. We'll reuse the entrypoint's name as
+  // prefix. We assume the entrypoint names to be unique.
+  for (const auto &Fname : ContextRoots) {
+    if (const auto *F = M.getFunction(Fname)) {
+      if (F->isDeclaration())
+        continue;
+      auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
+      cast<GlobalVariable>(G)->setInitializer(
+          Constant::getNullValue(ContextRootTy));
+      ContextRootMap.insert(std::make_pair(F, G));
+    }
+  }
+
+  // Declare the functions we will call.
+  StartCtx = cast<Function>(
+      M.getOrInsertFunction(
+           CompilerRtAPINames::StartCtx,
+           FunctionType::get(ContextNodeTy->getPointerTo(),
+                             {ContextRootTy->getPointerTo(), /*ContextRoot*/
+                              I64Ty, /*Guid*/ I32Ty,
+                              /*NrCounters*/ I32Ty /*NrCallsites*/},
+                             false))
+          .getCallee());
+  GetCtx = cast<Function>(
+      M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
+                            FunctionType::get(ContextNodeTy->getPointerTo(),
+                                              {PointerTy, /*Callee*/
+                                               I64Ty,     /*Guid*/
+                                               I32Ty,     /*NrCounters*/
+                                               I32Ty},    /*NrCallsites*/
+                                              false))
+          .getCallee());
+  ReleaseCtx = cast<Function>(
+      M.getOrInsertFunction(
+           CompilerRtAPINames::ReleaseCtx,
+           FunctionType::get(Type::getVoidTy(M.getContext()),
+                             {
+                                 ContextRootTy->getPointerTo(), /*ContextRoot*/
+                             },
+                             false))
+          .getCallee());
+
+  // Declare the TLSes we will need to use.
+  CallsiteInfoTLS =
+      new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
+                         nullptr, CompilerRtAPINames::CallsiteTLS);
+  CallsiteInfoTLS->setThreadLocal(true);
+  CallsiteInfoTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
+  ExpectedCalleeTLS =
+      new GlobalVariable(M, PointerTy, false, GlobalValue::ExternalLinkage,
+                         nullptr, CompilerRtAPINames::ExpectedCalleeTLS);
+  ExpectedCalleeTLS->setThreadLocal(true);
+  ExpectedCalleeTLS->setVisibility(llvm::GlobalValue::HiddenVisibility);
+}
+
+PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
+                                              ModuleAnalysisManager &MAM) {
+  CtxInstrumentationLowerer Lowerer(M, MAM);
+  bool Changed = false;
+  for (auto &F : M)
+    Changed |= Lowerer.lowerFunction(F);
+  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
+
+bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
+  if (F.isDeclaration())
+    return false;
+  auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+
+  Value *Guid = nullptr;
+  auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F);
+
+  Value *Context = nullptr;
+  Value *RealContext = nullptr;
+
+  StructType *ThisContextType = nullptr;
+  Value *TheRootContext = nullptr;
+  Value *ExpectedCalleeTLSAddr = nullptr;
+  Value *CallsiteInfoTLSAddr = nullptr;
+
+  auto &Head = F.getEntryBlock();
+  for (auto &I : Head) {
+    // Find the increment intrinsic in the entry basic block.
+    if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
+      assert(Mark->getIndex()->isZero());
+
+      IRBuilder<> Builder(Mark);
+      // FIXME(mtrofin): use InstrProfSymtab::getCanonicalName
+      Guid = Builder.getInt64(F.getGUID());
+      // The type of the context of this function is now knowable since we have
+      // NrCallsites and NrCounters. We declare it here because it's more
+      // convenient - we have the Builder.
+      ThisContextType = StructType::get(
+          F.getContext(),
+          {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
+           ArrayType::get(Builder.getPtrTy(), NrCallsites)});
+      // Figure out which way we obtain the context object for this function -
+      // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
+      // former case, we also set TheRootContext since we need it to release it
+      // at the end (plus it can be used to know if we have an entrypoint or a
+      // regular function)
+      auto Iter = ContextRootMap.find(&F);
+      if (Iter != ContextRootMap.end()) {
+        TheRootContext = Iter->second;
+        Context = Builder.CreateCall(StartCtx, {TheRootContext, Guid,
+                                                Builder.getInt32(NrCounters),
+                                                Builder.getInt32(NrCallsites)});
+        ORE.emit(
+            [&] { return OptimizationRemark(DEBUG_TYPE, "Entrypoint", &F); });
+      } else {
+        Context =
+            Builder.CreateCall(GetCtx, {&F, Guid, Builder.getInt32(NrCounters),
+                                        Builder.getInt32(NrCallsites)});
+        ORE.emit([&] {
+          return OptimizationRemark(DEBUG_TYPE, "RegularFunction", &F);
+        });
+      }
+      // The context could be scratch.
+      auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
+      if (NrCallsites > 0) {
+        // Figure out which index of the TLS 2-element buffers to use.
+        // Scratch context => we use index == 1. Real contexts => index == 0.
+        auto *Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
+        // The GEPs corresponding to that index, in the respective TLS.
+        ExpectedCalleeTLSAddr = Builder.CreateGEP(
+            Builder.getInt8Ty()->getPointerTo(),
+            Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {Index});
+        CallsiteInfoTLSAddr = Builder.CreateGEP(
+            Builder.getInt32Ty(),
+            Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {Index});
+      }
+      // Because the context pointer may have LSB set (to indicate scratch),
+      // clear it for the value we use as base address for the counter vector.
+      // This way, if later we want to have "real" (not clobbered) buffers
+      // acting as scratch, the lowering (at least this part of it that deals
+      // with counters) stays the same.
+      RealContext = Builder.CreateIntToPtr(
+          Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
+          ThisContextType->getPointerTo());
+      I.eraseFromParent();
+      break;
+    }
+  }
+  if (!Context) {
+    ORE.emit([&] {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "Skip", &F)
+             << "Function doesn't have instrumentation, skipping";
+    });
+    return false;
+  }
+
+  bool ContextWasReleased = false;
+  for (auto &BB : F) {
+    for (auto &I : llvm::make_early_inc_range(BB)) {
+      if (auto *Instr = dyn_cast<InstrProfCntrInstBase>(&I)) {
+        IRBuilder<> Builder(Instr);
+        switch (Instr->getIntrinsicID()) {
+        case llvm::Intrinsic::instrprof_increment:
+        case llvm::Intrinsic::instrprof_increment_step: {
+          // Increments (or increment-steps) are just a typical load - increment
+          // - store in the RealContext.
+          auto *AsStep = cast<InstrProfIncrementInst>(Instr);
+          auto *GEP = Builder.CreateGEP(
+              ThisContextType, RealContext,
+              {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
+          Builder.CreateStore(
+              Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(), GEP),
+                                AsStep->getStep()),
+              GEP);
+        } break;
+        case llvm::Intrinsic::instrprof_callsite:
+          // callsite lowering: write the called value in the expected callee
+          // TLS. We treat the TLS as volatile because of signal handlers and to
+          // avoid these being moved away from the callsite they decorate.
+          auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
+          Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
+                              true);
+          // write the GEP of the slot in the sub-contexts portion of the
+          // context in TLS. Now, here, we use the actual Context value - as
+          // returned from compiler-rt - which may have the LSB set if the
+          // Context was scratch. Since the header of the context object and
+          // then the values are all 8-aligned (or, really, insofar as we care,
+          // they are even) - if the context is scratch (meaning, an odd value),
+          // so will the GEP. This is important because this is then visible to
+          // compiler-rt which will produce scratch contexts for callers that
+          // have a scratch context.
+          Builder.CreateStore(
+              Builder.CreateGEP(ThisContextType, Context,
+                                {Builder.getInt32(0), Builder.getInt32(2),
+                                 CSIntrinsic->getIndex()}),
+              CallsiteInfoTLSAddr, true);
+          break;
+        }
+        I.eraseFromParent();
+      } else if (TheRootContext && isa<ReturnInst>(I)) {
+        // Remember to release the context if we are an entrypoint.
+        IRBuilder<> Builder(&I);
+        Builder.CreateCall(ReleaseCtx, {TheRootContext});
+        ContextWasReleased = true;
+      }
+    }
+  }
+  // FIXME: This would happen if the entrypoint tailcalls. A way to fix would be
+  // to disallow this, (so this then stays as an error), another is to detect
----------------
snehasish wrote:

The sentence mentions "disallow this" and then ends with "disallow the tail call". Is this intended?

Disallowing tail calls is not trivial, unfortunately: while `-fno-optimize-sibling-calls` can stop the compiler from generating tail calls, clang's `musttail` attribute [1] is used in performance-critical code. I don't think these would be "entrypoints", though. Perhaps we can check this where the entrypoint intrinsics are emitted?

[1] https://clang.llvm.org/docs/AttributeReference.html#musttail

https://github.com/llvm/llvm-project/pull/90821


More information about the llvm-commits mailing list