[llvm] Reland "[CodeGen] Support start/stop in CodeGenPassBuilder (#70912)" (PR #78570)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 18 04:51:14 PST 2024


https://github.com/paperchalice created https://github.com/llvm/llvm-project/pull/78570

Unfortunately the legacy pass system can't recognize `no-op-module` and `no-op-function` so it  causes test failure in `CodeGenTests`. Add a workaround in  function `PassInfo *getPassInfo(StringRef PassName)`, `TargetPassConfig.cpp`.

>From fcb7de6ecb95abde783fb4000aaaa4f56b6c5ee5 Mon Sep 17 00:00:00 2001
From: paperchalice <liujunchang97 at outlook.com>
Date: Thu, 18 Jan 2024 14:54:56 +0800
Subject: [PATCH] [CodeGen] Support start/stop in CodeGenPassBuilder (#70912)

Add `-start/stop-before/after` support for CodeGenPassBuilder.
Part of #69879.
---
 .../include/llvm/CodeGen/CodeGenPassBuilder.h | 133 +++++++++++++++---
 llvm/include/llvm/CodeGen/TargetPassConfig.h  |  15 ++
 llvm/lib/CodeGen/TargetPassConfig.cpp         |  40 ++++++
 llvm/lib/Passes/PassBuilder.cpp               |   4 +-
 .../CodeGen/CodeGenPassBuilderTest.cpp        |  41 ++++++
 5 files changed, 212 insertions(+), 21 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
index 78ee7bef02ab1a..12088f6fc35e0b 100644
--- a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
+++ b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
@@ -44,6 +44,7 @@
 #include "llvm/CodeGen/ShadowStackGCLowering.h"
 #include "llvm/CodeGen/SjLjEHPrepare.h"
 #include "llvm/CodeGen/StackProtector.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/UnreachableBlockElim.h"
 #include "llvm/CodeGen/WasmEHPrepare.h"
 #include "llvm/CodeGen/WinEHPrepare.h"
@@ -176,73 +177,80 @@ template <typename DerivedT> class CodeGenPassBuilder {
   // Function object to maintain state while adding codegen IR passes.
   class AddIRPass {
   public:
-    AddIRPass(ModulePassManager &MPM) : MPM(MPM) {}
+    AddIRPass(ModulePassManager &MPM, const DerivedT &PB) : MPM(MPM), PB(PB) {}
     ~AddIRPass() {
       if (!FPM.isEmpty())
         MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
     }
 
-    template <typename PassT> void operator()(PassT &&Pass) {
+    template <typename PassT>
+    void operator()(PassT &&Pass, StringRef Name = PassT::name()) {
       static_assert((is_detected<is_function_pass_t, PassT>::value ||
                      is_detected<is_module_pass_t, PassT>::value) &&
                     "Only module pass and function pass are supported.");
 
+      if (!PB.runBeforeAdding(Name))
+        return;
+
       // Add Function Pass
       if constexpr (is_detected<is_function_pass_t, PassT>::value) {
         FPM.addPass(std::forward<PassT>(Pass));
+
+        for (auto &C : PB.AfterCallbacks)
+          C(Name);
       } else {
         // Add Module Pass
         if (!FPM.isEmpty()) {
           MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
           FPM = FunctionPassManager();
         }
+
         MPM.addPass(std::forward<PassT>(Pass));
+
+        for (auto &C : PB.AfterCallbacks)
+          C(Name);
       }
     }
 
   private:
     ModulePassManager &MPM;
     FunctionPassManager FPM;
+    const DerivedT &PB;
   };
 
   // Function object to maintain state while adding codegen machine passes.
   class AddMachinePass {
   public:
-    AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {}
+    AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
+        : PM(PM), PB(PB) {}
 
     template <typename PassT> void operator()(PassT &&Pass) {
       static_assert(
           is_detected<has_key_t, PassT>::value,
           "Machine function pass must define a static member variable `Key`.");
-      for (auto &C : BeforeCallbacks)
-        if (!C(&PassT::Key))
-          return;
+
+      if (!PB.runBeforeAdding(PassT::name()))
+        return;
+
       PM.addPass(std::forward<PassT>(Pass));
-      for (auto &C : AfterCallbacks)
-        C(&PassT::Key);
+
+      for (auto &C : PB.AfterCallbacks)
+        C(PassT::name());
     }
 
     template <typename PassT> void insertPass(MachinePassKey *ID, PassT Pass) {
-      AfterCallbacks.emplace_back(
+      PB.AfterCallbacks.emplace_back(
           [this, ID, Pass = std::move(Pass)](MachinePassKey *PassID) {
             if (PassID == ID)
               this->PM.addPass(std::move(Pass));
           });
     }
 
-    void disablePass(MachinePassKey *ID) {
-      BeforeCallbacks.emplace_back(
-          [ID](MachinePassKey *PassID) { return PassID != ID; });
-    }
-
     MachineFunctionPassManager releasePM() { return std::move(PM); }
 
   private:
     MachineFunctionPassManager &PM;
-    SmallVector<llvm::unique_function<bool(MachinePassKey *)>, 4>
-        BeforeCallbacks;
-    SmallVector<llvm::unique_function<void(MachinePassKey *)>, 4>
-        AfterCallbacks;
+    const DerivedT &PB;
   };
 
   LLVMTargetMachine &TM;
@@ -473,6 +481,25 @@ template <typename DerivedT> class CodeGenPassBuilder {
   const DerivedT &derived() const {
     return static_cast<const DerivedT &>(*this);
   }
+
+  bool runBeforeAdding(StringRef Name) const {
+    bool ShouldAdd = true;
+    for (auto &C : BeforeCallbacks)
+      ShouldAdd &= C(Name);
+    return ShouldAdd;
+  }
+
+  void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const;
+
+  Error verifyStartStop(const TargetPassConfig::StartStopInfo &Info) const;
+
+  mutable SmallVector<llvm::unique_function<bool(StringRef)>, 4>
+      BeforeCallbacks;
+  mutable SmallVector<llvm::unique_function<void(StringRef)>, 4> AfterCallbacks;
+
+  /// Helper variable for `-start-before/-start-after/-stop-before/-stop-after`
+  mutable bool Started = true;
+  mutable bool Stopped = true;
 };
 
 template <typename Derived>
@@ -480,13 +507,17 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
     ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
     raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
     CodeGenFileType FileType) const {
-  AddIRPass addIRPass(MPM);
+  auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC);
+  if (!StartStopInfo)
+    return StartStopInfo.takeError();
+  setStartStopPasses(*StartStopInfo);
+  AddIRPass addIRPass(MPM, derived());
   // `ProfileSummaryInfo` is always valid.
   addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
   addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
   addISelPasses(addIRPass);
 
-  AddMachinePass addPass(MFPM);
+  AddMachinePass addPass(MFPM, derived());
   if (auto Err = addCoreISelPasses(addPass))
     return std::move(Err);
 
@@ -499,6 +530,68 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
       });
 
   addPass(FreeMachineFunctionPass());
+  return verifyStartStop(*StartStopInfo);
+}
+
+template <typename Derived>
+void CodeGenPassBuilder<Derived>::setStartStopPasses(
+    const TargetPassConfig::StartStopInfo &Info) const {
+  if (!Info.StartPass.empty()) {
+    Started = false;
+    BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StartAfter,
+                                  Count = 0u](StringRef ClassName) mutable {
+      if (Count == Info.StartInstanceNum) {
+        if (AfterFlag) {
+          AfterFlag = false;
+          Started = true;
+        }
+        return Started;
+      }
+
+      auto PassName = PIC->getPassNameForClassName(ClassName);
+      if (Info.StartPass == PassName && ++Count == Info.StartInstanceNum)
+        Started = !Info.StartAfter;
+
+      return Started;
+    });
+  }
+
+  if (!Info.StopPass.empty()) {
+    Stopped = false;
+    BeforeCallbacks.emplace_back([this, &Info, AfterFlag = Info.StopAfter,
+                                  Count = 0u](StringRef ClassName) mutable {
+      if (Count == Info.StopInstanceNum) {
+        if (AfterFlag) {
+          AfterFlag = false;
+          Stopped = true;
+        }
+        return !Stopped;
+      }
+
+      auto PassName = PIC->getPassNameForClassName(ClassName);
+      if (Info.StopPass == PassName && ++Count == Info.StopInstanceNum)
+        Stopped = !Info.StopAfter;
+      return !Stopped;
+    });
+  }
+}
+
+template <typename Derived>
+Error CodeGenPassBuilder<Derived>::verifyStartStop(
+    const TargetPassConfig::StartStopInfo &Info) const {
+  if (Started && Stopped)
+    return Error::success();
+
+  if (!Started)
+    return make_error<StringError>(
+        "Can't find start pass \"" +
+            PIC->getPassNameForClassName(Info.StartPass) + "\".",
+        std::make_error_code(std::errc::invalid_argument));
+  if (!Stopped)
+    return make_error<StringError>(
+        "Can't find stop pass \"" +
+            PIC->getPassNameForClassName(Info.StopPass) + "\".",
+        std::make_error_code(std::errc::invalid_argument));
   return Error::success();
 }
 
diff --git a/llvm/include/llvm/CodeGen/TargetPassConfig.h b/llvm/include/llvm/CodeGen/TargetPassConfig.h
index 66365419aa330b..de6a760c4e4fd1 100644
--- a/llvm/include/llvm/CodeGen/TargetPassConfig.h
+++ b/llvm/include/llvm/CodeGen/TargetPassConfig.h
@@ -15,6 +15,7 @@
 
 #include "llvm/Pass.h"
 #include "llvm/Support/CodeGen.h"
+#include "llvm/Support/Error.h"
 #include <cassert>
 #include <string>
 
@@ -176,6 +177,20 @@ class TargetPassConfig : public ImmutablePass {
   static std::string
   getLimitedCodeGenPipelineReason(const char *Separator = "/");
 
+  struct StartStopInfo {
+    bool StartAfter;
+    bool StopAfter;
+    unsigned StartInstanceNum;
+    unsigned StopInstanceNum;
+    StringRef StartPass;
+    StringRef StopPass;
+  };
+
+  /// Returns pass name in `-stop-before` or `-stop-after`
+  /// NOTE: New pass manager migration only
+  static Expected<StartStopInfo>
+  getStartStopInfo(PassInstrumentationCallbacks &PIC);
+
   void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); }
 
   bool getEnableTailMerge() const { return EnableTailMerge; }
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index 3bbc792f4cbf46..c3d409931fffce 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -406,6 +406,12 @@ static const PassInfo *getPassInfo(StringRef PassName) {
   if (PassName.empty())
     return nullptr;
 
+  // FIXME: Workaround for a failure in unittests/CodeGen/CodeGenTests
+  // There is no counterpart in legacy pass
+  // delete this when related tests are migrated to lit.
+  if (PassName == "no-op-module" || PassName == "no-op-function")
+    return nullptr;
+
   const PassRegistry &PR = *PassRegistry::getPassRegistry();
   const PassInfo *PI = PR.getPassInfo(PassName);
   if (!PI)
@@ -609,6 +615,40 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
   registerPartialPipelineCallback(PIC, LLVMTM);
 }
 
+Expected<TargetPassConfig::StartStopInfo>
+TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) {
+  auto [StartBefore, StartBeforeInstanceNum] =
+      getPassNameAndInstanceNum(StartBeforeOpt);
+  auto [StartAfter, StartAfterInstanceNum] =
+      getPassNameAndInstanceNum(StartAfterOpt);
+  auto [StopBefore, StopBeforeInstanceNum] =
+      getPassNameAndInstanceNum(StopBeforeOpt);
+  auto [StopAfter, StopAfterInstanceNum] =
+      getPassNameAndInstanceNum(StopAfterOpt);
+
+  if (!StartBefore.empty() && !StartAfter.empty())
+    return make_error<StringError>(
+        Twine(StartBeforeOptName) + " and " + StartAfterOptName + " specified!",
+        std::make_error_code(std::errc::invalid_argument));
+  if (!StopBefore.empty() && !StopAfter.empty())
+    return make_error<StringError>(
+        Twine(StopBeforeOptName) + " and " + StopAfterOptName + " specified!",
+        std::make_error_code(std::errc::invalid_argument));
+
+  StartStopInfo Result;
+  Result.StartPass = StartBefore.empty() ? StartAfter : StartBefore;
+  Result.StopPass = StopBefore.empty() ? StopAfter : StopBefore;
+  Result.StartInstanceNum =
+      StartBefore.empty() ? StartAfterInstanceNum : StartBeforeInstanceNum;
+  Result.StopInstanceNum =
+      StopBefore.empty() ? StopAfterInstanceNum : StopBeforeInstanceNum;
+  Result.StartAfter = !StartAfter.empty();
+  Result.StopAfter = !StopAfter.empty();
+  Result.StartInstanceNum += Result.StartInstanceNum == 0;
+  Result.StopInstanceNum += Result.StopInstanceNum == 0;
+  return Result;
+}
+
 // Out of line constructor provides default values for pass options and
 // registers all common codegen passes.
 TargetPassConfig::TargetPassConfig(LLVMTargetMachine &TM, PassManagerBase &pm)
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index d309ed999bd206..8d3f69be503831 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -93,6 +93,7 @@
 #include "llvm/CodeGen/ShadowStackGCLowering.h"
 #include "llvm/CodeGen/SjLjEHPrepare.h"
 #include "llvm/CodeGen/StackProtector.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/TypePromotion.h"
 #include "llvm/CodeGen/WasmEHPrepare.h"
 #include "llvm/CodeGen/WinEHPrepare.h"
@@ -316,7 +317,8 @@ namespace {
 /// We currently only use this for --print-before/after.
 bool shouldPopulateClassToPassNames() {
   return PrintPipelinePasses || !printBeforePasses().empty() ||
-         !printAfterPasses().empty() || !isFilterPassesEmpty();
+         !printAfterPasses().empty() || !isFilterPassesEmpty() ||
+         TargetPassConfig::hasLimitedCodeGenPipeline();
 }
 
 // A pass for testing -print-on-crash.
diff --git a/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp b/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp
index d6ec393155cf09..63499b056d1ef4 100644
--- a/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp
+++ b/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp
@@ -138,4 +138,45 @@ TEST_F(CodeGenPassBuilderTest, basic) {
   EXPECT_EQ(MIRPipeline, ExpectedMIRPipeline);
 }
 
+// TODO: Move this to lit test when llc support new pm.
+TEST_F(CodeGenPassBuilderTest, start_stop) {
+  static const char *argv[] = {
+      "test",
+      "-start-after=no-op-module",
+      "-stop-before=no-op-function,2",
+  };
+  int argc = std::size(argv);
+  cl::ParseCommandLineOptions(argc, argv);
+
+  LoopAnalysisManager LAM;
+  FunctionAnalysisManager FAM;
+  CGSCCAnalysisManager CGAM;
+  ModuleAnalysisManager MAM;
+
+  PassInstrumentationCallbacks PIC;
+  DummyCodeGenPassBuilder CGPB(*TM, getCGPassBuilderOption(), &PIC);
+  PipelineTuningOptions PTO;
+  PassBuilder PB(TM.get(), PTO, std::nullopt, &PIC);
+
+  PB.registerModuleAnalyses(MAM);
+  PB.registerCGSCCAnalyses(CGAM);
+  PB.registerFunctionAnalyses(FAM);
+  PB.registerLoopAnalyses(LAM);
+  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+
+  ModulePassManager MPM;
+  MachineFunctionPassManager MFPM;
+
+  Error Err =
+      CGPB.buildPipeline(MPM, MFPM, outs(), nullptr, CodeGenFileType::Null);
+  EXPECT_FALSE(Err);
+  std::string IRPipeline;
+  raw_string_ostream IROS(IRPipeline);
+  MPM.printPipeline(IROS, [&PIC](StringRef Name) {
+    auto PassName = PIC.getPassNameForClassName(Name);
+    return PassName.empty() ? Name : PassName;
+  });
+  EXPECT_EQ(IRPipeline, "function(no-op-function)");
+}
+
 } // namespace



More information about the llvm-commits mailing list