[Mlir-commits] [mlir] a62d38a - Disable implicit nesting on parsing textual pass pipeline
Mehdi Amini
llvmlistbot at llvm.org
Wed Nov 11 11:21:59 PST 2020
Author: Mehdi Amini
Date: 2020-11-11T19:21:51Z
New Revision: a62d38a90d2feee3955df032f22a02a09a42cd44
URL: https://github.com/llvm/llvm-project/commit/a62d38a90d2feee3955df032f22a02a09a42cd44
DIFF: https://github.com/llvm/llvm-project/commit/a62d38a90d2feee3955df032f22a02a09a42cd44.diff
LOG: Disable implicit nesting on parsing textual pass pipeline
Previously, the textual form of the pass pipeline would implicitly nest;
instead, we now opt for the explicit form, which is less surprising.
This also avoids asserting in the bindings when passing a pass pipeline
with incorrect nesting.
Differential Revision: https://reviews.llvm.org/D91233
Added:
Modified:
mlir/include/mlir/Pass/PassManager.h
mlir/include/mlir/Pass/PassRegistry.h
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassRegistry.cpp
mlir/lib/Support/MlirOptMain.cpp
mlir/test/Bindings/Python/pass_manager.py
mlir/test/Pass/pipeline-options-parsing.mlir
mlir/test/Pass/pipeline-parsing.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 3080fc35a153..33ffb542c56b 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -118,6 +118,14 @@ class OpPassManager {
/// documentation for the same method on the Pass class.
void getDependentDialects(DialectRegistry &dialects) const;
+ /// Enable or disable the implicit nesting on this particular PassManager.
+ /// This will also apply to any newly nested PassManager built from this
+ /// instance.
+ void setNesting(Nesting nesting);
+
+ /// Return the current nesting mode.
+ Nesting getNesting();
+
private:
/// A pointer to an internal implementation instance.
std::unique_ptr<detail::OpPassManagerImpl> impl;
diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
index 3187d8eb34c7..52ded58ccfe6 100644
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -28,8 +28,12 @@ class PassOptions;
/// A registry function that adds passes to the given pass manager. This should
/// also parse options and return success() if parsing succeeded.
-using PassRegistryFunction =
- std::function<LogicalResult(OpPassManager &, StringRef options)>;
+/// `errorHandler` is a functor used to emit errors during parsing; its only
+/// parameter is the error message to report. Implementations of this functor
+/// should always return failure.
+using PassRegistryFunction = std::function<LogicalResult(
+ OpPassManager &, StringRef options,
+ function_ref<LogicalResult(const Twine &)> errorHandler)>;
using PassAllocatorFunction = std::function<std::unique_ptr<Pass>()>;
//===----------------------------------------------------------------------===//
@@ -43,10 +47,12 @@ class PassRegistryEntry {
/// Adds this pass registry entry to the given pass manager. `options` is
/// an opaque string that will be parsed by the builder. The success of
/// parsing will be returned.
- LogicalResult addToPipeline(OpPassManager &pm, StringRef options) const {
+ LogicalResult
+ addToPipeline(OpPassManager &pm, StringRef options,
+ function_ref<LogicalResult(const Twine &)> errorHandler) const {
assert(builder &&
"cannot call addToPipeline on PassRegistryEntry without builder");
- return builder(pm, options);
+ return builder(pm, options, errorHandler);
}
/// Returns the command line option that may be passed to 'mlir-opt' that will
@@ -163,7 +169,8 @@ struct PassPipelineRegistration {
std::function<void(OpPassManager &, const Options &options)> builder) {
registerPassPipeline(
arg, description,
- [builder](OpPassManager &pm, StringRef optionsStr) {
+ [builder](OpPassManager &pm, StringRef optionsStr,
+ function_ref<LogicalResult(const Twine &)> errorHandler) {
Options options;
if (failed(options.parseFromString(optionsStr)))
return failure();
@@ -183,7 +190,8 @@ template <> struct PassPipelineRegistration<EmptyPipelineOptions> {
std::function<void(OpPassManager &)> builder) {
registerPassPipeline(
arg, description,
- [builder](OpPassManager &pm, StringRef optionsStr) {
+ [builder](OpPassManager &pm, StringRef optionsStr,
+ function_ref<LogicalResult(const Twine &)> errorHandler) {
if (!optionsStr.empty())
return failure();
builder(pm);
@@ -230,7 +238,9 @@ class PassPipelineCLParser {
/// Adds the passes defined by this parser entry to the given pass manager.
/// Returns failure() if the pass could not be properly constructed due
/// to options parsing.
- LogicalResult addToPipeline(OpPassManager &pm) const;
+ LogicalResult
+ addToPipeline(OpPassManager &pm,
+ function_ref<LogicalResult(const Twine &)> errorHandler) const;
private:
std::unique_ptr<detail::PassPipelineCLParserImpl> impl;
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index a1aecf6ffb20..8682171834ca 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -348,6 +348,10 @@ void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
registerDialectsForPipeline(*this, dialects);
}
+OpPassManager::Nesting OpPassManager::getNesting() { return impl->nesting; }
+
+void OpPassManager::setNesting(Nesting nesting) { impl->nesting = nesting; }
+
//===----------------------------------------------------------------------===//
// OpToOpPassAdaptor
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 442233024bbe..78e40d5b0aa7 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -27,9 +27,16 @@ static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
/// Utility to create a default registry function from a pass instance.
static PassRegistryFunction
buildDefaultRegistryFn(const PassAllocatorFunction &allocator) {
- return [=](OpPassManager &pm, StringRef options) {
+ return [=](OpPassManager &pm, StringRef options,
+ function_ref<LogicalResult(const Twine &)> errorHandler) {
std::unique_ptr<Pass> pass = allocator();
LogicalResult result = pass->initializeOptions(options);
+ if ((pm.getNesting() == OpPassManager::Nesting::Explicit) &&
+ pass->getOpName() && *pass->getOpName() != pm.getOpName())
+ return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
+ "' restricted to '" + *pass->getOpName() +
+ "' on a PassManager intended to run on '" +
+ pm.getOpName() + "', did you intend to nest?");
pm.addPass(std::move(pass));
return result;
};
@@ -229,7 +236,9 @@ class TextualPipeline {
LogicalResult initialize(StringRef text, raw_ostream &errorStream);
/// Add the internal pipeline elements to the provided pass manager.
- LogicalResult addToPipeline(OpPassManager &pm) const;
+ LogicalResult
+ addToPipeline(OpPassManager &pm,
+ function_ref<LogicalResult(const Twine &)> errorHandler) const;
private:
/// A functor used to emit errors found during pipeline handling. The first
@@ -269,8 +278,9 @@ class TextualPipeline {
ErrorHandlerT errorHandler);
/// Add the given pipeline elements to the provided pass manager.
- LogicalResult addToPipeline(ArrayRef<PipelineElement> elements,
- OpPassManager &pm) const;
+ LogicalResult
+ addToPipeline(ArrayRef<PipelineElement> elements, OpPassManager &pm,
+ function_ref<LogicalResult(const Twine &)> errorHandler) const;
std::vector<PipelineElement> pipeline;
};
@@ -299,8 +309,10 @@ LogicalResult TextualPipeline::initialize(StringRef text,
}
/// Add the internal pipeline elements to the provided pass manager.
-LogicalResult TextualPipeline::addToPipeline(OpPassManager &pm) const {
- return addToPipeline(pipeline, pm);
+LogicalResult TextualPipeline::addToPipeline(
+ OpPassManager &pm,
+ function_ref<LogicalResult(const Twine &)> errorHandler) const {
+ return addToPipeline(pipeline, pm, errorHandler);
}
/// Parse the given pipeline text into the internal pipeline vector. This
@@ -397,7 +409,6 @@ TextualPipeline::resolvePipelineElement(PipelineElement &element,
// pipeline.
if (!element.innerPipeline.empty())
return resolvePipelineElements(element.innerPipeline, errorHandler);
-
// Otherwise, this must be a pass or pass pipeline.
// Check to see if a pipeline was registered with this name.
auto pipelineRegistryIt = passPipelineRegistry->find(element.name);
@@ -422,13 +433,16 @@ TextualPipeline::resolvePipelineElement(PipelineElement &element,
}
/// Add the given pipeline elements to the provided pass manager.
-LogicalResult TextualPipeline::addToPipeline(ArrayRef<PipelineElement> elements,
- OpPassManager &pm) const {
+LogicalResult TextualPipeline::addToPipeline(
+ ArrayRef<PipelineElement> elements, OpPassManager &pm,
+ function_ref<LogicalResult(const Twine &)> errorHandler) const {
for (auto &elt : elements) {
if (elt.registryEntry) {
- if (failed(elt.registryEntry->addToPipeline(pm, elt.options)))
+ if (failed(
+ elt.registryEntry->addToPipeline(pm, elt.options, errorHandler)))
return failure();
- } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name)))) {
+ } else if (failed(addToPipeline(elt.innerPipeline, pm.nest(elt.name),
+ errorHandler))) {
return failure();
}
}
@@ -444,7 +458,11 @@ LogicalResult mlir::parsePassPipeline(StringRef pipeline, OpPassManager &pm,
TextualPipeline pipelineParser;
if (failed(pipelineParser.initialize(pipeline, errorStream)))
return failure();
- if (failed(pipelineParser.addToPipeline(pm)))
+ auto errorHandler = [&](Twine msg) {
+ errorStream << msg << "\n";
+ return failure();
+ };
+ if (failed(pipelineParser.addToPipeline(pm, errorHandler)))
return failure();
return success();
}
@@ -634,13 +652,21 @@ bool PassPipelineCLParser::contains(const PassRegistryEntry *entry) const {
}
/// Adds the passes defined by this parser entry to the given pass manager.
-LogicalResult PassPipelineCLParser::addToPipeline(OpPassManager &pm) const {
+LogicalResult PassPipelineCLParser::addToPipeline(
+ OpPassManager &pm,
+ function_ref<LogicalResult(const Twine &)> errorHandler) const {
for (auto &passIt : impl->passList) {
if (passIt.registryEntry) {
- if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options)))
+ if (failed(passIt.registryEntry->addToPipeline(pm, passIt.options,
+ errorHandler)))
+ return failure();
+ } else {
+ OpPassManager::Nesting nesting = pm.getNesting();
+ pm.setNesting(OpPassManager::Nesting::Explicit);
+ LogicalResult status = passIt.pipeline.addToPipeline(pm, errorHandler);
+ pm.setNesting(nesting);
+ if (failed(status))
return failure();
- } else if (failed(passIt.pipeline.addToPipeline(pm))) {
- return failure();
}
}
return success();
diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
index 7b9470c0d630..18d17de92531 100644
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -62,8 +62,13 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
pm.enableVerifier(verifyPasses);
applyPassManagerCLOptions(pm);
+ auto errorHandler = [&](const Twine &msg) {
+ emitError(UnknownLoc::get(context)) << msg;
+ return failure();
+ };
+
// Build the provided pipeline.
- if (failed(passPipeline.addToPipeline(pm)))
+ if (failed(passPipeline.addToPipeline(pm, errorHandler)))
return failure();
// Run the pipeline.
diff --git a/mlir/test/Bindings/Python/pass_manager.py b/mlir/test/Bindings/Python/pass_manager.py
index bd269ce29896..2b682d84b751 100644
--- a/mlir/test/Bindings/Python/pass_manager.py
+++ b/mlir/test/Bindings/Python/pass_manager.py
@@ -66,6 +66,21 @@ def testParseFail():
run(testParseFail)
+# Verify failure on incorrect level of nesting.
+# CHECK-LABEL: TEST: testInvalidNesting
+def testInvalidNesting():
+ with Context():
+ try:
+ pm = PassManager.parse("func(print-op-graph)")
+ except ValueError as e:
+ # CHECK: Can't add pass 'PrintOp' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest?
+ # CHECK: ValueError exception: invalid pass pipeline 'func(print-op-graph)'.
+ log("ValueError exception:", e)
+ else:
+ log("Exception not produced")
+run(testInvalidNesting)
+
+
# Verify that a pass manager can execute on IR
# CHECK-LABEL: TEST: testRun
def testRunPipeline():
diff --git a/mlir/test/Pass/pipeline-options-parsing.mlir b/mlir/test/Pass/pipeline-options-parsing.mlir
index bfb24af93027..777c8437f7d8 100644
--- a/mlir/test/Pass/pipeline-options-parsing.mlir
+++ b/mlir/test/Pass/pipeline-options-parsing.mlir
@@ -1,11 +1,11 @@
// RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass{)' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_1 %s
// RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass{test-option=3})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s
-// RUN: not mlir-opt %s -pass-pipeline='module(test-options-pass{list=3}, test-module-pass{invalid-option=3})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s
+// RUN: not mlir-opt %s -pass-pipeline='module(func(test-options-pass{list=3}), test-module-pass{invalid-option=3})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s
// RUN: not mlir-opt %s -pass-pipeline='test-options-pass{list=3 list=notaninteger}' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_4 %s
// RUN: not mlir-opt %s -pass-pipeline='func(test-options-pass{list=1,2,3,4 list=5 string=value1 string=value2})' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_5 %s
// RUN: mlir-opt %s -verify-each=false -pass-pipeline='func(test-options-pass{string-list=a list=1,2,3,4 string-list=b,c list=5 string-list=d string=some_value})' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_1 %s
// RUN: mlir-opt %s -verify-each=false -test-options-pass-pipeline='list=1 string-list=a,b' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_2 %s
-// RUN: mlir-opt %s -verify-each=false -pass-pipeline='module(test-options-pass{list=3}, test-options-pass{list=1,2,3,4})' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_3 %s
+// RUN: mlir-opt %s -verify-each=false -pass-pipeline='module(func(test-options-pass{list=3}), func(test-options-pass{list=1,2,3,4}))' -test-dump-pipeline 2>&1 | FileCheck --check-prefix=CHECK_3 %s
// CHECK_ERROR_1: missing closing '}' while processing pass options
// CHECK_ERROR_2: no such option test-option
diff --git a/mlir/test/Pass/pipeline-parsing.mlir b/mlir/test/Pass/pipeline-parsing.mlir
index 118a87d42865..34bb4b8d783d 100644
--- a/mlir/test/Pass/pipeline-parsing.mlir
+++ b/mlir/test/Pass/pipeline-parsing.mlir
@@ -4,12 +4,13 @@
// RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass))' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s
// RUN: not mlir-opt %s -pass-pipeline='module()(' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s
// RUN: not mlir-opt %s -pass-pipeline=',' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_4 %s
+// RUN: not mlir-opt %s -pass-pipeline='func(test-module-pass)' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_5 %s
// CHECK_ERROR_1: encountered unbalanced parentheses while parsing pipeline
// CHECK_ERROR_2: encountered extra closing ')' creating unbalanced parentheses while parsing pipeline
// CHECK_ERROR_3: expected ',' after parsing pipeline
// CHECK_ERROR_4: does not refer to a registered pass or pass pipeline
-
+// CHECK_ERROR_5: Can't add pass '{{.*}}TestModulePass' restricted to 'module' on a PassManager intended to run on 'func', did you intend to nest?
func @foo() {
return
}
More information about the Mlir-commits
mailing list