[llvm] [nfc] Improve testability of PGOInstrumentationGen (PR #104490)

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 15 12:43:02 PDT 2024


https://github.com/mtrofin created https://github.com/llvm/llvm-project/pull/104490

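Make contextual-profiling instrumentation an explicit `IsCtxProf` parameter of `PGOInstrumentationGen` (exposed as the new `ctx-instr-gen` pass), instead of inferring it from command-line flags, so it can be exercised directly in tests.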

>From c82371a92fe03e7dc2f0fba3cbc47e8ac1a60874 Mon Sep 17 00:00:00 2001
From: Mircea Trofin <mtrofin at google.com>
Date: Thu, 15 Aug 2024 09:26:02 -0700
Subject: [PATCH] [nfc] Improve testability of PGOInstrumentationGen

---
 .../Instrumentation/PGOInstrumentation.h      |  6 +-
 llvm/lib/Passes/PassBuilderPipelines.cpp      |  2 +-
 llvm/lib/Passes/PassRegistry.def              |  1 +
 .../Instrumentation/PGOInstrumentation.cpp    | 95 +++++++++++--------
 .../ctx-instrumentation-invalid-roots.ll      |  2 +-
 .../PGOProfile/ctx-instrumentation.ll         |  4 +-
 6 files changed, 66 insertions(+), 44 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Instrumentation/PGOInstrumentation.h b/llvm/include/llvm/Transforms/Instrumentation/PGOInstrumentation.h
index 7199f27dbc991a..f796a05ef86e60 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/PGOInstrumentation.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/PGOInstrumentation.h
@@ -56,12 +56,15 @@ class PGOInstrumentationGenCreateVar
 /// The instrumentation (profile-instr-gen) pass for IR based PGO.
 class PGOInstrumentationGen : public PassInfoMixin<PGOInstrumentationGen> {
 public:
-  PGOInstrumentationGen(bool IsCS = false) : IsCS(IsCS) {}
+  PGOInstrumentationGen(bool IsCS = false, bool IsCtxProf = false)
+      : IsCS(IsCS), IsCtxProf(IsCtxProf) {}
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
 
 private:
   // If this is a context sensitive instrumentation.
-  bool IsCS;
+  const bool IsCS;
+  // If this is a contextual profiling (ctx-instr-gen) instrumentation.
+  const bool IsCtxProf;
 };
 
 /// The profile annotation (profile-instr-use) pass for IR based PGO.
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 0201e69f3e216a..b4b2bdcccc38b1 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -1193,7 +1193,7 @@ PassBuilder::buildModuleSimplificationPipeline(OptimizationLevel Level,
                       PGOOpt->ProfileFile, PGOOpt->ProfileRemappingFile,
                       PGOOpt->FS);
   } else if (IsCtxProfGen || IsCtxProfUse) {
-    MPM.addPass(PGOInstrumentationGen(false));
+    MPM.addPass(PGOInstrumentationGen(/*IsCS=*/false, /*IsCtxProf=*/true));
     // In pre-link, we just want the instrumented IR. We use the contextual
     // profile in the post-thinlink phase.
     // The instrumentation will be removed in post-thinlink after IPO.
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 18f4aa19224da0..d3babded763920 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -56,6 +56,7 @@ MODULE_PASS("constmerge", ConstantMergePass())
 MODULE_PASS("coro-cleanup", CoroCleanupPass())
 MODULE_PASS("coro-early", CoroEarlyPass())
 MODULE_PASS("cross-dso-cfi", CrossDSOCFIPass())
+MODULE_PASS("ctx-instr-gen", PGOInstrumentationGen(false, true))
 MODULE_PASS("deadargelim", DeadArgumentEliminationPass())
 MODULE_PASS("debugify", NewPMDebugifyPass())
 MODULE_PASS("dfsan", DataFlowSanitizerPass())
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 41618194d12ed7..7867bee5c6b080 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -110,7 +110,6 @@
 #include "llvm/Transforms/Instrumentation.h"
 #include "llvm/Transforms/Instrumentation/BlockCoverageInference.h"
 #include "llvm/Transforms/Instrumentation/CFGMST.h"
-#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/MisExpect.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -321,7 +320,6 @@ static cl::opt<unsigned> PGOFunctionCriticalEdgeThreshold(
              " greater than this threshold."));
 
 extern cl::opt<unsigned> MaxNumVTableAnnotations;
-extern cl::opt<std::string> UseCtxProfile;
 
 namespace llvm {
 // Command line option to turn on CFG dot dump after profile annotation.
@@ -339,21 +337,44 @@ extern cl::opt<bool> EnableVTableProfileUse;
 extern cl::opt<InstrProfCorrelator::ProfCorrelatorKind> ProfileCorrelate;
 } // namespace llvm

-bool shouldInstrumentForCtxProf() {
-  return PGOCtxProfLoweringPass::isCtxIRPGOInstrEnabled() ||
-         !UseCtxProfile.empty();
-}
-bool shouldInstrumentEntryBB() {
-  return PGOInstrumentEntry || shouldInstrumentForCtxProf();
-}
+namespace {
+/// Per-function driver for IR PGO instrumentation: build one per function and
+/// call instrument(). With IsCtxProf, callsites are also instrumented (for
+/// contextual profiling) and value profiling is turned off.
+class FunctionInstrumenter final {
+  Module &M;
+  Function &F;
+  TargetLibraryInfo &TLI;
+  std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers;
+  BranchProbabilityInfo *const BPI;
+  BlockFrequencyInfo *const BFI;

-// FIXME(mtrofin): re-enable this for ctx profiling, for non-indirect calls. Ctx
-// profiling implicitly captures indirect call cases, but not other values.
-// Supporting other values is relatively straight-forward - just another counter
-// range within the context.
-bool isValueProfilingDisabled() {
-  return DisableValueProfiling || shouldInstrumentForCtxProf();
-}
+  const bool IsCS;
+  const bool IsCtxProf;
+  // FIXME(mtrofin): re-enable this for ctx profiling, for non-indirect calls.
+  // Ctx profiling implicitly captures indirect call cases, but not other
+  // values. Supporting other values is relatively straight-forward - just
+  // another counter range within the context.
+  bool isValueProfilingDisabled() const {
+    return DisableValueProfiling || IsCtxProf;
+  }
+
+  bool shouldInstrumentEntryBB() const {
+    return PGOInstrumentEntry || IsCtxProf;
+  }
+
+public:
+  FunctionInstrumenter(
+      Module &M, Function &F, TargetLibraryInfo &TLI,
+      std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
+      BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr,
+      bool IsCS = false, bool IsCtxProf = false)
+      : M(M), F(F), TLI(TLI), ComdatMembers(ComdatMembers), BPI(BPI), BFI(BFI),
+        IsCS(IsCS), IsCtxProf(IsCtxProf) {}
+
+  void instrument();
+};
+} // namespace

 // Return a string describing the branch condition that can be
 // used in static branch probability heuristics:
@@ -395,13 +413,14 @@ static const char *ValueProfKindDescr[] = {
 
 // Create a COMDAT variable INSTR_PROF_RAW_VERSION_VAR to make the runtime
 // aware this is an ir_level profile so it can set the version flag.
-static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) {
+static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS,
+                                                   bool IsCtxProf) {
   const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
   Type *IntTy64 = Type::getInt64Ty(M.getContext());
   uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF);
   if (IsCS)
     ProfileVersion |= VARIANT_MASK_CSIR_PROF;
-  if (shouldInstrumentEntryBB())
+  if (PGOInstrumentEntry || IsCtxProf)
     ProfileVersion |= VARIANT_MASK_INSTR_ENTRY;
   if (DebugInfoCorrelate || ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO)
     ProfileVersion |= VARIANT_MASK_DBG_CORRELATE;
@@ -871,11 +890,7 @@ populateEHOperandBundle(VPCandidateInfo &Cand,
 
 // Visit all edge and instrument the edges not in MST, and do value profiling.
 // Critical edges will be split.
-static void instrumentOneFunc(
-    Function &F, Module *M, TargetLibraryInfo &TLI, BranchProbabilityInfo *BPI,
-    BlockFrequencyInfo *BFI,
-    std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
-    bool IsCS) {
+void FunctionInstrumenter::instrument() {
   if (!PGOBlockCoverage) {
     // Split indirectbr critical edges here before computing the MST rather than
     // later in getInstrBB() to avoid invalidating it.
@@ -887,7 +902,7 @@ static void instrumentOneFunc(
       PGOBlockCoverage);
 
   auto Name = FuncInfo.FuncNameVar;
-  auto CFGHash = ConstantInt::get(Type::getInt64Ty(M->getContext()),
+  auto CFGHash = ConstantInt::get(Type::getInt64Ty(M.getContext()),
                                   FuncInfo.FunctionHash);
   if (PGOFunctionEntryCoverage) {
     auto &EntryBB = F.getEntryBlock();
@@ -895,7 +910,7 @@ static void instrumentOneFunc(
     // llvm.instrprof.cover(i8* <name>, i64 <hash>, i32 <num-counters>,
     //                      i32 <index>)
     Builder.CreateCall(
-        Intrinsic::getDeclaration(M, Intrinsic::instrprof_cover),
+        Intrinsic::getDeclaration(&M, Intrinsic::instrprof_cover),
         {Name, CFGHash, Builder.getInt32(1), Builder.getInt32(0)});
     return;
   }
@@ -905,9 +920,9 @@ static void instrumentOneFunc(
   unsigned NumCounters =
       InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts();
 
-  if (shouldInstrumentForCtxProf()) {
+  if (IsCtxProf) {
     auto *CSIntrinsic =
-        Intrinsic::getDeclaration(M, Intrinsic::instrprof_callsite);
+        Intrinsic::getDeclaration(&M, Intrinsic::instrprof_callsite);
     // We want to count the instrumentable callsites, then instrument them. This
     // is because the llvm.instrprof.callsite intrinsic has an argument (like
     // the other instrprof intrinsics) capturing the total number of
@@ -950,7 +965,7 @@ static void instrumentOneFunc(
     // llvm.instrprof.timestamp(i8* <name>, i64 <hash>, i32 <num-counters>,
     //                          i32 <index>)
     Builder.CreateCall(
-        Intrinsic::getDeclaration(M, Intrinsic::instrprof_timestamp),
+        Intrinsic::getDeclaration(&M, Intrinsic::instrprof_timestamp),
         {Name, CFGHash, Builder.getInt32(NumCounters), Builder.getInt32(I)});
     I += PGOBlockCoverage ? 8 : 1;
   }
@@ -962,7 +977,7 @@ static void instrumentOneFunc(
     // llvm.instrprof.increment(i8* <name>, i64 <hash>, i32 <num-counters>,
     //                          i32 <index>)
     Builder.CreateCall(
-        Intrinsic::getDeclaration(M, PGOBlockCoverage
+        Intrinsic::getDeclaration(&M, PGOBlockCoverage
                                          ? Intrinsic::instrprof_cover
                                          : Intrinsic::instrprof_increment),
         {Name, CFGHash, Builder.getInt32(NumCounters), Builder.getInt32(I++)});
@@ -1011,7 +1026,7 @@ static void instrumentOneFunc(
       SmallVector<OperandBundleDef, 1> OpBundles;
       populateEHOperandBundle(Cand, BlockColors, OpBundles);
       Builder.CreateCall(
-          Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile),
+          Intrinsic::getDeclaration(&M, Intrinsic::instrprof_value_profile),
           {FuncInfo.FuncNameVar, Builder.getInt64(FuncInfo.FunctionHash),
            ToProfile, Builder.getInt32(Kind), Builder.getInt32(SiteIndex++)},
           OpBundles);
@@ -1746,7 +1761,7 @@ static uint32_t getMaxNumAnnotations(InstrProfValueKind ValueProfKind) {
 
 // Traverse all valuesites and annotate the instructions for all value kind.
 void PGOUseFunc::annotateValueSites() {
-  if (isValueProfilingDisabled())
+  if (DisableValueProfiling)
     return;
 
   // Create the PGOFuncName meta data.
@@ -1861,11 +1876,12 @@ static bool skipPGOGen(const Function &F) {
 static bool InstrumentAllFunctions(
     Module &M, function_ref<TargetLibraryInfo &(Function &)> LookupTLI,
     function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
-    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) {
+    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS,
+    bool IsCtxProf) {
   // For the context-sensitve instrumentation, we should have a separated pass
   // (before LTO/ThinLTO linking) to create these variables.
-  if (!IsCS && !shouldInstrumentForCtxProf())
-    createIRLevelProfileFlagVar(M, /*IsCS=*/false);
+  if (!IsCS && !IsCtxProf)
+    createIRLevelProfileFlagVar(M, /*IsCS=*/false, /*IsCtxProf=*/false);
 
   Triple TT(M.getTargetTriple());
   LLVMContext &Ctx = M.getContext();
@@ -1884,7 +1900,9 @@ static bool InstrumentAllFunctions(
     auto &TLI = LookupTLI(F);
     auto *BPI = LookupBPI(F);
     auto *BFI = LookupBFI(F);
-    instrumentOneFunc(F, &M, TLI, BPI, BFI, ComdatMembers, IsCS);
+    FunctionInstrumenter FI(M, F, TLI, ComdatMembers, BPI, BFI, IsCS,
+                            IsCtxProf);
+    FI.instrument();
   }
   return true;
 }
@@ -1894,7 +1912,8 @@ PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &MAM) {
   createProfileFileNameVar(M, CSInstrName);
   // The variable in a comdat may be discarded by LTO. Ensure the declaration
   // will be retained.
-  appendToCompilerUsed(M, createIRLevelProfileFlagVar(M, /*IsCS=*/true));
+  appendToCompilerUsed(
+      M, createIRLevelProfileFlagVar(M, /*IsCS=*/true, /*IsCtxProf=*/false));
   if (ProfileSampling)
     createProfileSamplingVar(M);
   PreservedAnalyses PA;
@@ -1916,7 +1935,8 @@ PreservedAnalyses PGOInstrumentationGen::run(Module &M,
     return &FAM.getResult<BlockFrequencyAnalysis>(F);
   };
 
-  if (!InstrumentAllFunctions(M, LookupTLI, LookupBPI, LookupBFI, IsCS))
+  if (!InstrumentAllFunctions(M, LookupTLI, LookupBPI, LookupBFI, IsCS,
+                              IsCtxProf))
     return PreservedAnalyses::all();
 
   return PreservedAnalyses::none();
@@ -2115,7 +2135,6 @@ static bool annotateAllFunctions(
   bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled();
   if (PGOInstrumentEntry.getNumOccurrences() > 0)
     InstrumentFuncEntry = PGOInstrumentEntry;
-  InstrumentFuncEntry |= shouldInstrumentForCtxProf();
 
   bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage();
   for (auto &F : M) {
diff --git a/llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll b/llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll
index 99c7762a67dfbd..454780153b8236 100644
--- a/llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll
+++ b/llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -passes=pgo-instr-gen,ctx-instr-lower -profile-context-root=good \
+; RUN: not opt -passes=ctx-instr-gen,ctx-instr-lower -profile-context-root=good \
 ; RUN:   -profile-context-root=bad \
 ; RUN:   -S < %s 2>&1 | FileCheck %s
 
diff --git a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
index a70f94e1521f0d..df4e467567c46e 100644
--- a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
+++ b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 4
-; RUN: opt -passes=pgo-instr-gen -profile-context-root=an_entrypoint \
+; RUN: opt -passes=ctx-instr-gen -profile-context-root=an_entrypoint \
 ; RUN:   -S < %s | FileCheck --check-prefix=INSTRUMENT %s
-; RUN: opt -passes=pgo-instr-gen,assign-guid,ctx-instr-lower -profile-context-root=an_entrypoint \
+; RUN: opt -passes=ctx-instr-gen,assign-guid,ctx-instr-lower -profile-context-root=an_entrypoint \
 ; RUN:   -profile-context-root=another_entrypoint_no_callees \
 ; RUN:   -S < %s | FileCheck --check-prefix=LOWERING %s
 



More information about the llvm-commits mailing list