[llvm] 911565d - [NewPM][CodeGen] Introduce machine pass and machine pass manager

Yuanfang Chen via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 7 11:00:59 PDT 2020


Author: Yuanfang Chen
Date: 2020-08-07T11:00:31-07:00
New Revision: 911565d1085d9447363fe8ad041817436c4998fe

URL: https://github.com/llvm/llvm-project/commit/911565d1085d9447363fe8ad041817436c4998fe
DIFF: https://github.com/llvm/llvm-project/commit/911565d1085d9447363fe8ad041817436c4998fe.diff

LOG: [NewPM][CodeGen] Introduce machine pass and machine pass manager

A machine pass can define four methods:
- `PreservedAnalyses run(MachineFunction &, MachineFunctionAnalysisManager &)`
- `Error doInitialization(Module &, MachineFunctionAnalysisManager &)`
- `Error doFinalization(Module &, MachineFunctionAnalysisManager &)`
- `Error run(Module &, MachineFunctionAnalysisManager &)`

Machine pass manager:
- MachineFunctionAnalysisManager:
  Basically an AnalysisManager<MachineFunction> augmented with the ability to
  register and query IR analyses
- MachineFunctionPassManager: supports only two methods, `addPass` and `run`

Reviewed By: arsenm, asbirlea, aeubanks

Differential Revision: https://reviews.llvm.org/D67687

Added: 
    llvm/include/llvm/CodeGen/MachinePassManager.h
    llvm/lib/CodeGen/MachinePassManager.cpp
    llvm/unittests/CodeGen/PassManagerTest.cpp

Modified: 
    llvm/include/llvm/IR/PassManager.h
    llvm/lib/CodeGen/CMakeLists.txt
    llvm/lib/CodeGen/LLVMBuild.txt
    llvm/unittests/CodeGen/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/MachinePassManager.h b/llvm/include/llvm/CodeGen/MachinePassManager.h
new file mode 100644
index 000000000000..3d1684b722fd
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/MachinePassManager.h
@@ -0,0 +1,252 @@
+//===- MachinePassManager.h --- Pass management for CodeGen -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines the pass manager interface for codegen. The codegen
+// pipeline consists of only machine function passes. There is no container
+// relationship between IR module/function and machine function in terms of pass
+// manager organization. So there is no need for adaptor classes (for example
+// ModuleToMachineFunctionAdaptor). Since invalidation could only happen among
+// machine function passes, there are no proxy classes to handle cross-IR-unit
+// invalidation. IR analysis results are provided for machine function passes by
+// their respective analysis managers such as ModuleAnalysisManager and
+// FunctionAnalysisManager.
+//
+// TODO: Add MachineFunctionProperties support.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_MACHINEPASSMANAGER_H
+#define LLVM_CODEGEN_MACHINEPASSMANAGER_H
+
+#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/type_traits.h"
+
+namespace llvm {
+class Module;
+
+extern template class AnalysisManager<MachineFunction>;
+
+/// An AnalysisManager<MachineFunction> that also exposes IR analysis results.
+class MachineFunctionAnalysisManager : public AnalysisManager<MachineFunction> {
+public:
+  using Base = AnalysisManager<MachineFunction>;
+
+  MachineFunctionAnalysisManager() : Base(false), FAM(nullptr), MAM(nullptr) {}
+  MachineFunctionAnalysisManager(FunctionAnalysisManager &FAM,
+                                 ModuleAnalysisManager &MAM,
+                                 bool DebugLogging = false)
+      : Base(DebugLogging), FAM(&FAM), MAM(&MAM) {}
+  MachineFunctionAnalysisManager(MachineFunctionAnalysisManager &&) = default;
+  MachineFunctionAnalysisManager &
+  operator=(MachineFunctionAnalysisManager &&) = default;
+
+  /// Get the result of an analysis pass for a Function.
+  ///
+  /// Runs the analysis if a cached result is not available.
+  template <typename PassT> typename PassT::Result &getResult(Function &F) {
+    return FAM->getResult<PassT>(F);
+  }
+
+  /// Get the cached result of an analysis pass for a Function.
+  ///
+  /// This method never runs the analysis.
+  ///
+  /// \returns null if there is no cached result.
+  template <typename PassT>
+  typename PassT::Result *getCachedResult(Function &F) {
+    return FAM->getCachedResult<PassT>(F);
+  }
+
+  /// Get the result of an analysis pass for a Module.
+  ///
+  /// Runs the analysis if a cached result is not available.
+  template <typename PassT> typename PassT::Result &getResult(Module &M) {
+    return MAM->getResult<PassT>(M);
+  }
+
+  /// Get the cached result of an analysis pass for a Module.
+  ///
+  /// This method never runs the analysis.
+  ///
+  /// \returns null if there is no cached result.
+  template <typename PassT> typename PassT::Result *getCachedResult(Module &M) {
+    return MAM->getCachedResult<PassT>(M);
+  }
+
+  /// Get the result of an analysis pass for a MachineFunction.
+  ///
+  /// Runs the analysis if a cached result is not available.
+  using Base::getResult;
+
+  /// Get the cached result of an analysis pass for a MachineFunction.
+  ///
+  /// This method never runs the analysis.
+  ///
+  /// \returns null if there is no cached result.
+  using Base::getCachedResult;
+
+  // FIXME: Add LoopAnalysisManager or CGSCCAnalysisManager if needed.
+  FunctionAnalysisManager *FAM;
+  ModuleAnalysisManager *MAM;
+};
+
+extern template class PassManager<MachineFunction>;
+
+/// MachineFunctionPassManager adds/removes the features below to/from the base
+/// PassManager template instantiation.
+///
+/// - Support passes that implement doInitialization/doFinalization. This is for
+///   machine function passes to work on module level constructs. One such pass
+///   is AsmPrinter.
+///
+/// - Support machine module pass which runs over the module (for example,
+///   MachineOutliner). A machine module pass needs to define the method:
+///
+///   ```Error run(Module &, MachineFunctionAnalysisManager &)```
+///
+///   FIXME: machine module passes still need to define the usual machine
+///          function pass interface, namely,
+///          `PreservedAnalyses run(MachineFunction &,
+///                                 MachineFunctionAnalysisManager &)`
+///          But this interface wouldn't be executed. It is just a placeholder
+///          to satisfy the pass manager type-erased interface. This
+///          special-casing of machine module pass is due to its limited use
+///          cases and the unnecessary complexity it may bring to the machine
+///          pass manager.
+///
+/// - The base class `run` method is replaced by an alternative `run` method.
+///   See details below.
+///
+/// - Support codegening in the SCC order. Users include interprocedural
+///   register allocation (IPRA).
+class MachineFunctionPassManager
+    : public PassManager<MachineFunction, MachineFunctionAnalysisManager> {
+  using Base = PassManager<MachineFunction, MachineFunctionAnalysisManager>;
+
+public:
+  MachineFunctionPassManager(bool DebugLogging = false,
+                             bool RequireCodeGenSCCOrder = false)
+      : Base(DebugLogging), RequireCodeGenSCCOrder(RequireCodeGenSCCOrder) {}
+  MachineFunctionPassManager(MachineFunctionPassManager &&) = default;
+  MachineFunctionPassManager &
+  operator=(MachineFunctionPassManager &&) = default;
+
+  /// Run machine passes for a Module.
+  ///
+  /// The intended use is to start the codegen pipeline for a Module. The base
+  /// class's `run` method is deliberately hidden by this due to the observation
+  /// that we don't yet have the use cases of compositing two instances of
+  /// machine pass managers, or compositing machine pass managers with other
+  /// types of pass managers.
+  Error run(Module &M, MachineFunctionAnalysisManager &MFAM);
+
+  template <typename PassT> void addPass(PassT &&Pass) {
+    Base::addPass(std::forward<PassT>(Pass));
+    PassConceptT *P = Passes.back().get();
+    addDoInitialization<PassT>(P);
+    addDoFinalization<PassT>(P);
+
+    // Add machine module pass.
+    addRunOnModule<PassT>(P);
+  }
+
+private:
+  template <typename PassT>
+  using has_init_t = decltype(std::declval<PassT &>().doInitialization(
+      std::declval<Module &>(),
+      std::declval<MachineFunctionAnalysisManager &>()));
+
+  template <typename PassT>
+  std::enable_if_t<!is_detected<has_init_t, PassT>::value>
+  addDoInitialization(PassConceptT *Pass) {}
+
+  template <typename PassT>
+  std::enable_if_t<is_detected<has_init_t, PassT>::value>
+  addDoInitialization(PassConceptT *Pass) {
+    using PassModelT =
+        detail::PassModel<MachineFunction, PassT, PreservedAnalyses,
+                          MachineFunctionAnalysisManager>;
+    auto *P = static_cast<PassModelT *>(Pass);
+    InitializationFuncs.emplace_back(
+        [=](Module &M, MachineFunctionAnalysisManager &MFAM) {
+          return P->Pass.doInitialization(M, MFAM);
+        });
+  }
+
+  template <typename PassT>
+  using has_fini_t = decltype(std::declval<PassT &>().doFinalization(
+      std::declval<Module &>(),
+      std::declval<MachineFunctionAnalysisManager &>()));
+
+  template <typename PassT>
+  std::enable_if_t<!is_detected<has_fini_t, PassT>::value>
+  addDoFinalization(PassConceptT *Pass) {}
+
+  template <typename PassT>
+  std::enable_if_t<is_detected<has_fini_t, PassT>::value>
+  addDoFinalization(PassConceptT *Pass) {
+    using PassModelT =
+        detail::PassModel<MachineFunction, PassT, PreservedAnalyses,
+                          MachineFunctionAnalysisManager>;
+    auto *P = static_cast<PassModelT *>(Pass);
+    FinalizationFuncs.emplace_back(
+        [=](Module &M, MachineFunctionAnalysisManager &MFAM) {
+          return P->Pass.doFinalization(M, MFAM);
+        });
+  }
+
+  template <typename PassT>
+  using is_machine_module_pass_t = decltype(std::declval<PassT &>().run(
+      std::declval<Module &>(),
+      std::declval<MachineFunctionAnalysisManager &>()));
+
+  template <typename PassT>
+  using is_machine_function_pass_t = decltype(std::declval<PassT &>().run(
+      std::declval<MachineFunction &>(),
+      std::declval<MachineFunctionAnalysisManager &>()));
+
+  template <typename PassT>
+  std::enable_if_t<!is_detected<is_machine_module_pass_t, PassT>::value>
+  addRunOnModule(PassConceptT *Pass) {}
+
+  template <typename PassT>
+  std::enable_if_t<is_detected<is_machine_module_pass_t, PassT>::value>
+  addRunOnModule(PassConceptT *Pass) {
+    static_assert(is_detected<is_machine_function_pass_t, PassT>::value,
+                  "machine module pass needs to define machine function pass "
+                  "api. sorry.");
+
+    using PassModelT =
+        detail::PassModel<MachineFunction, PassT, PreservedAnalyses,
+                          MachineFunctionAnalysisManager>;
+    auto *P = static_cast<PassModelT *>(Pass);
+    MachineModulePasses.emplace(
+        Passes.size() - 1,
+        [=](Module &M, MachineFunctionAnalysisManager &MFAM) {
+          return P->Pass.run(M, MFAM);
+        });
+  }
+
+  using FuncTy = Error(Module &, MachineFunctionAnalysisManager &);
+  SmallVector<llvm::unique_function<FuncTy>, 4> InitializationFuncs;
+  SmallVector<llvm::unique_function<FuncTy>, 4> FinalizationFuncs;
+
+  using PassIndex = decltype(Passes)::size_type;
+  std::map<PassIndex, llvm::unique_function<FuncTy>> MachineModulePasses;
+
+  // Run codegen in the SCC order.
+  bool RequireCodeGenSCCOrder;
+};
+
+} // end namespace llvm
+
+#endif // LLVM_CODEGEN_MACHINEPASSMANAGER_H

diff  --git a/llvm/include/llvm/IR/PassManager.h b/llvm/include/llvm/IR/PassManager.h
index f16696d7c2e3..b2592669f947 100644
--- a/llvm/include/llvm/IR/PassManager.h
+++ b/llvm/include/llvm/IR/PassManager.h
@@ -558,7 +558,7 @@ class PassManager : public PassInfoMixin<
 
   static bool isRequired() { return true; }
 
-private:
+protected:
   using PassConceptT =
       detail::PassConcept<IRUnitT, AnalysisManagerT, ExtraArgTs...>;
 

diff  --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
index ea1a10d58263..64881c6f12d5 100644
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -92,6 +92,7 @@ add_llvm_component_library(LLVMCodeGen
   MachineOperand.cpp
   MachineOptimizationRemarkEmitter.cpp
   MachineOutliner.cpp
+  MachinePassManager.cpp
   MachinePipeliner.cpp
   MachinePostDominators.cpp
   MachineRegionInfo.cpp

diff  --git a/llvm/lib/CodeGen/LLVMBuild.txt b/llvm/lib/CodeGen/LLVMBuild.txt
index 0a766fd7aca6..442ef407bc7a 100644
--- a/llvm/lib/CodeGen/LLVMBuild.txt
+++ b/llvm/lib/CodeGen/LLVMBuild.txt
@@ -21,4 +21,4 @@ subdirectories = AsmPrinter SelectionDAG MIRParser GlobalISel
 type = Library
 name = CodeGen
 parent = Libraries
-required_libraries = Analysis BitReader BitWriter Core MC ProfileData Scalar Support Target TransformUtils
+required_libraries = Analysis BitReader BitWriter Core MC Passes ProfileData Scalar Support Target TransformUtils

diff  --git a/llvm/lib/CodeGen/MachinePassManager.cpp b/llvm/lib/CodeGen/MachinePassManager.cpp
new file mode 100644
index 000000000000..654d97fde049
--- /dev/null
+++ b/llvm/lib/CodeGen/MachinePassManager.cpp
@@ -0,0 +1,103 @@
+//===---------- MachinePassManager.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the pass management machinery for machine functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/MachinePassManager.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/IR/PassManagerImpl.h"
+
+using namespace llvm;
+
+namespace llvm {
+template class AllAnalysesOn<MachineFunction>;
+template class AnalysisManager<MachineFunction>;
+template class PassManager<MachineFunction>;
+
+Error MachineFunctionPassManager::run(Module &M,
+                                      MachineFunctionAnalysisManager &MFAM) {
+  // MachineModuleAnalysis is a module analysis pass that is never invalidated
+  // because we don't run any module pass in codegen pipeline. This is very
+  // important because the codegen state is stored in MMI which is the analysis
+  // result of MachineModuleAnalysis. MMI should not be recomputed.
+  auto &MMI = MFAM.getResult<MachineModuleAnalysis>(M);
+
+  assert(!RequireCodeGenSCCOrder && "not implemented");
+
+  if (DebugLogging) {
+    dbgs() << "Starting " << getTypeName<MachineFunction>()
+           << " pass manager run.\n";
+  }
+
+  for (auto &F : InitializationFuncs) {
+    if (auto Err = F(M, MFAM))
+      return Err;
+  }
+
+  unsigned Idx = 0;
+  size_t Size = Passes.size();
+  do {
+    // Run machine module passes
+    for (; MachineModulePasses.count(Idx) && Idx != Size; ++Idx) {
+      if (DebugLogging)
+        dbgs() << "Running pass: " << Passes[Idx]->name() << " on "
+               << M.getName() << '\n';
+      if (auto Err = MachineModulePasses.at(Idx)(M, MFAM))
+        return Err;
+    }
+
+    // Finish running all passes.
+    if (Idx == Size)
+      break;
+
+    // Run machine function passes
+
+    // Get index range of machine function passes.
+    unsigned Begin = Idx;
+    for (; !MachineModulePasses.count(Idx) && Idx != Size; ++Idx)
+      ;
+
+    for (Function &F : M) {
+      // Do not codegen any 'available_externally' functions at all, they have
+      // definitions outside the translation unit.
+      if (F.hasAvailableExternallyLinkage())
+        continue;
+
+      MachineFunction &MF = MMI.getOrCreateMachineFunction(F);
+      PassInstrumentation PI = MFAM.getResult<PassInstrumentationAnalysis>(MF);
+
+      for (unsigned I = Begin, E = Idx; I != E; ++I) {
+        auto *P = Passes[I].get();
+
+        if (!PI.runBeforePass<MachineFunction>(*P, MF))
+          continue;
+
+        // TODO: EmitSizeRemarks
+        PreservedAnalyses PassPA = P->run(MF, MFAM);
+        PI.runAfterPass(*P, MF);
+        MFAM.invalidate(MF, PassPA);
+      }
+    }
+  } while (true);
+
+  for (auto &F : FinalizationFuncs) {
+    if (auto Err = F(M, MFAM))
+      return Err;
+  }
+
+  if (DebugLogging) {
+    dbgs() << "Finished " << getTypeName<MachineFunction>()
+           << " pass manager run.\n";
+  }
+
+  return Error::success();
+}
+
+} // namespace llvm

diff  --git a/llvm/unittests/CodeGen/CMakeLists.txt b/llvm/unittests/CodeGen/CMakeLists.txt
index fa3cb1fa7669..831eb66e82cf 100644
--- a/llvm/unittests/CodeGen/CMakeLists.txt
+++ b/llvm/unittests/CodeGen/CMakeLists.txt
@@ -7,6 +7,7 @@ set(LLVM_LINK_COMPONENTS
   Core
   MC
   MIRParser
+  Passes
   SelectionDAG
   Support
   Target
@@ -20,6 +21,7 @@ add_llvm_unittest(CodeGenTests
   MachineInstrBundleIteratorTest.cpp
   MachineInstrTest.cpp
   MachineOperandTest.cpp
+  PassManagerTest.cpp
   ScalableVectorMVTsTest.cpp
   TypeTraitsTest.cpp
   TargetOptionsTest.cpp

diff  --git a/llvm/unittests/CodeGen/PassManagerTest.cpp b/llvm/unittests/CodeGen/PassManagerTest.cpp
new file mode 100644
index 000000000000..871e46581adc
--- /dev/null
+++ b/llvm/unittests/CodeGen/PassManagerTest.cpp
@@ -0,0 +1,305 @@
+//===- llvm/unittest/CodeGen/PassManagerTest.cpp - PassManager tests ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/CGSCCPassManager.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/MachinePassManager.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/TargetRegistry.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Target/TargetMachine.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> {
+public:
+  struct Result {
+    Result(int Count) : InstructionCount(Count) {}
+    int InstructionCount;
+  };
+
+  /// Run the analysis pass over the function and return a result.
+  Result run(Function &F, FunctionAnalysisManager &AM) {
+    int Count = 0;
+    for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI)
+      for (BasicBlock::iterator II = BBI->begin(), IE = BBI->end(); II != IE;
+           ++II)
+        ++Count;
+    return Result(Count);
+  }
+
+private:
+  friend AnalysisInfoMixin<TestFunctionAnalysis>;
+  static AnalysisKey Key;
+};
+
+AnalysisKey TestFunctionAnalysis::Key;
+
+class TestMachineFunctionAnalysis
+    : public AnalysisInfoMixin<TestMachineFunctionAnalysis> {
+public:
+  struct Result {
+    Result(int Count) : InstructionCount(Count) {}
+    int InstructionCount;
+  };
+
+  /// Run the analysis pass over the machine function and return a result.
+  Result run(MachineFunction &MF, MachineFunctionAnalysisManager::Base &AM) {
+    auto &MFAM = static_cast<MachineFunctionAnalysisManager &>(AM);
+    // Query function analysis result.
+    TestFunctionAnalysis::Result &FAR =
+        MFAM.getResult<TestFunctionAnalysis>(MF.getFunction());
+    // + 5
+    return FAR.InstructionCount;
+  }
+
+private:
+  friend AnalysisInfoMixin<TestMachineFunctionAnalysis>;
+  static AnalysisKey Key;
+};
+
+AnalysisKey TestMachineFunctionAnalysis::Key;
+
+const std::string DoInitErrMsg = "doInitialization failed";
+const std::string DoFinalErrMsg = "doFinalization failed";
+
+struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> {
+  TestMachineFunctionPass(int &Count, std::vector<int> &BeforeInitialization,
+                          std::vector<int> &BeforeFinalization,
+                          std::vector<int> &MachineFunctionPassCount)
+      : Count(Count), BeforeInitialization(BeforeInitialization),
+        BeforeFinalization(BeforeFinalization),
+        MachineFunctionPassCount(MachineFunctionPassCount) {}
+
+  Error doInitialization(Module &M, MachineFunctionAnalysisManager &MFAM) {
+    // Force doInitialization fail by starting with big `Count`.
+    if (Count > 10000)
+      return make_error<StringError>(DoInitErrMsg, inconvertibleErrorCode());
+
+    // + 1
+    ++Count;
+    BeforeInitialization.push_back(Count);
+    return Error::success();
+  }
+  Error doFinalization(Module &M, MachineFunctionAnalysisManager &MFAM) {
+    // Force doFinalization fail by starting with big `Count`.
+    if (Count > 1000)
+      return make_error<StringError>(DoFinalErrMsg, inconvertibleErrorCode());
+
+    // + 1
+    ++Count;
+    BeforeFinalization.push_back(Count);
+    return Error::success();
+  }
+
+  PreservedAnalyses run(MachineFunction &MF,
+                        MachineFunctionAnalysisManager &MFAM) {
+    // Query function analysis result.
+    TestFunctionAnalysis::Result &FAR =
+        MFAM.getResult<TestFunctionAnalysis>(MF.getFunction());
+    // 3 + 1 + 1 = 5
+    Count += FAR.InstructionCount;
+
+    // Query module analysis result.
+    MachineModuleInfo &MMI =
+        MFAM.getResult<MachineModuleAnalysis>(*MF.getFunction().getParent());
+    // 1 + 1 + 1 = 3
+    Count += (MMI.getModule() == MF.getFunction().getParent());
+
+    // Query machine function analysis result.
+    TestMachineFunctionAnalysis::Result &MFAR =
+        MFAM.getResult<TestMachineFunctionAnalysis>(MF);
+    // 3 + 1 + 1 = 5
+    Count += MFAR.InstructionCount;
+
+    MachineFunctionPassCount.push_back(Count);
+
+    return PreservedAnalyses::none();
+  }
+
+  int &Count;
+  std::vector<int> &BeforeInitialization;
+  std::vector<int> &BeforeFinalization;
+  std::vector<int> &MachineFunctionPassCount;
+};
+
+struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> {
+  TestMachineModulePass(int &Count, std::vector<int> &MachineModulePassCount)
+      : Count(Count), MachineModulePassCount(MachineModulePassCount) {}
+
+  Error run(Module &M, MachineFunctionAnalysisManager &MFAM) {
+    MachineModuleInfo &MMI = MFAM.getResult<MachineModuleAnalysis>(M);
+    // + 1
+    Count += (MMI.getModule() == &M);
+    MachineModulePassCount.push_back(Count);
+    return Error::success();
+  }
+
+  PreservedAnalyses run(MachineFunction &MF,
+                        MachineFunctionAnalysisManager &AM) {
+    llvm_unreachable(
+        "This should never be reached because this is machine module pass");
+  }
+
+  int &Count;
+  std::vector<int> &MachineModulePassCount;
+};
+
+std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) {
+  SMDiagnostic Err;
+  return parseAssemblyString(IR, Err, Context);
+}
+
+class PassManagerTest : public ::testing::Test {
+protected:
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+  std::unique_ptr<TargetMachine> TM;
+
+public:
+  PassManagerTest()
+      : M(parseIR(Context, "define void @f() {\n"
+                           "entry:\n"
+                           "  call void @g()\n"
+                           "  call void @h()\n"
+                           "  ret void\n"
+                           "}\n"
+                           "define void @g() {\n"
+                           "  ret void\n"
+                           "}\n"
+                           "define void @h() {\n"
+                           "  ret void\n"
+                           "}\n")) {
+    // MachineModuleAnalysis needs a TargetMachine instance.
+    llvm::InitializeAllTargets();
+
+    std::string Error;
+    const Target *TheTarget =
+        TargetRegistry::lookupTarget("x86_64-unknown-linux", Error);
+    // If we didn't build x86, do not run the test.
+    if (!TheTarget)
+      return;
+
+    TargetOptions Options;
+    TM.reset(TheTarget->createTargetMachine("x86_64-unknown-linux", "", "",
+                                            Options, None));
+  }
+};
+
+TEST_F(PassManagerTest, Basic) {
+  LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get());
+  M->setDataLayout(TM->createDataLayout());
+
+  LoopAnalysisManager LAM(/*DebugLogging=*/true);
+  FunctionAnalysisManager FAM(/*DebugLogging=*/true);
+  CGSCCAnalysisManager CGAM(/*DebugLogging=*/true);
+  ModuleAnalysisManager MAM(/*DebugLogging=*/true);
+  PassBuilder PB(TM.get());
+  PB.registerModuleAnalyses(MAM);
+  PB.registerFunctionAnalyses(FAM);
+  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+
+  FAM.registerPass([&] { return TestFunctionAnalysis(); });
+  FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+  MAM.registerPass([&] { return MachineModuleAnalysis(LLVMTM); });
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+
+  MachineFunctionAnalysisManager MFAM;
+  {
+    // Test move assignment.
+    MachineFunctionAnalysisManager NestedMFAM(FAM, MAM,
+                                              /*DebugLogging*/ true);
+    NestedMFAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+    NestedMFAM.registerPass([&] { return TestMachineFunctionAnalysis(); });
+    MFAM = std::move(NestedMFAM);
+  }
+
+  int Count = 0;
+  std::vector<int> BeforeInitialization[2];
+  std::vector<int> BeforeFinalization[2];
+  std::vector<int> TestMachineFunctionCount[2];
+  std::vector<int> TestMachineModuleCount[2];
+
+  MachineFunctionPassManager MFPM;
+  {
+    // Test move assignment.
+    MachineFunctionPassManager NestedMFPM(/*DebugLogging*/ true);
+    NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[0]));
+    NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[0],
+                                               BeforeFinalization[0],
+                                               TestMachineFunctionCount[0]));
+    NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[1]));
+    NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
+                                               BeforeFinalization[1],
+                                               TestMachineFunctionCount[1]));
+    MFPM = std::move(NestedMFPM);
+  }
+
+  ASSERT_FALSE(errorToBool(MFPM.run(*M, MFAM)));
+
+  // Check first machine module pass
+  EXPECT_EQ(1u, TestMachineModuleCount[0].size());
+  EXPECT_EQ(3, TestMachineModuleCount[0][0]);
+
+  // Check first machine function pass
+  EXPECT_EQ(1u, BeforeInitialization[0].size());
+  EXPECT_EQ(1, BeforeInitialization[0][0]);
+  EXPECT_EQ(3u, TestMachineFunctionCount[0].size());
+  EXPECT_EQ(10, TestMachineFunctionCount[0][0]);
+  EXPECT_EQ(13, TestMachineFunctionCount[0][1]);
+  EXPECT_EQ(16, TestMachineFunctionCount[0][2]);
+  EXPECT_EQ(1u, BeforeFinalization[0].size());
+  EXPECT_EQ(31, BeforeFinalization[0][0]);
+
+  // Check second machine module pass
+  EXPECT_EQ(1u, TestMachineModuleCount[1].size());
+  EXPECT_EQ(17, TestMachineModuleCount[1][0]);
+
+  // Check second machine function pass
+  EXPECT_EQ(1u, BeforeInitialization[1].size());
+  EXPECT_EQ(2, BeforeInitialization[1][0]);
+  EXPECT_EQ(3u, TestMachineFunctionCount[1].size());
+  EXPECT_EQ(24, TestMachineFunctionCount[1][0]);
+  EXPECT_EQ(27, TestMachineFunctionCount[1][1]);
+  EXPECT_EQ(30, TestMachineFunctionCount[1][2]);
+  EXPECT_EQ(1u, BeforeFinalization[1].size());
+  EXPECT_EQ(32, BeforeFinalization[1][0]);
+
+  EXPECT_EQ(32, Count);
+
+  // doInitialization returns error
+  Count = 10000;
+  MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
+                                       BeforeFinalization[1],
+                                       TestMachineFunctionCount[1]));
+  std::string Message;
+  llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) {
+    Message = Error.getMessage();
+  });
+  EXPECT_EQ(Message, DoInitErrMsg);
+
+  // doFinalization returns error
+  Count = 1000;
+  MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
+                                       BeforeFinalization[1],
+                                       TestMachineFunctionCount[1]));
+  llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) {
+    Message = Error.getMessage();
+  });
+  EXPECT_EQ(Message, DoFinalErrMsg);
+}
+
+} // namespace


        


More information about the llvm-commits mailing list