[llvm] [Pass] Support start/stop in instrumentation (PR #70912)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 9 22:03:44 PST 2024


https://github.com/paperchalice updated https://github.com/llvm/llvm-project/pull/70912

>From 7cc6b7b07d9b2057a108781a24e38b4c33e6d7c8 Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Wed, 3 Jan 2024 10:34:33 +0800
Subject: [PATCH] [CodeGen][NewPM] Support start/stop in CodeGen

---
 .../include/llvm/CodeGen/CodeGenPassBuilder.h | 161 +++++++++++++++---
 llvm/include/llvm/CodeGen/TargetPassConfig.h  |  15 ++
 llvm/lib/CodeGen/TargetPassConfig.cpp         |  34 ++++
 llvm/lib/IR/PassInstrumentation.cpp           |   8 +
 llvm/lib/Passes/PassBuilder.cpp               |   4 +-
 5 files changed, 198 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
index 2100c30aad1180..04f4a43fea7160 100644
--- a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
+++ b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
@@ -40,6 +40,7 @@
 #include "llvm/CodeGen/SelectOptimize.h"
 #include "llvm/CodeGen/ShadowStackGCLowering.h"
 #include "llvm/CodeGen/SjLjEHPrepare.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/UnreachableBlockElim.h"
 #include "llvm/CodeGen/WasmEHPrepare.h"
 #include "llvm/CodeGen/WinEHPrepare.h"
@@ -172,8 +173,9 @@ template <typename DerivedT> class CodeGenPassBuilder {
   // Function object to maintain state while adding codegen IR passes.
   class AddIRPass {
   public:
-    AddIRPass(ModulePassManager &MPM, bool DebugPM, bool Check = true)
-        : MPM(MPM) {
+    AddIRPass(ModulePassManager &MPM, const DerivedT &PB, bool DebugPM,
+              bool Check = true)
+        : MPM(MPM), PB(PB) {
       if (Check)
         AddingFunctionPasses = false;
     }
@@ -184,10 +186,17 @@ template <typename DerivedT> class CodeGenPassBuilder {
     // Add Function Pass
     template <typename PassT>
     std::enable_if_t<is_detected<is_function_pass_t, PassT>::value>
-    operator()(PassT &&Pass) {
+    operator()(PassT &&Pass, StringRef Name = PassT::name()) {
       if (AddingFunctionPasses && !*AddingFunctionPasses)
         AddingFunctionPasses = true;
+      // Any before-callback may veto the pass (e.g. outside -start/-stop-*).
+      for (auto &C : PB.BeforeCallbacks)
+        if (!C(Name))
+          return;
       FPM.addPass(std::forward<PassT>(Pass));
+      // `auto &`: unique_function is move-only with a non-const operator().
+      for (auto &C : PB.AfterCallbacks)
+        C(Name);
     }

     // Add Module Pass
@@ -197,12 +206,20 @@ template <typename DerivedT> class CodeGenPassBuilder {
     operator()(PassT &&Pass) {
       assert((!AddingFunctionPasses || !*AddingFunctionPasses) &&
              "could not add module pass after adding function pass");
+      // Same veto/notify protocol as the function-pass overload.
+      for (auto &C : PB.BeforeCallbacks)
+        if (!C(PassT::name()))
+          return;
       MPM.addPass(std::forward<PassT>(Pass));
+      // Iterate by reference: the callbacks are move-only.
+      for (auto &C : PB.AfterCallbacks)
+        C(PassT::name());
     }

   private:
     ModulePassManager &MPM;
     FunctionPassManager FPM;
+    const DerivedT &PB; // Holds BeforeCallbacks/AfterCallbacks.
     // The codegen IR pipeline are mostly function passes with the exceptions of
     // a few loop and module passes. `AddingFunctionPasses` make sures that
     // we could only add module passes at the beginning of the pipeline. Once
@@ -216,41 +233,39 @@ template <typename DerivedT> class CodeGenPassBuilder {
   // Function object to maintain state while adding codegen machine passes.
   class AddMachinePass {
   public:
-    AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {}
+    AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
+        : PM(PM), PB(PB) {}

     template <typename PassT> void operator()(PassT &&Pass) {
       static_assert(
           is_detected<has_key_t, PassT>::value,
           "Machine function pass must define a static member variable `Key`.");
-      for (auto &C : BeforeCallbacks)
-        if (!C(&PassT::Key))
+      for (auto &C : PB.BeforeCallbacks)
+        if (!C(PassT::name()))
           return;
       PM.addPass(std::forward<PassT>(Pass));
-      for (auto &C : AfterCallbacks)
-        C(&PassT::Key);
+      for (auto &C : PB.AfterCallbacks)
+        C(PassT::name());
     }

-    template <typename PassT> void insertPass(MachinePassKey *ID, PassT Pass) {
-      AfterCallbacks.emplace_back(
-          [this, ID, Pass = std::move(Pass)](MachinePassKey *PassID) {
-            if (PassID == ID)
-              this->PM.addPass(std::move(Pass));
-          });
+    /// Add a copy of \p Pass right after each pass whose name() equals
+    /// \p TargetName (the callbacks are keyed by `PassT::name()`). The
+    /// callback captures the pass manager rather than `this` because it is
+    /// stored in PB and outlives this AddMachinePass object.
+    template <typename PassT>
+    void insertPass(StringRef TargetName, PassT Pass) {
+      PB.AfterCallbacks.emplace_back(
+          [&MFPM = PM, TargetName, Pass = std::move(Pass)](StringRef Name) {
+            if (Name == TargetName)
+              MFPM.addPass(PassT(Pass)); // Copy: may match more than once.
+          });
     }

-    void disablePass(MachinePassKey *ID) {
-      BeforeCallbacks.emplace_back(
-          [ID](MachinePassKey *PassID) { return PassID != ID; });
-    }
-
     MachineFunctionPassManager releasePM() { return std::move(PM); }

   private:
     MachineFunctionPassManager &PM;
-    SmallVector<llvm::unique_function<bool(MachinePassKey *)>, 4>
-        BeforeCallbacks;
-    SmallVector<llvm::unique_function<void(MachinePassKey *)>, 4>
-        AfterCallbacks;
+    const DerivedT &PB;
   };

   LLVMTargetMachine &TM;
@@ -478,6 +488,19 @@ template <typename DerivedT> class CodeGenPassBuilder {
   const DerivedT &derived() const {
     return static_cast<const DerivedT &>(*this);
   }
+  /// Install the callbacks implementing the -start-*/-stop-* request \p Info.
+  void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const;
+  /// Return an error if a requested start or stop pass was never reached.
+  Error verifyStartStop(const TargetPassConfig::StartStopInfo &Info) const;
+  /// Name-keyed hooks; a before-hook returning false skips that pass.
+  mutable SmallVector<llvm::unique_function<bool(StringRef)>, 4>
+      BeforeCallbacks;
+  mutable SmallVector<llvm::unique_function<void(StringRef)>, 4> AfterCallbacks;
+  /// State for `-start-before/-start-after/-stop-before/-stop-after`; all
+  /// three stay true when no start/stop pass is requested.
+  mutable bool ShouldAddPass = true; // Whether passes are currently added.
+  mutable bool Started = true; // False until the requested start pass is seen.
+  mutable bool Stopped = true; // False until the requested stop pass is seen.
 };

 template <typename Derived>
@@ -485,12 +508,16 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
     ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
     raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
     CodeGenFileType FileType) const {
-  AddIRPass addIRPass(MPM, Opt.DebugPM);
+  auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC);
+  if (!StartStopInfo) // Conflicting -start-*/-stop-* options.
+    return StartStopInfo.takeError();
+  setStartStopPasses(*StartStopInfo);
+  AddIRPass addIRPass(MPM, derived(), Opt.DebugPM);
   // `ProfileSummaryInfo` is always valid.
   addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
   addISelPasses(addIRPass);

-  AddMachinePass addPass(MFPM);
+  AddMachinePass addPass(MFPM, derived());
   if (auto Err = addCoreISelPasses(addPass))
     return std::move(Err);

@@ -503,7 +530,76 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
       });

   addPass(FreeMachineFunctionPass());
-  return Error::success();
+  // Fail if a requested start/stop pass never appeared in the pipeline.
+  return verifyStartStop(*StartStopInfo);
+}
+
+template <typename Derived>
+void CodeGenPassBuilder<Derived>::setStartStopPasses(
+    const TargetPassConfig::StartStopInfo &Info) const {
+  // Without -start-*/-stop-* options every pass is added; install nothing.
+  if (Info.StartPass.empty() && Info.StopPass.empty())
+    return;
+
+  // Adding begins immediately unless a start pass was requested. Started and
+  // Stopped record whether the requested start/stop pass has been reached.
+  ShouldAddPass = Info.StartPass.empty();
+  Started = Info.StartPass.empty();
+  Stopped = Info.StopPass.empty();
+
+  // A single before-callback drives the whole start/stop state machine:
+  //  * It observes every pass, including the ones it rejects. The
+  //    -start-after pass is itself skipped, so an after-callback (run only
+  //    for added passes) would never see it.
+  //  * It is the only callback that vetoes passes, so a "start" hook that
+  //    returns true cannot override a "stop" hook, or vice versa.
+  // `Info` is captured by value: the callback outlives the caller's copy.
+  BeforeCallbacks.emplace_back([this, Info, StartCount = 0u,
+                                StopCount = 0u](StringRef ClassName) mutable {
+    StringRef PassName = PIC->getPassNameForClassName(ClassName);
+    const bool HitStart = !Started && PassName == Info.StartPass &&
+                          ++StartCount == Info.StartInstanceNum;
+    const bool HitStop = !Stopped && PassName == Info.StopPass &&
+                         ++StopCount == Info.StopInstanceNum;
+    auto MarkStarted = [&] {
+      Started = true;
+      // Reaching the start pass after the stop pass enables nothing.
+      ShouldAddPass = Info.StopPass.empty() || !Stopped;
+    };
+    auto MarkStopped = [&] {
+      Stopped = true;
+      ShouldAddPass = false;
+    };
+
+    // -start-before/-stop-before decide the fate of this very pass...
+    if (HitStart && Info.StartBefore)
+      MarkStarted();
+    if (HitStop && Info.StopBefore)
+      MarkStopped();
+    const bool AddThisPass = ShouldAddPass;
+    // ...whereas -start-after/-stop-after only affect the passes after it.
+    if (HitStart && !Info.StartBefore)
+      MarkStarted();
+    if (HitStop && !Info.StopBefore)
+      MarkStopped();
+    return AddThisPass;
+  });
+}
+
+template <typename Derived>
+Error CodeGenPassBuilder<Derived>::verifyStartStop(
+    const TargetPassConfig::StartStopInfo &Info) const {
+  // StartPass/StopPass already hold the command-line pass names, so report
+  // them as-is instead of looking them up as class names.
+  if (!Started)
+    return make_error<StringError>(
+        "Can't find start pass \"" + Info.StartPass + "\".",
+        std::make_error_code(std::errc::invalid_argument));
+  if (!Stopped)
+    return make_error<StringError>(
+        "Can't find stop pass \"" + Info.StopPass + "\".",
+        std::make_error_code(std::errc::invalid_argument));
+  return Error::success();
 }
 
 static inline AAManager registerAAAnalyses() {
@@ -623,8 +737,10 @@ void CodeGenPassBuilder<Derived>::addIRPasses(AddIRPass &addPass) const {

   // Run loop strength reduction before anything else.
   if (getOptLevel() != CodeGenOptLevel::None && !Opt.DisableLSR) {
-    addPass(createFunctionToLoopPassAdaptor(
-        LoopStrengthReducePass(), /*UseMemorySSA*/ true, Opt.DebugPM));
+    // Report LSR's own name, not the adaptor's, to the start/stop callbacks.
+    addPass(createFunctionToLoopPassAdaptor(LoopStrengthReducePass(),
+                                            /*UseMemorySSA=*/true, Opt.DebugPM),
+            /*Name=*/LoopStrengthReducePass::name());
     // FIXME: use -stop-after so we could remove PrintLSR
     if (Opt.PrintLSR)
       addPass(PrintFunctionPass(dbgs(), "\n\n*** Code after LSR ***\n"));
diff --git a/llvm/include/llvm/CodeGen/TargetPassConfig.h b/llvm/include/llvm/CodeGen/TargetPassConfig.h
index 66365419aa330b..2823485562c20f 100644
--- a/llvm/include/llvm/CodeGen/TargetPassConfig.h
+++ b/llvm/include/llvm/CodeGen/TargetPassConfig.h
@@ -15,6 +15,7 @@
 
 #include "llvm/Pass.h"
 #include "llvm/Support/CodeGen.h"
+#include "llvm/Support/Error.h"
 #include <cassert>
 #include <string>
 
@@ -176,6 +177,27 @@ class TargetPassConfig : public ImmutablePass {
   static std::string
   getLimitedCodeGenPipelineReason(const char *Separator = "/");

+  /// A -start-before/-start-after/-stop-before/-stop-after request.
+  struct StartStopInfo {
+    /// True for -start-before, false for -start-after.
+    bool StartBefore = false;
+    /// True for -stop-before, false for -stop-after.
+    bool StopBefore = false;
+    /// Which occurrence (counting from 1) of StartPass/StopPass triggers.
+    unsigned StartInstanceNum = 1;
+    unsigned StopInstanceNum = 1;
+    /// Pass names from the command line; empty when not requested.
+    StringRef StartPass;
+    StringRef StopPass;
+  };
+
+  /// Parse the -start-before/-start-after/-stop-before/-stop-after options.
+  /// Fails if both the -before and the -after form of the start (or stop)
+  /// option are given. \p PIC is currently unused.
+  /// NOTE: New pass manager migration only
+  static Expected<StartStopInfo>
+  getStartStopInfo(PassInstrumentationCallbacks &PIC);
+
   void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); }

   bool getEnableTailMerge() const { return EnableTailMerge; }
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index 4003a08a5422dd..9aab0a5703592d 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -609,6 +609,42 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
   registerPartialPipelineCallback(PIC, LLVMTM);
 }

+Expected<TargetPassConfig::StartStopInfo>
+TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) {
+  auto [StartBefore, StartBeforeInstanceNum] =
+      getPassNameAndInstanceNum(StartBeforeOpt);
+  auto [StartAfter, StartAfterInstanceNum] =
+      getPassNameAndInstanceNum(StartAfterOpt);
+  auto [StopBefore, StopBeforeInstanceNum] =
+      getPassNameAndInstanceNum(StopBeforeOpt);
+  auto [StopAfter, StopAfterInstanceNum] =
+      getPassNameAndInstanceNum(StopAfterOpt);
+
+  if (!StartBefore.empty() && !StartAfter.empty())
+    return make_error<StringError>(
+        Twine(StartBeforeOptName) + " and " + StartAfterOptName + " specified!",
+        std::make_error_code(std::errc::invalid_argument));
+  if (!StopBefore.empty() && !StopAfter.empty())
+    return make_error<StringError>(
+        Twine(StopBeforeOptName) + " and " + StopAfterOptName + " specified!",
+        std::make_error_code(std::errc::invalid_argument));
+
+  StartStopInfo Result;
+  Result.StartBefore = !StartBefore.empty();
+  Result.StopBefore = !StopBefore.empty();
+  Result.StartPass = Result.StartBefore ? StartBefore : StartAfter;
+  Result.StopPass = Result.StopBefore ? StopBefore : StopAfter;
+  // "<pass>,N" selects an instance by index. NOTE(review): the legacy
+  // TargetPassConfig::addPass treats N as 0-based (N == 1 is the second
+  // occurrence); CodeGenPassBuilder counts occurrences from 1, hence "+ 1".
+  // Mapping only 0 to 1 would make N == 0 and N == 1 select the same pass.
+  Result.StartInstanceNum =
+      (Result.StartBefore ? StartBeforeInstanceNum : StartAfterInstanceNum) + 1;
+  Result.StopInstanceNum =
+      (Result.StopBefore ? StopBeforeInstanceNum : StopAfterInstanceNum) + 1;
+  return Result;
+}
+
 // Out of line constructor provides default values for pass options and
 // registers all common codegen passes.
 TargetPassConfig::TargetPassConfig(LLVMTargetMachine &TM, PassManagerBase &pm)
diff --git a/llvm/lib/IR/PassInstrumentation.cpp b/llvm/lib/IR/PassInstrumentation.cpp
index 6d5f3acb7a4d35..e64c30132793a4 100644
--- a/llvm/lib/IR/PassInstrumentation.cpp
+++ b/llvm/lib/IR/PassInstrumentation.cpp
@@ -28,6 +28,19 @@ PassInstrumentationCallbacks::getPassNameForClassName(StringRef ClassName) {
   return ClassToPassName[ClassName];
 }

+/// Reverse lookup in ClassToPassName: return a class whose registered pass
+/// name equals \p PassName, or an empty StringRef if there is none.
+/// NOTE(review): the diffstat lists no PassInstrumentation.h change; confirm
+/// that a matching declaration exists there.
+StringRef
+PassInstrumentationCallbacks::getClassNameForPassName(StringRef PassName) {
+  for (const auto &Entry : ClassToPassName) {
+    if (Entry.getValue() == PassName)
+      return Entry.getKey();
+  }
+  return {};
+}
+
 AnalysisKey PassInstrumentationAnalysis::Key;
 
 bool isSpecialPass(StringRef PassID, const std::vector<StringRef> &Specials) {
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 649451edc0e2c6..e5eadbdfc7dfe0 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -88,6 +88,7 @@
 #include "llvm/CodeGen/SelectOptimize.h"
 #include "llvm/CodeGen/ShadowStackGCLowering.h"
 #include "llvm/CodeGen/SjLjEHPrepare.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/TypePromotion.h"
 #include "llvm/CodeGen/WasmEHPrepare.h"
 #include "llvm/CodeGen/WinEHPrepare.h"
@@ -407,7 +408,9 @@ AnalysisKey NoOpLoopAnalysis::Key;
-/// We currently only use this for --print-before/after.
+/// We currently use this for --print-before/after, pipeline printing, pass
+/// filtering, and limited codegen pipelines (hasLimitedCodeGenPipeline).
 bool shouldPopulateClassToPassNames() {
   return PrintPipelinePasses || !printBeforePasses().empty() ||
-         !printAfterPasses().empty() || !isFilterPassesEmpty();
+         !printAfterPasses().empty() || !isFilterPassesEmpty() ||
+         TargetPassConfig::hasLimitedCodeGenPipeline();
 }
 
 // A pass for testing -print-on-crash.



More information about the llvm-commits mailing list