[llvm] [CodeGen] Let `PassBuilder` support machine passes (PR #76320)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 8 19:34:29 PST 2024


https://github.com/paperchalice updated https://github.com/llvm/llvm-project/pull/76320

>From 0a6fe7476748ab01ccaa9a93714c425e940b6101 Mon Sep 17 00:00:00 2001
From: PaperChalice <liujunchang97 at outlook.com>
Date: Sun, 24 Dec 2023 15:19:12 +0800
Subject: [PATCH] [CodeGen] Let `PassBuilder` support machine passes

`PassBuilder` is a better place to parse MIR pipelines. It lets us reuse the existing code for parsing passes with parameters, and targets can reuse `registerPassBuilderCallbacks` to register their target-specific passes. `PassBuilder` can also check whether a pass is a machine pass, so it can take over part of the work of `LLVMTargetMachine::getPassNameFromLegacyName`.
---
 llvm/include/llvm/Passes/PassBuilder.h        |  39 ++
 llvm/include/llvm/Target/TargetMachine.h      |   3 +-
 llvm/lib/Passes/PassBuilder.cpp               | 105 ++++
 llvm/unittests/MIR/CMakeLists.txt             |   2 +
 .../MIR/PassBuilderCallbacksTest.cpp          | 499 ++++++++++++++++++
 5 files changed, 646 insertions(+), 2 deletions(-)
 create mode 100644 llvm/unittests/MIR/PassBuilderCallbacksTest.cpp

diff --git a/llvm/include/llvm/Passes/PassBuilder.h b/llvm/include/llvm/Passes/PassBuilder.h
index 61417431f8a8f3..6b0ad7e7d11a57 100644
--- a/llvm/include/llvm/Passes/PassBuilder.h
+++ b/llvm/include/llvm/Passes/PassBuilder.h
@@ -16,6 +16,7 @@
 #define LLVM_PASSES_PASSBUILDER_H
 
 #include "llvm/Analysis/CGSCCPassManager.h"
+#include "llvm/CodeGen/MachinePassManager.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Passes/OptimizationLevel.h"
 #include "llvm/Support/Error.h"
@@ -165,6 +166,14 @@ class PassBuilder {
   /// additional analyses.
   void registerLoopAnalyses(LoopAnalysisManager &LAM);
 
+  /// Registers all available machine function analysis passes.
+  ///
+  /// This is an interface that can be used to populate a \c
+  /// MachineFunctionAnalysisManager with all registered machine function
+  /// analyses. Callers can still manually register any additional analyses.
+  /// Callers can also pre-register analyses and this will not override those.
+  void registerMachineFunctionAnalyses(MachineFunctionAnalysisManager &MFAM);
+
   /// Construct the core LLVM function canonicalization and simplification
   /// pipeline.
   ///
@@ -352,6 +361,18 @@ class PassBuilder {
   Error parsePassPipeline(LoopPassManager &LPM, StringRef PipelineText);
   /// @}}
 
+  /// Parse a textual MIR pipeline into the provided \c
+  /// MachineFunctionPassManager.
+  /// The format of the textual machine pipeline is a comma separated list of
+  /// machine pass names:
+  ///
+  ///   machine-function-pass,machine-module-pass,...
+  ///
+  /// There is no need to specify the pass nesting; nested pipelines are
+  /// currently not supported and are rejected with an error.
+  Error parsePassPipeline(MachineFunctionPassManager &MFPM,
+                          StringRef PipelineText);
+
   /// Parse a textual alias analysis pipeline into the provided AA manager.
   ///
   /// The format of the textual AA pipeline is a comma separated list of AA
@@ -520,6 +541,10 @@ class PassBuilder {
       const std::function<void(ModuleAnalysisManager &)> &C) {
     ModuleAnalysisRegistrationCallbacks.push_back(C);
   }
+  void registerAnalysisRegistrationCallback(
+      const std::function<void(MachineFunctionAnalysisManager &)> &C) {
+    MachineFunctionAnalysisRegistrationCallbacks.push_back(C);
+  }
   /// @}}
 
   /// {{@ Register pipeline parsing callbacks with this pass builder instance.
@@ -546,6 +571,11 @@ class PassBuilder {
                                ArrayRef<PipelineElement>)> &C) {
     ModulePipelineParsingCallbacks.push_back(C);
   }
+  void registerPipelineParsingCallback(
+      const std::function<bool(StringRef Name, MachineFunctionPassManager &)>
+          &C) {
+    MachinePipelineParsingCallbacks.push_back(C);
+  }
   /// @}}
 
   /// Register a callback for a top-level pipeline entry.
@@ -616,8 +646,12 @@ class PassBuilder {
   Error parseCGSCCPass(CGSCCPassManager &CGPM, const PipelineElement &E);
   Error parseFunctionPass(FunctionPassManager &FPM, const PipelineElement &E);
   Error parseLoopPass(LoopPassManager &LPM, const PipelineElement &E);
+  Error parseMachinePass(MachineFunctionPassManager &MFPM,
+                         const PipelineElement &E);
   bool parseAAPassName(AAManager &AA, StringRef Name);
 
+  Error parseMachinePassPipeline(MachineFunctionPassManager &MFPM,
+                                 ArrayRef<PipelineElement> Pipeline);
   Error parseLoopPassPipeline(LoopPassManager &LPM,
                               ArrayRef<PipelineElement> Pipeline);
   Error parseFunctionPassPipeline(FunctionPassManager &FPM,
@@ -699,6 +733,11 @@ class PassBuilder {
   // AA callbacks
   SmallVector<std::function<bool(StringRef Name, AAManager &AA)>, 2>
       AAParsingCallbacks;
+  // Machine pass callbacks
+  SmallVector<std::function<void(MachineFunctionAnalysisManager &)>, 2>
+      MachineFunctionAnalysisRegistrationCallbacks;
+  SmallVector<std::function<bool(StringRef, MachineFunctionPassManager &)>, 2>
+      MachinePipelineParsingCallbacks;
 };
 
 /// This utility template takes care of adding require<> and invalidate<>
diff --git a/llvm/include/llvm/Target/TargetMachine.h b/llvm/include/llvm/Target/TargetMachine.h
index 1fe47dec70b163..ec82cf90521551 100644
--- a/llvm/include/llvm/Target/TargetMachine.h
+++ b/llvm/include/llvm/Target/TargetMachine.h
@@ -463,8 +463,7 @@ class LLVMTargetMachine : public TargetMachine {
   }
 
   virtual std::pair<StringRef, bool> getPassNameFromLegacyName(StringRef) {
-    llvm_unreachable(
-        "getPassNameFromLegacyName parseMIRPipeline is not overridden");
+    llvm_unreachable("getPassNameFromLegacyName is not overridden");
   }
 
   /// Add passes to the specified pass manager to get machine code emitted with
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 649451edc0e2c6..85e39d0af11523 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -73,6 +73,7 @@
 #include "llvm/Analysis/TypeBasedAliasAnalysis.h"
 #include "llvm/Analysis/UniformityAnalysis.h"
 #include "llvm/CodeGen/CallBrPrepare.h"
+#include "llvm/CodeGen/CodeGenPassBuilder.h"
 #include "llvm/CodeGen/DwarfEHPrepare.h"
 #include "llvm/CodeGen/ExpandLargeDivRem.h"
 #include "llvm/CodeGen/ExpandLargeFpConvert.h"
@@ -484,6 +485,18 @@ PassBuilder::PassBuilder(TargetMachine *TM, PipelineTuningOptions PTO,
 #define CGSCC_ANALYSIS(NAME, CREATE_PASS)                                      \
   PIC->addClassToPassName(decltype(CREATE_PASS)::name(), NAME);
 #include "PassRegistry.def"
+// Map every class listed in the machine pass registry to its pipeline name.
+#define ADD_MACHINE_PASS_CLASS_NAME(NAME, PASS_NAME, CONSTRUCTOR)              \
+  PIC->addClassToPassName(PASS_NAME::name(), NAME);
+#define MACHINE_MODULE_PASS ADD_MACHINE_PASS_CLASS_NAME
+#define MACHINE_FUNCTION_PASS ADD_MACHINE_PASS_CLASS_NAME
+#define MACHINE_FUNCTION_ANALYSIS ADD_MACHINE_PASS_CLASS_NAME
+#define DUMMY_FUNCTION_PASS ADD_MACHINE_PASS_CLASS_NAME
+#define DUMMY_MACHINE_MODULE_PASS ADD_MACHINE_PASS_CLASS_NAME
+#define DUMMY_MACHINE_FUNCTION_PASS ADD_MACHINE_PASS_CLASS_NAME
+#define DUMMY_MACHINE_FUNCTION_ANALYSIS ADD_MACHINE_PASS_CLASS_NAME
+#include "llvm/CodeGen/MachinePassRegistry.def"
+#undef ADD_MACHINE_PASS_CLASS_NAME
   }
 }
 
@@ -519,6 +532,19 @@ void PassBuilder::registerFunctionAnalyses(FunctionAnalysisManager &FAM) {
     C(FAM);
 }
 
+void PassBuilder::registerMachineFunctionAnalyses(
+    MachineFunctionAnalysisManager &MFAM) {
+  // NOTE(review): CONSTRUCTOR args are unused; analyses are default-built.
+#define MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR)                \
+  MFAM.registerPass([&] { return PASS_NAME(); });
+#define DUMMY_MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR)          \
+  MFAM.registerPass([&] { return PASS_NAME(); });
+#include "llvm/CodeGen/MachinePassRegistry.def"
+
+  for (auto &C : MachineFunctionAnalysisRegistrationCallbacks)
+    C(MFAM);
+}
+
 void PassBuilder::registerLoopAnalyses(LoopAnalysisManager &LAM) {
 #define LOOP_ANALYSIS(NAME, CREATE_PASS)                                       \
   LAM.registerPass([&] { return CREATE_PASS; });
@@ -1873,6 +1899,43 @@ Error PassBuilder::parseLoopPass(LoopPassManager &LPM,
                                  inconvertibleErrorCode());
 }
 
+Error PassBuilder::parseMachinePass(MachineFunctionPassManager &MFPM,
+                                    const PipelineElement &E) {
+  StringRef Name = E.Name;
+  if (!E.InnerPipeline.empty())
+    return make_error<StringError>(
+        formatv("invalid use of '{0}' pass as machine pipeline", Name).str(),
+        inconvertibleErrorCode());
+  // Look the name up among the built-in machine passes, then the callbacks.
+#define MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR)                      \
+  if (Name == NAME) {                                                          \
+    MFPM.addPass(PASS_NAME());                                                 \
+    return Error::success();                                                   \
+  }
+#define MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR)                    \
+  if (Name == NAME) {                                                          \
+    MFPM.addPass(PASS_NAME());                                                 \
+    return Error::success();                                                   \
+  }
+#define DUMMY_MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR)                \
+  if (Name == NAME) {                                                          \
+    MFPM.addPass(PASS_NAME());                                                 \
+    return Error::success();                                                   \
+  }
+#define DUMMY_MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR)              \
+  if (Name == NAME) {                                                          \
+    MFPM.addPass(PASS_NAME());                                                 \
+    return Error::success();                                                   \
+  }
+#include "llvm/CodeGen/MachinePassRegistry.def"
+  for (auto &C : MachinePipelineParsingCallbacks)
+    if (C(Name, MFPM))
+      return Error::success();
+  return make_error<StringError>(
+      formatv("unknown machine pass '{0}'", Name).str(),
+      inconvertibleErrorCode());
+}
+
 bool PassBuilder::parseAAPassName(AAManager &AA, StringRef Name) {
 #define MODULE_ALIAS_ANALYSIS(NAME, CREATE_PASS)                               \
   if (Name == NAME) {                                                          \
@@ -1894,6 +1957,15 @@ bool PassBuilder::parseAAPassName(AAManager &AA, StringRef Name) {
   return false;
 }
 
+Error PassBuilder::parseMachinePassPipeline(
+    MachineFunctionPassManager &MFPM, ArrayRef<PipelineElement> Pipeline) {
+  for (const PipelineElement &Element : Pipeline)
+    if (Error Err = parseMachinePass(MFPM, Element))
+      return Err;
+  // Every element parsed successfully.
+  return Error::success();
+}
+
 Error PassBuilder::parseLoopPassPipeline(LoopPassManager &LPM,
                                          ArrayRef<PipelineElement> Pipeline) {
   for (const auto &Element : Pipeline) {
@@ -2053,6 +2125,20 @@ Error PassBuilder::parsePassPipeline(LoopPassManager &CGPM,
   return Error::success();
 }
 
+Error PassBuilder::parsePassPipeline(MachineFunctionPassManager &MFPM,
+                                     StringRef PipelineText) {
+  // Tokenize the text first. Machine pipelines do not support nesting, so
+  // the resulting elements are handed straight to the flat element parser.
+  auto Pipeline = parsePipelineText(PipelineText);
+  if (Pipeline && !Pipeline->empty())
+    return parseMachinePassPipeline(MFPM, *Pipeline);
+
+  // Reaching here means the text was malformed or named no passes at all.
+  return make_error<StringError>(
+      formatv("invalid machine pass pipeline '{0}'", PipelineText).str(),
+      inconvertibleErrorCode());
+}
+
 Error PassBuilder::parseAAPipeline(AAManager &AA, StringRef PipelineText) {
   // If the pipeline just consists of the word 'default' just replace the AA
   // manager with our default one.
@@ -2147,6 +2233,25 @@ void PassBuilder::printPassNames(raw_ostream &OS) {
   OS << "Loop analyses:\n";
 #define LOOP_ANALYSIS(NAME, CREATE_PASS) printPassName(NAME, OS);
 #include "PassRegistry.def"
+
+  OS << "Machine module passes (WIP):\n";
+#define MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR)                      \
+  printPassName(NAME, OS);
+#define DUMMY_MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR)                \
+  printPassName(NAME, OS);
+#include "llvm/CodeGen/MachinePassRegistry.def"
+  OS << "Machine function passes (WIP):\n";
+#define MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR)                    \
+  printPassName(NAME, OS);
+#define DUMMY_MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR)              \
+  printPassName(NAME, OS);
+#include "llvm/CodeGen/MachinePassRegistry.def"
+  OS << "Machine function analyses (WIP):\n";
+#define MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR)                \
+  printPassName(NAME, OS);
+#define DUMMY_MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR)          \
+  printPassName(NAME, OS);
+#include "llvm/CodeGen/MachinePassRegistry.def"
 }
 
 void PassBuilder::registerParseTopLevelPipelineCallback(
diff --git a/llvm/unittests/MIR/CMakeLists.txt b/llvm/unittests/MIR/CMakeLists.txt
index 3c0e9e43f9afb1..f485dcbd971b6a 100644
--- a/llvm/unittests/MIR/CMakeLists.txt
+++ b/llvm/unittests/MIR/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_LINK_COMPONENTS
   FileCheck
   MC
   MIRParser
+  Passes
   Support
   Target
   TargetParser
@@ -13,6 +14,7 @@ set(LLVM_LINK_COMPONENTS
 
 add_llvm_unittest(MIRTests
   MachineMetadata.cpp
+  PassBuilderCallbacksTest.cpp
   )
 
 target_link_libraries(MIRTests PRIVATE LLVMTestingSupport)
diff --git a/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp b/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp
new file mode 100644
index 00000000000000..29e238e65c12a4
--- /dev/null
+++ b/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp
@@ -0,0 +1,499 @@
+//===- unittests/MIR/PassBuilderCallbacksTest.cpp - PB Callback Tests --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Testing/Support/Error.h"
+#include <functional>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <llvm/ADT/Any.h>
+#include <llvm/AsmParser/Parser.h>
+#include <llvm/CodeGen/MIRParser/MIRParser.h>
+#include <llvm/CodeGen/MachineFunction.h>
+#include <llvm/CodeGen/MachineModuleInfo.h>
+#include <llvm/CodeGen/MachinePassManager.h>
+#include <llvm/IR/LLVMContext.h>
+#include <llvm/IR/PassInstrumentation.h>
+#include <llvm/IR/PassManager.h>
+#include <llvm/Passes/PassBuilder.h>
+#include <llvm/Support/Regex.h>
+#include <llvm/Support/SourceMgr.h>
+#include <llvm/Support/TargetSelect.h>
+
+using namespace llvm;
+
+namespace {
+using testing::_;
+using testing::AnyNumber;
+using testing::DoAll;
+using testing::Not;
+using testing::Return;
+using testing::WithArgs;
+
+StringRef MIRString = R"MIR(
+--- |
+  define void @test() {
+    ret void
+  }
+...
+---
+name:            test
+body:             |
+  bb.0 (%ir-block.0):
+    RET64
+...
+)MIR";
+
+/// Helper for HasName matcher that returns getName both for IRUnit and
+/// for IRUnit pointer wrapper into llvm::Any (wrapped by PassInstrumentation).
+template <typename IRUnitT> std::string getName(const IRUnitT &IR) {
+  return std::string(IR.getName());
+}
+
+template <> std::string getName(const StringRef &name) {
+  return std::string(name);
+}
+
+template <> std::string getName(const Any &WrappedIR) {
+  if (const auto *const *M = llvm::any_cast<const Module *>(&WrappedIR))
+    return (*M)->getName().str();
+  if (const auto *const *F = llvm::any_cast<const Function *>(&WrappedIR))
+    return (*F)->getName().str();
+  if (const auto *const *MF =
+          llvm::any_cast<const MachineFunction *>(&WrappedIR))
+    return (*MF)->getName().str();
+  return "<UNKNOWN>";
+}
+/// Define a custom matcher for objects which support a 'getName' method.
+///
+/// LLVM often has IR objects or analysis objects which expose a name
+/// and in tests it is convenient to match these by name for readability.
+/// Usually, this name is either a StringRef or a plain std::string. The
+/// getName() helpers above turn the name into a std::string, so this matcher
+/// accepts any type with a getName() method, a StringRef, or an llvm::Any
+/// wrapping a const Module, Function or MachineFunction pointer.
+///
+/// It should be used as:
+///
+///   HasName("my_function")
+///
+/// No namespace or other qualification is required.
+MATCHER_P(HasName, Name, "") {
+  *result_listener << "has name '" << getName(arg) << "'";
+  return Name == getName(arg);
+}
+
+MATCHER_P(HasNameRegex, Name, "") {
+  *result_listener << "has name '" << getName(arg) << "'";
+  llvm::Regex r(Name);
+  return r.match(getName(arg));
+}
+
+struct MockPassInstrumentationCallbacks {
+  PassInstrumentationCallbacks Callbacks;
+
+  MockPassInstrumentationCallbacks() {
+    ON_CALL(*this, runBeforePass(_, _)).WillByDefault(Return(true));
+  }
+  MOCK_METHOD2(runBeforePass, bool(StringRef PassID, llvm::Any));
+  MOCK_METHOD2(runBeforeSkippedPass, void(StringRef PassID, llvm::Any));
+  MOCK_METHOD2(runBeforeNonSkippedPass, void(StringRef PassID, llvm::Any));
+  MOCK_METHOD3(runAfterPass,
+               void(StringRef PassID, llvm::Any, const PreservedAnalyses &PA));
+  MOCK_METHOD2(runAfterPassInvalidated,
+               void(StringRef PassID, const PreservedAnalyses &PA));
+  MOCK_METHOD2(runBeforeAnalysis, void(StringRef PassID, llvm::Any));
+  MOCK_METHOD2(runAfterAnalysis, void(StringRef PassID, llvm::Any));
+
+  void registerPassInstrumentation() {
+    Callbacks.registerShouldRunOptionalPassCallback(
+        [this](StringRef P, llvm::Any IR) {
+          return this->runBeforePass(P, IR);
+        });
+    Callbacks.registerBeforeSkippedPassCallback(
+        [this](StringRef P, llvm::Any IR) {
+          this->runBeforeSkippedPass(P, IR);
+        });
+    Callbacks.registerBeforeNonSkippedPassCallback(
+        [this](StringRef P, llvm::Any IR) {
+          this->runBeforeNonSkippedPass(P, IR);
+        });
+    Callbacks.registerAfterPassCallback(
+        [this](StringRef P, llvm::Any IR, const PreservedAnalyses &PA) {
+          this->runAfterPass(P, IR, PA);
+        });
+    Callbacks.registerAfterPassInvalidatedCallback(
+        [this](StringRef P, const PreservedAnalyses &PA) {
+          this->runAfterPassInvalidated(P, PA);
+        });
+    Callbacks.registerBeforeAnalysisCallback([this](StringRef P, llvm::Any IR) {
+      return this->runBeforeAnalysis(P, IR);
+    });
+    Callbacks.registerAfterAnalysisCallback(
+        [this](StringRef P, llvm::Any IR) { this->runAfterAnalysis(P, IR); });
+  }
+
+  void ignoreNonMockPassInstrumentation(StringRef IRName) {
+    // Generic EXPECT_CALLs are needed to match instrumentation on unimportant
+    // parts of a pipeline that we do not care about (e.g. passes added by
+    // default or MachineModuleAnalysis). Mock passes/analyses are excluded via
+    // Not(HasNameRegex("Mock")) so each test still checks them explicitly.
+    // The MachineModuleAnalysis matchers below overlap with the Not(Mock) ones.
+    EXPECT_CALL(*this,
+                runBeforePass(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(
+        *this, runBeforeSkippedPass(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this, runBeforeNonSkippedPass(Not(HasNameRegex("Mock")),
+                                               HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this,
+                runAfterPass(Not(HasNameRegex("Mock")), HasName(IRName), _))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this, runBeforeAnalysis(HasNameRegex("MachineModuleAnalysis"),
+                                         HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this,
+                runBeforeAnalysis(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this, runAfterAnalysis(HasNameRegex("MachineModuleAnalysis"),
+                                        HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this,
+                runAfterAnalysis(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+  }
+};
+
+template <typename DerivedT> class MockAnalysisHandleBase {
+public:
+  class Analysis : public AnalysisInfoMixin<Analysis> {
+    friend AnalysisInfoMixin<Analysis>;
+    friend MockAnalysisHandleBase;
+    static AnalysisKey Key;
+
+    DerivedT *Handle;
+
+    Analysis(DerivedT &Handle) : Handle(&Handle) {
+      static_assert(std::is_base_of<MockAnalysisHandleBase, DerivedT>::value,
+                    "Must pass the derived type to this template!");
+    }
+
+  public:
+    class Result {
+      friend MockAnalysisHandleBase;
+
+      DerivedT *Handle;
+
+      Result(DerivedT &Handle) : Handle(&Handle) {}
+
+    public:
+      // Forward invalidation events to the mock handle.
+      bool invalidate(MachineFunction &IR, const PreservedAnalyses &PA,
+                      MachineFunctionAnalysisManager::Invalidator &Inv) {
+        return Handle->invalidate(IR, PA, Inv);
+      }
+    };
+
+    Result run(MachineFunction &IR, MachineFunctionAnalysisManager::Base &AM) {
+      return Handle->run(IR, AM);
+    }
+  };
+
+  Analysis getAnalysis() { return Analysis(static_cast<DerivedT &>(*this)); }
+  typename Analysis::Result getResult() {
+    return typename Analysis::Result(static_cast<DerivedT &>(*this));
+  }
+  static StringRef getName() { return llvm::getTypeName<DerivedT>(); }
+
+protected:
+  // FIXME: MSVC seems unable to handle a lambda argument to Invoke from within
+  // the template, so we use a boring static function.
+  static bool
+  invalidateCallback(MachineFunction &IR, const PreservedAnalyses &PA,
+                     MachineFunctionAnalysisManager::Invalidator &Inv) {
+    auto PAC = PA.template getChecker<Analysis>();
+    return !PAC.preserved() &&
+           !PAC.template preservedSet<AllAnalysesOn<MachineFunction>>();
+  }
+
+  /// Installs the default actions: run() returns a Result bound to this handle
+  /// and invalidate() defers to invalidateCallback. Derived classes must call
+  /// this in their constructor; it cannot run in ours before DerivedT exists.
+  void setDefaults() {
+    ON_CALL(static_cast<DerivedT &>(*this), run(_, _))
+        .WillByDefault(Return(this->getResult()));
+    ON_CALL(static_cast<DerivedT &>(*this), invalidate(_, _, _))
+        .WillByDefault(&invalidateCallback);
+  }
+};
+
+template <typename DerivedT> class MockPassHandleBase {
+public:
+  class Pass : public MachinePassInfoMixin<Pass> {
+    friend MockPassHandleBase;
+
+    DerivedT *Handle;
+
+    Pass(DerivedT &Handle) : Handle(&Handle) {
+      static_assert(std::is_base_of<MockPassHandleBase, DerivedT>::value,
+                    "Must pass the derived type to this template!");
+    }
+
+  public:
+    static MachinePassKey Key;
+    PreservedAnalyses run(MachineFunction &IR,
+                          MachineFunctionAnalysisManager::Base &AM) {
+      return Handle->run(IR, AM);
+    }
+  };
+
+  static StringRef getName() { return llvm::getTypeName<DerivedT>(); }
+
+  Pass getPass() { return Pass(static_cast<DerivedT &>(*this)); }
+
+protected:
+  /// Installs the default action: run() preserves all analyses. Derived
+  /// classes must call this in their constructor; it cannot run in ours
+  /// before DerivedT exists.
+  void setDefaults() {
+    ON_CALL(static_cast<DerivedT &>(*this), run(_, _))
+        .WillByDefault(Return(PreservedAnalyses::all()));
+  }
+};
+
+struct MockAnalysisHandle : public MockAnalysisHandleBase<MockAnalysisHandle> {
+  MOCK_METHOD2(run, Analysis::Result(MachineFunction &,
+                                     MachineFunctionAnalysisManager::Base &));
+
+  MOCK_METHOD3(invalidate, bool(MachineFunction &, const PreservedAnalyses &,
+                                MachineFunctionAnalysisManager::Invalidator &));
+
+  MockAnalysisHandle() { setDefaults(); }
+};
+
+template <typename DerivedT>
+MachinePassKey MockPassHandleBase<DerivedT>::Pass::Key;
+
+template <typename DerivedT>
+AnalysisKey MockAnalysisHandleBase<DerivedT>::Analysis::Key;
+
+class MockPassHandle : public MockPassHandleBase<MockPassHandle> {
+public:
+  MOCK_METHOD2(run, PreservedAnalyses(MachineFunction &,
+                                      MachineFunctionAnalysisManager::Base &));
+
+  MockPassHandle() { setDefaults(); }
+};
+
+class MachineFunctionCallbacksTest : public testing::Test {
+protected:
+  static void SetUpTestCase() {
+    InitializeAllTargetInfos();
+    InitializeAllTargets();
+    InitializeAllTargetMCs();
+  }
+
+  std::unique_ptr<TargetMachine> TM;
+
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<MIRParser> MIR;
+
+  MockPassInstrumentationCallbacks CallbacksHandle;
+
+  PassBuilder PB;
+  ModulePassManager PM;
+  MachineFunctionPassManager MFPM;
+  FunctionAnalysisManager FAM;
+  ModuleAnalysisManager AM;
+  MachineFunctionAnalysisManager MFAM;
+
+  MockPassHandle PassHandle;
+  MockAnalysisHandle AnalysisHandle;
+
+  std::unique_ptr<Module> parseMIR(const TargetMachine &TM, StringRef MIRCode,
+                                   MachineModuleInfo &MMI) {
+    SMDiagnostic Diagnostic;
+    std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode);
+    MIR = createMIRParser(std::move(MBuffer), Context);
+    if (!MIR)
+      return nullptr;
+
+    std::unique_ptr<Module> Mod = MIR->parseIRModule();
+    if (!Mod)
+      return nullptr;
+
+    Mod->setDataLayout(TM.createDataLayout());
+
+    if (MIR->parseMachineFunctions(*Mod, MMI)) {
+      // Mod is destroyed on return; the caller treats nullptr as failure.
+      return nullptr;
+    }
+    return Mod;
+  }
+
+  static PreservedAnalyses
+  getAnalysisResult(MachineFunction &U,
+                    MachineFunctionAnalysisManager::Base &AM) {
+    auto &MFAM = static_cast<MachineFunctionAnalysisManager &>(AM);
+    MFAM.getResult<MockAnalysisHandle::Analysis>(U);
+    return PreservedAnalyses::all();
+  }
+
+  void SetUp() override {
+    std::string Error;
+    auto TripleName = "x86_64-pc-linux-gnu";
+    auto *T = TargetRegistry::lookupTarget(TripleName, Error);
+    if (!T)
+      GTEST_SKIP();
+    TM.reset(T->createTargetMachine(TripleName, "", "", TargetOptions(),
+                                    std::nullopt));
+    if (!TM)
+      GTEST_SKIP();
+    MachineModuleInfo MMI(static_cast<LLVMTargetMachine *>(TM.get()));
+    M = parseMIR(*TM, MIRString, MMI);
+    ASSERT_TRUE(M);
+    AM.registerPass([&] {
+      return MachineModuleAnalysis(static_cast<LLVMTargetMachine *>(TM.get()));
+    });
+  }
+
+  MachineFunctionCallbacksTest()
+      : CallbacksHandle(), PB(nullptr, PipelineTuningOptions(), std::nullopt,
+                              &CallbacksHandle.Callbacks),
+        PM(), FAM(), AM(), MFAM(FAM, AM) {
+
+    EXPECT_EQ(&CallbacksHandle.Callbacks, PB.getPassInstrumentationCallbacks());
+
+    /// Register a callback for analysis registration.
+    ///
+    /// The callback is a function taking a reference to an AnalysisManager
+    /// object. When called, the callee gets to register its own analyses with
+    /// this PassBuilder instance.
+    PB.registerAnalysisRegistrationCallback(
+        [this](MachineFunctionAnalysisManager &AM) {
+          // Register our mock analysis
+          AM.registerPass([this] { return AnalysisHandle.getAnalysis(); });
+        });
+
+    /// Register a callback for pipeline parsing.
+    ///
+    /// While parsing a textual machine pipeline, the PassBuilder calls these
+    /// callbacks for each pass name that it does not know itself.
+    PB.registerPipelineParsingCallback(
+        [this](StringRef Name, MachineFunctionPassManager &PM) {
+          if (parseAnalysisUtilityPasses<MockAnalysisHandle::Analysis>(
+                  "test-analysis", Name, PM))
+            return true;
+
+          /// Parse the name of our pass mock handle
+          if (Name == "test-transform") {
+            PM.addPass(PassHandle.getPass());
+            return true;
+          }
+          return false;
+        });
+
+    /// Register the builtin analyses for each IR unit used by the tests.
+    PB.registerModuleAnalyses(AM);
+    PB.registerFunctionAnalyses(FAM);
+    PB.registerMachineFunctionAnalyses(MFAM);
+  }
+};
+
+TEST_F(MachineFunctionCallbacksTest, Passes) {
+  // getAnalysisResult queries the mock analysis, so both mocks run once.
+  EXPECT_CALL(AnalysisHandle, run(HasName("test"), _));
+  EXPECT_CALL(PassHandle, run(HasName("test"), _)).WillOnce(&getAnalysisResult);
+
+  StringRef PipelineText = "test-transform";
+  ASSERT_THAT_ERROR(PB.parsePassPipeline(MFPM, PipelineText), Succeeded())
+      << "Pipeline was: " << PipelineText;
+  ASSERT_THAT_ERROR(MFPM.run(*M, MFAM), Succeeded());
+}
+
+TEST_F(MachineFunctionCallbacksTest, InstrumentedPasses) {
+  CallbacksHandle.registerPassInstrumentation();
+  // Non-mock instrumentation not specifically mentioned below can be ignored.
+  CallbacksHandle.ignoreNonMockPassInstrumentation("<string>");
+  CallbacksHandle.ignoreNonMockPassInstrumentation("test");
+  CallbacksHandle.ignoreNonMockPassInstrumentation("");
+
+  // PassInstrumentation calls should happen in-sequence, in the same order
+  // as passes/analyses are scheduled.
+  ::testing::Sequence PISequence;
+  EXPECT_CALL(CallbacksHandle,
+              runBeforePass(HasNameRegex("MockPassHandle"), HasName("test")))
+      .InSequence(PISequence);
+  EXPECT_CALL(
+      CallbacksHandle,
+      runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), HasName("test")))
+      .InSequence(PISequence);
+  EXPECT_CALL(CallbacksHandle,
+              runAfterPass(HasNameRegex("MockPassHandle"), HasName("test"), _))
+      .InSequence(PISequence);
+
+  EXPECT_CALL(AnalysisHandle, run(HasName("test"), _));
+  EXPECT_CALL(PassHandle, run(HasName("test"), _)).WillOnce(&getAnalysisResult);
+
+  StringRef PipelineText = "test-transform";
+  ASSERT_THAT_ERROR(PB.parsePassPipeline(MFPM, PipelineText), Succeeded())
+      << "Pipeline was: " << PipelineText;
+  ASSERT_THAT_ERROR(MFPM.run(*M, MFAM), Succeeded());
+}
+
+TEST_F(MachineFunctionCallbacksTest, InstrumentedSkippedPasses) {
+  CallbacksHandle.registerPassInstrumentation();
+  // Non-mock instrumentation run here can safely be ignored.
+  CallbacksHandle.ignoreNonMockPassInstrumentation("<string>");
+  CallbacksHandle.ignoreNonMockPassInstrumentation("test");
+  CallbacksHandle.ignoreNonMockPassInstrumentation("");
+
+  // Skip the pass by returning false.
+  EXPECT_CALL(CallbacksHandle,
+              runBeforePass(HasNameRegex("MockPassHandle"), HasName("test")))
+      .WillOnce(Return(false));
+
+  EXPECT_CALL(
+      CallbacksHandle,
+      runBeforeSkippedPass(HasNameRegex("MockPassHandle"), HasName("test")))
+      .Times(1);
+
+  // A skipped pass never queries the mock analysis, so neither mock runs.
+  EXPECT_CALL(AnalysisHandle, run(HasName("test"), _)).Times(0);
+  EXPECT_CALL(PassHandle, run(HasName("test"), _)).Times(0);
+
+  // As the pass is skipped, none of the later callbacks fire for it: no
+  // beforeNonSkippedPass, no afterPass/afterPassInvalidated, and no
+  // beforeAnalysis/afterAnalysis for the mock analysis.
+  EXPECT_CALL(CallbacksHandle,
+              runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), _))
+      .Times(0);
+  EXPECT_CALL(CallbacksHandle,
+              runAfterPass(HasNameRegex("MockPassHandle"), _, _))
+      .Times(0);
+  EXPECT_CALL(CallbacksHandle,
+              runAfterPassInvalidated(HasNameRegex("MockPassHandle"), _))
+      .Times(0);
+  EXPECT_CALL(CallbacksHandle,
+              runBeforeAnalysis(HasNameRegex("MockAnalysisHandle"), _))
+      .Times(0);
+  EXPECT_CALL(CallbacksHandle,
+              runAfterAnalysis(HasNameRegex("MockAnalysisHandle"), _))
+      .Times(0);
+
+  StringRef PipelineText = "test-transform";
+  ASSERT_THAT_ERROR(PB.parsePassPipeline(MFPM, PipelineText), Succeeded())
+      << "Pipeline was: " << PipelineText;
+  ASSERT_THAT_ERROR(MFPM.run(*M, MFAM), Succeeded());
+}
+
+} // end anonymous namespace



More information about the llvm-commits mailing list