[llvm] [CodeGen] Let `PassBuilder` support machine passes (PR #76320)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 11 18:22:01 PST 2024
================
@@ -0,0 +1,499 @@
+//===- unittests/MIR/PassBuilderCallbacksTest.cpp - PB Callback Tests --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Testing/Support/Error.h"
+#include <functional>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <llvm/ADT/Any.h>
+#include <llvm/AsmParser/Parser.h>
+#include <llvm/CodeGen/MIRParser/MIRParser.h>
+#include <llvm/CodeGen/MachineFunction.h>
+#include <llvm/CodeGen/MachineModuleInfo.h>
+#include <llvm/CodeGen/MachinePassManager.h>
+#include <llvm/IR/LLVMContext.h>
+#include <llvm/IR/PassInstrumentation.h>
+#include <llvm/IR/PassManager.h>
+#include <llvm/Passes/PassBuilder.h>
+#include <llvm/Support/Regex.h>
+#include <llvm/Support/SourceMgr.h>
+#include <llvm/Support/TargetSelect.h>
+
+using namespace llvm;
+
+namespace {
+using testing::_;
+using testing::AnyNumber;
+using testing::DoAll;
+using testing::Not;
+using testing::Return;
+using testing::WithArgs;
+
+/// MIR input shared by the tests in this file: an embedded IR module that
+/// defines an empty function @test, followed by a machine-function document
+/// named `test` whose single block bb.0 contains one RET64 instruction.
+StringRef MIRString = R"MIR(
+--- |
+ define void @test() {
+ ret void
+ }
+...
+---
+name: test
+body: |
+ bb.0 (%ir-block.0):
+ RET64
+...
+)MIR";
+
+/// Name-extraction helpers used by the HasName / HasNameRegex matchers.
+///
+/// Primary template: works for any object exposing a getName() member whose
+/// result can be used to construct a std::string.
+template <typename IRUnitT> std::string getName(const IRUnitT &IR) {
+  std::string UnitName(IR.getName());
+  return UnitName;
+}
+
+/// A bare StringRef (for example a pass ID) is treated as its own name.
+template <> std::string getName(const StringRef &Name) { return Name.str(); }
+
+/// An llvm::Any, as handed to PassInstrumentation callbacks, is probed for a
+/// wrapped `const Module *`, `const Function *` or `const MachineFunction *`
+/// (in that order); the first one found supplies the name. Any other payload
+/// yields the placeholder "<UNKNOWN>".
+template <> std::string getName(const Any &WrappedIR) {
+  const auto *const *AsModule = llvm::any_cast<const Module *>(&WrappedIR);
+  if (AsModule)
+    return (*AsModule)->getName().str();
+
+  const auto *const *AsFunction = llvm::any_cast<const Function *>(&WrappedIR);
+  if (AsFunction)
+    return (*AsFunction)->getName().str();
+
+  const auto *const *AsMachineFunction =
+      llvm::any_cast<const MachineFunction *>(&WrappedIR);
+  if (AsMachineFunction)
+    return (*AsMachineFunction)->getName().str();
+
+  return "<UNKNOWN>";
+}
+/// Matcher that compares an object's name against an expected value.
+///
+/// LLVM IR and analysis objects commonly expose a name, and matching on it
+/// keeps test expectations readable. The argument is first turned into a
+/// std::string through getName(), so anything getName() accepts can be
+/// matched: objects with a getName() member, plain StringRefs, and IR units
+/// wrapped in llvm::Any. The observed name is also streamed to the result
+/// listener so that it shows up in mismatch diagnostics.
+///
+/// It should be used as:
+///
+///   HasName("my_function")
+///
+/// No namespace or other qualification is required.
+MATCHER_P(HasName, Name, "") {
+  const std::string ActualName = getName(arg);
+  *result_listener << "has name '" << ActualName << "'";
+  return Name == ActualName;
+}
+
+/// Matcher that tests an object's name against a regular expression.
+///
+/// The argument is converted to a std::string through getName(); \p Name is
+/// compiled into an llvm::Regex and matched against that string. The observed
+/// name is streamed to the result listener for diagnostics.
+MATCHER_P(HasNameRegex, Name, "") {
+  const std::string ActualName = getName(arg);
+  *result_listener << "has name '" << ActualName << "'";
+  llvm::Regex Pattern(Name);
+  return Pattern.match(ActualName);
+}
+
+/// Pairs a PassInstrumentationCallbacks object with one gmock method per
+/// instrumentation event, so tests can place expectations on the events that
+/// a pipeline triggers.
+///
+/// Typical use: construct, call registerPassInstrumentation() to wire the
+/// mock methods into `Callbacks`, then set EXPECT_CALLs on the mock methods.
+/// The registered lambdas capture `this`, so the object has to stay at the
+/// same address for as long as `Callbacks` may invoke them.
+struct MockPassInstrumentationCallbacks {
+  PassInstrumentationCallbacks Callbacks;
+
+  MockPassInstrumentationCallbacks() {
+    // Unless a test installs its own action, runBeforePass() answers true.
+    ON_CALL(*this, runBeforePass(_, _)).WillByDefault(Return(true));
+  }
+  MOCK_METHOD2(runBeforePass, bool(StringRef PassID, llvm::Any));
+  MOCK_METHOD2(runBeforeSkippedPass, void(StringRef PassID, llvm::Any));
+  MOCK_METHOD2(runBeforeNonSkippedPass, void(StringRef PassID, llvm::Any));
+  MOCK_METHOD3(runAfterPass,
+               void(StringRef PassID, llvm::Any, const PreservedAnalyses &PA));
+  MOCK_METHOD2(runAfterPassInvalidated,
+               void(StringRef PassID, const PreservedAnalyses &PA));
+  MOCK_METHOD2(runBeforeAnalysis, void(StringRef PassID, llvm::Any));
+  MOCK_METHOD2(runAfterAnalysis, void(StringRef PassID, llvm::Any));
+
+  /// Registers one forwarding lambda per callback kind on `Callbacks`; each
+  /// lambda simply calls the mock method of the corresponding name.
+  void registerPassInstrumentation() {
+    // runBeforePass() is the only mock with a result; it is forwarded as the
+    // answer of the should-run-optional-pass callback.
+    Callbacks.registerShouldRunOptionalPassCallback(
+        [this](StringRef P, llvm::Any IR) {
+          return this->runBeforePass(P, IR);
+        });
+    Callbacks.registerBeforeSkippedPassCallback(
+        [this](StringRef P, llvm::Any IR) {
+          this->runBeforeSkippedPass(P, IR);
+        });
+    Callbacks.registerBeforeNonSkippedPassCallback(
+        [this](StringRef P, llvm::Any IR) {
+          this->runBeforeNonSkippedPass(P, IR);
+        });
+    Callbacks.registerAfterPassCallback(
+        [this](StringRef P, llvm::Any IR, const PreservedAnalyses &PA) {
+          this->runAfterPass(P, IR, PA);
+        });
+    Callbacks.registerAfterPassInvalidatedCallback(
+        [this](StringRef P, const PreservedAnalyses &PA) {
+          this->runAfterPassInvalidated(P, PA);
+        });
+    // runBeforeAnalysis() returns void, so it is called as a plain statement
+    // like the other void wrappers rather than through `return`.
+    Callbacks.registerBeforeAnalysisCallback(
+        [this](StringRef P, llvm::Any IR) { this->runBeforeAnalysis(P, IR); });
+    Callbacks.registerAfterAnalysisCallback(
+        [this](StringRef P, llvm::Any IR) { this->runAfterAnalysis(P, IR); });
+  }
+
+  /// Installs permissive expectations (any number of calls) for events on the
+  /// IR unit named \p IRName whose pass/analysis ID does not match "Mock".
+  /// runAfterPassInvalidated() gets no expectation here.
+  void ignoreNonMockPassInstrumentation(StringRef IRName) {
+    // Generic EXPECT_CALLs are needed to match instrumentation on unimportant
+    // parts of a pipeline that we do not care about (e.g. various passes added
+    // by default by PassBuilder - Verifier pass etc).
+    // Make sure to avoid ignoring Mock passes/analysis, we definitely want
+    // to check these explicitly.
+    EXPECT_CALL(*this,
+                runBeforePass(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(
+        *this, runBeforeSkippedPass(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this, runBeforeNonSkippedPass(Not(HasNameRegex("Mock")),
+                                               HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this,
+                runAfterPass(Not(HasNameRegex("Mock")), HasName(IRName), _))
+        .Times(AnyNumber());
+    // NOTE(review): the two "MachineModuleAnalysis" expectations look like
+    // they are already covered by the generic non-Mock ones that follow them;
+    // kept as in the original - confirm whether they are still needed.
+    EXPECT_CALL(*this, runBeforeAnalysis(HasNameRegex("MachineModuleAnalysis"),
+                                         HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this,
+                runBeforeAnalysis(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this, runAfterAnalysis(HasNameRegex("MachineModuleAnalysis"),
+                                        HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this,
+                runAfterAnalysis(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+  }
+};
+
+template <typename DerivedT> class MockAnalysisHandleBase {
+public:
+ class Analysis : public AnalysisInfoMixin<Analysis> {
+ friend AnalysisInfoMixin<Analysis>;
+ friend MockAnalysisHandleBase;
+ static AnalysisKey Key;
+
+ DerivedT *Handle;
+
+ Analysis(DerivedT &Handle) : Handle(&Handle) {
+ static_assert(std::is_base_of<MockAnalysisHandleBase, DerivedT>::value,
+ "Must pass the derived type to this template!");
+ }
+
+ public:
+ class Result {
+ friend MockAnalysisHandleBase;
+
+ DerivedT *Handle;
+
+ Result(DerivedT &Handle) : Handle(&Handle) {}
+
+ public:
+ // Forward invalidation events to the mock handle.
+ bool invalidate(MachineFunction &IR, const PreservedAnalyses &PA,
+ MachineFunctionAnalysisManager::Invalidator &Inv) {
+ return Handle->invalidate(IR, PA, Inv);
+ }
+ };
+
+ Result run(MachineFunction &IR, MachineFunctionAnalysisManager::Base &AM) {
----------------
paperchalice wrote:
`MachineFunctionPassManager` can run on a module directly (see [run](https://github.com/llvm/llvm-project/blob/9095eec0524d39d447d6f94cd3f9896cc5fc656f/llvm/lib/CodeGen/MachinePassManager.cpp#L25-L34) in `MachinePassManager`), which seems contrary to the pass manager's design.
https://github.com/llvm/llvm-project/pull/76320
More information about the llvm-commits
mailing list