[llvm] [CodeGen] Let `PassBuilder` support machine passes (PR #76320)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 26 04:23:12 PST 2023


================
@@ -0,0 +1,499 @@
+//===- unittests/MIR/PassBuilderCallbacksTest.cpp - PB Callback Tests --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Testing/Support/Error.h"
+#include <functional>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <llvm/ADT/Any.h>
+#include <llvm/AsmParser/Parser.h>
+#include <llvm/CodeGen/MIRParser/MIRParser.h>
+#include <llvm/CodeGen/MachineFunction.h>
+#include <llvm/CodeGen/MachineModuleInfo.h>
+#include <llvm/CodeGen/MachinePassManager.h>
+#include <llvm/IR/LLVMContext.h>
+#include <llvm/IR/PassInstrumentation.h>
+#include <llvm/IR/PassManager.h>
+#include <llvm/Passes/PassBuilder.h>
+#include <llvm/Support/Regex.h>
+#include <llvm/Support/SourceMgr.h>
+#include <llvm/Support/TargetSelect.h>
+
+using namespace llvm;
+
+namespace {
+using testing::_;
+using testing::AnyNumber;
+using testing::DoAll;
+using testing::Not;
+using testing::Return;
+using testing::WithArgs;
+
+/// Minimal MIR fixture shared by the tests: an embedded IR module containing a
+/// single empty function `@test`, followed by the matching machine function
+/// whose only basic block holds a `RET64` instruction.
+/// NOTE(review): `RET64` looks like an X86 opcode, so parsing this fixture
+/// presumably requires the X86 target to be available — confirm against the
+/// test fixture setup.
+StringRef MIRString = R"MIR(
+--- |
+  define void @test() {
+    ret void
+  }
+...
+---
+name:            test
+body:             |
+  bb.0 (%ir-block.0):
+    RET64
+...
+)MIR";
+
+/// Name-extraction helpers used by the HasName/HasNameRegex matchers.
+///
+/// The primary template accepts any object whose getName() result can be
+/// converted to a std::string (e.g. IR units returning StringRef).
+template <typename IRUnitT> std::string getName(const IRUnitT &IR) {
+  std::string Result(IR.getName());
+  return Result;
+}
+
+/// A bare StringRef (such as a pass ID) is its own name.
+template <> std::string getName(const StringRef &Name) { return Name.str(); }
+
+/// Unwraps the llvm::Any that PassInstrumentation hands to its callbacks.
+/// Handles const Module, Function and MachineFunction pointers; any other
+/// payload is reported as "<UNKNOWN>".
+template <> std::string getName(const Any &WrappedIR) {
+  std::string Name = "<UNKNOWN>";
+  if (const auto *const *Mod = llvm::any_cast<const Module *>(&WrappedIR))
+    Name = (*Mod)->getName().str();
+  else if (const auto *const *Fn =
+               llvm::any_cast<const Function *>(&WrappedIR))
+    Name = (*Fn)->getName().str();
+  else if (const auto *const *MFn =
+               llvm::any_cast<const MachineFunction *>(&WrappedIR))
+    Name = (*MFn)->getName().str();
+  return Name;
+}
+/// Define a custom matcher for objects that have a name.
+///
+/// LLVM often has IR objects or analysis objects which expose a name
+/// and in tests it is convenient to match these by name for readability.
+/// The name is obtained through the getName() helper overloads defined
+/// earlier in this file. The matcher therefore works on:
+///   - any object with a getName() method whose result converts to
+///     std::string;
+///   - a plain StringRef, such as a pass ID;
+///   - a Module, Function or MachineFunction pointer wrapped in an llvm::Any,
+///     as PassInstrumentation callbacks receive them.
+///
+/// It should be used as:
+///
+///   HasName("my_function")
+///
+/// No namespace or other qualification is required.
+MATCHER_P(HasName, Name, "") {
+  std::string ArgName = getName(arg);
+  *result_listener << "has name '" << ArgName << "'";
+  return Name == ArgName;
+}
+
+/// Like HasName, but succeeds when the regular expression \p Name matches the
+/// object's name (unanchored, so e.g. HasNameRegex("Mock") matches any name
+/// containing "Mock").
+MATCHER_P(HasNameRegex, Name, "") {
+  std::string ArgName = getName(arg);
+  *result_listener << "has name '" << ArgName << "'";
+  llvm::Regex R(Name);
+  return R.match(ArgName);
+}
+
+/// gMock wrapper around PassInstrumentationCallbacks. Each instrumentation
+/// hook is forwarded to a mock method so tests can place expectations on it.
+/// Call registerPassInstrumentation() to wire the mock methods into
+/// Callbacks before running a pipeline.
+struct MockPassInstrumentationCallbacks {
+  PassInstrumentationCallbacks Callbacks;
+
+  MockPassInstrumentationCallbacks() {
+    // runBeforePass backs the should-run-optional-pass callback, which must
+    // return true for a pass to run. gMock's default for bool is false, which
+    // would skip every optional pass, so default it to true here.
+    ON_CALL(*this, runBeforePass(_, _)).WillByDefault(Return(true));
+  }
+  MOCK_METHOD2(runBeforePass, bool(StringRef PassID, llvm::Any));
+  MOCK_METHOD2(runBeforeSkippedPass, void(StringRef PassID, llvm::Any));
+  MOCK_METHOD2(runBeforeNonSkippedPass, void(StringRef PassID, llvm::Any));
+  MOCK_METHOD3(runAfterPass,
+               void(StringRef PassID, llvm::Any, const PreservedAnalyses &PA));
+  MOCK_METHOD2(runAfterPassInvalidated,
+               void(StringRef PassID, const PreservedAnalyses &PA));
+  MOCK_METHOD2(runBeforeAnalysis, void(StringRef PassID, llvm::Any));
+  MOCK_METHOD2(runAfterAnalysis, void(StringRef PassID, llvm::Any));
+
+  /// Registers one lambda per instrumentation point that forwards to the
+  /// corresponding mock method.
+  void registerPassInstrumentation() {
+    Callbacks.registerShouldRunOptionalPassCallback(
+        [this](StringRef P, llvm::Any IR) {
+          return this->runBeforePass(P, IR);
+        });
+    Callbacks.registerBeforeSkippedPassCallback(
+        [this](StringRef P, llvm::Any IR) {
+          this->runBeforeSkippedPass(P, IR);
+        });
+    Callbacks.registerBeforeNonSkippedPassCallback(
+        [this](StringRef P, llvm::Any IR) {
+          this->runBeforeNonSkippedPass(P, IR);
+        });
+    Callbacks.registerAfterPassCallback(
+        [this](StringRef P, llvm::Any IR, const PreservedAnalyses &PA) {
+          this->runAfterPass(P, IR, PA);
+        });
+    Callbacks.registerAfterPassInvalidatedCallback(
+        [this](StringRef P, const PreservedAnalyses &PA) {
+          this->runAfterPassInvalidated(P, PA);
+        });
+    Callbacks.registerBeforeAnalysisCallback(
+        [this](StringRef P, llvm::Any IR) { this->runBeforeAnalysis(P, IR); });
+    Callbacks.registerAfterAnalysisCallback(
+        [this](StringRef P, llvm::Any IR) { this->runAfterAnalysis(P, IR); });
+  }
+
+  /// Allows any number of instrumentation calls on \p IRName for passes and
+  /// analyses whose name does not contain "Mock".
+  void ignoreNonMockPassInstrumentation(StringRef IRName) {
+    // Generic EXPECT_CALLs are needed to match instrumentation on unimportant
+    // parts of a pipeline that we do not care about (e.g. various passes added
+    // by default by PassBuilder - Verifier pass etc).
+    // Make sure to avoid ignoring Mock passes/analysis, we definitely want
+    // to check these explicitly.
+    // Non-Mock analyses such as MachineModuleAnalysis are already covered by
+    // the Not(HasNameRegex("Mock")) expectations below, so they need no
+    // dedicated EXPECT_CALLs.
+    EXPECT_CALL(*this,
+                runBeforePass(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(
+        *this, runBeforeSkippedPass(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this, runBeforeNonSkippedPass(Not(HasNameRegex("Mock")),
+                                               HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this,
+                runAfterPass(Not(HasNameRegex("Mock")), HasName(IRName), _))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this,
+                runBeforeAnalysis(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+    EXPECT_CALL(*this,
+                runAfterAnalysis(Not(HasNameRegex("Mock")), HasName(IRName)))
+        .Times(AnyNumber());
+  }
+};
+
+template <typename DerivedT> class MockAnalysisHandleBase {
+public:
+  class Analysis : public AnalysisInfoMixin<Analysis> {
+    friend AnalysisInfoMixin<Analysis>;
+    friend MockAnalysisHandleBase;
+    static AnalysisKey Key;
+
+    DerivedT *Handle;
+
+    Analysis(DerivedT &Handle) : Handle(&Handle) {
+      static_assert(std::is_base_of<MockAnalysisHandleBase, DerivedT>::value,
+                    "Must pass the derived type to this template!");
+    }
+
+  public:
+    class Result {
+      friend MockAnalysisHandleBase;
+
+      DerivedT *Handle;
+
+      Result(DerivedT &Handle) : Handle(&Handle) {}
+
+    public:
+      // Forward invalidation events to the mock handle.
+      bool invalidate(MachineFunction &IR, const PreservedAnalyses &PA,
+                      MachineFunctionAnalysisManager::Invalidator &Inv) {
+        return Handle->invalidate(IR, PA, Inv);
+      }
+    };
+
+    Result run(MachineFunction &IR, MachineFunctionAnalysisManager::Base &AM) {
----------------
paperchalice wrote:

Not sure why `MachineFunctionAnalysisManager` is designed this way. An `InnerAnalysisManagerProxy<IRUnit, MachineFunction>` also seems OK.

https://github.com/llvm/llvm-project/pull/76320


More information about the llvm-commits mailing list