[llvm-branch-commits] [mlir] 00c6ef8 - [mlir][Pass] Remove the restriction that PassManager can only run on ModuleOp

River Riddle via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 3 15:58:04 PST 2020


Author: River Riddle
Date: 2020-12-03T15:47:01-08:00
New Revision: 00c6ef8628a6ee062d5104692a92d5e000dd7c05

URL: https://github.com/llvm/llvm-project/commit/00c6ef8628a6ee062d5104692a92d5e000dd7c05
DIFF: https://github.com/llvm/llvm-project/commit/00c6ef8628a6ee062d5104692a92d5e000dd7c05.diff

LOG: [mlir][Pass] Remove the restriction that PassManager can only run on ModuleOp

This was a somewhat important restriction in the past when ModuleOp was distinctly the top-level container operation, as well as before the pass manager had support for running nested pass managers natively. With these two issues fading away, there isn't really a good reason to enforce that a ModuleOp is the thing running within a pass manager. As such, this revision removes the restriction and allows for users to pass in the name of the operation that the pass manager will be scheduled on.

The only remaining dependency on BuiltinOps from Pass after this revision is due to FunctionPass, which will be resolved in a followup revision.

Differential Revision: https://reviews.llvm.org/D92450

Added: 
    

Modified: 
    mlir/include/mlir/Pass/AnalysisManager.h
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/Pass/IRPrinting.cpp
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Pass/PassManagerOptions.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
index de428da5abd4..ec6b7696ce60 100644
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -9,7 +9,7 @@
 #ifndef MLIR_PASS_ANALYSISMANAGER_H
 #define MLIR_PASS_ANALYSISMANAGER_H
 
-#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/Pass/PassInstrumentation.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseMap.h"
@@ -177,8 +177,8 @@ class AnalysisMap {
     bool wasInserted;
     std::tie(it, wasInserted) = analyses.try_emplace(id);
 
-    // If we don't have a cached analysis for this function, compute it directly
-    // and add it to the cache.
+    // If we don't have a cached analysis for this operation, compute it
+    // directly and add it to the cache.
     if (wasInserted) {
       if (pi)
         pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
@@ -321,14 +321,14 @@ class AnalysisManager {
   friend class ModuleAnalysisManager;
 };
 
-/// An analysis manager class specifically for the top-level module operation.
-/// This class contains the memory allocations for all nested analysis managers,
-/// and provides an anchor point. This is necessary because AnalysisManager is
+/// An analysis manager class specifically for the top-level operation. This
+/// class contains the memory allocations for all nested analysis managers, and
+/// provides an anchor point. This is necessary because AnalysisManager is
 /// designed to be a thin wrapper around an existing analysis map instance.
 class ModuleAnalysisManager {
 public:
-  ModuleAnalysisManager(ModuleOp module, PassInstrumentor *passInstrumentor)
-      : analyses(module), passInstrumentor(passInstrumentor) {}
+  ModuleAnalysisManager(Operation *op, PassInstrumentor *passInstrumentor)
+      : analyses(op), passInstrumentor(passInstrumentor) {}
   ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
   ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
 

diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index eb21359d6211..5e9c9a790d29 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -28,7 +28,6 @@ namespace mlir {
 class AnalysisManager;
 class Identifier;
 class MLIRContext;
-class ModuleOp;
 class Operation;
 class Pass;
 class PassInstrumentation;
@@ -158,12 +157,20 @@ enum class PassDisplayMode {
 /// The main pass manager and pipeline builder.
 class PassManager : public OpPassManager {
 public:
-  PassManager(MLIRContext *ctx, Nesting nesting = Nesting::Explicit);
+  /// Create a new pass manager under the given context with a specific nesting
+  /// style. The created pass manager can schedule operations that match
+  /// `operationName`.
+  PassManager(MLIRContext *ctx, Nesting nesting = Nesting::Explicit,
+              StringRef operationName = "module");
+  PassManager(MLIRContext *ctx, StringRef operationName)
+      : PassManager(ctx, Nesting::Explicit, operationName) {}
   ~PassManager();
 
-  /// Run the passes within this manager on the provided module.
+  /// Run the passes within this manager on the provided operation. The
+  /// specified operation must have the same name as the one provided the pass
+  /// manager on construction.
   LLVM_NODISCARD
-  LogicalResult run(ModuleOp module);
+  LogicalResult run(Operation *op);
 
   /// Return an instance of the context.
   MLIRContext *getContext() const { return context; }
@@ -318,11 +325,11 @@ class PassManager : public OpPassManager {
   void dumpStatistics();
 
   /// Run the pass manager with crash recover enabled.
-  LogicalResult runWithCrashRecovery(ModuleOp module, AnalysisManager am);
+  LogicalResult runWithCrashRecovery(Operation *op, AnalysisManager am);
   /// Run the given passes with crash recover enabled.
   LogicalResult
   runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
-                       ModuleOp module, AnalysisManager am);
+                       Operation *op, AnalysisManager am);
 
   /// Context this PassManager was initialized with.
   MLIRContext *context;

diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index 2f6c3a2a5af4..b27b39dd322d 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -7,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/PassManager.h"
 #include "llvm/Support/Format.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -97,14 +96,10 @@ class IRPrinterInstrumentation : public PassInstrumentation {
 
 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
                     OpPrintingFlags flags) {
-  // Check to see if we are printing the top-level module.
-  auto module = dyn_cast<ModuleOp>(op);
-  if (module && !op->getBlock())
-    return module.print(out << "\n", flags);
-
   // Otherwise, check to see if we are not printing at module scope.
   if (!printModuleScope)
-    return op->print(out << "\n", flags.useLocalScope());
+    return op->print(out << "\n",
+                     op->getBlock() ? flags.useLocalScope() : flags);
 
   // Otherwise, we are printing at module scope.
   out << " ('" << op->getName() << "' operation";
@@ -113,17 +108,11 @@ static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
     out << ": @" << symbolName.getValue();
   out << ")\n";
 
-  // Find the top-level module operation.
+  // Find the top-level operation.
   auto *topLevelOp = op;
   while (auto *parentOp = topLevelOp->getParentOp())
     topLevelOp = parentOp;
-
-  // Check to see if the top-level operation is actually a module in the case of
-  // invalid-ir.
-  if (auto module = dyn_cast<ModuleOp>(topLevelOp))
-    module.print(out, flags);
-  else
-    topLevelOp->print(out, flags);
+  topLevelOp->print(out, flags);
 }
 
 /// Instrumentation hooks.

diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 813b7a8db509..056da035a5b5 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -12,7 +12,6 @@
 
 #include "mlir/Pass/Pass.h"
 #include "PassDetail.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Verifier.h"
@@ -528,9 +527,9 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
     asyncExecutors.assign(llvm::hardware_concurrency().compute_thread_count(),
                           mgrs);
 
-  // Run a prepass over the module to collect the operations to execute over.
-  // This ensures that an analysis manager exists for each operation, as well as
-  // providing a queue of operations to execute over.
+  // Run a prepass over the operation to collect the nested operations to
+  // execute over. This ensures that an analysis manager exists for each
+  // operation, as well as providing a queue of operations to execute over.
   std::vector<std::pair<Operation *, AnalysisManager>> opAMPairs;
   for (auto &region : getOperation()->getRegions()) {
     for (auto &block : region) {
@@ -614,7 +613,7 @@ namespace {
 /// reproducers when a signal is raised, such as a segfault.
 struct RecoveryReproducerContext {
   RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,
-                            ModuleOp module, StringRef filename,
+                            Operation *op, StringRef filename,
                             bool disableThreads, bool verifyPasses);
   ~RecoveryReproducerContext();
 
@@ -631,8 +630,8 @@ struct RecoveryReproducerContext {
   /// The textual description of the currently executing pipeline.
   std::string pipeline;
 
-  /// The MLIR module representing the IR before the crash.
-  OwningModuleRef module;
+  /// The MLIR operation representing the IR before the crash.
+  Operation *preCrashOperation;
 
   /// The filename to use when generating the reproducer.
   StringRef filename;
@@ -658,9 +657,9 @@ llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
     RecoveryReproducerContext::reproducerSet;
 
 RecoveryReproducerContext::RecoveryReproducerContext(
-    MutableArrayRef<std::unique_ptr<Pass>> passes, ModuleOp module,
+    MutableArrayRef<std::unique_ptr<Pass>> passes, Operation *op,
     StringRef filename, bool disableThreads, bool verifyPasses)
-    : module(module.clone()), filename(filename),
+    : preCrashOperation(op->clone()), filename(filename),
       disableThreads(disableThreads), verifyPasses(verifyPasses) {
   // Grab the textual pipeline being executed..
   {
@@ -677,6 +676,9 @@ RecoveryReproducerContext::RecoveryReproducerContext(
 }
 
 RecoveryReproducerContext::~RecoveryReproducerContext() {
+  // Erase the cloned preCrash IR that we cached.
+  preCrashOperation->erase();
+
   llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
   reproducerSet->remove(this);
   if (reproducerSet->empty())
@@ -700,7 +702,7 @@ LogicalResult RecoveryReproducerContext::generate(std::string &error) {
            << "\n";
 
   // Output the .mlir module.
-  module->print(outputOS);
+  preCrashOperation->print(outputOS);
   outputFile->keep();
   return success();
 }
@@ -722,11 +724,11 @@ void RecoveryReproducerContext::registerSignalHandler() {
 }
 
 /// Run the pass manager with crash recover enabled.
-LogicalResult PassManager::runWithCrashRecovery(ModuleOp module,
+LogicalResult PassManager::runWithCrashRecovery(Operation *op,
                                                 AnalysisManager am) {
   // If this isn't a local producer, run all of the passes in recovery mode.
   if (!localReproducer)
-    return runWithCrashRecovery(impl->passes, module, am);
+    return runWithCrashRecovery(impl->passes, op, am);
 
   // Split the passes within adaptors to ensure that each pass can be run in
   // isolation.
@@ -735,7 +737,7 @@ LogicalResult PassManager::runWithCrashRecovery(ModuleOp module,
   // If this is a local producer, run each of the passes individually.
   MutableArrayRef<std::unique_ptr<Pass>> passes = impl->passes;
   for (std::unique_ptr<Pass> &pass : passes)
-    if (failed(runWithCrashRecovery(pass, module, am)))
+    if (failed(runWithCrashRecovery(pass, op, am)))
       return failure();
   return success();
 }
@@ -743,8 +745,8 @@ LogicalResult PassManager::runWithCrashRecovery(ModuleOp module,
 /// Run the given passes with crash recover enabled.
 LogicalResult
 PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
-                                  ModuleOp module, AnalysisManager am) {
-  RecoveryReproducerContext context(passes, module, *crashReproducerFileName,
+                                  Operation *op, AnalysisManager am) {
+  RecoveryReproducerContext context(passes, op, *crashReproducerFileName,
                                     !getContext()->isMultithreadingEnabled(),
                                     verifyPasses);
 
@@ -753,7 +755,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
   llvm::CrashRecoveryContext recoveryContext;
   recoveryContext.RunSafelyOnThread([&] {
     for (std::unique_ptr<Pass> &pass : passes)
-      if (failed(OpToOpPassAdaptor::run(pass.get(), module, am, verifyPasses)))
+      if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses)))
         return;
     passManagerResult = success();
   });
@@ -762,8 +764,8 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
 
   std::string error;
   if (failed(context.generate(error)))
-    return module.emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
-  return module.emitError()
+    return op->emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
+  return op->emitError()
          << "A failure has been detected while processing the MLIR module, a "
             "reproducer has been generated in '"
          << *crashReproducerFileName << "'";
@@ -773,18 +775,21 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
 // PassManager
 //===----------------------------------------------------------------------===//
 
-PassManager::PassManager(MLIRContext *ctx, Nesting nesting)
-    : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx),
-                    nesting),
-      context(ctx), passTiming(false), localReproducer(false),
-      verifyPasses(true) {}
+PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
+                         StringRef operationName)
+    : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
+      passTiming(false), localReproducer(false), verifyPasses(true) {}
 
 PassManager::~PassManager() {}
 
 void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
 
-/// Run the passes within this manager on the provided module.
-LogicalResult PassManager::run(ModuleOp module) {
+/// Run the passes within this manager on the provided operation.
+LogicalResult PassManager::run(Operation *op) {
+  MLIRContext *context = getContext();
+  assert(op->getName().getIdentifier() == getOpName(*context) &&
+         "operation has a different name than the PassManager");
+
   // Before running, make sure to coalesce any adjacent pass adaptors in the
   // pipeline.
   getImpl().coalesceAdjacentAdaptorPasses();
@@ -792,23 +797,23 @@ LogicalResult PassManager::run(ModuleOp module) {
   // Register all dialects for the current pipeline.
   DialectRegistry dependentDialects;
   getDependentDialects(dependentDialects);
-  dependentDialects.loadAll(module.getContext());
+  dependentDialects.loadAll(context);
 
-  // Construct an analysis manager for the pipeline.
-  ModuleAnalysisManager am(module, instrumentor.get());
+  // Construct a top level analysis manager for the pipeline.
+  ModuleAnalysisManager am(op, instrumentor.get());
 
   // Notify the context that we start running a pipeline for book keeping.
-  module.getContext()->enterMultiThreadedExecution();
+  context->enterMultiThreadedExecution();
 
   // If reproducer generation is enabled, run the pass manager with crash
   // handling enabled.
-  LogicalResult result = crashReproducerFileName
-                             ? runWithCrashRecovery(module, am)
-                             : OpToOpPassAdaptor::runPipeline(
-                                   getPasses(), module, am, verifyPasses);
+  LogicalResult result =
+      crashReproducerFileName
+          ? runWithCrashRecovery(op, am)
+          : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses);
 
   // Notify the context that the run is done.
-  module.getContext()->exitMultiThreadedExecution();
+  context->exitMultiThreadedExecution();
 
   // Dump all of the pass statistics if necessary.
   if (passStatisticsMode)

diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp
index b00f992eceb9..a581ce070fc4 100644
--- a/mlir/lib/Pass/PassManagerOptions.cpp
+++ b/mlir/lib/Pass/PassManagerOptions.cpp
@@ -50,7 +50,7 @@ struct PassManagerOptions {
   llvm::cl::opt<bool> printModuleScope{
       "print-ir-module-scope",
       llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} "
-                     "always print the top-level module operation"),
+                     "always print the top-level operation"),
       llvm::cl::init(false)};
 
   /// Add an IR printing instrumentation if enabled by any 'print-ir' flags.


        


More information about the llvm-branch-commits mailing list