[llvm] [llvm][ctx_profile] Add instrumentation lowering (PR #90821)
Snehasish Kumar via llvm-commits
llvm-commits at lists.llvm.org
Wed May 8 09:39:28 PDT 2024
================
@@ -22,3 +32,298 @@ static cl::list<std::string> ContextRoots(
bool PGOCtxProfLoweringPass::isContextualIRPGOEnabled() {
return !ContextRoots.empty();
}
+
+// The names of symbols we expect in compiler-rt. Using a namespace for
+// readability. These must stay in sync with compiler-rt/lib/ctx_profile.
+namespace CompilerRtAPINames {
+// `constexpr` makes each name an immutable compile-time constant; the
+// previous `static auto` deduced a mutable `const char *` global that could
+// be reassigned by accident.
+static constexpr const char *StartCtx = "__llvm_ctx_profile_start_context";
+static constexpr const char *ReleaseCtx = "__llvm_ctx_profile_release_context";
+static constexpr const char *GetCtx = "__llvm_ctx_profile_get_context";
+static constexpr const char *ExpectedCalleeTLS =
+    "__llvm_ctx_profile_expected_callee";
+static constexpr const char *CallsiteTLS = "__llvm_ctx_profile_callsite";
+} // namespace CompilerRtAPINames
+
+namespace {
+// The lowering logic and state. The constructor sets up the per-module state
+// (mirrored compiler-rt types, one `<root>_ctx_root` global per defined
+// context root, and the compiler-rt function/TLS declarations);
+// lowerFunction then uses that state on each function.
+class CtxInstrumentationLowerer final {
+  Module &M;
+  ModuleAnalysisManager &MAM;
+  // Mirrors of compiler-rt's ContextNode header and ContextRoot structs.
+  Type *ContextNodeTy = nullptr;
+  Type *ContextRootTy = nullptr;
+
+  // Maps each defined context-root function to its `_ctx_root` global.
+  DenseMap<const Function *, Constant *> ContextRootMap;
+  Function *StartCtx = nullptr;
+  Function *GetCtx = nullptr;
+  Function *ReleaseCtx = nullptr;
+  GlobalVariable *ExpectedCalleeTLS = nullptr;
+  GlobalVariable *CallsiteInfoTLS = nullptr;
+
+public:
+  CtxInstrumentationLowerer(Module &M, ModuleAnalysisManager &MAM);
+  // Return true if lowering happened (i.e. a change was made).
+  bool lowerFunction(Function &F);
+};
+
+// Return {NrCounters, NrCallsites} for F: the "num counters" operand of the
+// first InstrProfIncrementInst and of the first InstrProfCallsite intrinsic
+// found in F. Each value is 0 when F has no intrinsic of that kind. A function
+// with counters but no instrumented callsites therefore reports {N, 0}.
+std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(const Function &F) {
+  uint32_t NrCounters = 0;
+  uint32_t NrCallsites = 0;
+  for (const auto &BB : F) {
+    for (const auto &I : BB) {
+      if (const auto *Incr = dyn_cast<InstrProfIncrementInst>(&I)) {
+        if (!NrCounters)
+          NrCounters =
+              static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
+      } else if (const auto *CSIntr = dyn_cast<InstrProfCallsite>(&I)) {
+        if (!NrCallsites)
+          NrCallsites =
+              static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
+      }
+      // Both values are known; no need to scan the rest of the function.
+      if (NrCounters && NrCallsites)
+        return {NrCounters, NrCallsites};
+    }
+  }
+  // At least one of the two counts is still 0 here. Report what was found
+  // instead of discarding it: returning {0, 0} would make a function with
+  // counters but no callsites look uninstrumented.
+  return {NrCounters, NrCallsites};
+}
+} // namespace
+
+// Set up tie-in with compiler-rt.
+// NOTE!!!
+// These have to match compiler-rt/lib/ctx_profile/CtxInstrProfiling.h
+CtxInstrumentationLowerer::CtxInstrumentationLowerer(Module &M,
+                                                     ModuleAnalysisManager &MAM)
+    : M(M), MAM(MAM) {
+  // With opaque pointers every pointer is plain `ptr`, so this single type is
+  // used for all pointer fields, parameters and return values below.
+  auto *PointerTy = PointerType::get(M.getContext(), 0);
+  auto *SanitizerMutexType = Type::getInt8Ty(M.getContext());
+  auto *I32Ty = Type::getInt32Ty(M.getContext());
+  auto *I64Ty = Type::getInt64Ty(M.getContext());
+
+  // The ContextRoot type
+  ContextRootTy = StructType::get(M.getContext(), {
+                                                      PointerTy, /*FirstNode*/
+                                                      PointerTy, /*FirstMemBlock*/
+                                                      PointerTy, /*CurrentMem*/
+                                                      SanitizerMutexType, /*Taken*/
+                                                  });
+  // The Context header.
+  ContextNodeTy = StructType::get(M.getContext(), {
+                                                      I64Ty,     /*Guid*/
+                                                      PointerTy, /*Next*/
+                                                      I32Ty,     /*NrCounters*/
+                                                      I32Ty,     /*NrCallsites*/
+                                                  });
+
+  // Define a global for each entrypoint. We'll reuse the entrypoint's name as
+  // prefix. We assume the entrypoint names to be unique. Roots that are absent
+  // from this module, or only declared here, get no global.
+  for (const auto &Fname : ContextRoots) {
+    const auto *F = M.getFunction(Fname);
+    if (!F || F->isDeclaration())
+      continue;
+    auto *G = M.getOrInsertGlobal(Fname + "_ctx_root", ContextRootTy);
+    cast<GlobalVariable>(G)->setInitializer(
+        Constant::getNullValue(ContextRootTy));
+    ContextRootMap.insert(std::make_pair(F, G));
+  }
+
+  // Declare the functions we will call.
+  StartCtx = cast<Function>(
+      M.getOrInsertFunction(CompilerRtAPINames::StartCtx,
+                            FunctionType::get(PointerTy,
+                                              {PointerTy, /*ContextRoot*/
+                                               I64Ty,     /*Guid*/
+                                               I32Ty,     /*NrCounters*/
+                                               I32Ty},    /*NrCallsites*/
+                                              false))
+          .getCallee());
+  GetCtx = cast<Function>(
+      M.getOrInsertFunction(CompilerRtAPINames::GetCtx,
+                            FunctionType::get(PointerTy,
+                                              {PointerTy, /*Callee*/
+                                               I64Ty,     /*Guid*/
+                                               I32Ty,     /*NrCounters*/
+                                               I32Ty},    /*NrCallsites*/
+                                              false))
+          .getCallee());
+  ReleaseCtx = cast<Function>(
+      M.getOrInsertFunction(CompilerRtAPINames::ReleaseCtx,
+                            FunctionType::get(Type::getVoidTy(M.getContext()),
+                                              {PointerTy /*ContextRoot*/},
+                                              false))
+          .getCallee());
+
+  // Declare the TLSes we will need to use. They are external declarations (no
+  // initializer) with hidden visibility, expected to be defined by
+  // compiler-rt. Creation order is kept: callsite TLS first.
+  auto DeclareTLS = [&](const char *Name) {
+    auto *GV = new GlobalVariable(M, PointerTy, /*isConstant=*/false,
+                                  GlobalValue::ExternalLinkage,
+                                  /*Initializer=*/nullptr, Name);
+    GV->setThreadLocal(true);
+    GV->setVisibility(GlobalValue::HiddenVisibility);
+    return GV;
+  };
+  CallsiteInfoTLS = DeclareTLS(CompilerRtAPINames::CallsiteTLS);
+  ExpectedCalleeTLS = DeclareTLS(CompilerRtAPINames::ExpectedCalleeTLS);
+}
+
+// Lower every function in M. Report "nothing preserved" if any function was
+// lowered, otherwise report everything preserved.
+PreservedAnalyses PGOCtxProfLoweringPass::run(Module &M,
+                                              ModuleAnalysisManager &MAM) {
+  CtxInstrumentationLowerer Lowerer(M, MAM);
+  // Visit every function: there is deliberately no early exit after the first
+  // successful lowering.
+  bool AnyLowered = false;
+  for (Function &Fn : M)
+    if (Lowerer.lowerFunction(Fn))
+      AnyLowered = true;
+  // NOTE(review): the Lowerer constructor already inserts function
+  // declarations and TLS globals into M, yet all() is returned when no
+  // function was lowered. Confirm that reporting "no change" is intended.
+  if (!AnyLowered)
+    return PreservedAnalyses::all();
+  return PreservedAnalyses::none();
+}
+
+bool CtxInstrumentationLowerer::lowerFunction(Function &F) {
+ if (F.isDeclaration())
+ return false;
+ auto &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+ auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+
+ Value *Guid = nullptr;
+ auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(F);
+
+ Value *Context = nullptr;
+ Value *RealContext = nullptr;
+
+ StructType *ThisContextType = nullptr;
+ Value *TheRootContext = nullptr;
+ Value *ExpectedCalleeTLSAddr = nullptr;
+ Value *CallsiteInfoTLSAddr = nullptr;
+
+ auto &Head = F.getEntryBlock();
+ for (auto &I : Head) {
+ // Find the increment intrinsic in the entry basic block.
+ if (auto *Mark = dyn_cast<InstrProfIncrementInst>(&I)) {
+ assert(Mark->getIndex()->isZero());
+
+ IRBuilder<> Builder(Mark);
+ // FIXME(mtrofin): use InstrProfSymtab::getCanonicalName
+ Guid = Builder.getInt64(F.getGUID());
+ // The type of the context of this function is now knowable since we have
+ // NrCallsites and NrCounters. We delcare it here because it's more
+ // convenient - we have the Builder.
+ ThisContextType = StructType::get(
+ F.getContext(),
+ {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
+ ArrayType::get(Builder.getPtrTy(), NrCallsites)});
+ // Figure out which way we obtain the context object for this function -
+ // if it's an entrypoint, then we call StartCtx, otherwise GetCtx. In the
+ // former case, we also set TheRootContext since we need it to release it
----------------
snehasish wrote:
typo: "need it to" -> "need to"?
https://github.com/llvm/llvm-project/pull/90821
More information about the llvm-commits
mailing list