[llvm] 3f18a0a - [nfc] Improve testability of PGOInstrumentationGen (#104490)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 16 09:45:34 PDT 2024


Author: Mircea Trofin
Date: 2024-08-16T09:45:29-07:00
New Revision: 3f18a0a71cc29c502591a3d29a1845a011415f2a

URL: https://github.com/llvm/llvm-project/commit/3f18a0a71cc29c502591a3d29a1845a011415f2a
DIFF: https://github.com/llvm/llvm-project/commit/3f18a0a71cc29c502591a3d29a1845a011415f2a.diff

LOG: [nfc] Improve testability of PGOInstrumentationGen (#104490)

Pass to the `PGOInstrumentationGen` pass, as a parameter, whether it needs to produce contextual profiling instrumentation. In the process, slightly restructure the places that need to be aware of that (some were unnecessarily in `PGOInstrumentationUse`).

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Instrumentation/PGOInstrumentation.h
    llvm/lib/Passes/PassBuilderPipelines.cpp
    llvm/lib/Passes/PassRegistry.def
    llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
    llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll
    llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Instrumentation/PGOInstrumentation.h b/llvm/include/llvm/Transforms/Instrumentation/PGOInstrumentation.h
index 7199f27dbc991a..1d0af87e965af5 100644
--- a/llvm/include/llvm/Transforms/Instrumentation/PGOInstrumentation.h
+++ b/llvm/include/llvm/Transforms/Instrumentation/PGOInstrumentation.h
@@ -53,15 +53,18 @@ class PGOInstrumentationGenCreateVar
   bool ProfileSampling;
 };
 
+enum class PGOInstrumentationType { Invalid = 0, FDO, CSFDO, CTXPROF };
 /// The instrumentation (profile-instr-gen) pass for IR based PGO.
 class PGOInstrumentationGen : public PassInfoMixin<PGOInstrumentationGen> {
 public:
-  PGOInstrumentationGen(bool IsCS = false) : IsCS(IsCS) {}
+  PGOInstrumentationGen(
+      PGOInstrumentationType InstrumentationType = PGOInstrumentationType ::FDO)
+      : InstrumentationType(InstrumentationType) {}
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
 
 private:
   // If this is a context sensitive instrumentation.
-  bool IsCS;
+  const PGOInstrumentationType InstrumentationType;
 };
 
 /// The profile annotation (profile-instr-use) pass for IR based PGO.

diff  --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 0201e69f3e216a..1184123c7710f0 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -844,7 +844,8 @@ void PassBuilder::addPGOInstrPasses(ModulePassManager &MPM,
   }
 
   // Perform PGO instrumentation.
-  MPM.addPass(PGOInstrumentationGen(IsCS));
+  MPM.addPass(PGOInstrumentationGen(IsCS ? PGOInstrumentationType::CSFDO
+                                         : PGOInstrumentationType::FDO));
 
   addPostPGOLoopRotation(MPM, Level);
   // Add the profile lowering pass.
@@ -879,7 +880,8 @@ void PassBuilder::addPGOInstrPassesForO0(
   }
 
   // Perform PGO instrumentation.
-  MPM.addPass(PGOInstrumentationGen(IsCS));
+  MPM.addPass(PGOInstrumentationGen(IsCS ? PGOInstrumentationType::CSFDO
+                                         : PGOInstrumentationType::FDO));
   // Add the profile lowering pass.
   InstrProfOptions Options;
   if (!ProfileFile.empty())
@@ -1193,7 +1195,7 @@ PassBuilder::buildModuleSimplificationPipeline(OptimizationLevel Level,
                       PGOOpt->ProfileFile, PGOOpt->ProfileRemappingFile,
                       PGOOpt->FS);
   } else if (IsCtxProfGen || IsCtxProfUse) {
-    MPM.addPass(PGOInstrumentationGen(false));
+    MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF));
     // In pre-link, we just want the instrumented IR. We use the contextual
     // profile in the post-thinlink phase.
     // The instrumentation will be removed in post-thinlink after IPO.

diff  --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 442c972fc616fc..a11fc3755494ab 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -56,6 +56,8 @@ MODULE_PASS("constmerge", ConstantMergePass())
 MODULE_PASS("coro-cleanup", CoroCleanupPass())
 MODULE_PASS("coro-early", CoroEarlyPass())
 MODULE_PASS("cross-dso-cfi", CrossDSOCFIPass())
+MODULE_PASS("ctx-instr-gen",
+            PGOInstrumentationGen(PGOInstrumentationType::CTXPROF))
 MODULE_PASS("deadargelim", DeadArgumentEliminationPass())
 MODULE_PASS("debugify", NewPMDebugifyPass())
 MODULE_PASS("dfsan", DataFlowSanitizerPass())

diff  --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index 41618194d12ed7..b3644031c5a44b 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -110,7 +110,6 @@
 #include "llvm/Transforms/Instrumentation.h"
 #include "llvm/Transforms/Instrumentation/BlockCoverageInference.h"
 #include "llvm/Transforms/Instrumentation/CFGMST.h"
-#include "llvm/Transforms/Instrumentation/PGOCtxProfLowering.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/MisExpect.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -321,7 +320,6 @@ static cl::opt<unsigned> PGOFunctionCriticalEdgeThreshold(
              " greater than this threshold."));
 
 extern cl::opt<unsigned> MaxNumVTableAnnotations;
-extern cl::opt<std::string> UseCtxProfile;
 
 namespace llvm {
 // Command line option to turn on CFG dot dump after profile annotation.
@@ -339,21 +337,43 @@ extern cl::opt<bool> EnableVTableProfileUse;
 extern cl::opt<InstrProfCorrelator::ProfCorrelatorKind> ProfileCorrelate;
 } // namespace llvm
 
-bool shouldInstrumentForCtxProf() {
-  return PGOCtxProfLoweringPass::isCtxIRPGOInstrEnabled() ||
-         !UseCtxProfile.empty();
-}
-bool shouldInstrumentEntryBB() {
-  return PGOInstrumentEntry || shouldInstrumentForCtxProf();
-}
+namespace {
+class FunctionInstrumenter final {
+  Module &M;
+  Function &F;
+  TargetLibraryInfo &TLI;
+  std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers;
+  BranchProbabilityInfo *const BPI;
+  BlockFrequencyInfo *const BFI;
 
-// FIXME(mtrofin): re-enable this for ctx profiling, for non-indirect calls. Ctx
-// profiling implicitly captures indirect call cases, but not other values.
-// Supporting other values is relatively straight-forward - just another counter
-// range within the context.
-bool isValueProfilingDisabled() {
-  return DisableValueProfiling || shouldInstrumentForCtxProf();
-}
+  const PGOInstrumentationType InstrumentationType;
+
+  // FIXME(mtrofin): re-enable this for ctx profiling, for non-indirect calls.
+  // Ctx profiling implicitly captures indirect call cases, but not other
+  // values. Supporting other values is relatively straight-forward - just
+  // another counter range within the context.
+  bool isValueProfilingDisabled() const {
+    return DisableValueProfiling ||
+           InstrumentationType == PGOInstrumentationType::CTXPROF;
+  }
+
+  bool shouldInstrumentEntryBB() const {
+    return PGOInstrumentEntry ||
+           InstrumentationType == PGOInstrumentationType::CTXPROF;
+  }
+
+public:
+  FunctionInstrumenter(
+      Module &M, Function &F, TargetLibraryInfo &TLI,
+      std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
+      BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr,
+      PGOInstrumentationType InstrumentationType = PGOInstrumentationType::FDO)
+      : M(M), F(F), TLI(TLI), ComdatMembers(ComdatMembers), BPI(BPI), BFI(BFI),
+        InstrumentationType(InstrumentationType) {}
+
+  void instrument();
+};
+} // namespace
 
 // Return a string describing the branch condition that can be
 // used in static branch probability heuristics:
@@ -395,13 +415,16 @@ static const char *ValueProfKindDescr[] = {
 
 // Create a COMDAT variable INSTR_PROF_RAW_VERSION_VAR to make the runtime
 // aware this is an ir_level profile so it can set the version flag.
-static GlobalVariable *createIRLevelProfileFlagVar(Module &M, bool IsCS) {
+static GlobalVariable *
+createIRLevelProfileFlagVar(Module &M,
+                            PGOInstrumentationType InstrumentationType) {
   const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
   Type *IntTy64 = Type::getInt64Ty(M.getContext());
   uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF);
-  if (IsCS)
+  if (InstrumentationType == PGOInstrumentationType::CSFDO)
     ProfileVersion |= VARIANT_MASK_CSIR_PROF;
-  if (shouldInstrumentEntryBB())
+  if (PGOInstrumentEntry ||
+      InstrumentationType == PGOInstrumentationType::CTXPROF)
     ProfileVersion |= VARIANT_MASK_INSTR_ENTRY;
   if (DebugInfoCorrelate || ProfileCorrelate == InstrProfCorrelator::DEBUG_INFO)
     ProfileVersion |= VARIANT_MASK_DBG_CORRELATE;
@@ -871,11 +894,7 @@ populateEHOperandBundle(VPCandidateInfo &Cand,
 
 // Visit all edge and instrument the edges not in MST, and do value profiling.
 // Critical edges will be split.
-static void instrumentOneFunc(
-    Function &F, Module *M, TargetLibraryInfo &TLI, BranchProbabilityInfo *BPI,
-    BlockFrequencyInfo *BFI,
-    std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
-    bool IsCS) {
+void FunctionInstrumenter::instrument() {
   if (!PGOBlockCoverage) {
     // Split indirectbr critical edges here before computing the MST rather than
     // later in getInstrBB() to avoid invalidating it.
@@ -883,19 +902,20 @@ static void instrumentOneFunc(
   }
 
   FuncPGOInstrumentation<PGOEdge, PGOBBInfo> FuncInfo(
-      F, TLI, ComdatMembers, true, BPI, BFI, IsCS, shouldInstrumentEntryBB(),
-      PGOBlockCoverage);
+      F, TLI, ComdatMembers, true, BPI, BFI,
+      InstrumentationType == PGOInstrumentationType::CSFDO,
+      shouldInstrumentEntryBB(), PGOBlockCoverage);
 
   auto Name = FuncInfo.FuncNameVar;
-  auto CFGHash = ConstantInt::get(Type::getInt64Ty(M->getContext()),
-                                  FuncInfo.FunctionHash);
+  auto CFGHash =
+      ConstantInt::get(Type::getInt64Ty(M.getContext()), FuncInfo.FunctionHash);
   if (PGOFunctionEntryCoverage) {
     auto &EntryBB = F.getEntryBlock();
     IRBuilder<> Builder(&EntryBB, EntryBB.getFirstInsertionPt());
     // llvm.instrprof.cover(i8* <name>, i64 <hash>, i32 <num-counters>,
     //                      i32 <index>)
     Builder.CreateCall(
-        Intrinsic::getDeclaration(M, Intrinsic::instrprof_cover),
+        Intrinsic::getDeclaration(&M, Intrinsic::instrprof_cover),
         {Name, CFGHash, Builder.getInt32(1), Builder.getInt32(0)});
     return;
   }
@@ -905,9 +925,9 @@ static void instrumentOneFunc(
   unsigned NumCounters =
       InstrumentBBs.size() + FuncInfo.SIVisitor.getNumOfSelectInsts();
 
-  if (shouldInstrumentForCtxProf()) {
+  if (InstrumentationType == PGOInstrumentationType::CTXPROF) {
     auto *CSIntrinsic =
-        Intrinsic::getDeclaration(M, Intrinsic::instrprof_callsite);
+        Intrinsic::getDeclaration(&M, Intrinsic::instrprof_callsite);
     // We want to count the instrumentable callsites, then instrument them. This
     // is because the llvm.instrprof.callsite intrinsic has an argument (like
     // the other instrprof intrinsics) capturing the total number of
@@ -950,7 +970,7 @@ static void instrumentOneFunc(
     // llvm.instrprof.timestamp(i8* <name>, i64 <hash>, i32 <num-counters>,
     //                          i32 <index>)
     Builder.CreateCall(
-        Intrinsic::getDeclaration(M, Intrinsic::instrprof_timestamp),
+        Intrinsic::getDeclaration(&M, Intrinsic::instrprof_timestamp),
         {Name, CFGHash, Builder.getInt32(NumCounters), Builder.getInt32(I)});
     I += PGOBlockCoverage ? 8 : 1;
   }
@@ -962,9 +982,9 @@ static void instrumentOneFunc(
     // llvm.instrprof.increment(i8* <name>, i64 <hash>, i32 <num-counters>,
     //                          i32 <index>)
     Builder.CreateCall(
-        Intrinsic::getDeclaration(M, PGOBlockCoverage
-                                         ? Intrinsic::instrprof_cover
-                                         : Intrinsic::instrprof_increment),
+        Intrinsic::getDeclaration(&M, PGOBlockCoverage
+                                          ? Intrinsic::instrprof_cover
+                                          : Intrinsic::instrprof_increment),
         {Name, CFGHash, Builder.getInt32(NumCounters), Builder.getInt32(I++)});
   }
 
@@ -1011,7 +1031,7 @@ static void instrumentOneFunc(
       SmallVector<OperandBundleDef, 1> OpBundles;
       populateEHOperandBundle(Cand, BlockColors, OpBundles);
       Builder.CreateCall(
-          Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile),
+          Intrinsic::getDeclaration(&M, Intrinsic::instrprof_value_profile),
           {FuncInfo.FuncNameVar, Builder.getInt64(FuncInfo.FunctionHash),
            ToProfile, Builder.getInt32(Kind), Builder.getInt32(SiteIndex++)},
           OpBundles);
@@ -1746,7 +1766,7 @@ static uint32_t getMaxNumAnnotations(InstrProfValueKind ValueProfKind) {
 
 // Traverse all valuesites and annotate the instructions for all value kind.
 void PGOUseFunc::annotateValueSites() {
-  if (isValueProfilingDisabled())
+  if (DisableValueProfiling)
     return;
 
   // Create the PGOFuncName meta data.
@@ -1861,11 +1881,12 @@ static bool skipPGOGen(const Function &F) {
 static bool InstrumentAllFunctions(
     Module &M, function_ref<TargetLibraryInfo &(Function &)> LookupTLI,
     function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
-    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) {
+    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+    PGOInstrumentationType InstrumentationType) {
   // For the context-sensitve instrumentation, we should have a separated pass
   // (before LTO/ThinLTO linking) to create these variables.
-  if (!IsCS && !shouldInstrumentForCtxProf())
-    createIRLevelProfileFlagVar(M, /*IsCS=*/false);
+  if (InstrumentationType == PGOInstrumentationType::FDO)
+    createIRLevelProfileFlagVar(M, InstrumentationType);
 
   Triple TT(M.getTargetTriple());
   LLVMContext &Ctx = M.getContext();
@@ -1884,7 +1905,9 @@ static bool InstrumentAllFunctions(
     auto &TLI = LookupTLI(F);
     auto *BPI = LookupBPI(F);
     auto *BFI = LookupBFI(F);
-    instrumentOneFunc(F, &M, TLI, BPI, BFI, ComdatMembers, IsCS);
+    FunctionInstrumenter FI(M, F, TLI, ComdatMembers, BPI, BFI,
+                            InstrumentationType);
+    FI.instrument();
   }
   return true;
 }
@@ -1894,7 +1917,8 @@ PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &MAM) {
   createProfileFileNameVar(M, CSInstrName);
   // The variable in a comdat may be discarded by LTO. Ensure the declaration
   // will be retained.
-  appendToCompilerUsed(M, createIRLevelProfileFlagVar(M, /*IsCS=*/true));
+  appendToCompilerUsed(
+      M, createIRLevelProfileFlagVar(M, PGOInstrumentationType::CSFDO));
   if (ProfileSampling)
     createProfileSamplingVar(M);
   PreservedAnalyses PA;
@@ -1916,7 +1940,8 @@ PreservedAnalyses PGOInstrumentationGen::run(Module &M,
     return &FAM.getResult<BlockFrequencyAnalysis>(F);
   };
 
-  if (!InstrumentAllFunctions(M, LookupTLI, LookupBPI, LookupBFI, IsCS))
+  if (!InstrumentAllFunctions(M, LookupTLI, LookupBPI, LookupBFI,
+                              InstrumentationType))
     return PreservedAnalyses::all();
 
   return PreservedAnalyses::none();
@@ -2115,7 +2140,6 @@ static bool annotateAllFunctions(
   bool InstrumentFuncEntry = PGOReader->instrEntryBBEnabled();
   if (PGOInstrumentEntry.getNumOccurrences() > 0)
     InstrumentFuncEntry = PGOInstrumentEntry;
-  InstrumentFuncEntry |= shouldInstrumentForCtxProf();
 
   bool HasSingleByteCoverage = PGOReader->hasSingleByteCoverage();
   for (auto &F : M) {

diff  --git a/llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll b/llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll
index 99c7762a67dfbd..454780153b8236 100644
--- a/llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll
+++ b/llvm/test/Transforms/PGOProfile/ctx-instrumentation-invalid-roots.ll
@@ -1,4 +1,4 @@
-; RUN: not opt -passes=pgo-instr-gen,ctx-instr-lower -profile-context-root=good \
+; RUN: not opt -passes=ctx-instr-gen,ctx-instr-lower -profile-context-root=good \
 ; RUN:   -profile-context-root=bad \
 ; RUN:   -S < %s 2>&1 | FileCheck %s
 

diff  --git a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
index a70f94e1521f0d..df4e467567c46e 100644
--- a/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
+++ b/llvm/test/Transforms/PGOProfile/ctx-instrumentation.ll
@@ -1,7 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 4
-; RUN: opt -passes=pgo-instr-gen -profile-context-root=an_entrypoint \
+; RUN: opt -passes=ctx-instr-gen -profile-context-root=an_entrypoint \
 ; RUN:   -S < %s | FileCheck --check-prefix=INSTRUMENT %s
-; RUN: opt -passes=pgo-instr-gen,assign-guid,ctx-instr-lower -profile-context-root=an_entrypoint \
+; RUN: opt -passes=ctx-instr-gen,assign-guid,ctx-instr-lower -profile-context-root=an_entrypoint \
 ; RUN:   -profile-context-root=another_entrypoint_no_callees \
 ; RUN:   -S < %s | FileCheck --check-prefix=LOWERING %s
 


        


More information about the llvm-commits mailing list