[Mlir-commits] [mlir] a62d38a - Disable implicit nesting on parsing textual pass pipeline

Mehdi Amini llvmlistbot at llvm.org
Wed Nov 11 11:21:59 PST 2020


Author: Mehdi Amini
Date: 2020-11-11T19:21:51Z
New Revision: a62d38a90d2feee3955df032f22a02a09a42cd44

URL: https://github.com/llvm/llvm-project/commit/a62d38a90d2feee3955df032f22a02a09a42cd44
DIFF: https://github.com/llvm/llvm-project/commit/a62d38a90d2feee3955df032f22a02a09a42cd44.diff

LOG: Disable implicit nesting on parsing textual pass pipeline

Previously, the textual form of the pass pipeline would implicitly nest;
instead, we opt for the explicit form here, as it is less surprising.

This also avoids asserting in the bindings when passing a pass pipeline
with incorrect nesting.

Differential Revision: https://reviews.llvm.org/D91233

Added: 
    

Modified: 
    mlir/include/mlir/Pass/PassManager.h
    mlir/include/mlir/Pass/PassRegistry.h
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Pass/PassRegistry.cpp
    mlir/lib/Support/MlirOptMain.cpp
    mlir/test/Bindings/Python/pass_manager.py
    mlir/test/Pass/pipeline-options-parsing.mlir
    mlir/test/Pass/pipeline-parsing.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 3080fc35a153..33ffb542c56b 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -118,6 +118,14 @@ class OpPassManager {
   /// documentation for the same method on the Pass class.
   void getDependentDialects(DialectRegistry &dialects) const;
 
+  /// Enable or disable the implicit nesting on this particular PassManager.
+  /// This will also apply to any newly nested PassManager built from this
+  /// instance.
+  void setNesting(Nesting nesting);
+
+  /// Return the current nesting mode.
+  Nesting getNesting();
+
 private:
   /// A pointer to an internal implementation instance.
   std::unique_ptr<detail::OpPassManagerImpl> impl;

diff  --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
index 3187d8eb34c7..52ded58ccfe6 100644
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -28,8 +28,12 @@ class PassOptions;
 
 /// A registry function that adds passes to the given pass manager. This should
 /// also parse options and return success() if parsing succeeded.
-using PassRegistryFunction =
-    std::function<LogicalResult(OpPassManager &, StringRef options)>;
+/// `errorHandler` is a functor used to emit errors during parsing; it takes
+/// the error message to report. It should always return failure, so that a
+/// builder can directly return its result.
+using PassRegistryFunction = std::function<LogicalResult(
+    OpPassManager &, StringRef options,
+    function_ref<LogicalResult(const Twine &)> errorHandler)>;
 using PassAllocatorFunction = std::function<std::unique_ptr<Pass>()>;
 
 //===----------------------------------------------------------------------===//
@@ -43,10 +47,12 @@ class PassRegistryEntry {
   /// Adds this pass registry entry to the given pass manager. `options` is
   /// an opaque string that will be parsed by the builder. The success of
   /// parsing will be returned.
-  LogicalResult addToPipeline(OpPassManager &pm, StringRef options) const {
+  LogicalResult
+  addToPipeline(OpPassManager &pm, StringRef options,
+                function_ref<LogicalResult(const Twine &)> errorHandler) const {
     assert(builder &&
            "cannot call addToPipeline on PassRegistryEntry without builder");
-    return builder(pm, options);
+    return builder(pm, options, errorHandler);
   }
 
   /// Returns the command line option that may be passed to 'mlir-opt' that will
@@ -163,7 +169,8 @@ struct PassPipelineRegistration {
       std::function<void(OpPassManager &, const Options &options)> builder) {
     registerPassPipeline(
         arg, description,
-        [builder](OpPassManager &pm, StringRef optionsStr) {
+        [builder](OpPassManager &pm, StringRef optionsStr,
+                  function_ref<LogicalResult(const Twine &)> errorHandler) {
           Options options;
           if (failed(options.parseFromString(optionsStr)))
             return failure();
@@ -183,7 +190,8 @@ template <> struct PassPipelineRegistration<EmptyPipelineOptions> {
                            std::function<void(OpPassManager &)> builder) {
     registerPassPipeline(
         arg, description,
-        [builder](OpPassManager &pm, StringRef optionsStr) {
+        [builder](OpPassManager &pm, StringRef optionsStr,
+                  function_ref<LogicalResult(const Twine &)> errorHandler) {
           if (!optionsStr.empty())
             return failure();
           builder(pm);
@@ -230,7 +238,9 @@ class PassPipelineCLParser {
   /// Adds the passes defined by this parser entry to the given pass manager.
   /// Returns failure() if the pass could not be properly constructed due
   /// to options parsing.
-  LogicalResult addToPipeline(OpPassManager &pm) const;
+  LogicalResult
+  addToPipeline(OpPassManager &pm,
+                function_ref<LogicalResult(const Twine &)> errorHandler) const;
 
 private:
   std::unique_ptr<detail::PassPipelineCLParserImpl> impl;

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index a1aecf6ffb20..8682171834ca 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -348,6 +348,10 @@ void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
   registerDialectsForPipeline(*this, dialects);
 }
 
+OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
+
+void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
+
 //===----------------------------------------------------------------------===//
 // OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 442233024bbe..78e40d5b0aa7 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -27,9 +27,16 @@ static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
 /// Utility to create a default registry function from a pass instance.
 static PassRegistryFunction
 buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
-  return [=](OpPassManager &pm, StringRef options) {
+  return [=](OpPassManager &pm, StringRef options,
+             function_ref<LogicalResult(const Twine &)> errorHandler) {
     std::unique_ptr<Pass> pass = allocator();
     LogicalResult result = pass->initializeOptions(options);
+    if ((pm.getNesting() == OpPassManager::Nesting::Explicit) &&
+        pass->getOpName() && *pass->getOpName() != pm.getOpName())
+      return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
+                          "' restricted to '" + *pass->getOpName() +
+                          "' on a PassManager intended to run on '" +
+                          pm.getOpName() + "', did you intend to nest?");
     pm.addPass(std::move(pass));
     return result;
   };
@@ -229,7 +236,9 @@ class TextualPipeline {
   LogicalResult initialize(StringRef text, raw_ostream &errorStream);
 
   /// Add the internal pipeline elements to the provided pass manager.
-  LogicalResult addToPipeline(OpPassManager &pm) const;
+  LogicalResult
+  addToPipeline(OpPassManager &pm,
+                function_ref<LogicalResult(const Twine &)> errorHandler) const;
 
 private:
   /// A functor used to emit errors found during pipeline handling. The first
@@ -269,8 +278,9 @@ class TextualPipeline {
                                        ErrorHandlerT errorHandler);
 
   /// Add the given pipeline elements to the provided pass manager.
-  LogicalResult addToPipeline(ArrayRef<PipelineElement> elements,
-                              OpPassManager &pm) const;
+  LogicalResult
+  addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
+                function_ref<LogicalResult(const Twine &)> errorHandler) const;
 
   std::vector<PipelineElement> pipeline;
 };
@@ -299,8 +309,10 @@ LogicalResult TextualPipeline::initialize(StringRef text,
 }
 
 /// Add the internal pipeline elements to the provided pass manager.
-LogicalResult TextualPipeline::addToPipeline(OpPassManager &pm) const {
-  return addToPipeline(pipeline, pm);
+LogicalResult TextualPipeline::addToPipeline(
+    OpPassManager &pm,
+    function_ref<LogicalResult(const Twine &)> errorHandler) const {
+  return addToPipeline(pipeline, pm, errorHandler);
 }
 
 /// Parse the given pipeline text into the internal pipeline vector. This
@@ -397,7 +409,6 @@ TextualPipeline::resolvePipelineElement(PipelineElement &element,
   // pipeline.
   if (!element.innerPipeline.empty())
     return resolvePipelineElements(element.innerPipeline, errorHandler);
-
   // Otherwise, this must be a pass or pass pipeline.
   // Check to see if a pipeline was registered with this name.
   auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
@@ -422,13 +433,16 @@ TextualPipeline::resolvePipelineElement(PipelineElement &element,
 }
 
 /// Add the given pipeline elements to the provided pass manager.
-LogicalResult TextualPipeline::addToPipeline(ArrayRef<PipelineElement> elements,
-                                             OpPassManager &pm) const {
+LogicalResult TextualPipeline::addToPipeline(
+    ArrayRef<PipelineElement> elements, OpPassManager &pm,
+    function_ref<LogicalResult(const Twine &)> errorHandler) const {
   for (auto &elt : elements) {
     if (elt.registryEntry) {
-      if (failed(elt.registryEntry->addToPipeline(pm, elt.options)))
+      if (failed(
+              elt.registryEntry->addToPipeline(pm, elt.options, errorHandler)))
         return failure();
-    } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name)))) {
+    } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
+                                    errorHandler))) {
       return failure();
     }
   }
@@ -444,7 +458,11 @@ LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
   TextualPipeline pipelineParser;
   if (failed(pipelineParser.initialize(pipeline, errorStream)))
     return failure();
-  if (failed(pipelineParser.addToPipeline(pm)))
+  auto errorHandler = [&](Twine msg) {
+    errorStream << msg << "\n";
+    return failure();
+  };
+  if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
     return failure();
   return success();
 }
@@ -634,13 +652,21 @@ bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
 }
 
 /// Adds the passes defined by this parser entry to the given pass manager.
-LogicalResult PassPipelineCLParser::addToPipeline(OpPassManager &pm) const {
+LogicalResult PassPipelineCLParser::addToPipeline(
+    OpPassManager &pm,
+    function_ref<LogicalResult(const Twine &)> errorHandler) const {
   for (auto &passIt : impl->passList) {
     if (passIt.registryEntry) {
-      if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options)))
+      if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
+                                                     errorHandler)))
+        return failure();
+    } else {
+      OpPassManager::Nesting nesting = pm.getNesting();
+      pm.setNesting(OpPassManager::Nesting::Explicit);
+      LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler);
+      pm.setNesting(nesting);
+      if (failed(status))
         return failure();
-    } else if (failed(passIt.pipeline.addToPipeline(pm))) {
-      return failure();
     }
   }
   return success();

diff  --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
index 7b9470c0d630..18d17de92531 100644
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -62,8 +62,13 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
   pm.enableVerifier(verifyPasses);
   applyPassManagerCLOptions(pm);
 
+  auto errorHandler = [&](const Twine &msg) {
+    emitError(UnknownLoc::get(context)) << msg;
+    return failure();
+  };
+
   // Build the provided pipeline.
-  if (failed(passPipeline.addToPipeline(pm)))
+  if (failed(passPipeline.addToPipeline(pm, errorHandler)))
     return failure();
 
   // Run the pipeline.

diff  --git a/mlir/test/Bindings/Python/pass_manager.py b/mlir/test/Bindings/Python/pass_manager.py
index bd269ce29896..2b682d84b751 100644
--- a/mlir/test/Bindings/Python/pass_manager.py
+++ b/mlir/test/Bindings/Python/pass_manager.py
@@ -66,6 +66,21 @@ def testParseFail():
 run(testParseFail)
 
 
+# Verify failure on incorrect level of nesting.
+# CHECK-LABEL: TEST: testInvalidNesting
+def testInvalidNesting():
+  with Context():
+    try:
+      pm = PassManager.parse("func(print-op-graph)")
+    except ValueError as e:
+      # CHECK: Can't add pass 'PrintOp' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest?
+      # CHECK: ValueError exception: invalid pass pipeline 'func(print-op-graph)'.
+      log("ValueError exception:", e)
+    else:
+      log("Exception not produced")
+run(testInvalidNesting)
+
+
 # Verify that a pass manager can execute on IR
 # CHECK-LABEL: TEST: testRun
 def testRunPipeline():

diff  --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir
index bfb24af93027..777c8437f7d8 100644
--- a/mlir/test/Pass/pipeline-options-parsing.mlir
+++ b/mlir/test/Pass/pipeline-options-parsing.mlir
@@ -1,11 +1,11 @@
 // RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass{)' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_1 %s
 // RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass{test-option=3})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s
-// RUN: not mlir-opt %s -pass-pipeline='module(test-options-pass{list=3}, test-module-pass{invalid-option=3})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s
+// RUN: not mlir-opt %s -pass-pipeline='module(func(test-options-pass{list=3}), test-module-pass{invalid-option=3})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s
 // RUN: not mlir-opt %s -pass-pipeline='test-options-pass{list=3 list=notaninteger}' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_4 %s
 // RUN: not mlir-opt %s -pass-pipeline='func(test-options-pass{list=1,2,3,4 list=5 string=value1 string=value2})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_5 %s
 // RUN: mlir-opt %s -verify-each=false -pass-pipeline='func(test-options-pass{string-list=a list=1,2,3,4 string-list=b,c list=5 string-list=d string=some_value})' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_1 %s
 // RUN: mlir-opt %s -verify-each=false -test-options-pass-pipeline='list=1 string-list=a,b' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_2 %s
-// RUN: mlir-opt %s -verify-each=false -pass-pipeline='module(test-options-pass{list=3}, test-options-pass{list=1,2,3,4})' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_3 %s
+// RUN: mlir-opt %s -verify-each=false -pass-pipeline='module(func(test-options-pass{list=3}), func(test-options-pass{list=1,2,3,4}))' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_3 %s
 
 // CHECK_ERROR_1: missing closing '}' while processing pass options
 // CHECK_ERROR_2: no such option test-option

diff  --git a/mlir/test/Pass/pipeline-parsing.mlir b/mlir/test/Pass/pipeline-parsing.mlir
index 118a87d42865..34bb4b8d783d 100644
--- a/mlir/test/Pass/pipeline-parsing.mlir
+++ b/mlir/test/Pass/pipeline-parsing.mlir
@@ -4,12 +4,13 @@
 // RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass))' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s
 // RUN: not mlir-opt %s -pass-pipeline='module()(' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s
 // RUN: not mlir-opt %s -pass-pipeline=',' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_4 %s
+// RUN: not mlir-opt %s -pass-pipeline='func(test-module-pass)' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_5 %s
 
 // CHECK_ERROR_1: encountered unbalanced parentheses while parsing pipeline
 // CHECK_ERROR_2: encountered extra closing ')' creating unbalanced parentheses while parsing pipeline
 // CHECK_ERROR_3: expected ',' after parsing pipeline
 // CHECK_ERROR_4: does not refer to a registered pass or pass pipeline
-
+// CHECK_ERROR_5:  Can't add pass '{{.*}}TestModulePass' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest?
 func @foo() {
   return
 }


        


More information about the Mlir-commits mailing list