[llvm] [CodeGen] Let `PassBuilder` support machine passes (PR #76320)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 8 19:34:29 PST 2024
https://github.com/paperchalice updated https://github.com/llvm/llvm-project/pull/76320
>From 0a6fe7476748ab01ccaa9a93714c425e940b6101 Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Sun, 24 Dec 2023 15:19:12 +0800
Subject: [PATCH] [CodeGen] Let `PassBuilder` support machine passes
`PassBuilder` would be a better place to parse the MIR pipeline. We can reuse its code to support parsing passes with parameters, and targets can reuse `registerPassBuilderCallbacks` to register target-specific passes. `PassBuilder` also has the ability to check whether a pass is a machine pass, so it can take over part of the work of `LLVMTargetMachine::getPassNameFromLegacyName`.
---
llvm/include/llvm/Passes/PassBuilder.h | 39 ++
llvm/include/llvm/Target/TargetMachine.h | 3 +-
llvm/lib/Passes/PassBuilder.cpp | 105 ++++
llvm/unittests/MIR/CMakeLists.txt | 2 +
.../MIR/PassBuilderCallbacksTest.cpp | 499 ++++++++++++++++++
5 files changed, 646 insertions(+), 2 deletions(-)
create mode 100644 llvm/unittests/MIR/PassBuilderCallbacksTest.cpp
diff --git a/llvm/include/llvm/Passes/PassBuilder.h b/llvm/include/llvm/Passes/PassBuilder.h
index 61417431f8a8f3..6b0ad7e7d11a57 100644
--- a/llvm/include/llvm/Passes/PassBuilder.h
+++ b/llvm/include/llvm/Passes/PassBuilder.h
@@ -16,6 +16,7 @@
#define LLVM_PASSES_PASSBUILDER_H
#include "llvm/Analysis/CGSCCPassManager.h"
+#include "llvm/CodeGen/MachinePassManager.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Support/Error.h"
@@ -165,6 +166,14 @@ class PassBuilder {
/// additional analyses.
void registerLoopAnalyses(LoopAnalysisManager &LAM);
+ /// Registers all available machine function analysis passes.
+ ///
+ /// This is an interface that can be used to populate a \c
+ /// MachineFunctionAnalysisManager with all registered machine function
+ /// analyses, followed by any analyses added through the
+ /// MachineFunctionAnalysisManager overload of
+ /// registerAnalysisRegistrationCallback. Callers can still manually register
+ /// any additional analyses. Callers can also pre-register analyses and this
+ /// will not override those.
+ void registerMachineFunctionAnalyses(MachineFunctionAnalysisManager &MFAM);
+
/// Construct the core LLVM function canonicalization and simplification
/// pipeline.
///
@@ -352,6 +361,18 @@ class PassBuilder {
Error parsePassPipeline(LoopPassManager &LPM, StringRef PipelineText);
/// @}}
+ /// Parse a textual MIR pipeline into the provided \c MachineFunctionPassManager.
+ ///
+ /// The format of the textual machine pipeline is a comma separated list of
+ /// machine pass names:
+ ///
+ /// machine-function-pass,machine-module-pass,...
+ ///
+ /// There is no need to specify the pass nesting, and this function
+ /// currently cannot handle the pass nesting: an element carrying a nested
+ /// pipeline is rejected with an error, and so is an empty pipeline.
+ Error parsePassPipeline(MachineFunctionPassManager &MFPM,
+ StringRef PipelineText);
+
/// Parse a textual alias analysis pipeline into the provided AA manager.
///
/// The format of the textual AA pipeline is a comma separated list of AA
@@ -520,6 +541,10 @@ class PassBuilder {
const std::function<void(ModuleAnalysisManager &)> &C) {
ModuleAnalysisRegistrationCallbacks.push_back(C);
}
+  /// Register a callback that adds analyses to a
+  /// \c MachineFunctionAnalysisManager. Callbacks are invoked by
+  /// registerMachineFunctionAnalyses after the built-in analyses.
+  void registerAnalysisRegistrationCallback(
+      const std::function<void(MachineFunctionAnalysisManager &)> &C) {
+    MachineFunctionAnalysisRegistrationCallbacks.emplace_back(C);
+  }
/// @}}
/// {{@ Register pipeline parsing callbacks with this pass builder instance.
@@ -546,6 +571,11 @@ class PassBuilder {
ArrayRef<PipelineElement>)> &C) {
ModulePipelineParsingCallbacks.push_back(C);
}
+  /// Register a callback that handles machine pass names unknown to the
+  /// built-in machine pass registry. The callback returns true when it handled
+  /// \p Name and populated the given \c MachineFunctionPassManager.
+  void registerPipelineParsingCallback(
+      const std::function<bool(StringRef Name, MachineFunctionPassManager &)>
+          &C) {
+    MachinePipelineParsingCallbacks.emplace_back(C);
+  }
/// @}}
/// Register a callback for a top-level pipeline entry.
@@ -616,8 +646,12 @@ class PassBuilder {
Error parseCGSCCPass(CGSCCPassManager &CGPM, const PipelineElement &E);
Error parseFunctionPass(FunctionPassManager &FPM, const PipelineElement &E);
Error parseLoopPass(LoopPassManager &LPM, const PipelineElement &E);
+ /// Parse a single machine pipeline element (a plain pass name, no nested
+ /// pipeline) and append the matching pass to \p MFPM.
+ Error parseMachinePass(MachineFunctionPassManager &MFPM,
+ const PipelineElement &E);
bool parseAAPassName(AAManager &AA, StringRef Name);
+ /// Parse every element of \p Pipeline into \p MFPM, stopping at the first
+ /// error.
+ Error parseMachinePassPipeline(MachineFunctionPassManager &MFPM,
+ ArrayRef<PipelineElement> Pipeline);
Error parseLoopPassPipeline(LoopPassManager &LPM,
ArrayRef<PipelineElement> Pipeline);
Error parseFunctionPassPipeline(FunctionPassManager &FPM,
@@ -699,6 +733,11 @@ class PassBuilder {
// AA callbacks
SmallVector<std::function<bool(StringRef Name, AAManager &AA)>, 2>
AAParsingCallbacks;
+ // Machine pass callbacks
+ // Run by registerMachineFunctionAnalyses after the built-in analyses.
+ SmallVector<std::function<void(MachineFunctionAnalysisManager &)>, 2>
+ MachineFunctionAnalysisRegistrationCallbacks;
+ // Consulted by parseMachinePass for names not found in the machine pass
+ // registry.
+ SmallVector<std::function<bool(StringRef, MachineFunctionPassManager &)>, 2>
+ MachinePipelineParsingCallbacks;
};
/// This utility template takes care of adding require<> and invalidate<>
diff --git a/llvm/include/llvm/Target/TargetMachine.h b/llvm/include/llvm/Target/TargetMachine.h
index 1fe47dec70b163..ec82cf90521551 100644
--- a/llvm/include/llvm/Target/TargetMachine.h
+++ b/llvm/include/llvm/Target/TargetMachine.h
@@ -463,8 +463,7 @@ class LLVMTargetMachine : public TargetMachine {
}
+ /// The default implementation aborts via llvm_unreachable; targets that rely
+ /// on this legacy-name mapping must override it.
virtual std::pair<StringRef, bool> getPassNameFromLegacyName(StringRef) {
- llvm_unreachable(
- "getPassNameFromLegacyName parseMIRPipeline is not overridden");
+ llvm_unreachable("getPassNameFromLegacyName is not overridden");
}
/// Add passes to the specified pass manager to get machine code emitted with
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 649451edc0e2c6..85e39d0af11523 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -73,6 +73,7 @@
#include "llvm/Analysis/TypeBasedAliasAnalysis.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/CodeGen/CallBrPrepare.h"
+#include "llvm/CodeGen/CodeGenPassBuilder.h"
#include "llvm/CodeGen/DwarfEHPrepare.h"
#include "llvm/CodeGen/ExpandLargeDivRem.h"
#include "llvm/CodeGen/ExpandLargeFpConvert.h"
@@ -484,6 +485,18 @@ PassBuilder::PassBuilder(TargetMachine *TM, PipelineTuningOptions PTO,
#define CGSCC_ANALYSIS(NAME, CREATE_PASS) \
PIC->addClassToPassName(decltype(CREATE_PASS)::name(), NAME);
#include "PassRegistry.def"
+
+  // Map the class names of the machine-pass-registry entries to their
+  // pipeline names, mirroring the PassRegistry.def entries above. This covers
+  // every machine entry kind that parseMachinePass accepts or that
+  // registerMachineFunctionAnalyses registers, including machine module passes
+  // and placeholder (DUMMY) machine function analyses.
+#define MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  PIC->addClassToPassName(PASS_NAME::name(), NAME);
+#define MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR) \
+  PIC->addClassToPassName(PASS_NAME::name(), NAME);
+#define MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  PIC->addClassToPassName(PASS_NAME::name(), NAME);
+#define DUMMY_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  PIC->addClassToPassName(PASS_NAME::name(), NAME);
+#define DUMMY_MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  PIC->addClassToPassName(PASS_NAME::name(), NAME);
+#define DUMMY_MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  PIC->addClassToPassName(PASS_NAME::name(), NAME);
+#define DUMMY_MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR) \
+  PIC->addClassToPassName(PASS_NAME::name(), NAME);
+#include "llvm/CodeGen/MachinePassRegistry.def"
}
}
@@ -519,6 +532,19 @@ void PassBuilder::registerFunctionAnalyses(FunctionAnalysisManager &FAM) {
C(FAM);
}
+void PassBuilder::registerMachineFunctionAnalyses(
+    MachineFunctionAnalysisManager &MFAM) {
+  // Built-in analyses from the machine pass registry: both real and
+  // placeholder (DUMMY) machine function analyses.
+#define MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR) \
+  MFAM.registerPass([&] { return PASS_NAME(); });
+#define DUMMY_MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR) \
+  MFAM.registerPass([&] { return PASS_NAME(); });
+#include "llvm/CodeGen/MachinePassRegistry.def"
+
+  // Then give every registered client callback a chance to add its own.
+  for (auto &RegisterExtra : MachineFunctionAnalysisRegistrationCallbacks)
+    RegisterExtra(MFAM);
+}
+
void PassBuilder::registerLoopAnalyses(LoopAnalysisManager &LAM) {
#define LOOP_ANALYSIS(NAME, CREATE_PASS) \
LAM.registerPass([&] { return CREATE_PASS; });
@@ -1873,6 +1899,43 @@ Error PassBuilder::parseLoopPass(LoopPassManager &LPM,
inconvertibleErrorCode());
}
+Error PassBuilder::parseMachinePass(MachineFunctionPassManager &MFPM,
+                                    const PipelineElement &E) {
+  const StringRef PassName = E.Name;
+
+  // Machine pipelines are flat; an element with a nested pipeline is invalid.
+  if (!E.InnerPipeline.empty())
+    return make_error<StringError>("invalid pipeline",
+                                   inconvertibleErrorCode());
+
+  // Every machine pass kind in the registry is handled the same way:
+  // default-construct the pass and append it to MFPM.
+#define ADD_MACHINE_PASS_IF_NAMED(NAME, PASS_NAME) \
+  if (PassName == NAME) { \
+    MFPM.addPass(PASS_NAME()); \
+    return Error::success(); \
+  }
+#define MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  ADD_MACHINE_PASS_IF_NAMED(NAME, PASS_NAME)
+#define MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  ADD_MACHINE_PASS_IF_NAMED(NAME, PASS_NAME)
+#define DUMMY_MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  ADD_MACHINE_PASS_IF_NAMED(NAME, PASS_NAME)
+#define DUMMY_MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  ADD_MACHINE_PASS_IF_NAMED(NAME, PASS_NAME)
+#include "llvm/CodeGen/MachinePassRegistry.def"
+#undef ADD_MACHINE_PASS_IF_NAMED
+
+  // Unknown to the registry: defer to externally registered callbacks.
+  for (auto &TryParse : MachinePipelineParsingCallbacks)
+    if (TryParse(PassName, MFPM))
+      return Error::success();
+
+  return make_error<StringError>(
+      formatv("unknown machine pass '{0}'", PassName).str(),
+      inconvertibleErrorCode());
+}
+
bool PassBuilder::parseAAPassName(AAManager &AA, StringRef Name) {
#define MODULE_ALIAS_ANALYSIS(NAME, CREATE_PASS) \
if (Name == NAME) { \
@@ -1894,6 +1957,15 @@ bool PassBuilder::parseAAPassName(AAManager &AA, StringRef Name) {
return false;
}
+Error PassBuilder::parseMachinePassPipeline(
+    MachineFunctionPassManager &MFPM, ArrayRef<PipelineElement> Pipeline) {
+  // Parse elements in order and stop at the first failure.
+  for (const PipelineElement &Element : Pipeline)
+    if (Error Err = parseMachinePass(MFPM, Element))
+      return Err;
+  return Error::success();
+}
+
Error PassBuilder::parseLoopPassPipeline(LoopPassManager &LPM,
ArrayRef<PipelineElement> Pipeline) {
for (const auto &Element : Pipeline) {
@@ -2053,6 +2125,20 @@ Error PassBuilder::parsePassPipeline(LoopPassManager &CGPM,
return Error::success();
}
+Error PassBuilder::parsePassPipeline(MachineFunctionPassManager &MFPM,
+                                     StringRef PipelineText) {
+  auto Pipeline = parsePipelineText(PipelineText);
+  // Both malformed text and an empty pipeline are rejected.
+  const bool IsUsable = Pipeline && !Pipeline->empty();
+  if (!IsUsable)
+    return make_error<StringError>(
+        formatv("invalid machine pass pipeline '{0}'", PipelineText).str(),
+        inconvertibleErrorCode());
+
+  return parseMachinePassPipeline(MFPM, *Pipeline);
+}
+
Error PassBuilder::parseAAPipeline(AAManager &AA, StringRef PipelineText) {
// If the pipeline just consists of the word 'default' just replace the AA
// manager with our default one.
@@ -2147,6 +2233,25 @@ void PassBuilder::printPassNames(raw_ostream &OS) {
OS << "Loop analyses:\n";
#define LOOP_ANALYSIS(NAME, CREATE_PASS) printPassName(NAME, OS);
#include "PassRegistry.def"
+
+  // Machine passes and analyses from the machine pass registry. The analysis
+  // listing includes both the real and the placeholder (DUMMY) analyses,
+  // matching what registerMachineFunctionAnalyses registers.
+  OS << "Machine module passes (WIP):\n";
+#define MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  printPassName(NAME, OS);
+#define DUMMY_MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  printPassName(NAME, OS);
+#include "llvm/CodeGen/MachinePassRegistry.def"
+
+  OS << "Machine function passes (WIP):\n";
+#define MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  printPassName(NAME, OS);
+#define DUMMY_MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR) \
+  printPassName(NAME, OS);
+#include "llvm/CodeGen/MachinePassRegistry.def"
+
+  OS << "Machine function analyses (WIP):\n";
+#define MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR) \
+  printPassName(NAME, OS);
+#define DUMMY_MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR) \
+  printPassName(NAME, OS);
+#include "llvm/CodeGen/MachinePassRegistry.def"
}
void PassBuilder::registerParseTopLevelPipelineCallback(
diff --git a/llvm/unittests/MIR/CMakeLists.txt b/llvm/unittests/MIR/CMakeLists.txt
index 3c0e9e43f9afb1..f485dcbd971b6a 100644
--- a/llvm/unittests/MIR/CMakeLists.txt
+++ b/llvm/unittests/MIR/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_LINK_COMPONENTS
FileCheck
MC
MIRParser
+ Passes
Support
Target
TargetParser
@@ -13,6 +14,7 @@ set(LLVM_LINK_COMPONENTS
add_llvm_unittest(MIRTests
MachineMetadata.cpp
+ PassBuilderCallbacksTest.cpp
)
target_link_libraries(MIRTests PRIVATE LLVMTestingSupport)
diff --git a/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp b/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp
new file mode 100644
index 00000000000000..29e238e65c12a4
--- /dev/null
+++ b/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp
@@ -0,0 +1,499 @@
+//===- unittests/MIR/PassBuilderCallbacksTest.cpp - PB Callback Tests --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Testing/Support/Error.h"
+#include <functional>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <llvm/ADT/Any.h>
+#include <llvm/AsmParser/Parser.h>
+#include <llvm/CodeGen/MIRParser/MIRParser.h>
+#include <llvm/CodeGen/MachineFunction.h>
+#include <llvm/CodeGen/MachineModuleInfo.h>
+#include <llvm/CodeGen/MachinePassManager.h>
+#include <llvm/IR/LLVMContext.h>
+#include <llvm/IR/PassInstrumentation.h>
+#include <llvm/IR/PassManager.h>
+#include <llvm/Passes/PassBuilder.h>
+#include <llvm/Support/Regex.h>
+#include <llvm/Support/SourceMgr.h>
+#include <llvm/Support/TargetSelect.h>
+
+using namespace llvm;
+
+namespace {
+using testing::_;
+using testing::AnyNumber;
+using testing::DoAll;
+using testing::Not;
+using testing::Return;
+using testing::WithArgs;
+
+/// Test input: an embedded IR module with a single `void @test()` function,
+/// plus a machine function "test" whose only block ends in RET64 (an x86-64
+/// instruction, matching the x86_64 triple used by the fixture).
+StringRef MIRString = R"MIR(
+--- |
+ define void @test() {
+ ret void
+ }
+...
+---
+name: test
+body: |
+ bb.0 (%ir-block.0):
+ RET64
+...
+)MIR";
+
+/// Return the name of an IR unit as a std::string. The specializations let
+/// the HasName/HasNameRegex matchers accept plain StringRefs as well as the
+/// llvm::Any-wrapped IR unit pointers that PassInstrumentation hands out.
+template <typename IRUnitT> std::string getName(const IRUnitT &IR) {
+  return std::string(IR.getName());
+}
+
+template <> std::string getName(const StringRef &name) { return name.str(); }
+
+template <> std::string getName(const Any &WrappedIR) {
+  // Probe each IR unit pointer type that instrumentation may wrap.
+  if (const auto *const *ModPtr = llvm::any_cast<const Module *>(&WrappedIR))
+    return std::string((*ModPtr)->getName());
+  if (const auto *const *FnPtr = llvm::any_cast<const Function *>(&WrappedIR))
+    return std::string((*FnPtr)->getName());
+  if (const auto *const *MFPtr =
+          llvm::any_cast<const MachineFunction *>(&WrappedIR))
+    return std::string((*MFPtr)->getName());
+  // Anything else is reported with a fixed placeholder name.
+  return "<UNKNOWN>";
+}
+/// gmock matcher for any object accepted by getName(): IR units, StringRefs,
+/// and llvm::Any-wrapped IR unit pointers from PassInstrumentation. It matches
+/// when the object's name equals \p Name.
+///
+/// Usage:
+///
+///   HasName("my_function")
+MATCHER_P(HasName, Name, "") {
+  const std::string ActualName = getName(arg);
+  *result_listener << "has name '" << ActualName << "'";
+  return Name == ActualName;
+}
+
+/// Like HasName, but \p Name is an llvm::Regex pattern that must match the
+/// object's name.
+MATCHER_P(HasNameRegex, Name, "") {
+  const std::string ActualName = getName(arg);
+  *result_listener << "has name '" << ActualName << "'";
+  llvm::Regex Pattern(Name);
+  return Pattern.match(ActualName);
+}
+
+/// gmock-backed PassInstrumentationCallbacks: once registerPassInstrumentation
+/// is called, every instrumentation hook forwards to a mocked method so tests
+/// can set expectations on pass/analysis instrumentation.
+struct MockPassInstrumentationCallbacks {
+ PassInstrumentationCallbacks Callbacks;
+
+ // By default runBeforePass (the should-run-optional-pass hook) returns
+ // true, i.e. passes are not skipped.
+ MockPassInstrumentationCallbacks() {
+ ON_CALL(*this, runBeforePass(_, _)).WillByDefault(Return(true));
+ }
+ MOCK_METHOD2(runBeforePass, bool(StringRef PassID, llvm::Any));
+ MOCK_METHOD2(runBeforeSkippedPass, void(StringRef PassID, llvm::Any));
+ MOCK_METHOD2(runBeforeNonSkippedPass, void(StringRef PassID, llvm::Any));
+ MOCK_METHOD3(runAfterPass,
+ void(StringRef PassID, llvm::Any, const PreservedAnalyses &PA));
+ MOCK_METHOD2(runAfterPassInvalidated,
+ void(StringRef PassID, const PreservedAnalyses &PA));
+ MOCK_METHOD2(runBeforeAnalysis, void(StringRef PassID, llvm::Any));
+ MOCK_METHOD2(runAfterAnalysis, void(StringRef PassID, llvm::Any));
+
+ /// Hook each mocked method into \c Callbacks. runBeforePass backs the
+ /// should-run-optional-pass callback, so returning false skips the pass.
+ void registerPassInstrumentation() {
+ Callbacks.registerShouldRunOptionalPassCallback(
+ [this](StringRef P, llvm::Any IR) {
+ return this->runBeforePass(P, IR);
+ });
+ Callbacks.registerBeforeSkippedPassCallback(
+ [this](StringRef P, llvm::Any IR) {
+ this->runBeforeSkippedPass(P, IR);
+ });
+ Callbacks.registerBeforeNonSkippedPassCallback(
+ [this](StringRef P, llvm::Any IR) {
+ this->runBeforeNonSkippedPass(P, IR);
+ });
+ Callbacks.registerAfterPassCallback(
+ [this](StringRef P, llvm::Any IR, const PreservedAnalyses &PA) {
+ this->runAfterPass(P, IR, PA);
+ });
+ Callbacks.registerAfterPassInvalidatedCallback(
+ [this](StringRef P, const PreservedAnalyses &PA) {
+ this->runAfterPassInvalidated(P, PA);
+ });
+ Callbacks.registerBeforeAnalysisCallback([this](StringRef P, llvm::Any IR) {
+ return this->runBeforeAnalysis(P, IR);
+ });
+ Callbacks.registerAfterAnalysisCallback(
+ [this](StringRef P, llvm::Any IR) { this->runAfterAnalysis(P, IR); });
+ }
+
+ /// Allow any number of instrumentation calls for passes/analyses whose name
+ /// does not contain "Mock", on the IR unit named \p IRName.
+ void ignoreNonMockPassInstrumentation(StringRef IRName) {
+ // Generic EXPECT_CALLs are needed to match instrumentation on unimportant
+ // parts of a pipeline that we do not care about (e.g. various passes added
+ // by default by PassBuilder - Verifier pass etc).
+ // Make sure to avoid ignoring Mock passes/analysis, we definitely want
+ // to check these explicitly.
+ EXPECT_CALL(*this,
+ runBeforePass(Not(HasNameRegex("Mock")), HasName(IRName)))
+ .Times(AnyNumber());
+ EXPECT_CALL(
+ *this, runBeforeSkippedPass(Not(HasNameRegex("Mock")), HasName(IRName)))
+ .Times(AnyNumber());
+ EXPECT_CALL(*this, runBeforeNonSkippedPass(Not(HasNameRegex("Mock")),
+ HasName(IRName)))
+ .Times(AnyNumber());
+ EXPECT_CALL(*this,
+ runAfterPass(Not(HasNameRegex("Mock")), HasName(IRName), _))
+ .Times(AnyNumber());
+ // NOTE(review): the MachineModuleAnalysis expectations below look subsumed
+ // by the Not(HasNameRegex("Mock")) ones declared after them (gmock prefers
+ // the most recently declared matching expectation); confirm before removing.
+ EXPECT_CALL(*this, runBeforeAnalysis(HasNameRegex("MachineModuleAnalysis"),
+ HasName(IRName)))
+ .Times(AnyNumber());
+ EXPECT_CALL(*this,
+ runBeforeAnalysis(Not(HasNameRegex("Mock")), HasName(IRName)))
+ .Times(AnyNumber());
+ EXPECT_CALL(*this, runAfterAnalysis(HasNameRegex("MachineModuleAnalysis"),
+ HasName(IRName)))
+ .Times(AnyNumber());
+ EXPECT_CALL(*this,
+ runAfterAnalysis(Not(HasNameRegex("Mock")), HasName(IRName)))
+ .Times(AnyNumber());
+ }
+};
+
+/// CRTP base for a mock machine function analysis. The nested \c Analysis
+/// forwards run() to DerivedT::run, and its \c Result forwards invalidate() to
+/// DerivedT::invalidate, so both can be mocked in the derived class.
+template <typename DerivedT> class MockAnalysisHandleBase {
+public:
+ class Analysis : public AnalysisInfoMixin<Analysis> {
+ friend AnalysisInfoMixin<Analysis>;
+ friend MockAnalysisHandleBase;
+ static AnalysisKey Key;
+
+ DerivedT *Handle;
+
+ Analysis(DerivedT &Handle) : Handle(&Handle) {
+ static_assert(std::is_base_of<MockAnalysisHandleBase, DerivedT>::value,
+ "Must pass the derived type to this template!");
+ }
+
+ public:
+ // Analysis result; holds only a pointer back to the mock handle.
+ class Result {
+ friend MockAnalysisHandleBase;
+
+ DerivedT *Handle;
+
+ Result(DerivedT &Handle) : Handle(&Handle) {}
+
+ public:
+ // Forward invalidation events to the mock handle.
+ bool invalidate(MachineFunction &IR, const PreservedAnalyses &PA,
+ MachineFunctionAnalysisManager::Invalidator &Inv) {
+ return Handle->invalidate(IR, PA, Inv);
+ }
+ };
+
+ // Forward the analysis run to the mock handle.
+ Result run(MachineFunction &IR, MachineFunctionAnalysisManager::Base &AM) {
+ return Handle->run(IR, AM);
+ }
+ };
+
+ /// Create an Analysis / Result bound to this handle.
+ Analysis getAnalysis() { return Analysis(static_cast<DerivedT &>(*this)); }
+ typename Analysis::Result getResult() {
+ return typename Analysis::Result(static_cast<DerivedT &>(*this));
+ }
+ static StringRef getName() { return llvm::getTypeName<DerivedT>(); }
+
+protected:
+ // FIXME: MSVC seems unable to handle a lambda argument to Invoke from within
+ // the template, so we use a boring static function.
+ // Returns true (invalidate) unless this analysis, or all analyses on
+ // MachineFunction, are marked preserved in \p PA.
+ static bool
+ invalidateCallback(MachineFunction &IR, const PreservedAnalyses &PA,
+ MachineFunctionAnalysisManager::Invalidator &Inv) {
+ auto PAC = PA.template getChecker<Analysis>();
+ return !PAC.preserved() &&
+ !PAC.template preservedSet<AllAnalysesOn<MachineFunction>>();
+ }
+
+ /// Derived classes should call this in their constructor to set up default
+ /// mock actions. (We can't do this in our constructor because this has to
+ /// run after the DerivedT is constructed.)
+ void setDefaults() {
+ ON_CALL(static_cast<DerivedT &>(*this), run(_, _))
+ .WillByDefault(Return(this->getResult()));
+ ON_CALL(static_cast<DerivedT &>(*this), invalidate(_, _, _))
+ .WillByDefault(&invalidateCallback);
+ }
+};
+
+/// CRTP base for a mock machine function pass. The nested \c Pass forwards
+/// run() to DerivedT::run so it can be mocked in the derived class.
+template <typename DerivedT> class MockPassHandleBase {
+public:
+ class Pass : public MachinePassInfoMixin<Pass> {
+ friend MockPassHandleBase;
+
+ DerivedT *Handle;
+
+ Pass(DerivedT &Handle) : Handle(&Handle) {
+ static_assert(std::is_base_of<MockPassHandleBase, DerivedT>::value,
+ "Must pass the derived type to this template!");
+ }
+
+ public:
+ static MachinePassKey Key;
+ // Forward the pass run to the mock handle.
+ PreservedAnalyses run(MachineFunction &IR,
+ MachineFunctionAnalysisManager::Base &AM) {
+ return Handle->run(IR, AM);
+ }
+ };
+
+ static StringRef getName() { return llvm::getTypeName<DerivedT>(); }
+
+ /// Create a Pass bound to this handle.
+ Pass getPass() { return Pass(static_cast<DerivedT &>(*this)); }
+
+protected:
+ /// Derived classes should call this in their constructor to set up default
+ /// mock actions. (We can't do this in our constructor because this has to
+ /// run after the DerivedT is constructed.)
+ // By default the mocked run() preserves all analyses.
+ void setDefaults() {
+ ON_CALL(static_cast<DerivedT &>(*this), run(_, _))
+ .WillByDefault(Return(PreservedAnalyses::all()));
+ }
+};
+
+/// Concrete mock analysis: run() and invalidate() are gmock methods whose
+/// defaults are installed by setDefaults().
+struct MockAnalysisHandle : public MockAnalysisHandleBase<MockAnalysisHandle> {
+ MOCK_METHOD2(run, Analysis::Result(MachineFunction &,
+ MachineFunctionAnalysisManager::Base &));
+
+ MOCK_METHOD3(invalidate, bool(MachineFunction &, const PreservedAnalyses &,
+ MachineFunctionAnalysisManager::Invalidator &));
+
+ MockAnalysisHandle() { setDefaults(); }
+};
+
+// Out-of-line definitions of the static identity keys of the mock pass and
+// mock analysis templates.
+template <typename DerivedT>
+MachinePassKey MockPassHandleBase<DerivedT>::Pass::Key;
+
+template <typename DerivedT>
+AnalysisKey MockAnalysisHandleBase<DerivedT>::Analysis::Key;
+
+/// Concrete mock pass: run() is a gmock method whose default (preserve all)
+/// is installed by setDefaults().
+class MockPassHandle : public MockPassHandleBase<MockPassHandle> {
+public:
+ MOCK_METHOD2(run, PreservedAnalyses(MachineFunction &,
+ MachineFunctionAnalysisManager::Base &));
+
+ MockPassHandle() { setDefaults(); }
+};
+
+/// Fixture that parses a small x86-64 MIR module and wires a PassBuilder, the
+/// IR/machine analysis managers and the mock pass/analysis handles together.
+/// Tests are skipped when the x86-64 target is not available.
+class MachineFunctionCallbacksTest : public testing::Test {
+protected:
+  static void SetUpTestCase() {
+    InitializeAllTargetInfos();
+    InitializeAllTargets();
+    InitializeAllTargetMCs();
+  }
+
+  // Owned by the fixture so the target machine is released after each test.
+  // Declared first so it outlives the analysis managers that reference it.
+  std::unique_ptr<TargetMachine> TM;
+
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<MIRParser> MIR;
+
+  MockPassInstrumentationCallbacks CallbacksHandle;
+
+  PassBuilder PB;
+  ModulePassManager PM;
+  MachineFunctionPassManager MFPM;
+  FunctionAnalysisManager FAM;
+  ModuleAnalysisManager AM;
+  MachineFunctionAnalysisManager MFAM;
+
+  MockPassHandle PassHandle;
+  MockAnalysisHandle AnalysisHandle;
+
+  /// Parse \p MIRCode. Returns the IR module and fills \p MMI with the parsed
+  /// machine functions, or returns nullptr on any parse error. The parser is
+  /// kept alive in the \c MIR member.
+  std::unique_ptr<Module> parseMIR(const TargetMachine &TM, StringRef MIRCode,
+                                   MachineModuleInfo &MMI) {
+    std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode);
+    MIR = createMIRParser(std::move(MBuffer), Context);
+    if (!MIR)
+      return nullptr;
+
+    std::unique_ptr<Module> Mod = MIR->parseIRModule();
+    if (!Mod)
+      return nullptr;
+
+    Mod->setDataLayout(TM.createDataLayout());
+
+    // On failure only the local module is discarded; the fixture's M is
+    // assigned by the caller and is not touched here.
+    if (MIR->parseMachineFunctions(*Mod, MMI))
+      return nullptr;
+    return Mod;
+  }
+
+  /// Action for the mock pass: query the mock analysis on \p U and report that
+  /// all analyses are preserved.
+  static PreservedAnalyses
+  getAnalysisResult(MachineFunction &U,
+                    MachineFunctionAnalysisManager::Base &AM) {
+    auto &MFAM = static_cast<MachineFunctionAnalysisManager &>(AM);
+    MFAM.getResult<MockAnalysisHandle::Analysis>(U);
+    return PreservedAnalyses::all();
+  }
+
+  void SetUp() override {
+    std::string Error;
+    auto TripleName = "x86_64-pc-linux-gnu";
+    auto *T = TargetRegistry::lookupTarget(TripleName, Error);
+    if (!T)
+      GTEST_SKIP();
+    TM.reset(T->createTargetMachine(TripleName, "", "", TargetOptions(),
+                                    std::nullopt));
+    ASSERT_TRUE(TM != nullptr) << "could not create the x86-64 target machine";
+
+    // NOTE(review): MMI is local, so the machine functions parsed into it are
+    // dropped when SetUp returns; the tests appear to rely only on the IR
+    // module M. Confirm this is intended.
+    MachineModuleInfo MMI(static_cast<LLVMTargetMachine *>(TM.get()));
+    M = parseMIR(*TM, MIRString, MMI);
+    ASSERT_TRUE(M != nullptr) << "failed to parse the test MIR";
+    AM.registerPass([&] {
+      return MachineModuleAnalysis(static_cast<LLVMTargetMachine *>(TM.get()));
+    });
+  }
+
+  MachineFunctionCallbacksTest()
+      : CallbacksHandle(), PB(nullptr, PipelineTuningOptions(), std::nullopt,
+                              &CallbacksHandle.Callbacks),
+        PM(), FAM(), AM(), MFAM(FAM, AM) {
+
+    EXPECT_TRUE(&CallbacksHandle.Callbacks ==
+                PB.getPassInstrumentationCallbacks());
+
+    /// Register a callback for analysis registration.
+    ///
+    /// The callback is a function taking a reference to an AnalysisManager
+    /// object. When called, the callee gets to register its own analyses with
+    /// this PassBuilder instance.
+    PB.registerAnalysisRegistrationCallback(
+        [this](MachineFunctionAnalysisManager &AnalysisMgr) {
+          // Register our mock analysis
+          AnalysisMgr.registerPass(
+              [this] { return AnalysisHandle.getAnalysis(); });
+        });
+
+    /// Register a callback for pipeline parsing.
+    ///
+    /// During parsing of a textual machine pipeline, the PassBuilder calls
+    /// these callbacks for each pass name it does not know. Only plain pass
+    /// names reach them; nested pipelines are rejected before that.
+    PB.registerPipelineParsingCallback(
+        [this](StringRef Name, MachineFunctionPassManager &PassMgr) {
+          if (parseAnalysisUtilityPasses<MockAnalysisHandle::Analysis>(
+                  "test-analysis", Name, PassMgr))
+            return true;
+
+          /// Parse the name of our pass mock handle. The pass is added to the
+          /// manager being populated, not unconditionally to the fixture's.
+          if (Name == "test-transform") {
+            PassMgr.addPass(PassHandle.getPass());
+            return true;
+          }
+          return false;
+        });
+
+    /// Register the builtin analyses; MFAM also receives the mock analysis
+    /// through the registration callback above.
+    PB.registerModuleAnalyses(AM);
+    PB.registerFunctionAnalyses(FAM);
+    PB.registerMachineFunctionAnalyses(MFAM);
+  }
+};
+
+TEST_F(MachineFunctionCallbacksTest, Passes) {
+  // The mock pass runs once on "test" and, through getAnalysisResult, asks
+  // for the mock analysis, which is expected to run once as well.
+  EXPECT_CALL(AnalysisHandle, run(HasName("test"), _));
+  EXPECT_CALL(PassHandle, run(HasName("test"), _)).WillOnce(&getAnalysisResult);
+
+  const StringRef PipelineText("test-transform");
+  ASSERT_THAT_ERROR(PB.parsePassPipeline(MFPM, PipelineText), Succeeded())
+      << "Pipeline was: " << PipelineText;
+  ASSERT_THAT_ERROR(MFPM.run(*M, MFAM), Succeeded());
+}
+
+// Checks that the mock pass is reported to instrumentation as
+// before-pass -> before-non-skipped-pass -> after-pass, in that order, and
+// that the pass and its analysis still run.
+TEST_F(MachineFunctionCallbacksTest, InstrumentedPasses) {
+ CallbacksHandle.registerPassInstrumentation();
+ // Non-mock instrumentation not specifically mentioned below can be ignored.
+ CallbacksHandle.ignoreNonMockPassInstrumentation("<string>");
+ CallbacksHandle.ignoreNonMockPassInstrumentation("test");
+ CallbacksHandle.ignoreNonMockPassInstrumentation("");
+
+ // PassInstrumentation calls should happen in-sequence, in the same order
+ // as passes/analyses are scheduled.
+ ::testing::Sequence PISequence;
+ EXPECT_CALL(CallbacksHandle,
+ runBeforePass(HasNameRegex("MockPassHandle"), HasName("test")))
+ .InSequence(PISequence);
+ EXPECT_CALL(
+ CallbacksHandle,
+ runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), HasName("test")))
+ .InSequence(PISequence);
+ EXPECT_CALL(CallbacksHandle,
+ runAfterPass(HasNameRegex("MockPassHandle"), HasName("test"), _))
+ .InSequence(PISequence);
+
+ EXPECT_CALL(AnalysisHandle, run(HasName("test"), _));
+ EXPECT_CALL(PassHandle, run(HasName("test"), _)).WillOnce(&getAnalysisResult);
+
+ StringRef PipelineText = "test-transform";
+ ASSERT_THAT_ERROR(PB.parsePassPipeline(MFPM, PipelineText), Succeeded())
+ << "Pipeline was: " << PipelineText;
+ ASSERT_THAT_ERROR(MFPM.run(*M, MFAM), Succeeded());
+}
+
+// Checks that when instrumentation vetoes the mock pass, only the skipped-pass
+// hook fires, and neither the pass nor its analysis runs.
+TEST_F(MachineFunctionCallbacksTest, InstrumentedSkippedPasses) {
+  CallbacksHandle.registerPassInstrumentation();
+  // Non-mock instrumentation run here can safely be ignored.
+  CallbacksHandle.ignoreNonMockPassInstrumentation("<string>");
+  CallbacksHandle.ignoreNonMockPassInstrumentation("test");
+  CallbacksHandle.ignoreNonMockPassInstrumentation("");
+
+  // Skip the pass by returning false.
+  EXPECT_CALL(CallbacksHandle,
+              runBeforePass(HasNameRegex("MockPassHandle"), HasName("test")))
+      .WillOnce(Return(false));
+
+  EXPECT_CALL(
+      CallbacksHandle,
+      runBeforeSkippedPass(HasNameRegex("MockPassHandle"), HasName("test")))
+      .Times(1);
+
+  EXPECT_CALL(AnalysisHandle, run(HasName("test"), _)).Times(0);
+  EXPECT_CALL(PassHandle, run(HasName("test"), _)).Times(0);
+
+  // As the pass is skipped there is no afterPass, beforeAnalysis/afterAnalysis
+  // as well. Each forbidden callback is listed exactly once.
+  EXPECT_CALL(CallbacksHandle,
+              runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), _))
+      .Times(0);
+  EXPECT_CALL(CallbacksHandle,
+              runAfterPass(HasNameRegex("MockPassHandle"), _, _))
+      .Times(0);
+  EXPECT_CALL(CallbacksHandle,
+              runAfterPassInvalidated(HasNameRegex("MockPassHandle"), _))
+      .Times(0);
+  EXPECT_CALL(CallbacksHandle,
+              runBeforeAnalysis(HasNameRegex("MockAnalysisHandle"), _))
+      .Times(0);
+  EXPECT_CALL(CallbacksHandle,
+              runAfterAnalysis(HasNameRegex("MockAnalysisHandle"), _))
+      .Times(0);
+
+  StringRef PipelineText = "test-transform";
+  ASSERT_THAT_ERROR(PB.parsePassPipeline(MFPM, PipelineText), Succeeded())
+      << "Pipeline was: " << PipelineText;
+  ASSERT_THAT_ERROR(MFPM.run(*M, MFAM), Succeeded());
+}
+
+} // end anonymous namespace
More information about the llvm-commits
mailing list