[Mlir-commits] [mlir] b3c5f6b - [mlir][python] Include pipeline parse errors in exception message
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 27 10:05:44 PDT 2022
Author: rkayaith
Date: 2022-10-27T13:05:38-04:00
New Revision: b3c5f6b15b1eaa2552ce62329208ece5166356fe
URL: https://github.com/llvm/llvm-project/commit/b3c5f6b15b1eaa2552ce62329208ece5166356fe
DIFF: https://github.com/llvm/llvm-project/commit/b3c5f6b15b1eaa2552ce62329208ece5166356fe.diff
LOG: [mlir][python] Include pipeline parse errors in exception message
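Currently, any errors that occur during pass pipeline parsing are reported to stderr.
This adds a new pipeline parsing function to the C API, `mlirOpPassManagerAddPipeline`,
which reports errors through a callback. It also updates the Python bindings to use it,
so that parse errors are included in the message of the raised `ValueError`.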
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D136402
Added:
Modified:
mlir/include/mlir-c/Pass.h
mlir/lib/Bindings/Python/Pass.cpp
mlir/lib/CAPI/IR/Pass.cpp
mlir/test/CAPI/pass.c
mlir/test/python/pass_manager.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index b66bdfe024905..6f281b6dc7aa1 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -105,6 +105,13 @@ MLIR_CAPI_EXPORTED void mlirPassManagerAddOwnedPass(MlirPassManager passManager,
MLIR_CAPI_EXPORTED void
mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass);
+/// Parse a sequence of textual MLIR pass pipeline elements and add them to the
+/// provided OpPassManager. If parsing fails an error message is reported using
+/// the provided callback.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirOpPassManagerAddPipeline(
+ MlirOpPassManager passManager, MlirStringRef pipelineElements,
+ MlirStringCallback callback, void *userData);
+
/// Print a textual MLIR pass pipeline by sending chunks of the string
/// representation and forwarding `userData to `callback`. Note that the
/// callback may be called several times with consecutive chunks of the string.
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 3278d3a911c21..99d67582d1780 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -82,15 +82,15 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
py::arg("enable"), "Enable / disable verify-each.")
.def_static(
"parse",
- [](const std::string pipeline, DefaultingPyMlirContext context) {
+ [](const std::string &pipeline, DefaultingPyMlirContext context) {
MlirPassManager passManager = mlirPassManagerCreate(context->get());
- MlirLogicalResult status = mlirParsePassPipeline(
+ PyPrintAccumulator errorMsg;
+ MlirLogicalResult status = mlirOpPassManagerAddPipeline(
mlirPassManagerGetAsOpPassManager(passManager),
- mlirStringRefCreate(pipeline.data(), pipeline.size()));
+ mlirStringRefCreate(pipeline.data(), pipeline.size()),
+ errorMsg.getCallback(), errorMsg.getUserData());
if (mlirLogicalResultIsFailure(status))
- throw SetPyError(PyExc_ValueError,
- llvm::Twine("invalid pass pipeline '") +
- pipeline + "'.");
+ throw SetPyError(PyExc_ValueError, std::string(errorMsg.join()));
return new PyPassManager(passManager);
},
py::arg("pipeline"), py::arg("context") = py::none(),
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
index a2998939a1586..398abfee2ba1b 100644
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -65,6 +65,15 @@ void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager,
unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass)));
}
+MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager,
+ MlirStringRef pipelineElements,
+ MlirStringCallback callback,
+ void *userData) {
+ detail::CallbackOstream stream(callback, userData);
+ return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager),
+ stream));
+}
+
void mlirPrintPassPipeline(MlirOpPassManager passManager,
MlirStringCallback callback, void *userData) {
detail::CallbackOstream stream(callback, userData);
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
index 4c68a3ef7963b..966bcaf8caeac 100644
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -133,6 +133,11 @@ static void printToStderr(MlirStringRef str, void *userData) {
fwrite(str.data, 1, str.length, stderr);
}
+static void dontPrint(MlirStringRef str, void *userData) {
+ (void)str;
+ (void)userData;
+}
+
void testPrintPassPipeline() {
MlirContext ctx = mlirContextCreate();
MlirPassManager pm = mlirPassManagerCreate(ctx);
@@ -176,8 +181,7 @@ void testParsePassPipeline() {
MlirLogicalResult status = mlirParsePassPipeline(
mlirPassManagerGetAsOpPassManager(pm),
mlirStringRefCreateFromCString(
- "builtin.module(func.func(print-op-stats{json=false}),"
- " func.func(print-op-stats{json=false}))"));
+ "builtin.module(func.func(print-op-stats{json=false}))"));
// Expect a failure, we haven't registered the print-op-stats pass yet.
if (mlirLogicalResultIsSuccess(status)) {
fprintf(
@@ -190,8 +194,7 @@ void testParsePassPipeline() {
status = mlirParsePassPipeline(
mlirPassManagerGetAsOpPassManager(pm),
mlirStringRefCreateFromCString(
- "builtin.module(func.func(print-op-stats{json=false}),"
- " func.func(print-op-stats{json=false}))"));
+ "builtin.module(func.func(print-op-stats{json=false}))"));
// Expect a failure, we haven't registered the print-op-stats pass yet.
if (mlirLogicalResultIsFailure(status)) {
fprintf(stderr,
@@ -199,14 +202,61 @@ void testParsePassPipeline() {
exit(EXIT_FAILURE);
}
- // CHECK: Round-trip: builtin.module(builtin.module(
- // CHECK-SAME: func.func(print-op-stats{json=false}),
- // CHECK-SAME: func.func(print-op-stats{json=false})
- // CHECK-SAME: ))
+ // CHECK: Round-trip: builtin.module(
+ // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false}))
+ // CHECK-SAME: )
fprintf(stderr, "Round-trip: ");
mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
NULL);
fprintf(stderr, "\n");
+
+ // Try appending a pass:
+ status = mlirOpPassManagerAddPipeline(
+ mlirPassManagerGetAsOpPassManager(pm),
+ mlirStringRefCreateFromCString("func.func(print-op-stats{json=false})"),
+ printToStderr, NULL);
+ if (mlirLogicalResultIsFailure(status)) {
+ fprintf(stderr, "Unexpected failure appending pipeline\n");
+ exit(EXIT_FAILURE);
+ }
+ // CHECK: Appended: builtin.module(
+ // CHECK-SAME: builtin.module(func.func(print-op-stats{json=false})),
+ // CHECK-SAME: func.func(print-op-stats{json=false})
+ // CHECK-SAME: )
+ fprintf(stderr, "Appended: ");
+ mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
+ NULL);
+ fprintf(stderr, "\n");
+
+ mlirPassManagerDestroy(pm);
+ mlirContextDestroy(ctx);
+}
+
+void testParseErrorCapture() {
+ // CHECK-LABEL: testParseErrorCapture:
+ fprintf(stderr, "\nTEST: testParseErrorCapture:\n");
+
+ MlirContext ctx = mlirContextCreate();
+ MlirPassManager pm = mlirPassManagerCreate(ctx);
+ MlirOpPassManager opm = mlirPassManagerGetAsOpPassManager(pm);
+ MlirStringRef invalidPipeline = mlirStringRefCreateFromCString("invalid");
+
+ // CHECK: mlirOpPassManagerAddPipeline:
+ // CHECK: 'invalid' does not refer to a registered pass or pass pipeline
+ fprintf(stderr, "mlirOpPassManagerAddPipeline:\n");
+ if (mlirLogicalResultIsSuccess(mlirOpPassManagerAddPipeline(
+ opm, invalidPipeline, printToStderr, NULL)))
+ exit(EXIT_FAILURE);
+ fprintf(stderr, "\n");
+
+ // Make sure all output is going through the callback.
+ // CHECK: dontPrint: <>
+ fprintf(stderr, "dontPrint: <");
+ if (mlirLogicalResultIsSuccess(
+ mlirOpPassManagerAddPipeline(opm, invalidPipeline, dontPrint, NULL)))
+ exit(EXIT_FAILURE);
+ fprintf(stderr, ">\n");
+
mlirPassManagerDestroy(pm);
mlirContextDestroy(ctx);
}
@@ -534,6 +584,7 @@ int main() {
testRunPassOnNestedModule();
testPrintPassPipeline();
testParsePassPipeline();
+ testParseErrorCapture();
testExternalPass();
return 0;
}
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index df55f205c541d..a2d56a1f6e031 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -36,10 +36,8 @@ def testParseSuccess():
# An unregistered pass should not parse.
try:
pm = PassManager.parse("builtin.module(func.func(not-existing-pass{json=false}))")
- # TODO: this error should be propagate to Python but the C API does not help right now.
- # CHECK: error: 'not-existing-pass' does not refer to a registered pass or pass pipeline
except ValueError as e:
- # CHECK: ValueError exception: invalid pass pipeline 'builtin.module(func.func(not-existing-pass{json=false}))'.
+ # CHECK: ValueError exception: {{.+}} 'not-existing-pass' does not refer to a registered pass
log("ValueError exception:", e)
else:
log("Exception not produced")
@@ -57,7 +55,10 @@ def testParseFail():
try:
pm = PassManager.parse("unknown-pass")
except ValueError as e:
- # CHECK: ValueError exception: invalid pass pipeline 'unknown-pass'.
+ # CHECK: ValueError exception: MLIR Textual PassPipeline Parser:1:1: error:
+ # CHECK-SAME: 'unknown-pass' does not refer to a registered pass or pass pipeline
+ # CHECK: unknown-pass
+ # CHECK: ^
log("ValueError exception:", e)
else:
log("Exception not produced")
@@ -71,8 +72,7 @@ def testInvalidNesting():
try:
pm = PassManager.parse("func.func(normalize-memrefs)")
except ValueError as e:
- # CHECK: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
- # CHECK: ValueError exception: invalid pass pipeline 'func.func(normalize-memrefs)'.
+ # CHECK: ValueError exception: Can't add pass 'NormalizeMemRefs' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
log("ValueError exception:", e)
else:
log("Exception not produced")
More information about the Mlir-commits mailing list