[Mlir-commits] [mlir] b1aaed0 - Enable `Pass::initialize()` to fail by returning a LogicalResult

Mehdi Amini llvmlistbot at llvm.org
Wed Feb 10 17:52:05 PST 2021


Author: Mehdi Amini
Date: 2021-02-11T01:51:53Z
New Revision: b1aaed023e98ee9989532c8d3914c3bec7bbf964

URL: https://github.com/llvm/llvm-project/commit/b1aaed023e98ee9989532c8d3914c3bec7bbf964
DIFF: https://github.com/llvm/llvm-project/commit/b1aaed023e98ee9989532c8d3914c3bec7bbf964.diff

LOG: Enable `Pass::initialize()` to fail by returning a LogicalResult

Differential Revision: https://reviews.llvm.org/D96474

Added: 
    

Modified: 
    mlir/docs/PassManagement.md
    mlir/include/mlir/Pass/Pass.h
    mlir/include/mlir/Pass/PassManager.h
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Transforms/Canonicalizer.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
index 8fc7ff2f0ae1..7588f4fba1b5 100644
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -140,13 +140,15 @@ such example is when using [`PDL`-based](Dialects/PDLOps.md)
 runtime. In these situations, a pass may override the following hook to
 initialize this heavy state:
 
-*   `void initialize(MLIRContext *context)`
+*   `LogicalResult initialize(MLIRContext *context)`
 
 This hook is executed once per run of a full pass pipeline, meaning that it does
 not have access to the state available during a `runOnOperation` call. More
 concretely, all necessary accesses to an `MLIRContext` should be driven via the
 provided `context` parameter, and methods that utilize "per-run" state such as
 `getContext`/`getOperation`/`getAnalysis`/etc. must not be used.
+In case of an error during initialization, the pass is expected to emit an error
+diagnostic and return `failure()`, which will abort the pass pipeline execution.
 
 ## Analysis Management
 

diff  --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 3eeaf79b9610..10b1a5c57a81 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -170,7 +170,9 @@ class Pass {
   /// should not rely on any state accessible during the execution of a pass.
   /// For example, `getContext`/`getOperation`/`getAnalysis`/etc. should not be
   /// invoked within this hook.
-  virtual void initialize(MLIRContext *context) {}
+  /// Returns a LogicalResult to indicate failure, in which case the pass
+  /// pipeline won't execute.
+  virtual LogicalResult initialize(MLIRContext *context) { return success(); }
 
   /// Schedule an arbitrary pass pipeline on the provided operation.
   /// This can be invoked at any time in a pass to dynamically schedule more passes.

diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 99798bdd551d..b6e9e2766773 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -130,7 +130,7 @@ class OpPassManager {
   /// Initialize all of the passes within this pass manager with the given
   /// initialization generation. The initialization generation is used to detect
   /// if a pass manager has already been initialized.
-  void initialize(MLIRContext *context, unsigned newInitGeneration);
+  LogicalResult initialize(MLIRContext *context, unsigned newInitGeneration);
 
   /// A pointer to an internal implementation instance.
   std::unique_ptr<detail::OpPassManagerImpl> impl;

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index f4779cb860a0..8507fd6d3451 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -331,23 +331,26 @@ void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
 
 OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
 
-void OpPassManager::initialize(MLIRContext *context,
-                               unsigned newInitGeneration) {
+LogicalResult OpPassManager::initialize(MLIRContext *context,
+                                        unsigned newInitGeneration) {
   if (impl->initializationGeneration == newInitGeneration)
-    return;
+    return success();
   impl->initializationGeneration = newInitGeneration;
   for (Pass &pass : getPasses()) {
     // If this pass isn't an adaptor, directly initialize it.
     auto *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
     if (!adaptor) {
-      pass.initialize(context);
+      if (failed(pass.initialize(context)))
+        return failure();
       continue;
     }
 
     // Otherwise, initialize each of the adaptor's pass managers.
     for (OpPassManager &adaptorPM : adaptor->getPassManagers())
-      adaptorPM.initialize(context, newInitGeneration);
+      if (failed(adaptorPM.initialize(context, newInitGeneration)))
+        return failure();
   }
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -379,7 +382,8 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
     assert(pipeline.getOpName() == root->getName().getStringRef());
 
     // Initialize the user provided pipeline and execute the pipeline.
-    pipeline.initialize(root->getContext(), parentInitGeneration);
+    if (failed(pipeline.initialize(root->getContext(), parentInitGeneration)))
+      return failure();
     AnalysisManager nestedAm = root == op ? am : am.nest(root);
     return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
                                           verifyPasses, parentInitGeneration,
@@ -872,7 +876,8 @@ LogicalResult PassManager::run(Operation *op) {
   // Initialize all of the passes within the pass manager with a new generation.
   llvm::hash_code newInitKey = context->getRegistryHash();
   if (newInitKey != initializationKey) {
-    initialize(context, impl->initializationGeneration + 1);
+    if (failed(initialize(context, impl->initializationGeneration + 1)))
+      return failure();
     initializationKey = newInitKey;
   }
 

diff  --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index cf0b21735362..ba6ffe72a248 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -23,11 +23,12 @@ namespace {
 struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
-  void initialize(MLIRContext *context) override {
+  LogicalResult initialize(MLIRContext *context) override {
     OwningRewritePatternList owningPatterns;
     for (auto *op : context->getRegisteredOperations())
       op->getCanonicalizationPatterns(owningPatterns, context);
     patterns = std::move(owningPatterns);
+    return success();
   }
   void runOnOperation() override {
     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns);


        


More information about the Mlir-commits mailing list