[llvm-branch-commits] [mlir] 1ba5ea6 - [mlir] Add a hook for initializing passes before execution and use it in the Canonicalizer

River Riddle via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 8 13:42:03 PST 2021


Author: River Riddle
Date: 2021-01-08T13:36:12-08:00
New Revision: 1ba5ea67a30170053964a28f2f47aea4bb7f5ff1

URL: https://github.com/llvm/llvm-project/commit/1ba5ea67a30170053964a28f2f47aea4bb7f5ff1
DIFF: https://github.com/llvm/llvm-project/commit/1ba5ea67a30170053964a28f2f47aea4bb7f5ff1.diff

LOG: [mlir] Add a hook for initializing passes before execution and use it in the Canonicalizer

This revision adds a new `initialize(MLIRContext *)` hook to passes that allows them to initialize any heavy state before the first execution of the pass. A concrete use case for this is patterns that rely on PDL: given that PDL is compiled at run time, it is imperative that compilation results are cached as much as possible. The first use of this hook is in the Canonicalizer, which has the added benefit of reducing the number of expensive accesses to the context when collecting patterns.

Differential Revision: https://reviews.llvm.org/D93147

Added: 
    

Modified: 
    mlir/docs/PassManagement.md
    mlir/include/mlir/Pass/Pass.h
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Pass/PassDetail.h
    mlir/lib/Transforms/Canonicalizer.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index ee47d0bcc437..558d5f6d315f 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -131,6 +131,23 @@ end, a pass that may create an entity from a dialect that isn't guaranteed to
 already ne loaded must express this by overriding the `getDependentDialects()`
 method and declare this list of Dialects explicitly.
 
+### Initialization
+
+In certain situations, a Pass may contain state that is constructed dynamically,
+but is potentially expensive to recompute in successive runs of the Pass. One
+such example is when using [`PDL`-based](Dialects/PDLOps.md)
+[patterns](PatternRewriter.md), which are compiled into a bytecode during
+runtime. In these situations, a pass may override the following hook to
+initialize this heavy state:
+
+*   `void initialize(MLIRContext *context)`
+
+This hook is executed once per run of a full pass pipeline, meaning that it does
+not have access to the state available during a `runOnOperation` call. More
+concretely, all necessary accesses to an `MLIRContext` should be driven via the
+provided `context` parameter, and methods that utilize "per-run" state such as
+`getContext`/`getOperation`/`getAnalysis`/etc. must not be used.
+
 ## Analysis Management
 
 An important concept, along with transformation passes, are analyses. These are

diff  --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 7a9523714293..3eeaf79b9610 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -166,6 +166,12 @@ class Pass {
   /// The polymorphic API that runs the pass over the currently held operation.
   virtual void runOnOperation() = 0;
 
+  /// Initialize any complex state necessary for running this pass. This hook
+  /// should not rely on any state accessible during the execution of a pass.
+  /// For example, `getContext`/`getOperation`/`getAnalysis`/etc. should not be
+  /// invoked within this hook.
+  virtual void initialize(MLIRContext *context) {}
+
   /// Schedule an arbitrary pass pipeline on the provided operation.
   /// This can be invoke any time in a pass to dynamic schedule more passes.
   /// The provided operation must be the current one or one nested below.

diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 2715ebd05cac..9fc81b5e421d 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -35,6 +35,7 @@ class PassInstrumentor;
 
 namespace detail {
 struct OpPassManagerImpl;
+class OpToOpPassAdaptor;
 struct PassExecutionState;
 } // end namespace detail
 
@@ -126,9 +127,17 @@ class OpPassManager {
   Nesting getNesting();
 
 private:
+  /// Initialize all of the passes within this pass manager with the given
+  /// initialization generation. The initialization generation is used to detect
+  /// if a pass manager has already been initialized.
+  void initialize(MLIRContext *context, unsigned newInitGeneration);
+
   /// A pointer to an internal implementation instance.
   std::unique_ptr<detail::OpPassManagerImpl> impl;
 
+  /// Allow access to initialize.
+  friend detail::OpToOpPassAdaptor;
+
   /// Allow access to the constructor.
   friend class PassManager;
   friend class Pass;

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index d9046bef1469..fdc6d56d86a5 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -81,9 +81,10 @@ namespace mlir {
 namespace detail {
 struct OpPassManagerImpl {
   OpPassManagerImpl(Identifier identifier, OpPassManager::Nesting nesting)
-      : name(identifier.str()), identifier(identifier), nesting(nesting) {}
+      : name(identifier.str()), identifier(identifier),
+        initializationGeneration(0), nesting(nesting) {}
   OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting)
-      : name(name), nesting(nesting) {}
+      : name(name), initializationGeneration(0), nesting(nesting) {}
 
   /// Merge the passes of this pass manager into the one provided.
   void mergeInto(OpPassManagerImpl &rhs);
@@ -105,6 +106,7 @@ struct OpPassManagerImpl {
   /// pass.
   void splitAdaptorPasses();
 
+  /// Return the operation name of this pass manager as an identifier.
   Identifier getOpName(MLIRContext &context) {
     if (!identifier)
       identifier = Identifier::get(name, &context);
@@ -121,6 +123,10 @@ struct OpPassManagerImpl {
   /// The set of passes to run as part of this pass manager.
   std::vector<std::unique_ptr<Pass>> passes;
 
+  /// The current initialization generation of this pass manager. This is used
+  /// to indicate when a pass manager should be reinitialized.
+  unsigned initializationGeneration;
+
   /// Control the implicit nesting of passes that mismatch the name set for this
   /// OpPassManager.
   OpPassManager::Nesting nesting;
@@ -320,16 +326,36 @@ void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
   registerDialectsForPipeline(*this, dialects);
 }
 
+void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
+
 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
 
-void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
+void OpPassManager::initialize(MLIRContext *context,
+                               unsigned newInitGeneration) {
+  if (impl->initializationGeneration == newInitGeneration)
+    return;
+  impl->initializationGeneration = newInitGeneration;
+  for (Pass &pass : getPasses()) {
+    // If this pass isn't an adaptor, directly initialize it.
+    auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
+    if (!adaptor) {
+      pass.initialize(context);
+      continue;
+    }
+
+    // Otherwise, initialize each of the adaptor's pass managers.
+    for (OpPassManager &adaptorPM : adaptor->getPassManagers())
+      adaptorPM.initialize(context, newInitGeneration);
+  }
+}
 
 //===----------------------------------------------------------------------===//
 // OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//
 
 LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
-                                     AnalysisManager am, bool verifyPasses) {
+                                     AnalysisManager am, bool verifyPasses,
+                                     unsigned parentInitGeneration) {
   if (!op->getName().getAbstractOperation())
     return op->emitOpError()
            << "trying to schedule a pass on an unregistered operation";
@@ -352,9 +378,12 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
                 "nested under the current operation the pass is processing";
     assert(pipeline.getOpName() == root->getName().getStringRef());
 
+    // Initialize the user provided pipeline and execute the pipeline.
+    pipeline.initialize(root->getContext(), parentInitGeneration);
     AnalysisManager nestedAm = root == op ? am : am.nest(root);
     return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
-                                          verifyPasses, pi, &parentInfo);
+                                          verifyPasses, parentInitGeneration,
+                                          pi, &parentInfo);
   };
   pass->passState.emplace(op, am, dynamic_pipeline_callback);
 
@@ -391,7 +420,8 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
 /// Run the given operation and analysis manager on a provided op pass manager.
 LogicalResult OpToOpPassAdaptor::runPipeline(
     iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
-    AnalysisManager am, bool verifyPasses, PassInstrumentor *instrumentor,
+    AnalysisManager am, bool verifyPasses, unsigned parentInitGeneration,
+    PassInstrumentor *instrumentor,
     const PassInstrumentation::PipelineParentInfo *parentInfo) {
   assert((!instrumentor || parentInfo) &&
          "expected parent info if instrumentor is provided");
@@ -407,7 +437,7 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
   if (instrumentor)
     instrumentor->runBeforePipeline(op->getName().getIdentifier(), *parentInfo);
   for (Pass &pass : passes)
-    if (failed(run(&pass, op, am, verifyPasses)))
+    if (failed(run(&pass, op, am, verifyPasses, parentInitGeneration)))
       return failure();
   if (instrumentor)
     instrumentor->runAfterPipeline(op->getName().getIdentifier(), *parentInfo);
@@ -502,8 +532,10 @@ void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
           continue;
 
         // Run the held pipeline over the current operation.
+        unsigned initGeneration = mgr->impl->initializationGeneration;
         if (failed(runPipeline(mgr->getPasses(), &op, am.nest(&op),
-                               verifyPasses, instrumentor, &parentInfo)))
+                               verifyPasses, initGeneration, instrumentor,
+                               &parentInfo)))
           return signalPassFailure();
       }
     }
@@ -578,9 +610,10 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
               pms, it.first->getName().getIdentifier(), getContext());
           assert(pm && "expected valid pass manager for operation");
 
+          unsigned initGeneration = pm->impl->initializationGeneration;
           LogicalResult pipelineResult =
               runPipeline(pm->getPasses(), it.first, it.second, verifyPasses,
-                          instrumentor, &parentInfo);
+                          initGeneration, instrumentor, &parentInfo);
 
           // Drop this thread from being tracked by the diagnostic handler.
           // After this task has finished, the thread may be used outside of
@@ -753,7 +786,8 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
   llvm::CrashRecoveryContext recoveryContext;
   recoveryContext.RunSafelyOnThread([&] {
     for (std::unique_ptr<Pass> &pass : passes)
-      if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses)))
+      if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses,
+                                        impl->initializationGeneration)))
         return;
     passManagerResult = success();
   });
@@ -801,6 +835,9 @@ LogicalResult PassManager::run(Operation *op) {
   getDependentDialects(dependentDialects);
   dependentDialects.loadAll(context);
 
+  // Initialize all of the passes within the pass manager with a new generation.
+  initialize(context, impl->initializationGeneration + 1);
+
   // Construct a top level analysis manager for the pipeline.
   ModuleAnalysisManager am(op, instrumentor.get());
 
@@ -812,7 +849,8 @@ LogicalResult PassManager::run(Operation *op) {
   LogicalResult result =
       crashReproducerFileName
           ? runWithCrashRecovery(op, am)
-          : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses);
+          : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses,
+                                           impl->initializationGeneration);
 
   // Notify the context that the run is done.
   context->exitMultiThreadedExecution();

diff  --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 2533d877fc00..436ad65629cc 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -55,14 +55,19 @@ class OpToOpPassAdaptor
   void runOnOperationAsyncImpl(bool verifyPasses);
 
   /// Run the given operation and analysis manager on a single pass.
+  /// `parentInitGeneration` is the initialization generation of the parent pass
+  /// manager, and is used to initialize any dynamic pass pipelines run by the
+  /// given pass.
   static LogicalResult run(Pass *pass, Operation *op, AnalysisManager am,
-                           bool verifyPasses);
+                           bool verifyPasses, unsigned parentInitGeneration);
 
   /// Run the given operation and analysis manager on a provided op pass
-  /// manager.
+  /// manager. `parentInitGeneration` is the initialization generation of the
+  /// parent pass manager, and is used to initialize any dynamic pass pipelines
+  /// run by the given passes.
   static LogicalResult runPipeline(
       iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
-      AnalysisManager am, bool verifyPasses,
+      AnalysisManager am, bool verifyPasses, unsigned parentInitGeneration,
       PassInstrumentor *instrumentor = nullptr,
       const PassInstrumentation::PipelineParentInfo *parentInfo = nullptr);
 

diff  --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index 70208a89debf..81dfa9b74c5d 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -21,19 +21,19 @@ using namespace mlir;
 namespace {
 /// Canonicalize operations in nested regions.
 struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
-  void runOnOperation() override {
-    OwningRewritePatternList patterns;
-
-    // TODO: Instead of adding all known patterns from the whole system lazily
-    // add and cache the canonicalization patterns for ops we see in practice
-    // when building the worklist.  For now, we just grab everything.
-    auto *context = &getContext();
+  /// Initialize the canonicalizer by building the set of patterns used during
+  /// execution.
+  void initialize(MLIRContext *context) override {
+    OwningRewritePatternList owningPatterns;
     for (auto *op : context->getRegisteredOperations())
-      op->getCanonicalizationPatterns(patterns, context);
-
-    Operation *op = getOperation();
-    applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
+      op->getCanonicalizationPatterns(owningPatterns, context);
+    patterns = std::move(owningPatterns);
   }
+  void runOnOperation() override {
+    applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns);
+  }
+
+  FrozenRewritePatternList patterns;
 };
 } // end anonymous namespace
 


        


More information about the llvm-branch-commits mailing list