[llvm] [CodeGen] Support start/stop in CodeGenPassBuilder (PR #70912)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 16 17:04:55 PST 2024
https://github.com/paperchalice updated https://github.com/llvm/llvm-project/pull/70912
>From ce88f2bcbf328bbfb49b7fd762ac43584b7a2e89 Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Wed, 3 Jan 2024 10:34:33 +0800
Subject: [PATCH] [CodeGen][NewPM] Support start/stop in CodeGen
---
.../include/llvm/CodeGen/CodeGenPassBuilder.h | 133 +++++++++++++++---
llvm/include/llvm/CodeGen/TargetPassConfig.h | 15 ++
llvm/lib/CodeGen/TargetPassConfig.cpp | 34 +++++
llvm/lib/Passes/PassBuilder.cpp | 4 +-
.../CodeGen/CodeGenPassBuilderTest.cpp | 41 ++++++
5 files changed, 206 insertions(+), 21 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
index 0ea81347638e99..f425177ba2a902 100644
--- a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
+++ b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
@@ -43,6 +43,7 @@
#include "llvm/CodeGen/ShadowStackGCLowering.h"
#include "llvm/CodeGen/SjLjEHPrepare.h"
#include "llvm/CodeGen/StackProtector.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/UnreachableBlockElim.h"
#include "llvm/CodeGen/WasmEHPrepare.h"
#include "llvm/CodeGen/WinEHPrepare.h"
@@ -175,73 +176,80 @@ template <typename DerivedT> class CodeGenPassBuilder {
// Function object to maintain state while adding codegen IR passes.
class AddIRPass {
public:
- AddIRPass(ModulePassManager &MPM) : MPM(MPM) {}
+ /// \p PB is the builder whose before/after callbacks filter and observe
+ /// every pass added through this helper.
+ AddIRPass(ModulePassManager &MPM, const DerivedT &PB) : MPM(MPM), PB(PB) {}
~AddIRPass() {
if (!FPM.isEmpty())
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
}
- template <typename PassT> void operator()(PassT &&Pass) {
+ /// Appends \p Pass (reported under \p Name) unless a before-callback of the
+ /// builder rejects it, then notifies every after-callback with \p Name.
+ template <typename PassT>
+ void operator()(PassT &&Pass, StringRef Name = PassT::name()) {
static_assert((is_detected<is_function_pass_t, PassT>::value ||
is_detected<is_module_pass_t, PassT>::value) &&
"Only module pass and function pass are supported.");
+ // Let the start/stop filter veto the pass before either manager is touched.
+ const bool Accepted = PB.runBeforeAdding(Name);
+ if (!Accepted)
+ return;
// Add Function Pass
if constexpr (is_detected<is_function_pass_t, PassT>::value) {
FPM.addPass(std::forward<PassT>(Pass));
} else {
// Add Module Pass
if (!FPM.isEmpty()) {
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
FPM = FunctionPassManager();
}
MPM.addPass(std::forward<PassT>(Pass));
}
+
+ // Common to both branches: announce the pass that was just added.
+ for (auto &Callback : PB.AfterCallbacks)
+ Callback(Name);
}
private:
ModulePassManager &MPM;
FunctionPassManager FPM;
+ /// Builder this helper was created from (see buildPipeline).
+ const DerivedT &PB;
};
// Function object to maintain state while adding codegen machine passes.
class AddMachinePass {
public:
- AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {}
+ /// \p PB is the builder whose before/after callbacks filter and observe
+ /// every machine pass added through this helper.
+ AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
+ : PM(PM), PB(PB) {}
template <typename PassT> void operator()(PassT &&Pass) {
static_assert(
is_detected<has_key_t, PassT>::value,
"Machine function pass must define a static member variable `Key`.");
- for (auto &C : BeforeCallbacks)
- if (!C(&PassT::Key))
- return;
+
+ if (!PB.runBeforeAdding(PassT::name()))
+ return;
+
PM.addPass(std::forward<PassT>(Pass));
- for (auto &C : AfterCallbacks)
- C(&PassT::Key);
+
+ for (auto &C : PB.AfterCallbacks)
+ C(PassT::name());
}
- template <typename PassT> void insertPass(MachinePassKey *ID, PassT Pass) {
- AfterCallbacks.emplace_back(
- [this, ID, Pass = std::move(Pass)](MachinePassKey *PassID) {
- if (PassID == ID)
- this->PM.addPass(std::move(Pass));
- });
- }
+ /// Adds a copy of \p Pass to the machine pipeline right after any pass
+ /// (IR or machine) whose name is reported to the after-callbacks as
+ /// \p TargetName.
+ ///
+ /// AfterCallbacks now receive pass names (StringRef), so the target is
+ /// identified by name rather than by MachinePassKey. The callback is stored
+ /// in the builder and outlives this helper object, so it captures the pass
+ /// manager reference instead of `this`; \p TargetName must likewise remain
+ /// valid for as long as the builder keeps its callbacks.
+ template <typename PassT> void insertPass(StringRef TargetName, PassT Pass) {
+ PB.AfterCallbacks.emplace_back(
+ [&TargetPM = PM, TargetName, Pass = std::move(Pass)](StringRef Name) {
+ if (Name == TargetName)
+ TargetPM.addPass(PassT(Pass)); // fresh copy per insertion
+ });
+ }
- void disablePass(MachinePassKey *ID) {
- BeforeCallbacks.emplace_back(
- [ID](MachinePassKey *PassID) { return PassID != ID; });
- }
-
MachineFunctionPassManager releasePM() { return std::move(PM); }
private:
MachineFunctionPassManager &PM;
- SmallVector<llvm::unique_function<bool(MachinePassKey *)>, 4>
- BeforeCallbacks;
- SmallVector<llvm::unique_function<void(MachinePassKey *)>, 4>
- AfterCallbacks;
+ /// Builder this helper was created from (see buildPipeline).
+ const DerivedT &PB;
};
LLVMTargetMachine &TM;
@@ -469,6 +477,25 @@ template <typename DerivedT> class CodeGenPassBuilder {
const DerivedT &derived() const {
return static_cast<const DerivedT &>(*this);
}
+
+ /// Returns true if every before-callback accepts the pass named \p Name.
+ /// All callbacks are always invoked, even after one has rejected the pass,
+ /// because the start/stop callbacks count pass occurrences as they run.
+ bool runBeforeAdding(StringRef Name) const {
+ bool Accept = true;
+ for (auto &Callback : BeforeCallbacks)
+ Accept = Callback(Name) && Accept; // call first: no short-circuit
+ return Accept;
+ }
+
+ /// Installs before-callbacks implementing the -start-before/-start-after/
+ /// -stop-before/-stop-after filtering described by \p Info.
+ void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const;
+
+ /// Returns an error if a requested start or stop pass was not reached while
+ /// the pipeline was being built; Error::success() otherwise.
+ Error verifyStartStop(const TargetPassConfig::StartStopInfo &Info) const;
+
+ /// Filters consulted by runBeforeAdding() before each pass is added; the
+ /// pass is added only if all of them return true. `mutable` because they are
+ /// populated and invoked from the const buildPipeline().
+ mutable SmallVector<llvm::unique_function<bool(StringRef)>, 4>
+ BeforeCallbacks;
+ /// Hooks invoked with the name of each pass right after it has been added.
+ mutable SmallVector<llvm::unique_function<void(StringRef)>, 4> AfterCallbacks;
+
+ /// Start/stop bookkeeping for `-start-before/-start-after/-stop-before/
+ /// -stop-after`. Both are true when nothing was requested;
+ /// setStartStopPasses() clears a flag when the corresponding pass is
+ /// requested, and the installed callbacks set it back to true while the
+ /// pipeline is built. verifyStartStop() reports any flag still false.
+ mutable bool Started = true;
+ mutable bool Stopped = true;
};
template <typename Derived>
@@ -476,13 +503,17 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
CodeGenFileType FileType) const {
- AddIRPass addIRPass(MPM);
+ // Decode -start-*/-stop-* up front so the start/stop filter is installed
+ // before the first pass is added.
+ Expected<TargetPassConfig::StartStopInfo> StartStopInfo =
+ TargetPassConfig::getStartStopInfo(*PIC);
+ if (auto Err = StartStopInfo.takeError())
+ return std::move(Err);
+ setStartStopPasses(*StartStopInfo);
+ AddIRPass addIRPass(MPM, derived());
// `ProfileSummaryInfo` is always valid.
addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
addIRPass(RequireAnalysisPass<CollectorMetadataAnalysis, Module>());
addISelPasses(addIRPass);
- AddMachinePass addPass(MFPM);
+ AddMachinePass addPass(MFPM, derived());
if (auto Err = addCoreISelPasses(addPass))
return std::move(Err);
@@ -495,6 +526,68 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
});
addPass(FreeMachineFunctionPass());
+ // Report a requested start/stop pass that never showed up in the pipeline.
+ return verifyStartStop(*StartStopInfo);
+}
+
+template <typename Derived>
+void CodeGenPassBuilder<Derived>::setStartStopPasses(
+ const TargetPassConfig::StartStopInfo &Info) const {
+ // The callbacks are stored in BeforeCallbacks and outlive this call, while
+ // \p Info may refer to a caller-owned temporary (buildPipeline passes
+ // *StartStopInfo from a local Expected). Each closure therefore keeps its
+ // own copy of the small struct instead of a reference to it.
+ if (!Info.StartPass.empty()) {
+ Started = false;
+ BeforeCallbacks.emplace_back(
+ [this, Info, Count = 0u](StringRef ClassName) mutable {
+ // Already past the start point: keep every remaining pass.
+ if (Count == Info.StartInstanceNum)
+ return true;
+ if (Info.StartPass == PIC->getPassNameForClassName(ClassName) &&
+ ++Count == Info.StartInstanceNum) {
+ // Mark the start pass as found as soon as it is seen, so that
+ // -start-after on the final pass is not reported as missing.
+ Started = true;
+ // -start-before keeps this pass; -start-after skips it.
+ return !Info.StartAfter;
+ }
+ return false;
+ });
+ }
+
+ if (!Info.StopPass.empty()) {
+ Stopped = false;
+ BeforeCallbacks.emplace_back(
+ [this, Info, Count = 0u](StringRef ClassName) mutable {
+ // Already past the stop point: drop every remaining pass.
+ if (Count == Info.StopInstanceNum)
+ return false;
+ if (Info.StopPass == PIC->getPassNameForClassName(ClassName) &&
+ ++Count == Info.StopInstanceNum) {
+ // Mark the stop pass as found immediately, so that -stop-after on
+ // the final pass is not reported as missing.
+ Stopped = true;
+ // -stop-after keeps this pass; -stop-before drops it.
+ return Info.StopAfter;
+ }
+ return true;
+ });
+ }
+}
+
+template <typename Derived>
+Error CodeGenPassBuilder<Derived>::verifyStartStop(
+ const TargetPassConfig::StartStopInfo &Info) const {
+ // Info.StartPass/StopPass already hold pass names (the start/stop callbacks
+ // compare them with PIC->getPassNameForClassName() results), so print them
+ // as-is; mapping them through getPassNameForClassName() again would look a
+ // pass name up as if it were a class name.
+ if (!Started)
+ return make_error<StringError>(
+ "Can't find start pass \"" + Info.StartPass + "\".",
+ std::make_error_code(std::errc::invalid_argument));
+ if (!Stopped)
+ return make_error<StringError>(
+ "Can't find stop pass \"" + Info.StopPass + "\".",
+ std::make_error_code(std::errc::invalid_argument));
return Error::success();
}
diff --git a/llvm/include/llvm/CodeGen/TargetPassConfig.h b/llvm/include/llvm/CodeGen/TargetPassConfig.h
index 66365419aa330b..de6a760c4e4fd1 100644
--- a/llvm/include/llvm/CodeGen/TargetPassConfig.h
+++ b/llvm/include/llvm/CodeGen/TargetPassConfig.h
@@ -15,6 +15,7 @@
#include "llvm/Pass.h"
#include "llvm/Support/CodeGen.h"
+#include "llvm/Support/Error.h"
#include <cassert>
#include <string>
@@ -176,6 +177,20 @@ class TargetPassConfig : public ImmutablePass {
static std::string
getLimitedCodeGenPipelineReason(const char *Separator = "/");
+ /// Decoded -start-before/-start-after/-stop-before/-stop-after options.
+ struct StartStopInfo {
+ /// True when -start-after was given (the start pass itself is skipped).
+ bool StartAfter = false;
+ /// True when -stop-after was given (the stop pass itself is still added).
+ bool StopAfter = false;
+ /// 1-based occurrence of StartPass/StopPass to act on.
+ unsigned StartInstanceNum = 0;
+ unsigned StopInstanceNum = 0;
+ /// Pass names; empty when the corresponding option was not given.
+ StringRef StartPass;
+ StringRef StopPass;
+ };
+
+ /// Returns the start/stop information parsed from the `-start-before`,
+ /// `-start-after`, `-stop-before` and `-stop-after` options, or an error if
+ /// both the -before and -after variant of the same option were specified.
+ /// NOTE: New pass manager migration only
+ static Expected<StartStopInfo>
+ getStartStopInfo(PassInstrumentationCallbacks &PIC);
+
void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); }
bool getEnableTailMerge() const { return EnableTailMerge; }
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index 3bbc792f4cbf46..52cf6b84f32722 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -609,6 +609,40 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
registerPartialPipelineCallback(PIC, LLVMTM);
}
+Expected<TargetPassConfig::StartStopInfo>
+TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) {
+ // Split each option value into its pass name and instance number.
+ auto [StartBefore, StartBeforeInstanceNum] =
+ getPassNameAndInstanceNum(StartBeforeOpt);
+ auto [StartAfter, StartAfterInstanceNum] =
+ getPassNameAndInstanceNum(StartAfterOpt);
+ auto [StopBefore, StopBeforeInstanceNum] =
+ getPassNameAndInstanceNum(StopBeforeOpt);
+ auto [StopAfter, StopAfterInstanceNum] =
+ getPassNameAndInstanceNum(StopAfterOpt);
+
+ // The -before and -after flavours of each option are mutually exclusive.
+ if (!StartBefore.empty() && !StartAfter.empty())
+ return make_error<StringError>(
+ Twine(StartBeforeOptName) + " and " + StartAfterOptName + " specified!",
+ std::make_error_code(std::errc::invalid_argument));
+ if (!StopBefore.empty() && !StopAfter.empty())
+ return make_error<StringError>(
+ Twine(StopBeforeOptName) + " and " + StopAfterOptName + " specified!",
+ std::make_error_code(std::errc::invalid_argument));
+
+ StartStopInfo Result;
+ Result.StartAfter = !StartAfter.empty();
+ Result.StopAfter = !StopAfter.empty();
+
+ // Take the pass (and its instance number) from whichever flavour was given.
+ if (StartBefore.empty()) {
+ Result.StartPass = StartAfter;
+ Result.StartInstanceNum = StartAfterInstanceNum;
+ } else {
+ Result.StartPass = StartBefore;
+ Result.StartInstanceNum = StartBeforeInstanceNum;
+ }
+ if (StopBefore.empty()) {
+ Result.StopPass = StopAfter;
+ Result.StopInstanceNum = StopAfterInstanceNum;
+ } else {
+ Result.StopPass = StopBefore;
+ Result.StopInstanceNum = StopBeforeInstanceNum;
+ }
+
+ // Instance numbers are 1-based; normalize 0 to the first occurrence.
+ if (Result.StartInstanceNum == 0)
+ Result.StartInstanceNum = 1;
+ if (Result.StopInstanceNum == 0)
+ Result.StopInstanceNum = 1;
+ return Result;
+}
+
// Out of line constructor provides default values for pass options and
// registers all common codegen passes.
TargetPassConfig::TargetPassConfig(LLVMTargetMachine &TM, PassManagerBase &pm)
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index d0f3a55a12b056..8843d9bd984ee5 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -92,6 +92,7 @@
#include "llvm/CodeGen/ShadowStackGCLowering.h"
#include "llvm/CodeGen/SjLjEHPrepare.h"
#include "llvm/CodeGen/StackProtector.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TypePromotion.h"
#include "llvm/CodeGen/WasmEHPrepare.h"
#include "llvm/CodeGen/WinEHPrepare.h"
@@ -315,7 +316,8 @@ namespace {
-/// We currently only use this for --print-before/after.
+/// Returns true when an enabled option needs class-name -> pass-name lookups:
+/// pipeline printing, --print-before/after, a non-empty pass filter, or a
+/// limited codegen pipeline (-start-*/-stop-*).
bool shouldPopulateClassToPassNames() {
return PrintPipelinePasses || !printBeforePasses().empty() ||
- !printAfterPasses().empty() || !isFilterPassesEmpty();
+ !printAfterPasses().empty() || !isFilterPassesEmpty() ||
+ TargetPassConfig::hasLimitedCodeGenPipeline();
}
// A pass for testing -print-on-crash.
diff --git a/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp b/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp
index d6ec393155cf09..63499b056d1ef4 100644
--- a/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp
+++ b/llvm/unittests/CodeGen/CodeGenPassBuilderTest.cpp
@@ -138,4 +138,45 @@ TEST_F(CodeGenPassBuilderTest, basic) {
EXPECT_EQ(MIRPipeline, ExpectedMIRPipeline);
}
+// TODO: Move this to a lit test once llc supports the new pass manager.
+TEST_F(CodeGenPassBuilderTest, start_stop) {
+ static const char *argv[] = {
+ "test",
+ "-start-after=no-op-module",
+ "-stop-before=no-op-function,2",
+ };
+ int argc = std::size(argv);
+ // NOTE: cl::opt values are process-global and remain set after this test.
+ cl::ParseCommandLineOptions(argc, argv);
+
+ LoopAnalysisManager LAM;
+ FunctionAnalysisManager FAM;
+ CGSCCAnalysisManager CGAM;
+ ModuleAnalysisManager MAM;
+
+ PassInstrumentationCallbacks PIC;
+ DummyCodeGenPassBuilder CGPB(*TM, getCGPassBuilderOption(), &PIC);
+ PipelineTuningOptions PTO;
+ PassBuilder PB(TM.get(), PTO, std::nullopt, &PIC);
+
+ PB.registerModuleAnalyses(MAM);
+ PB.registerCGSCCAnalyses(CGAM);
+ PB.registerFunctionAnalyses(FAM);
+ PB.registerLoopAnalyses(LAM);
+ PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+
+ ModulePassManager MPM;
+ MachineFunctionPassManager MFPM;
+
+ Error Err =
+ CGPB.buildPipeline(MPM, MFPM, outs(), nullptr, CodeGenFileType::Null);
+ // Report and consume a failure; an unhandled llvm::Error would abort.
+ if (Err)
+ FAIL() << toString(std::move(Err));
+ std::string IRPipeline;
+ raw_string_ostream IROS(IRPipeline);
+ MPM.printPipeline(IROS, [&PIC](StringRef Name) {
+ auto PassName = PIC.getPassNameForClassName(Name);
+ return PassName.empty() ? Name : PassName;
+ });
+ EXPECT_EQ(IRPipeline, "function(no-op-function)");
+}
+
} // namespace
More information about the llvm-commits
mailing list