[llvm] [Pass] Support start/stop in instrumentation (PR #70912)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 8 22:45:40 PST 2024
https://github.com/paperchalice updated https://github.com/llvm/llvm-project/pull/70912
>From 16d2498b05c93819c4f2fa8495862d206481097e Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Wed, 3 Jan 2024 10:34:33 +0800
Subject: [PATCH] [CodeGen][NewPM] Support start/stop in CodeGen
---
.../include/llvm/CodeGen/CodeGenPassBuilder.h | 157 +++++++++++++++---
llvm/include/llvm/CodeGen/TargetPassConfig.h | 15 ++
llvm/include/llvm/IR/PassInstrumentation.h | 6 +
llvm/lib/CodeGen/TargetPassConfig.cpp | 62 +++++--
llvm/lib/IR/PassInstrumentation.cpp | 8 +
llvm/lib/Passes/PassBuilder.cpp | 4 +-
6 files changed, 216 insertions(+), 36 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
index 2100c30aad1180..76782620b9cd7e 100644
--- a/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
+++ b/llvm/include/llvm/CodeGen/CodeGenPassBuilder.h
@@ -40,6 +40,7 @@
#include "llvm/CodeGen/SelectOptimize.h"
#include "llvm/CodeGen/ShadowStackGCLowering.h"
#include "llvm/CodeGen/SjLjEHPrepare.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/UnreachableBlockElim.h"
#include "llvm/CodeGen/WasmEHPrepare.h"
#include "llvm/CodeGen/WinEHPrepare.h"
@@ -172,8 +173,9 @@ template <typename DerivedT> class CodeGenPassBuilder {
// Function object to maintain state while adding codegen IR passes.
class AddIRPass {
public:
- AddIRPass(ModulePassManager &MPM, bool DebugPM, bool Check = true)
- : MPM(MPM) {
+    AddIRPass(ModulePassManager &MPM, const DerivedT &PB, bool DebugPM,
+              bool Check = true)
+        : MPM(MPM), PB(PB) {
if (Check)
AddingFunctionPasses = false;
}
@@ -184,10 +186,17 @@ template <typename DerivedT> class CodeGenPassBuilder {
// Add Function Pass
+    // \p Name is the name PB's start/stop callbacks see. It defaults to
+    // PassT::name(); pass the wrapped pass's name when PassT is an adaptor.
template <typename PassT>
std::enable_if_t<is_detected<is_function_pass_t, PassT>::value>
- operator()(PassT &&Pass) {
+    operator()(PassT &&Pass, StringRef Name = PassT::name()) {
if (AddingFunctionPasses && !*AddingFunctionPasses)
AddingFunctionPasses = true;
+      // Run every BeforeCallback even after one rejects the pass: each keeps
+      // its own instance counter and must observe every pass. unique_function
+      // is move-only with a non-const call operator, so bind by reference.
+      bool ShouldAdd = true;
+      for (auto &C : PB.BeforeCallbacks)
+        ShouldAdd &= C(Name);
+      if (!ShouldAdd)
+        return;
+
FPM.addPass(std::forward<PassT>(Pass));
+
+      for (auto &C : PB.AfterCallbacks)
+        C(Name);
}
// Add Module Pass
@@ -197,12 +206,20 @@ template <typename DerivedT> class CodeGenPassBuilder {
operator()(PassT &&Pass) {
assert((!AddingFunctionPasses || !*AddingFunctionPasses) &&
"could not add module pass after adding function pass");
+      // Same protocol as for function passes: all callbacks see the pass,
+      // and it is added only if none of them rejected it.
+      bool ShouldAdd = true;
+      for (auto &C : PB.BeforeCallbacks)
+        ShouldAdd &= C(PassT::name());
+      if (!ShouldAdd)
+        return;
+
MPM.addPass(std::forward<PassT>(Pass));
+
+      for (auto &C : PB.AfterCallbacks)
+        C(PassT::name());
}
private:
ModulePassManager &MPM;
FunctionPassManager FPM;
+    const DerivedT &PB;
// The codegen IR pipeline are mostly function passes with the exceptions of
// a few loop and module passes. `AddingFunctionPasses` make sures that
// we could only add module passes at the beginning of the pipeline. Once
@@ -216,41 +233,34 @@ template <typename DerivedT> class CodeGenPassBuilder {
// Function object to maintain state while adding codegen machine passes.
class AddMachinePass {
public:
- AddMachinePass(MachineFunctionPassManager &PM) : PM(PM) {}
+    AddMachinePass(MachineFunctionPassManager &PM, const DerivedT &PB)
+        : PM(PM), PB(PB) {}
+    /// Add \p Pass to PM unless one of PB's BeforeCallbacks rejects it, then
+    /// notify PB's AfterCallbacks. Callbacks receive PassT::name().
template <typename PassT> void operator()(PassT &&Pass) {
static_assert(
is_detected<has_key_t, PassT>::value,
"Machine function pass must define a static member variable `Key`.");
- for (auto &C : BeforeCallbacks)
- if (!C(&PassT::Key))
- return;
+      // Every BeforeCallback must observe every pass (they count instances),
+      // so do not stop at the first rejection.
+      bool ShouldAdd = true;
+      for (auto &C : PB.BeforeCallbacks)
+        ShouldAdd &= C(PassT::name());
+      if (!ShouldAdd)
+        return;
PM.addPass(std::forward<PassT>(Pass));
- for (auto &C : AfterCallbacks)
- C(&PassT::Key);
+      for (auto &C : PB.AfterCallbacks)
+        C(PassT::name());
}
- template <typename PassT> void insertPass(MachinePassKey *ID, PassT Pass) {
- AfterCallbacks.emplace_back(
- [this, ID, Pass = std::move(Pass)](MachinePassKey *PassID) {
- if (PassID == ID)
- this->PM.addPass(std::move(Pass));
- });
+    /// Add a copy of \p Pass right after each added pass whose PassT::name()
+    /// equals \p PassName. The trigger is keyed by name because PB's
+    /// AfterCallbacks are unique_function<void(StringRef)>; a callback taking a
+    /// MachinePassKey cannot be stored there. \p PassName must outlive PB's
+    /// callbacks.
+    template <typename PassT> void insertPass(StringRef PassName, PassT Pass) {
+      // Capture the pass manager itself rather than `this`: the callback is
+      // stored in PB and can outlive this AddMachinePass object.
+      PB.AfterCallbacks.emplace_back(
+          [&MFPM = PM, PassName, Pass = std::move(Pass)](StringRef Name) {
+            // Copy, so the pass can be inserted after every matching pass.
+            if (Name == PassName)
+              MFPM.addPass(PassT(Pass));
+          });
}
- void disablePass(MachinePassKey *ID) {
- BeforeCallbacks.emplace_back(
- [ID](MachinePassKey *PassID) { return PassID != ID; });
- }
-
MachineFunctionPassManager releasePM() { return std::move(PM); }
private:
MachineFunctionPassManager &PM;
- SmallVector<llvm::unique_function<bool(MachinePassKey *)>, 4>
- BeforeCallbacks;
- SmallVector<llvm::unique_function<void(MachinePassKey *)>, 4>
- AfterCallbacks;
+    const DerivedT &PB;
};
LLVMTargetMachine &TM;
@@ -478,6 +488,19 @@ template <typename DerivedT> class CodeGenPassBuilder {
const DerivedT &derived() const {
return static_cast<const DerivedT &>(*this);
}
+
+ /// Install the callbacks implementing the start/stop options in \p Info.
+ void setStartStopPasses(const TargetPassConfig::StartStopInfo &Info) const;
+
+ /// Return an error if a requested start or stop pass was never reached.
+ Error verifyStartStop(const TargetPassConfig::StartStopInfo &Info) const;
+
+ /// BeforeCallbacks get a pass name before the pass is added; the pass is
+ /// skipped if a callback returns false. AfterCallbacks get the name after
+ /// the pass has been added. Both are mutable because they are filled from
+ /// const member functions such as buildPipeline().
+ mutable SmallVector<llvm::unique_function<bool(StringRef)>, 4>
+ BeforeCallbacks;
+ mutable SmallVector<llvm::unique_function<void(StringRef)>, 4> AfterCallbacks;
+
+ /// Helper variable for `-start-before/-start-after/-stop-before/-stop-after`
+ mutable bool ShouldAddPass = true;
+ /// True once the start pass was reached, or if no start pass was requested.
+ mutable bool Started = true;
+ /// True once the stop pass was reached, or if no stop pass was requested.
+ mutable bool Stopped = true;
};
template <typename Derived>
@@ -485,12 +508,16 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
ModulePassManager &MPM, MachineFunctionPassManager &MFPM,
raw_pwrite_stream &Out, raw_pwrite_stream *DwoOut,
CodeGenFileType FileType) const {
- AddIRPass addIRPass(MPM, Opt.DebugPM);
+  // Parse -start-*/-stop-* first and propagate malformed options. The error
+  // case is the one where the Expected does NOT hold a value.
+  auto StartStopInfo = TargetPassConfig::getStartStopInfo(*PIC);
+  if (!StartStopInfo)
+    return StartStopInfo.takeError();
+  setStartStopPasses(*StartStopInfo);
+  AddIRPass addIRPass(MPM, derived(), Opt.DebugPM);
// `ProfileSummaryInfo` is always valid.
addIRPass(RequireAnalysisPass<ProfileSummaryAnalysis, Module>());
addISelPasses(addIRPass);
- AddMachinePass addPass(MFPM);
+  AddMachinePass addPass(MFPM, derived());
if (auto Err = addCoreISelPasses(addPass))
return std::move(Err);
@@ -503,7 +530,90 @@ Error CodeGenPassBuilder<Derived>::buildPipeline(
});
addPass(FreeMachineFunctionPass());
- return Error::success();
+  // Report a requested start/stop pass that never showed up in the pipeline.
+  return verifyStartStop(*StartStopInfo);
+}
+
+template <typename Derived>
+void CodeGenPassBuilder<Derived>::setStartStopPasses(
+    const TargetPassConfig::StartStopInfo &Info) const {
+  // Both options are implemented with BeforeCallbacks only. A pass rejected
+  // by a BeforeCallback is never added, so its AfterCallbacks never run; a
+  // -start-after/-stop-after driven from AfterCallbacks could therefore miss
+  // its target while passes are being skipped. Instead, the matching
+  // instance itself flips Started/Stopped, and the callback's result decides
+  // whether that instance is kept.
+  //
+  // Info is captured by value: the callbacks are stored in this builder and
+  // outlive the caller's StartStopInfo object.
+  if (!Info.StartPass.empty()) {
+    Started = false;
+    BeforeCallbacks.emplace_back(
+        [this, Info, Count = 0u](StringRef ClassName) mutable {
+          if (Started)
+            return true;
+          if (Info.StartPass != ClassName || ++Count != Info.StartInstanceNum)
+            return false;
+          // Requested instance reached: -start-before keeps it, -start-after
+          // drops it and keeps everything that follows.
+          Started = true;
+          return Info.StartBefore;
+        });
+  }
+
+  if (!Info.StopPass.empty()) {
+    Stopped = false;
+    BeforeCallbacks.emplace_back(
+        [this, Info, Count = 0u](StringRef ClassName) mutable {
+          if (Stopped)
+            return false;
+          if (Info.StopPass != ClassName || ++Count != Info.StopInstanceNum)
+            return true;
+          // Requested instance reached: -stop-after keeps it, -stop-before
+          // drops it; everything that follows is dropped.
+          Stopped = true;
+          return !Info.StopBefore;
+        });
+  }
+}
+
+template <typename Derived>
+Error CodeGenPassBuilder<Derived>::verifyStartStop(
+    const TargetPassConfig::StartStopInfo &Info) const {
+  // Started/Stopped are false only when the corresponding pass was requested
+  // but not reached while the pipeline was being built.
+  if (!Started)
+    return make_error<StringError>(
+        "Can't find start pass \"" +
+            PIC->getPassNameForClassName(Info.StartPass) + "\".",
+        std::make_error_code(std::errc::invalid_argument));
+  if (!Stopped)
+    return make_error<StringError>(
+        "Can't find stop pass \"" +
+            PIC->getPassNameForClassName(Info.StopPass) + "\".",
+        std::make_error_code(std::errc::invalid_argument));
+  // Every requested start/stop pass was found.
+  return Error::success();
}
static inline AAManager registerAAAnalyses() {
@@ -623,8 +733,9 @@ void CodeGenPassBuilder<Derived>::addIRPasses(AddIRPass &addPass) const {
// Run loop strength reduction before anything else.
if (getOptLevel() != CodeGenOptLevel::None && !Opt.DisableLSR) {
- addPass(createFunctionToLoopPassAdaptor(
- LoopStrengthReducePass(), /*UseMemorySSA*/ true, Opt.DebugPM));
+ // Name LSR explicitly: the default Name would be the adaptor type's name(),
+ // so start/stop callbacks would never see LoopStrengthReducePass.
+ addPass(createFunctionToLoopPassAdaptor(LoopStrengthReducePass(),
+ /*UseMemorySSA*/ true, Opt.DebugPM),
+ LoopStrengthReducePass::name());
// FIXME: use -stop-after so we could remove PrintLSR
if (Opt.PrintLSR)
addPass(PrintFunctionPass(dbgs(), "\n\n*** Code after LSR ***\n"));
diff --git a/llvm/include/llvm/CodeGen/TargetPassConfig.h b/llvm/include/llvm/CodeGen/TargetPassConfig.h
index 66365419aa330b..2823485562c20f 100644
--- a/llvm/include/llvm/CodeGen/TargetPassConfig.h
+++ b/llvm/include/llvm/CodeGen/TargetPassConfig.h
@@ -15,6 +15,7 @@
#include "llvm/Pass.h"
#include "llvm/Support/CodeGen.h"
+#include "llvm/Support/Error.h"
#include <cassert>
#include <string>
@@ -176,6 +177,20 @@ class TargetPassConfig : public ImmutablePass {
static std::string
getLimitedCodeGenPipelineReason(const char *Separator = "/");
+  /// Parsed form of -start-before/-start-after/-stop-before/-stop-after.
+  /// StartPass/StopPass hold pass *class* names and are empty when the option
+  /// was not given. Instance numbers count matching passes starting at 1.
+  struct StartStopInfo {
+    bool StartBefore = false; ///< StartPass came from -start-before.
+    bool StopBefore = false;  ///< StopPass came from -stop-before.
+    unsigned StartInstanceNum = 0;
+    unsigned StopInstanceNum = 0;
+    StringRef StartPass;
+    StringRef StopPass;
+  };
+
+  /// Parse the -start-* and -stop-* options into a StartStopInfo, mapping the
+  /// given pass names to class names through \p PIC. Returns an error for
+  /// conflicting options, e.g. both -start-before and -start-after.
+  /// NOTE: New pass manager migration only
+  static Expected<StartStopInfo>
+  getStartStopInfo(PassInstrumentationCallbacks &PIC);
+
void setDisableVerify(bool Disable) { setOpt(DisableVerify, Disable); }
bool getEnableTailMerge() const { return EnableTailMerge; }
diff --git a/llvm/include/llvm/IR/PassInstrumentation.h b/llvm/include/llvm/IR/PassInstrumentation.h
index 519a5e46b4373b..ae493c740908d1 100644
--- a/llvm/include/llvm/IR/PassInstrumentation.h
+++ b/llvm/include/llvm/IR/PassInstrumentation.h
@@ -84,6 +84,7 @@ class PassInstrumentationCallbacks {
using AfterAnalysisFunc = void(StringRef, Any);
using AnalysisInvalidatedFunc = void(StringRef, Any);
using AnalysesClearedFunc = void(StringRef);
+ /// Signature of StartStopCallback: takes a pass name, returns a bool.
+ using StartStopFunc = bool(StringRef);
public:
PassInstrumentationCallbacks() = default;
@@ -152,6 +153,11 @@ class PassInstrumentationCallbacks {
void addClassToPassName(StringRef ClassName, StringRef PassName);
/// Get the pass name for a given pass class name.
StringRef getPassNameForClassName(StringRef ClassName);
+ /// Get the class name registered for \p PassName; the inverse of
+ /// getPassNameForClassName. Returns an empty StringRef when no registered
+ /// class has that pass name. Scans all registered classes.
+ StringRef getClassNameForPassName(StringRef PassName);
+
+ /// Helper callback for -start-before/-start-after/-stop-before/-stop-after,
+ /// set by registerPartialPipelineCallback (TargetPassConfig.cpp); its result
+ /// says whether the pass it is given is enabled. Empty until assigned.
+ llvm::unique_function<StartStopFunc> StartStopCallback;
private:
friend class PassInstrumentation;
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
index 4003a08a5422dd..313dbe96969b0e 100644
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -526,17 +526,16 @@ static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
getPassNameAndInstanceNum(StopAfterOpt);
if (StartBefore.empty() && StartAfter.empty() && StopBefore.empty() &&
- StopAfter.empty())
+ StopAfter.empty()) {
+ // No start/stop option given: enable every pass.
+ PIC.StartStopCallback = [](StringRef) { return true; };
return;
+ }
+
+ // Map the command-line pass names to registered pass class names.
+ // NOTE(review): getClassNameForPassName returns an empty StringRef for an
+ // unknown name, which silently disables that option; options that were not
+ // given are looked up too, with an empty name. Confirm both are intended.
+ StartBefore = PIC.getClassNameForPassName(StartBefore);
+ StartAfter = PIC.getClassNameForPassName(StartAfter);
+ StopBefore = PIC.getClassNameForPassName(StopBefore);
+ StopAfter = PIC.getClassNameForPassName(StopAfter);
- std::tie(StartBefore, std::ignore) =
- LLVMTM.getPassNameFromLegacyName(StartBefore);
- std::tie(StartAfter, std::ignore) =
- LLVMTM.getPassNameFromLegacyName(StartAfter);
- std::tie(StopBefore, std::ignore) =
- LLVMTM.getPassNameFromLegacyName(StopBefore);
- std::tie(StopAfter, std::ignore) =
- LLVMTM.getPassNameFromLegacyName(StopAfter);
if (!StartBefore.empty() && !StartAfter.empty())
report_fatal_error(Twine(StartBeforeOptName) + Twine(" and ") +
Twine(StartAfterOptName) + Twine(" specified!"));
@@ -544,11 +543,11 @@ static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
report_fatal_error(Twine(StopBeforeOptName) + Twine(" and ") +
Twine(StopAfterOptName) + Twine(" specified!"));
- PIC.registerShouldRunOptionalPassCallback(
+ // NOTE(review): this is no longer registered via
+ // registerShouldRunOptionalPassCallback; confirm that something invokes
+ // PIC.StartStopCallback, otherwise these options have no effect here.
+ PIC.StartStopCallback =
[=, EnableCurrent = StartBefore.empty() && StartAfter.empty(),
EnableNext = std::optional<bool>(), StartBeforeCount = 0u,
StartAfterCount = 0u, StopBeforeCount = 0u,
- StopAfterCount = 0u](StringRef P, Any) mutable {
+ StopAfterCount = 0u](StringRef P) mutable {
bool StartBeforePass = !StartBefore.empty() && P.contains(StartBefore);
bool StartAfterPass = !StartAfter.empty() && P.contains(StartAfter);
bool StopBeforePass = !StopBefore.empty() && P.contains(StopBefore);
@@ -576,7 +575,7 @@ static void registerPartialPipelineCallback(PassInstrumentationCallbacks &PIC,
if (StopBeforePass && StopBeforeCount++ == StopBeforeInstanceNum)
EnableCurrent = false;
return EnableCurrent;
- });
+ };
}
void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
@@ -609,6 +608,45 @@ void llvm::registerCodeGenCallback(PassInstrumentationCallbacks &PIC,
registerPartialPipelineCallback(PIC, LLVMTM);
}
+Expected<TargetPassConfig::StartStopInfo>
+TargetPassConfig::getStartStopInfo(PassInstrumentationCallbacks &PIC) {
+  auto [StartBefore, StartBeforeInstanceNum] =
+      getPassNameAndInstanceNum(StartBeforeOpt);
+  auto [StartAfter, StartAfterInstanceNum] =
+      getPassNameAndInstanceNum(StartAfterOpt);
+  auto [StopBefore, StopBeforeInstanceNum] =
+      getPassNameAndInstanceNum(StopBeforeOpt);
+  auto [StopAfter, StopAfterInstanceNum] =
+      getPassNameAndInstanceNum(StopAfterOpt);
+
+  // Replace each requested pass name with its registered class name, which is
+  // what the pass builders compare against. Options that were not given stay
+  // empty and are not looked up, so they cannot match a class that happens to
+  // have an empty pass name. An unknown name is reported instead of silently
+  // turning the option off.
+  for (StringRef *Name : {&StartBefore, &StartAfter, &StopBefore, &StopAfter}) {
+    if (Name->empty())
+      continue;
+    StringRef ClassName = PIC.getClassNameForPassName(*Name);
+    if (ClassName.empty())
+      return make_error<StringError>(
+          "\"" + *Name + "\" pass is not registered.",
+          std::make_error_code(std::errc::invalid_argument));
+    *Name = ClassName;
+  }
+
+  if (!StartBefore.empty() && !StartAfter.empty())
+    return make_error<StringError>(
+        Twine(StartBeforeOptName) + " and " + StartAfterOptName + " specified!",
+        std::make_error_code(std::errc::invalid_argument));
+  if (!StopBefore.empty() && !StopAfter.empty())
+    return make_error<StringError>(
+        Twine(StopBeforeOptName) + " and " + StopAfterOptName + " specified!",
+        std::make_error_code(std::errc::invalid_argument));
+
+  StartStopInfo Result;
+  Result.StartPass = StartBefore.empty() ? StartAfter : StartBefore;
+  Result.StopPass = StopBefore.empty() ? StopAfter : StopBefore;
+  Result.StartInstanceNum =
+      StartBefore.empty() ? StartAfterInstanceNum : StartBeforeInstanceNum;
+  Result.StopInstanceNum =
+      StopBefore.empty() ? StopAfterInstanceNum : StopBeforeInstanceNum;
+  Result.StartBefore = !StartBefore.empty();
+  Result.StopBefore = !StopBefore.empty();
+  // Instance numbers are 1-based here; an omitted instance (0) means the
+  // first one.
+  // NOTE(review): registerPartialPipelineCallback compares 0-based
+  // (`Count++ == InstanceNum`), so an explicit ",1" selects different
+  // instances in the two paths; confirm which convention is intended.
+  Result.StartInstanceNum += Result.StartInstanceNum == 0;
+  Result.StopInstanceNum += Result.StopInstanceNum == 0;
+  return Result;
+}
+
// Out of line constructor provides default values for pass options and
// registers all common codegen passes.
TargetPassConfig::TargetPassConfig(LLVMTargetMachine &TM, PassManagerBase &pm)
diff --git a/llvm/lib/IR/PassInstrumentation.cpp b/llvm/lib/IR/PassInstrumentation.cpp
index 6d5f3acb7a4d35..e64c30132793a4 100644
--- a/llvm/lib/IR/PassInstrumentation.cpp
+++ b/llvm/lib/IR/PassInstrumentation.cpp
@@ -28,6 +28,14 @@ PassInstrumentationCallbacks::getPassNameForClassName(StringRef ClassName) {
return ClassToPassName[ClassName];
}
+StringRef
+PassInstrumentationCallbacks::getClassNameForPassName(StringRef PassName) {
+  // getPassNameForClassName() uses ClassToPassName[...], which inserts an
+  // empty pass name for every unknown class it is asked about. An empty query
+  // would therefore match such an unrelated entry, so reject it up front.
+  if (PassName.empty())
+    return StringRef();
+  // Linear scan: the map is keyed by class name, not by pass name.
+  for (const auto &Entry : ClassToPassName)
+    if (Entry.second == PassName)
+      return Entry.first();
+  return StringRef();
+}
+
AnalysisKey PassInstrumentationAnalysis::Key;
bool isSpecialPass(StringRef PassID, const std::vector<StringRef> &Specials) {
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 649451edc0e2c6..e5eadbdfc7dfe0 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -88,6 +88,7 @@
#include "llvm/CodeGen/SelectOptimize.h"
#include "llvm/CodeGen/ShadowStackGCLowering.h"
#include "llvm/CodeGen/SjLjEHPrepare.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TypePromotion.h"
#include "llvm/CodeGen/WasmEHPrepare.h"
#include "llvm/CodeGen/WinEHPrepare.h"
@@ -407,7 +408,8 @@ AnalysisKey NoOpLoopAnalysis::Key;
-/// We currently only use this for --print-before/after.
+/// Populate the class-to-pass-name map only when one of these options needs it.
bool shouldPopulateClassToPassNames() {
return PrintPipelinePasses || !printBeforePasses().empty() ||
- !printAfterPasses().empty() || !isFilterPassesEmpty();
+ !printAfterPasses().empty() || !isFilterPassesEmpty() ||
+ TargetPassConfig::hasLimitedCodeGenPipeline();
}
// A pass for testing -print-on-crash.
More information about the llvm-commits
mailing list