[llvm] Reland "[CodeGen] Support start/stop in CodeGenPassBuilder (#70912)" (PR #78570)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 18 17:44:11 PST 2024


https://github.com/paperchalice updated https://github.com/llvm/llvm-project/pull/78570

>From bd1375589e0b9ed39e1b40206c9a13f1ac52f290 Mon Sep 17 00:00:00 2001
From: paperchalice <liujunchang97 at outlook.com>
Date: Thu, 18 Jan 2024 14:54:56 +0800
Subject: [PATCH] [CodeGen] Support start/stop in CodeGenPassBuilder (#70912)

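Add `-start-before`/`-start-after`/`-stop-before`/`-stop-after` support to CodeGenPassBuilder.
Before/after callbacks are now owned by the builder and keyed by pass name; `AddMachinePass::disablePass` is removed.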
Part of #69879.
---
 .../include/llvm/CodeGen/CodeGenPassBuilder.h | 133 +++++++++++++++---
 llvm/include/llvm/CodeGen/TargetPassConfig.h  |  15 ++
 llvm/lib/CodeGen/TargetPassConfig.cpp         | 109 +++++---------
 llvm/lib/Passes/PassBuilder.cpp               |   4 +-
 4 files changed, 164 insertions(+), 97 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
index 78ee7bef02ab1a..12088f6fc35e0b 100644
--- a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
+++ b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
@@ -44,6 +44,7 @@
 #include "llvm/CodeGen/ShadowStackGCLowering.h"
 #include "llvm/CodeGen/SjLjEHPrepare.h"
 #include "llvm/CodeGen/StackProtector.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/UnreachableBlockElim.h"
 #include "llvm/CodeGen/WasmEHPrepare.h"
 #include "llvm/CodeGen/WinEHPrepare.h"
@@ -176,73 +177,86 @@ template <typename DerivedT> class CodeGenPassBuilder {
   // Function object to maintain state while adding codegen IR passes.
   class AddIRPass {
   public:
-    AddIRPass(ModulePassManager &MPM) : MPM(MPM) {}
+    AddIRPass(ModulePassManager &MPM, const DerivedT &PB) : MPM(MPM), PB(PB) {}
     ~AddIRPass() {
       if (!FPM.isEmpty())
         MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
     }
 
-    template <typename PassT> void operator()(PassT &&Pass) {
+    /// Add \p Pass unless the builder's BeforeCallbacks reject \p Name
+    /// (PassT::name() by default); if added, AfterCallbacks get \p Name.
+    template <typename PassT>
+    void operator()(PassT &&Pass, StringRef Name = PassT::name()) {
       static_assert((is_detected<is_function_pass_t, PassT>::value ||
                      is_detected<is_module_pass_t, PassT>::value) &&
                     "Only module pass and function pass are supported.");
 
+      if (!PB.runBeforeAdding(Name))
+        return;
+
       // Add Function Pass
       if constexpr (is_detected<is_function_pass_t, PassT>::value) {
         FPM.addPass(std::forward<PassT>(Pass));
+
+        for (auto &C : PB.AfterCallbacks)
+          C(Name);
       } else {
         // Add Module Pass
         if (!FPM.isEmpty()) {
           MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
           FPM = FunctionPassManager();
         }
+
         MPM.addPass(std::forward<PassT>(Pass));
+
+        for (auto &C : PB.AfterCallbacks)
+          C(Name);
       }
     }
 
   private:
     ModulePassManager &MPM;
     FunctionPassManager FPM;
+    const DerivedT &PB;
   };
 
   // Function object to maintain state while adding codegen machine passes.
   class AddMachinePass {
   public:
-    AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {}
+    AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
+        : PM(PM), PB(PB) {}
 
     template <typename PassT> void operator()(PassT &&Pass) {
       static_assert(
           is_detected<has_key_t, PassT>::value,
           "Machine function pass must define a static member variable `Key`.");
-      for (auto &C : BeforeCallbacks)
-        if (!C(&PassT::Key))
-          return;
+
+      if (!PB.runBeforeAdding(PassT::name()))
+        return;
+
       PM.addPass(std::forward<PassT>(Pass));
-      for (auto &C : AfterCallbacks)
-        C(&PassT::Key);
+
+      for (auto &C : PB.AfterCallbacks)
+        C(PassT::name());
     }
 
-    template <typename PassT> void insertPass(MachinePassKey *ID, PassT Pass) {
-      AfterCallbacks.emplace_back(
-          [this, ID, Pass = std::move(Pass)](MachinePassKey *PassID) {
-            if (PassID == ID)
+    /// Add \p Pass right after each pass named \p PassName. AfterCallbacks
+    /// receive the added pass's name(), so matching is by name, not by key.
+    /// NOTE: the stored callback captures `this`; it must not run after this
+    /// AddMachinePass has been destroyed.
+    template <typename PassT> void insertPass(StringRef PassName, PassT Pass) {
+      PB.AfterCallbacks.emplace_back(
+          [this, PassName, Pass = std::move(Pass)](StringRef Name) {
+            if (Name == PassName)
               this->PM.addPass(std::move(Pass));
           });
     }
 
-    void disablePass(MachinePassKey *ID) {
-      BeforeCallbacks.emplace_back(
-          [ID](MachinePassKey *PassID) { return PassID != ID; });
-    }
-
     MachineFunctionPassManager releasePM() { return std::move(PM); }
 
   private:
     MachineFunctionPassManager &PM;
-    SmallVector<llvm::unique_function<bool(MachinePassKey *)>, 4>
-        BeforeCallbacks;
-    SmallVector<llvm::unique_function<void(MachinePassKey *)>, 4>
-        AfterCallbacks;
+    const DerivedT &PB;
   };
 
   LLVMTargetMachine &TM;
@@ -473,6 +481,37 @@ template <typename DerivedT> class CodeGenPassBuilder {
   const DerivedT &derived() const {
     return static_cast<const DerivedT &>(*this);
   }
+
+  /// Ask every BeforeCallback whether the pass named \p Name may be added.
+  /// All callbacks run even after one has said no: the start/stop callbacks
+  /// count pass instances, so none of them may be skipped.
+  bool runBeforeAdding(StringRef Name) const {
+    bool Accepted = true;
+    for (auto &Callback : BeforeCallbacks) {
+      const bool Vote = Callback(Name);
+      Accepted = Accepted && Vote;
+    }
+    return Accepted;
+  }
+
+  /// Register BeforeCallbacks implementing the start/stop options in \p Info.
+  void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const;
+
+  /// Return an error if a requested start or stop pass was never reached.
+  Error verifyStartStop(const TargetPassConfig::StartStopInfo &Info) const;
+
+  /// Queried before a pass is added; any false result drops the pass.
+  mutable SmallVector<llvm::unique_function<bool(StringRef)>, 4>
+      BeforeCallbacks;
+  /// Invoked with the pass name after each pass is added.
+  mutable SmallVector<llvm::unique_function<void(StringRef)>, 4> AfterCallbacks;
+
+  /// Helpers for `-start-before/-start-after/-stop-before/-stop-after`. Each
+  /// stays true when its option is unset; setStartStopPasses clears it, and
+  /// the callbacks set it back at the requested pass (one pass later for the
+  /// -after forms).
+  mutable bool Started = true;
+  mutable bool Stopped = true;
 };
 
 template <typename Derived>
@@ -480,13 +507,21 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
     ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
     raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
     CodeGenFileType FileType) const {
-  AddIRPass addIRPass(MPM);
+  // getStartStopInfo reports conflicting start/stop options as an Error, which
+  // is forwarded to the caller.
+  auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC);
+  if (auto Err = StartStopInfo.takeError())
+    return std::move(Err);
+  // The start/stop filters must be in place before the first pass is added.
+  setStartStopPasses(*StartStopInfo);
+
+  AddIRPass addIRPass(MPM, derived());
   // `ProfileSummaryInfo` is always valid.
   addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
   addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
   addISelPasses(addIRPass);
 
-  AddMachinePass addPass(MFPM);
+  AddMachinePass addPass(MFPM, derived());
   if (auto Err = addCoreISelPasses(addPass))
     return std::move(Err);
 
@@ -499,6 +530,78 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
       });
 
   addPass(FreeMachineFunctionPass());
+  return verifyStartStop(*StartStopInfo);
+}
+
+/// Install BeforeCallbacks implementing -start-before/-start-after and
+/// -stop-before/-stop-after as described by \p Info. Each callback counts
+/// passes whose registered name equals the requested pass; at the requested
+/// instance it updates Started/Stopped. With the "after" variants the change
+/// takes effect from the following pass on.
+///
+/// \p Info is captured by value because the callbacks live in this builder
+/// and may outlive the caller's StartStopInfo (e.g. buildPipeline's local).
+template <typename Derived>
+void CodeGenPassBuilder<Derived>::setStartStopPasses(
+    const TargetPassConfig::StartStopInfo &Info) const {
+  if (!Info.StartPass.empty()) {
+    Started = false;
+    BeforeCallbacks.emplace_back([this, Info, AfterFlag = Info.StartAfter,
+                                  Count = 0u](StringRef ClassName) mutable {
+      if (Count == Info.StartInstanceNum) {
+        if (AfterFlag) {
+          AfterFlag = false;
+          Started = true;
+        }
+        return Started;
+      }
+
+      auto PassName = PIC->getPassNameForClassName(ClassName);
+      if (Info.StartPass == PassName && ++Count == Info.StartInstanceNum)
+        Started = !Info.StartAfter;
+
+      return Started;
+    });
+  }
+
+  if (!Info.StopPass.empty()) {
+    Stopped = false;
+    BeforeCallbacks.emplace_back([this, Info, AfterFlag = Info.StopAfter,
+                                  Count = 0u](StringRef ClassName) mutable {
+      if (Count == Info.StopInstanceNum) {
+        if (AfterFlag) {
+          AfterFlag = false;
+          Stopped = true;
+        }
+        return !Stopped;
+      }
+
+      auto PassName = PIC->getPassNameForClassName(ClassName);
+      if (Info.StopPass == PassName && ++Count == Info.StopInstanceNum)
+        Stopped = !Info.StopAfter;
+      return !Stopped;
+    });
+  }
+}
+
+/// Return an error naming the start or stop pass from \p Info that was never
+/// reached, or success if both were reached (or were not requested).
+template <typename Derived>
+Error CodeGenPassBuilder<Derived>::verifyStartStop(
+    const TargetPassConfig::StartStopInfo &Info) const {
+  if (Started && Stopped)
+    return Error::success();
+
+  // Info.StartPass/StopPass already hold the pass name given on the command
+  // line, not a class name, so print them as-is.
+  if (!Started)
+    return make_error<StringError>(
+        "Can't find start pass \"" + Info.StartPass + "\".",
+        std::make_error_code(std::errc::invalid_argument));
+  if (!Stopped)
+    return make_error<StringError>(
+        "Can't find stop pass \"" + Info.StopPass + "\".",
+        std::make_error_code(std::errc::invalid_argument));
   return Error::success();
 }
 
diff --git a/llvm/include/llvm/CodeGen/TargetPassConfig.h b/llvm/include/llvm/CodeGen/TargetPassConfig.h
index 66365419aa330b..de6a760c4e4fd1 100644
--- a/llvm/include/llvm/CodeGen/TargetPassConfig.h
+++ b/llvm/include/llvm/CodeGen/TargetPassConfig.h
@@ -15,6 +15,7 @@
 
 #include "llvm/Pass.h"
 #include "llvm/Support/CodeGen.h"
+#include "llvm/Support/Error.h"
 #include <cassert>
 #include <string>
 
@@ -176,6 +177,27 @@ class TargetPassConfig : public ImmutablePass {
   static std::string
   getLimitedCodeGenPipelineReason(const char *Separator = "/");
 
+  /// Parsed form of -start-before/-start-after/-stop-before/-stop-after.
+  struct StartStopInfo {
+    /// True if StartPass came from -start-after rather than -start-before.
+    bool StartAfter = false;
+    /// True if StopPass came from -stop-after rather than -stop-before.
+    bool StopAfter = false;
+    /// Occurrence of StartPass/StopPass to act on; 1 selects the first.
+    unsigned StartInstanceNum = 0;
+    unsigned StopInstanceNum = 0;
+    /// Requested pass names; empty when the option was not given.
+    StringRef StartPass;
+    StringRef StopPass;
+  };
+
+  /// Returns the pass names and instance numbers requested with
+  /// -start-before/-start-after/-stop-before/-stop-after, or an error if
+  /// both the before and after form of the start or stop option are set.
+  /// NOTE: New pass manager migration only
+  static Expected<StartStopInfo>
+  getStartStopInfo(PassInstrumentationCallbacks &PIC);
+
   void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); }
 
   bool getEnableTailMerge() const { return EnableTailMerge; }
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index 3bbc792f4cbf46..76ba8da547e6b9 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -504,81 +504,6 @@ CGPassBuilderOption llvm::getCGPassBuilderOption() {
   return Opt;
 }
 
-static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
-                                            LLVMTargetMachine &LLVMTM) {
-  StringRef StartBefore;
-  StringRef StartAfter;
-  StringRef StopBefore;
-  StringRef StopAfter;
-
-  unsigned StartBeforeInstanceNum = 0;
-  unsigned StartAfterInstanceNum = 0;
-  unsigned StopBeforeInstanceNum = 0;
-  unsigned StopAfterInstanceNum = 0;
-
-  std::tie(StartBefore, StartBeforeInstanceNum) =
-      getPassNameAndInstanceNum(StartBeforeOpt);
-  std::tie(StartAfter, StartAfterInstanceNum) =
-      getPassNameAndInstanceNum(StartAfterOpt);
-  std::tie(StopBefore, StopBeforeInstanceNum) =
-      getPassNameAndInstanceNum(StopBeforeOpt);
-  std::tie(StopAfter, StopAfterInstanceNum) =
-      getPassNameAndInstanceNum(StopAfterOpt);
-
-  if (StartBefore.empty() && StartAfter.empty() && StopBefore.empty() &&
-      StopAfter.empty())
-    return;
-
-  std::tie(StartBefore, std::ignore) =
-      LLVMTM.getPassNameFromLegacyName(StartBefore);
-  std::tie(StartAfter, std::ignore) =
-      LLVMTM.getPassNameFromLegacyName(StartAfter);
-  std::tie(StopBefore, std::ignore) =
-      LLVMTM.getPassNameFromLegacyName(StopBefore);
-  std::tie(StopAfter, std::ignore) =
-      LLVMTM.getPassNameFromLegacyName(StopAfter);
-  if (!StartBefore.empty() && !StartAfter.empty())
-    report_fatal_error(Twine(StartBeforeOptName) + Twine(" and ") +
-                       Twine(StartAfterOptName) + Twine(" specified!"));
-  if (!StopBefore.empty() && !StopAfter.empty())
-    report_fatal_error(Twine(StopBeforeOptName) + Twine(" and ") +
-                       Twine(StopAfterOptName) + Twine(" specified!"));
-
-  PIC.registerShouldRunOptionalPassCallback(
-      [=, EnableCurrent = StartBefore.empty() && StartAfter.empty(),
-       EnableNext = std::optional<bool>(), StartBeforeCount = 0u,
-       StartAfterCount = 0u, StopBeforeCount = 0u,
-       StopAfterCount = 0u](StringRef P, Any) mutable {
-        bool StartBeforePass = !StartBefore.empty() && P.contains(StartBefore);
-        bool StartAfterPass = !StartAfter.empty() && P.contains(StartAfter);
-        bool StopBeforePass = !StopBefore.empty() && P.contains(StopBefore);
-        bool StopAfterPass = !StopAfter.empty() && P.contains(StopAfter);
-
-        // Implement -start-after/-stop-after
-        if (EnableNext) {
-          EnableCurrent = *EnableNext;
-          EnableNext.reset();
-        }
-
-        // Using PIC.registerAfterPassCallback won't work because if this
-        // callback returns false, AfterPassCallback is also skipped.
-        if (StartAfterPass && StartAfterCount++ == StartAfterInstanceNum) {
-          assert(!EnableNext && "Error: assign to EnableNext more than once");
-          EnableNext = true;
-        }
-        if (StopAfterPass && StopAfterCount++ == StopAfterInstanceNum) {
-          assert(!EnableNext && "Error: assign to EnableNext more than once");
-          EnableNext = false;
-        }
-
-        if (StartBeforePass && StartBeforeCount++ == StartBeforeInstanceNum)
-          EnableCurrent = true;
-        if (StopBeforePass && StopBeforeCount++ == StopBeforeInstanceNum)
-          EnableCurrent = false;
-        return EnableCurrent;
-      });
-}
-
 void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
                                    LLVMTargetMachine &LLVMTM) {
 
@@ -605,8 +530,60 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
 
     return true;
   });
+}
 
-  registerPartialPipelineCallback(PIC, LLVMTM);
+/// Collect the -start-before/-start-after/-stop-before/-stop-after options
+/// into a StartStopInfo. \p PIC is currently unused. Setting both the before
+/// and after form of the start (or stop) option yields an invalid_argument
+/// error instead of aborting.
+Expected<TargetPassConfig::StartStopInfo>
+TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) {
+  auto [StartBefore, StartBeforeInstanceNum] =
+      getPassNameAndInstanceNum(StartBeforeOpt);
+  auto [StartAfter, StartAfterInstanceNum] =
+      getPassNameAndInstanceNum(StartAfterOpt);
+  auto [StopBefore, StopBeforeInstanceNum] =
+      getPassNameAndInstanceNum(StopBeforeOpt);
+  auto [StopAfter, StopAfterInstanceNum] =
+      getPassNameAndInstanceNum(StopAfterOpt);
+
+  if (!StartBefore.empty() && !StartAfter.empty())
+    return make_error<StringError>(
+        Twine(StartBeforeOptName) + " and " + StartAfterOptName + " specified!",
+        std::make_error_code(std::errc::invalid_argument));
+  if (!StopBefore.empty() && !StopAfter.empty())
+    return make_error<StringError>(
+        Twine(StopBeforeOptName) + " and " + StopAfterOptName + " specified!",
+        std::make_error_code(std::errc::invalid_argument));
+
+  StartStopInfo Result;
+  // At most one form of each option is set here; take whichever one is.
+  if (StartBefore.empty()) {
+    Result.StartPass = StartAfter;
+    Result.StartInstanceNum = StartAfterInstanceNum;
+  } else {
+    Result.StartPass = StartBefore;
+    Result.StartInstanceNum = StartBeforeInstanceNum;
+  }
+  if (StopBefore.empty()) {
+    Result.StopPass = StopAfter;
+    Result.StopInstanceNum = StopAfterInstanceNum;
+  } else {
+    Result.StopPass = StopBefore;
+    Result.StopInstanceNum = StopBeforeInstanceNum;
+  }
+  Result.StartAfter = !StartAfter.empty();
+  Result.StopAfter = !StopAfter.empty();
+
+  // Instance numbers are matched 1-based; 0 (presumably: no ",N" suffix) is
+  // treated as the first instance. NOTE(review): the removed
+  // registerPartialPipelineCallback compared 0-based counters, so ",1" used
+  // to select the second instance; confirm the new numbering is intended.
+  if (Result.StartInstanceNum == 0)
+    Result.StartInstanceNum = 1;
+  if (Result.StopInstanceNum == 0)
+    Result.StopInstanceNum = 1;
+  return Result;
 }
 
 // Out of line constructor provides default values for pass options and
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index d309ed999bd206..8d3f69be503831 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -93,6 +93,7 @@
 #include "llvm/CodeGen/ShadowStackGCLowering.h"
 #include "llvm/CodeGen/SjLjEHPrepare.h"
 #include "llvm/CodeGen/StackProtector.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/TypePromotion.h"
 #include "llvm/CodeGen/WasmEHPrepare.h"
 #include "llvm/CodeGen/WinEHPrepare.h"
@@ -316,7 +317,9 @@ namespace {
-/// We currently only use this for --print-before/after.
+/// Used for pipeline printing, --print-before/after, pass filtering and the
+/// codegen -start-before/-start-after/-stop-before/-stop-after options.
 bool shouldPopulateClassToPassNames() {
   return PrintPipelinePasses || !printBeforePasses().empty() ||
-         !printAfterPasses().empty() || !isFilterPassesEmpty();
+         !printAfterPasses().empty() || !isFilterPassesEmpty() ||
+         TargetPassConfig::hasLimitedCodeGenPipeline();
 }
 
 // A pass for testing -print-on-crash.



More information about the llvm-commits mailing list