[llvm] [Pass] Support start/stop in instrumentation (PR #70912)

via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 7 06:08:56 PST 2024


https://github.com/paperchalice updated https://github.com/llvm/llvm-project/pull/70912

>From 13d0ba206c1ad1c3ca6f5c464b922ca3fa72912b Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Wed, 3 Jan 2024 10:34:33 +0800
Subject: [PATCH] [CodeGen][NewPM] Support start/stop in CodeGen

---
 .../include/llvm/CodeGen/CodeGenPassBuilder.h | 119 +++++++++++++++---
 llvm/include/llvm/CodeGen/TargetPassConfig.h  |  15 +++
 llvm/include/llvm/IR/PassInstrumentation.h    |   6 +
 llvm/lib/CodeGen/TargetPassConfig.cpp         |  66 ++++++++--
 llvm/lib/IR/PassInstrumentation.cpp           |   8 ++
 llvm/lib/Passes/PassBuilder.cpp               |   4 +-
 6 files changed, 185 insertions(+), 33 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
index 2100c30aad1180..d389350b3686a2 100644
--- a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
+++ b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
@@ -40,6 +40,7 @@
 #include "llvm/CodeGen/SelectOptimize.h"
 #include "llvm/CodeGen/ShadowStackGCLowering.h"
 #include "llvm/CodeGen/SjLjEHPrepare.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/UnreachableBlockElim.h"
 #include "llvm/CodeGen/WasmEHPrepare.h"
 #include "llvm/CodeGen/WinEHPrepare.h"
@@ -172,8 +173,9 @@ template <typename DerivedT> class CodeGenPassBuilder {
   // Function object to maintain state while adding codegen IR passes.
   class AddIRPass {
   public:
-    AddIRPass(ModulePassManager &MPM, bool DebugPM, bool Check = true)
-        : MPM(MPM) {
+    AddIRPass(ModulePassManager &MPM, const DerivedT &PB, bool DebugPM,
+              bool Check = true)
+        : MPM(MPM), PB(PB) {
       if (Check)
         AddingFunctionPasses = false;
     }
@@ -184,10 +186,17 @@ template <typename DerivedT> class CodeGenPassBuilder {
     // Add Function Pass
     template <typename PassT>
     std::enable_if_t<is_detected<is_function_pass_t, PassT>::value>
-    operator()(PassT &&Pass) {
+    operator()(PassT &&Pass, StringRef Name = PassT::name()) {
       if (AddingFunctionPasses && !*AddingFunctionPasses)
         AddingFunctionPasses = true;
+      for (auto &C : PB.BeforeCallbacks)
+        if (!C(Name))
+          return;
+
       FPM.addPass(std::forward<PassT>(Pass));
+
+      for (auto &C : PB.AfterCallbacks)
+        C(Name);
     }
 
     // Add Module Pass
@@ -197,12 +206,20 @@ template <typename DerivedT> class CodeGenPassBuilder {
     operator()(PassT &&Pass) {
       assert((!AddingFunctionPasses || !*AddingFunctionPasses) &&
              "could not add module pass after adding function pass");
+      for (auto &C : PB.BeforeCallbacks)
+        if (!C(PassT::name()))
+          return;
+
       MPM.addPass(std::forward<PassT>(Pass));
+
+      for (auto &C : PB.AfterCallbacks)
+        C(PassT::name());
     }
 
   private:
     ModulePassManager &MPM;
     FunctionPassManager FPM;
+    const DerivedT &PB;
     // The codegen IR pipeline are mostly function passes with the exceptions of
     // a few loop and module passes. `AddingFunctionPasses` make sures that
     // we could only add module passes at the beginning of the pipeline. Once
@@ -216,18 +233,19 @@ template <typename DerivedT> class CodeGenPassBuilder {
   // Function object to maintain state while adding codegen machine passes.
   class AddMachinePass {
   public:
-    AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {}
+    AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
+        : PM(PM), PB(PB) {}
 
     template <typename PassT> void operator()(PassT &&Pass) {
       static_assert(
           is_detected<has_key_t, PassT>::value,
           "Machine function pass must define a static member variable `Key`.");
-      for (auto &C : BeforeCallbacks)
-        if (!C(&PassT::Key))
+      for (auto &C : PB.BeforeCallbacks)
+        if (!C(PassT::name()))
           return;
       PM.addPass(std::forward<PassT>(Pass));
-      for (auto &C : AfterCallbacks)
-        C(&PassT::Key);
+      for (auto &C : PB.AfterCallbacks)
+        C(PassT::name());
     }
 
     template <typename PassT> void insertPass(MachinePassKey *ID, PassT Pass) {
@@ -238,19 +256,11 @@ template <typename DerivedT> class CodeGenPassBuilder {
             });
       }
 
-    void disablePass(MachinePassKey *ID) {
-      BeforeCallbacks.emplace_back(
-          [ID](MachinePassKey *PassID) { return PassID != ID; });
-    }
-
     MachineFunctionPassManager releasePM() { return std::move(PM); }
 
   private:
     MachineFunctionPassManager &PM;
-    SmallVector<llvm::unique_function<bool(MachinePassKey *)>, 4>
-        BeforeCallbacks;
-    SmallVector<llvm::unique_function<void(MachinePassKey *)>, 4>
-        AfterCallbacks;
+    const DerivedT &PB;
   };
 
   LLVMTargetMachine &TM;
@@ -478,6 +488,17 @@ template <typename DerivedT> class CodeGenPassBuilder {
   const DerivedT &derived() const {
     return static_cast<const DerivedT &>(*this);
   }
+
+  /// Adds a callback to BeforeCallbacks that limits the pipeline to the range
+  /// by \p Info (`-start-before`/`-start-after`/`-stop-before`/`-stop-after`).
+  void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const;
+
+  /// Consulted with each pass's name before the pass is added; the pass is
+  /// skipped as soon as one callback returns false.
+  mutable SmallVector<llvm::unique_function<bool(StringRef)>, 4>
+      BeforeCallbacks;
+  /// Invoked with each pass's name after the pass has been added.
+  mutable SmallVector<llvm::unique_function<void(StringRef)>, 4> AfterCallbacks;
 };
 
 template <typename Derived>
@@ -485,12 +506,16 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
     ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
     raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
     CodeGenFileType FileType) const {
-  AddIRPass addIRPass(MPM, Opt.DebugPM);
+  auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC);
+  if (!StartStopInfo)
+    return StartStopInfo.takeError();
+  setStartStopPasses(*StartStopInfo);
+  AddIRPass addIRPass(MPM, derived(), Opt.DebugPM);
   // `ProfileSummaryInfo` is always valid.
   addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
   addISelPasses(addIRPass);
 
-  AddMachinePass addPass(MFPM);
+  AddMachinePass addPass(MFPM, derived());
   if (auto Err = addCoreISelPasses(addPass))
     return std::move(Err);
 
@@ -506,6 +531,42 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
   return Error::success();
 }
 
+template <typename Derived>
+void CodeGenPassBuilder<Derived>::setStartStopPasses(
+    const TargetPassConfig::StartStopInfo &Info) const {
+  if (Info.StartPass.empty() && Info.StopPass.empty())
+    return;
+
+  // One callback handles both ends of the range so that every candidate pass
+  // is counted against both the start and the stop pass (the loop that
+  // invokes BeforeCallbacks stops at the first callback returning false).
+  // `Info` is captured by value because the callback is stored in this
+  // builder and outlives the caller's StartStopInfo.
+  BeforeCallbacks.emplace_back(
+      [Info, Started = Info.StartPass.empty(), Stopped = false,
+       StartCount = 0u, StopCount = 0u](StringRef ClassName) mutable {
+        if (Stopped)
+          return false;
+
+        bool AddCurrent = Started;
+        if (!Started && ClassName == Info.StartPass &&
+            ++StartCount == Info.StartInstanceNum) {
+          Started = true;
+          // `-start-before` adds the matched pass; `-start-after` skips it.
+          AddCurrent = Info.StartBefore;
+        }
+
+        if (!Info.StopPass.empty() && ClassName == Info.StopPass &&
+            ++StopCount == Info.StopInstanceNum) {
+          Stopped = true;
+          // `-stop-after` still adds the matched pass; `-stop-before` does not.
+          if (Info.StopBefore)
+            AddCurrent = false;
+        }
+        return AddCurrent;
+      });
+}
+
 static inline AAManager registerAAAnalyses() {
   AAManager AA;
 
@@ -623,8 +684,9 @@ void CodeGenPassBuilder<Derived>::addIRPasses(AddIRPass &addPass) const {
 
   // Run loop strength reduction before anything else.
   if (getOptLevel() != CodeGenOptLevel::None && !Opt.DisableLSR) {
-    addPass(createFunctionToLoopPassAdaptor(
-        LoopStrengthReducePass(), /*UseMemorySSA*/ true, Opt.DebugPM));
+    addPass(createFunctionToLoopPassAdaptor(LoopStrengthReducePass(),
+                                            /*UseMemorySSA*/ true, Opt.DebugPM),
+            LoopStrengthReducePass::name());
     // FIXME: use -stop-after so we could remove PrintLSR
     if (Opt.PrintLSR)
       addPass(PrintFunctionPass(dbgs(), "\n\n*** Code after LSR ***\n"));
diff --git a/llvm/include/llvm/CodeGen/TargetPassConfig.h b/llvm/include/llvm/CodeGen/TargetPassConfig.h
index 66365419aa330b..2823485562c20f 100644
--- a/llvm/include/llvm/CodeGen/TargetPassConfig.h
+++ b/llvm/include/llvm/CodeGen/TargetPassConfig.h
@@ -15,6 +15,7 @@
 
 #include "llvm/Pass.h"
 #include "llvm/Support/CodeGen.h"
+#include "llvm/Support/Error.h"
 #include <cassert>
 #include <string>
 
@@ -176,6 +177,29 @@ class TargetPassConfig : public ImmutablePass {
   static std::string
   getLimitedCodeGenPipelineReason(const char *Separator = "/");
 
+  /// Start/stop points of a partial codegen pipeline, as requested by
+  /// `-start-before`/`-start-after`/`-stop-before`/`-stop-after`.
+  struct StartStopInfo {
+    /// True if StartPass came from `-start-before` (else `-start-after`).
+    bool StartBefore;
+    /// True if StopPass came from `-stop-before` (else `-stop-after`).
+    bool StopBefore;
+    /// 1-based occurrence of StartPass / StopPass at which to start / stop.
+    unsigned StartInstanceNum;
+    unsigned StopInstanceNum;
+    /// Class names of the start / stop passes; empty if there is none.
+    StringRef StartPass;
+    StringRef StopPass;
+  };
+
+  /// Returns the start/stop points given by `-start-before`, `-start-after`,
+  /// `-stop-before` and `-stop-after`, with pass names mapped to class names
+  /// through \p PIC, or an error if the options are invalid (e.g. both
+  /// `-start-before` and `-start-after` are given).
+  /// NOTE: New pass manager migration only.
+  static Expected<StartStopInfo>
+  getStartStopInfo(PassInstrumentationCallbacks &PIC);
+
   void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); }
 
   bool getEnableTailMerge() const { return EnableTailMerge; }
diff --git a/llvm/include/llvm/IR/PassInstrumentation.h b/llvm/include/llvm/IR/PassInstrumentation.h
index 519a5e46b4373b..ae493c740908d1 100644
--- a/llvm/include/llvm/IR/PassInstrumentation.h
+++ b/llvm/include/llvm/IR/PassInstrumentation.h
@@ -84,6 +84,7 @@ class PassInstrumentationCallbacks {
   using AfterAnalysisFunc = void(StringRef, Any);
   using AnalysisInvalidatedFunc = void(StringRef, Any);
   using AnalysesClearedFunc = void(StringRef);
+  using StartStopFunc = bool(StringRef);
 
 public:
   PassInstrumentationCallbacks() = default;
@@ -152,6 +153,15 @@ class PassInstrumentationCallbacks {
   void addClassToPassName(StringRef ClassName, StringRef PassName);
   /// Get the pass name for a given pass class name.
   StringRef getPassNameForClassName(StringRef ClassName);
+  /// Get the class name registered for \p PassName; the inverse of
+  /// getPassNameForClassName. Returns an empty StringRef if \p PassName is
+  /// not found. Note that this is a linear scan over all registered entries.
+  StringRef getClassNameForPassName(StringRef PassName);
+
+  /// Returns whether the named pass falls within the range requested by
+  /// `-start-before`/`-start-after`/`-stop-before`/`-stop-after`; set by
+  /// registerCodeGenCallback in TargetPassConfig.
+  llvm::unique_function<StartStopFunc> StartStopCallback;
 
 private:
   friend class PassInstrumentation;
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index 4003a08a5422dd..29f1b810887e04 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -526,17 +526,16 @@ static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
       getPassNameAndInstanceNum(StopAfterOpt);
 
   if (StartBefore.empty() && StartAfter.empty() && StopBefore.empty() &&
-      StopAfter.empty())
+      StopAfter.empty()) {
+    PIC.StartStopCallback = [](StringRef) { return true; };
     return;
+  }
+
+  StartBefore = PIC.getClassNameForPassName(StartBefore);
+  StartAfter = PIC.getClassNameForPassName(StartAfter);
+  StopBefore = PIC.getClassNameForPassName(StopBefore);
+  StopAfter = PIC.getClassNameForPassName(StopAfter);
 
-  std::tie(StartBefore, std::ignore) =
-      LLVMTM.getPassNameFromLegacyName(StartBefore);
-  std::tie(StartAfter, std::ignore) =
-      LLVMTM.getPassNameFromLegacyName(StartAfter);
-  std::tie(StopBefore, std::ignore) =
-      LLVMTM.getPassNameFromLegacyName(StopBefore);
-  std::tie(StopAfter, std::ignore) =
-      LLVMTM.getPassNameFromLegacyName(StopAfter);
   if (!StartBefore.empty() && !StartAfter.empty())
     report_fatal_error(Twine(StartBeforeOptName) + Twine(" and ") +
                        Twine(StartAfterOptName) + Twine(" specified!"));
@@ -544,11 +543,11 @@ static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
     report_fatal_error(Twine(StopBeforeOptName) + Twine(" and ") +
                        Twine(StopAfterOptName) + Twine(" specified!"));
 
-  PIC.registerShouldRunOptionalPassCallback(
+  PIC.StartStopCallback =
       [=, EnableCurrent = StartBefore.empty() && StartAfter.empty(),
        EnableNext = std::optional<bool>(), StartBeforeCount = 0u,
        StartAfterCount = 0u, StopBeforeCount = 0u,
-       StopAfterCount = 0u](StringRef P, Any) mutable {
+       StopAfterCount = 0u](StringRef P) mutable {
         bool StartBeforePass = !StartBefore.empty() && P.contains(StartBefore);
         bool StartAfterPass = !StartAfter.empty() && P.contains(StartAfter);
         bool StopBeforePass = !StopBefore.empty() && P.contains(StopBefore);
@@ -576,7 +575,7 @@ static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
         if (StopBeforePass && StopBeforeCount++ == StopBeforeInstanceNum)
           EnableCurrent = false;
         return EnableCurrent;
-      });
+      };
 }
 
 void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
@@ -609,6 +608,73 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
   registerPartialPipelineCallback(PIC, LLVMTM);
 }
 
+/// Maps the pass name given to the start/stop option \p OptName to the class
+/// name that CodeGenPassBuilder compares against, updating \p Name in place.
+/// An empty \p Name (option not given) is left as is; a name that is not
+/// registered is reported as an error rather than silently ignored.
+static Error mapStartStopPassName(PassInstrumentationCallbacks &PIC,
+                                  StringRef OptName, StringRef &Name) {
+  if (Name.empty())
+    return Error::success();
+  StringRef ClassName = PIC.getClassNameForPassName(Name);
+  if (ClassName.empty())
+    return make_error<StringError>(Twine('"') + Name + "\" pass given to -" +
+                                       OptName + " is not registered.",
+                                   inconvertibleErrorCode());
+  Name = ClassName;
+  return Error::success();
+}
+
+Expected<TargetPassConfig::StartStopInfo>
+TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) {
+  auto [StartBefore, StartBeforeInstanceNum] =
+      getPassNameAndInstanceNum(StartBeforeOpt);
+  auto [StartAfter, StartAfterInstanceNum] =
+      getPassNameAndInstanceNum(StartAfterOpt);
+  auto [StopBefore, StopBeforeInstanceNum] =
+      getPassNameAndInstanceNum(StopBeforeOpt);
+  auto [StopAfter, StopAfterInstanceNum] =
+      getPassNameAndInstanceNum(StopAfterOpt);
+
+  // Check for conflicting options on the names as given, so that the check
+  // does not depend on whether the names are registered.
+  if (!StartBefore.empty() && !StartAfter.empty())
+    return make_error<StringError>(Twine(StartBeforeOptName) + Twine(" and ") +
+                                       Twine(StartAfterOptName) +
+                                       Twine(" specified!"),
+                                   inconvertibleErrorCode());
+  if (!StopBefore.empty() && !StopAfter.empty())
+    return make_error<StringError>(Twine(StopBeforeOptName) + Twine(" and ") +
+                                       Twine(StopAfterOptName) +
+                                       Twine(" specified!"),
+                                   inconvertibleErrorCode());
+
+  if (Error Err = mapStartStopPassName(PIC, StartBeforeOptName, StartBefore))
+    return std::move(Err);
+  if (Error Err = mapStartStopPassName(PIC, StartAfterOptName, StartAfter))
+    return std::move(Err);
+  if (Error Err = mapStartStopPassName(PIC, StopBeforeOptName, StopBefore))
+    return std::move(Err);
+  if (Error Err = mapStartStopPassName(PIC, StopAfterOptName, StopAfter))
+    return std::move(Err);
+
+  StartStopInfo Result;
+  Result.StartPass = StartBefore.empty() ? StartAfter : StartBefore;
+  Result.StopPass = StopBefore.empty() ? StopAfter : StopBefore;
+  Result.StartBefore = !StartBefore.empty();
+  Result.StopBefore = !StopBefore.empty();
+  Result.StartInstanceNum =
+      StartBefore.empty() ? StartAfterInstanceNum : StartBeforeInstanceNum;
+  Result.StopInstanceNum =
+      StopBefore.empty() ? StopAfterInstanceNum : StopBeforeInstanceNum;
+  // The options use 0-based instance numbers (`-stop-after=pass,1` names the
+  // second instance, as in registerPartialPipelineCallback), while
+  // StartStopInfo stores 1-based occurrence counts.
+  ++Result.StartInstanceNum;
+  ++Result.StopInstanceNum;
+  return Result;
+}
+
 // Out of line constructor provides default values for pass options and
 // registers all common codegen passes.
 TargetPassConfig::TargetPassConfig(LLVMTargetMachine &TM, PassManagerBase &pm)
diff --git a/llvm/lib/IR/PassInstrumentation.cpp b/llvm/lib/IR/PassInstrumentation.cpp
index 6d5f3acb7a4d35..e64c30132793a4 100644
--- a/llvm/lib/IR/PassInstrumentation.cpp
+++ b/llvm/lib/IR/PassInstrumentation.cpp
@@ -28,6 +28,19 @@ PassInstrumentationCallbacks::getPassNameForClassName(StringRef ClassName) {
   return ClassToPassName[ClassName];
 }
 
+StringRef
+PassInstrumentationCallbacks::getClassNameForPassName(StringRef PassName) {
+  // getPassNameForClassName uses operator[], which default-inserts an empty
+  // pass name for unknown classes, so an empty PassName could otherwise
+  // match an arbitrary class.
+  if (PassName.empty())
+    return StringRef();
+  for (const auto &P : ClassToPassName)
+    if (P.second == PassName)
+      return P.first();
+  return StringRef();
+}
+
 AnalysisKey PassInstrumentationAnalysis::Key;
 
 bool isSpecialPass(StringRef PassID, const std::vector<StringRef> &Specials) {
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 649451edc0e2c6..e5eadbdfc7dfe0 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -88,6 +88,7 @@
 #include "llvm/CodeGen/SelectOptimize.h"
 #include "llvm/CodeGen/ShadowStackGCLowering.h"
 #include "llvm/CodeGen/SjLjEHPrepare.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/TypePromotion.h"
 #include "llvm/CodeGen/WasmEHPrepare.h"
 #include "llvm/CodeGen/WinEHPrepare.h"
@@ -407,7 +408,10 @@ AnalysisKey NoOpLoopAnalysis::Key;
-/// We currently only use this for --print-before/after.
+/// Whether to record class-name-to-pass-name mappings. They are needed by
+/// --print-pipeline-passes, --print-before/after, pass filtering, and the
+/// codegen start/stop options (-start-before, -stop-after, etc.).
 bool shouldPopulateClassToPassNames() {
   return PrintPipelinePasses || !printBeforePasses().empty() ||
-         !printAfterPasses().empty() || !isFilterPassesEmpty();
+         !printAfterPasses().empty() || !isFilterPassesEmpty() ||
+         TargetPassConfig::hasLimitedCodeGenPipeline();
 }
 
 // A pass for testing -print-on-crash.



More information about the llvm-commits mailing list