[Mlir-commits] [mlir] 3ea4c50 - [mlir][python] Capture error diagnostics in exceptions
Rahul Kayaith
llvmlistbot at llvm.org
Tue Mar 7 11:59:28 PST 2023
Author: Rahul Kayaith
Date: 2023-03-07T14:59:22-05:00
New Revision: 3ea4c5014da3a18b56fea3579bed72c649357f47
URL: https://github.com/llvm/llvm-project/commit/3ea4c5014da3a18b56fea3579bed72c649357f47
DIFF: https://github.com/llvm/llvm-project/commit/3ea4c5014da3a18b56fea3579bed72c649357f47.diff
LOG: [mlir][python] Capture error diagnostics in exceptions
This updates most (all?) error-diagnostic-emitting python APIs to
capture error diagnostics and include them in the raised exception's
message:
```
>>> Operation.parse('"arith.addi"() : () -> ()')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
mlir._mlir_libs.MLIRError: Unable to parse operation assembly:
error: "-":1:1: 'arith.addi' op requires one result
 note: "-":1:1: see current operation: "arith.addi"() : () -> ()
```
The diagnostic information is available on the exception for users who
may want to customize the error message:
```
>>> try:
...   Operation.parse('"arith.addi"() : () -> ()')
... except MLIRError as e:
...   print(e.message)
...   print(e.error_diagnostics)
...   print(e.error_diagnostics[0].message)
...
Unable to parse operation assembly
[<mlir._mlir_libs._mlir.ir.DiagnosticInfo object at 0x7fed32bd6b70>]
'arith.addi' op requires one result
```
Error diagnostics captured in exceptions aren't propagated to diagnostic
handlers, to avoid double-reporting of errors. The context-level
`emit_error_diagnostics` option can be used to revert to the old
behaviour, causing error diagnostics to be reported to handlers instead
of as part of exceptions.
API changes:
- `Operation.verify` now raises an exception on verification failure,
  instead of returning `False`
- The exception raised by the following methods has been changed to
  `MLIRError`:
  - `PassManager.run`
  - `{Module,Operation,Type,Attribute}.parse`
  - `{RankedTensorType,UnrankedTensorType}.get`
  - `{MemRefType,UnrankedMemRefType}.get`
  - `VectorType.get`
  - `FloatAttr.get`
closes #60595
depends on D144804, D143830
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D143869
Added:
mlir/test/python/ir/exception.py
Modified:
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/IRTypes.cpp
mlir/lib/Bindings/Python/Pass.cpp
mlir/python/mlir/_mlir_libs/__init__.py
mlir/test/python/ir/attributes.py
mlir/test/python/ir/builtin_types.py
mlir/test/python/ir/diagnostic_handler.py
mlir/test/python/ir/module.py
mlir/test/python/ir/operation.py
mlir/test/python/pass_manager.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index c8ede8b06e1e3..b0c35ffb8a53f 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -344,15 +344,10 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
c.def_static(
"get",
[](PyType &type, double value, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirAttributeIsNull(attr)) {
- throw SetPyError(PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(type)).cast<std::string>() +
- "' and expected floating point type.");
- }
+ if (mlirAttributeIsNull(attr))
+ throw MLIRError("Invalid attribute", errors.take());
return PyFloatAttribute(type.getContext(), attr);
},
py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index e03b6470c4dbd..8d637ea2b4abf 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -15,6 +15,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Debug.h"
+#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
//#include "mlir-c/Registration.h"
#include "llvm/ADT/ArrayRef.h"
@@ -38,7 +39,7 @@ using llvm::Twine;
static const char kContextParseTypeDocstring[] =
R"(Parses the assembly form of a type.
-Returns a Type object or raises a ValueError if the type cannot be parsed.
+Returns a Type object or raises an MLIRError if the type cannot be parsed.
See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";
@@ -58,7 +59,7 @@ static const char kContextGetNameLocationDocString[] =
static const char kModuleParseDocstring[] =
R"(Parses a module's assembly format from a string.
-Returns a new MlirModule or raises a ValueError if the parsing fails.
+Returns a new MlirModule or raises an MLIRError if the parsing fails.
See also: https://mlir.llvm.org/docs/LangRef/
)";
@@ -654,6 +655,20 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
return pyHandlerObject;
}
+MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag,
+ void *userData) {
+ auto *self = static_cast<ErrorCapture *>(userData);
+ // Check if the context requested we emit errors instead of capturing them.
+ if (self->ctx->emitErrorDiagnostics)
+ return mlirLogicalResultFailure();
+
+ if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
+ return mlirLogicalResultFailure();
+
+ self->errors.emplace_back(PyDiagnostic(diag).getInfo());
+ return mlirLogicalResultSuccess();
+}
+
PyMlirContext &DefaultingPyMlirContext::resolve() {
PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
if (!context) {
@@ -870,6 +885,13 @@ py::tuple PyDiagnostic::getNotes() {
return *materializedNotes;
}
+PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() {
+ std::vector<DiagnosticInfo> notes;
+ for (py::handle n : getNotes())
+ notes.emplace_back(n.cast<PyDiagnostic>().getInfo());
+ return {getSeverity(), getLocation(), getMessage(), std::move(notes)};
+}
+
//------------------------------------------------------------------------------
// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
//------------------------------------------------------------------------------
@@ -1062,13 +1084,12 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
const std::string &sourceStr,
const std::string &sourceName) {
+ PyMlirContext::ErrorCapture errors(contextRef);
MlirOperation op =
mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
toMlirStringRef(sourceName));
- // TODO: Include error diagnostic messages in the exception message
if (mlirOperationIsNull(op))
- throw py::value_error(
- "Unable to parse operation assembly (see diagnostics)");
+ throw MLIRError("Unable to parse operation assembly", errors.take());
return PyOperation::createDetached(std::move(contextRef), op);
}
@@ -1155,6 +1176,14 @@ void PyOperationBase::moveBefore(PyOperationBase &other) {
operation.parentKeepAlive = otherOp.parentKeepAlive;
}
+bool PyOperationBase::verify() {
+ PyOperation &op = getOperation();
+ PyMlirContext::ErrorCapture errors(op.getContext());
+ if (!mlirOperationVerify(op.get()))
+ throw MLIRError("Verification failed", errors.take());
+ return true;
+}
+
std::optional<PyOperationRef> PyOperation::getParentOperation() {
checkValid();
if (!isAttached())
@@ -2287,6 +2316,16 @@ void mlir::python::populateIRCore(py::module &m) {
return self.getMessage();
});
+ py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo",
+ py::module_local())
+ .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); }))
+ .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity)
+ .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location)
+ .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message)
+ .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes)
+ .def("__str__",
+ [](PyDiagnostic::DiagnosticInfo &self) { return self.message; });
+
py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
.def("detach", &PyDiagnosticHandler::detach)
.def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
@@ -2375,6 +2414,11 @@ void mlir::python::populateIRCore(py::module &m) {
mlirContextAppendDialectRegistry(self.get(), registry);
},
py::arg("registry"))
+ .def_property("emit_error_diagnostics", nullptr,
+ &PyMlirContext::setEmitErrorDiagnostics,
+ "Emit error diagnostics to diagnostic handlers. By default "
+ "error diagnostics are captured and reported through "
+ "MLIRError exceptions.")
.def("load_all_available_dialects", [](PyMlirContext &self) {
mlirContextLoadAllAvailableDialects(self.get());
});
@@ -2566,16 +2610,12 @@ void mlir::python::populateIRCore(py::module &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
.def_static(
"parse",
- [](const std::string moduleAsm, DefaultingPyMlirContext context) {
+ [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
MlirModule module = mlirModuleCreateParse(
context->get(), toMlirStringRef(moduleAsm));
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirModuleIsNull(module)) {
- throw SetPyError(
- PyExc_ValueError,
- "Unable to parse module assembly (see diagnostics)");
- }
+ if (mlirModuleIsNull(module))
+ throw MLIRError("Unable to parse module assembly", errors.take());
return PyModule::forModule(module).releaseObject();
},
py::arg("asm"), py::arg("context") = py::none(),
@@ -2724,13 +2764,9 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false,
py::arg("assume_verified") = false, kOperationGetAsmDocstring)
- .def(
- "verify",
- [](PyOperationBase &self) {
- return mlirOperationVerify(self.getOperation());
- },
- "Verify the operation and return true if it passes, false if it "
- "fails.")
+ .def("verify", &PyOperationBase::verify,
+ "Verify the operation. Raises MLIRError if verification fails, and "
+ "returns true otherwise.")
.def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
"Puts self immediately after the other operation in its parent "
"block.")
@@ -2833,12 +2869,12 @@ void mlir::python::populateIRCore(py::module &m) {
// directly.
std::string clsOpName =
py::cast<std::string>(cls.attr("OPERATION_NAME"));
- MlirStringRef parsedOpName =
+ MlirStringRef identifier =
mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
- if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName)))
- throw py::value_error(
- "Expected a '" + clsOpName + "' op, got: '" +
- std::string(parsedOpName.data, parsedOpName.length) + "'");
+ std::string_view parsedOpName(identifier.data, identifier.length);
+ if (clsOpName != parsedOpName)
+ throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" +
+ parsedOpName + "'");
return PyOpView::constructDerived(cls, *parsed.get());
},
py::arg("cls"), py::arg("source"), py::kw_only(),
@@ -3071,19 +3107,16 @@ void mlir::python::populateIRCore(py::module &m) {
.def_static(
"parse",
[](std::string attrSpec, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
MlirAttribute type = mlirAttributeParseGet(
context->get(), toMlirStringRef(attrSpec));
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirAttributeIsNull(type)) {
- throw SetPyError(PyExc_ValueError,
- Twine("Unable to parse attribute: '") +
- attrSpec + "'");
- }
+ if (mlirAttributeIsNull(type))
+ throw MLIRError("Unable to parse attribute", errors.take());
return PyAttribute(context->getRef(), type);
},
py::arg("asm"), py::arg("context") = py::none(),
- "Parses an attribute from an assembly form")
+ "Parses an attribute from an assembly form. Raises an MLIRError on "
+ "failure.")
.def_property_readonly(
"context",
[](PyAttribute &self) { return self.getContext().getObject(); },
@@ -3182,15 +3215,11 @@ void mlir::python::populateIRCore(py::module &m) {
.def_static(
"parse",
[](std::string typeSpec, DefaultingPyMlirContext context) {
+ PyMlirContext::ErrorCapture errors(context->getRef());
MlirType type =
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(type)) {
- throw SetPyError(PyExc_ValueError,
- Twine("Unable to parse type: '") + typeSpec +
- "'");
- }
+ if (mlirTypeIsNull(type))
+ throw MLIRError("Unable to parse type", errors.take());
return PyType(context->getRef(), type);
},
py::arg("asm"), py::arg("context") = py::none(),
@@ -3342,4 +3371,17 @@ void mlir::python::populateIRCore(py::module &m) {
// Attribute builder getter.
PyAttrBuilderMap::bind(m);
+
+ py::register_local_exception_translator([](std::exception_ptr p) {
+ // We can't define exceptions with custom fields through pybind, so instead
+ // the exception class is defined in python and imported here.
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (const MLIRError &e) {
+ py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr("MLIRError")(e.message, e.errorDiagnostics);
+ PyErr_SetObject(PyExc_Exception, obj.ptr());
+ }
+ });
}
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 4aced3639127a..fc236b1c68ecb 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -221,6 +221,11 @@ class PyMlirContext {
/// registration object (internally a PyDiagnosticHandler).
pybind11::object attachDiagnosticHandler(pybind11::object callback);
+ /// Controls whether error diagnostics should be propagated to diagnostic
+ /// handlers, instead of being captured by `ErrorCapture`.
+ void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; }
+ struct ErrorCapture;
+
private:
PyMlirContext(MlirContext context);
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
@@ -248,6 +253,8 @@ class PyMlirContext {
llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
LiveOperationMap liveOperations;
+ bool emitErrorDiagnostics = false;
+
MlirContext context;
friend class PyModule;
friend class PyOperation;
@@ -281,6 +288,34 @@ class BaseContextObject {
PyMlirContextRef contextRef;
};
+/// Wrapper around an MlirLocation.
+class PyLocation : public BaseContextObject {
+public:
+ PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
+ : BaseContextObject(std::move(contextRef)), loc(loc) {}
+
+ operator MlirLocation() const { return loc; }
+ MlirLocation get() const { return loc; }
+
+ /// Enter and exit the context manager.
+ pybind11::object contextEnter();
+ void contextExit(const pybind11::object &excType,
+ const pybind11::object &excVal,
+ const pybind11::object &excTb);
+
+ /// Gets a capsule wrapping the void* within the MlirLocation.
+ pybind11::object getCapsule();
+
+ /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
+ /// Note that PyLocation instances are uniqued, so the returned object
+ /// may be a pre-existing object. Ownership of the underlying MlirLocation
+ /// is taken by calling this function.
+ static PyLocation createFromCapsule(pybind11::object capsule);
+
+private:
+ MlirLocation loc;
+};
+
/// Python class mirroring the C MlirDiagnostic struct. Note that these structs
/// are only valid for the duration of a diagnostic callback and attempting
/// to access them outside of that will raise an exception. This applies to
@@ -295,6 +330,16 @@ class PyDiagnostic {
pybind11::str getMessage();
pybind11::tuple getNotes();
+ /// Materialized diagnostic information. This is safe to access outside the
+ /// diagnostic callback.
+ struct DiagnosticInfo {
+ MlirDiagnosticSeverity severity;
+ PyLocation location;
+ std::string message;
+ std::vector<DiagnosticInfo> notes;
+ };
+ DiagnosticInfo getInfo();
+
private:
MlirDiagnostic diagnostic;
@@ -351,6 +396,30 @@ class PyDiagnosticHandler {
friend class PyMlirContext;
};
+/// RAII object that captures any error diagnostics emitted to the provided
+/// context.
+struct PyMlirContext::ErrorCapture {
+ ErrorCapture(PyMlirContextRef ctx)
+ : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler(
+ ctx->get(), handler, /*userData=*/this,
+ /*deleteUserData=*/nullptr)) {}
+ ~ErrorCapture() {
+ mlirContextDetachDiagnosticHandler(ctx->get(), handlerID);
+ assert(errors.empty() && "unhandled captured errors");
+ }
+
+ std::vector<PyDiagnostic::DiagnosticInfo> take() {
+ return std::move(errors);
+ };
+
+private:
+ PyMlirContextRef ctx;
+ MlirDiagnosticHandlerID handlerID;
+ std::vector<PyDiagnostic::DiagnosticInfo> errors;
+
+ static MlirLogicalResult handler(MlirDiagnostic diag, void *userData);
+};
+
/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
/// order to
differentiate it from the `Dialect` base class which is extended by
/// plugins which extend dialect functionality through extension python code.
@@ -416,34 +485,6 @@ class PyDialectRegistry {
MlirDialectRegistry registry;
};
-/// Wrapper around an MlirLocation.
-class PyLocation : public BaseContextObject {
-public:
- PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
- : BaseContextObject(std::move(contextRef)), loc(loc) {}
-
- operator MlirLocation() const { return loc; }
- MlirLocation get() const { return loc; }
-
- /// Enter and exit the context manager.
- pybind11::object contextEnter();
- void contextExit(const pybind11::object &excType,
- const pybind11::object &excVal,
- const pybind11::object &excTb);
-
- /// Gets a capsule wrapping the void* within the MlirLocation.
- pybind11::object getCapsule();
-
- /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
- /// Note that PyLocation instances are uniqued, so the returned object
- /// may be a pre-existing object. Ownership of the underlying MlirLocation
- /// is taken by calling this function.
- static PyLocation createFromCapsule(pybind11::object capsule);
-
-private:
- MlirLocation loc;
-};
-
/// Used in function arguments when None should resolve to the current context
/// manager set instance.
class DefaultingPyLocation
@@ -519,6 +560,10 @@ class PyOperationBase {
void moveAfter(PyOperationBase &other);
void moveBefore(PyOperationBase &other);
+ /// Verify the operation. Throws `MLIRError` if verification fails, and
+ /// returns `true` otherwise.
+ bool verify();
+
/// Each must provide access to the raw Operation.
virtual PyOperation &getOperation() = 0;
};
@@ -1073,6 +1118,16 @@ class PySymbolTable {
MlirSymbolTable symbolTable;
};
+/// Custom exception that allows access to error diagnostic information. This is
+/// converted to the `ir.MLIRError` python exception when thrown.
+struct MLIRError {
+ MLIRError(llvm::Twine message,
+ std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {})
+ : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {}
+ std::string message;
+ std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
+};
+
void populateIRAffine(pybind11::module &m);
void populateIRAttributes(pybind11::module &m);
void populateIRCore(pybind11::module &m);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 87ffe593655b2..2166bab902a13 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -407,17 +407,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
"get",
[](std::vector<int64_t> shape, PyType &elementType,
DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
elementType);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point or integer type.");
- }
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
return PyVectorType(elementType.getContext(), t);
},
py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
@@ -438,20 +432,12 @@ class PyRankedTensorType
"get",
[](std::vector<int64_t> shape, PyType &elementType,
std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirType t = mlirRankedTensorTypeGetChecked(
loc, shape.size(), shape.data(), elementType,
encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
return PyRankedTensorType(elementType.getContext(), t);
},
py::arg("shape"), py::arg("element_type"),
@@ -479,18 +465,10 @@ class PyUnrankedTensorType
c.def_static(
"get",
[](PyType &elementType, DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
return PyUnrankedTensorType(elementType.getContext(), t);
},
py::arg("element_type"), py::arg("loc") = py::none(),
@@ -511,23 +489,15 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
[](std::vector<int64_t> shape, PyType &elementType,
PyAttribute *layout, PyAttribute *memorySpace,
DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
MlirAttribute memSpaceAttr =
memorySpace ? *memorySpace : mlirAttributeGetNull();
MlirType t =
mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
shape.data(), layoutAttr, memSpaceAttr);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
return PyMemRefType(elementType.getContext(), t);
},
py::arg("shape"), py::arg("element_type"),
@@ -570,23 +540,15 @@ class PyUnrankedMemRefType
"get",
[](PyType &elementType, PyAttribute *memorySpace,
DefaultingPyLocation loc) {
+ PyMlirContext::ErrorCapture errors(loc->getContext());
MlirAttribute memSpaceAttr = {};
if (memorySpace)
memSpaceAttr = *memorySpace;
MlirType t =
mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
- // TODO: Rework error reporting once diagnostic engine is exposed
- // in C API.
- if (mlirTypeIsNull(t)) {
- throw SetPyError(
- PyExc_ValueError,
- Twine("invalid '") +
- py::repr(py::cast(elementType)).cast<std::string>() +
- "' and expected floating point, integer, vector or "
- "complex "
- "type.");
- }
+ if (mlirTypeIsNull(t))
+ throw MLIRError("Invalid type", errors.take());
return PyUnrankedMemRefType(elementType.getContext(), t);
},
py::arg("element_type"), py::arg("memory_space"),
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 7e90d8be66cb6..79c53084e9260 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -117,15 +117,16 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
.def(
"run",
[](PyPassManager &passManager, PyOperationBase &op) {
+ PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
MlirLogicalResult status = mlirPassManagerRunOnOp(
passManager.get(), op.getOperation().get());
if (mlirLogicalResultIsFailure(status))
- throw SetPyError(PyExc_RuntimeError,
- "Failure while executing pass pipeline.");
+ throw MLIRError("Failure while executing pass pipeline",
+ errors.take());
},
py::arg("operation"),
- "Run the pass manager on the provided operation, throw a "
- "RuntimeError on failure.")
+ "Run the pass manager on the provided operation, raising an "
+ "MLIRError on failure.")
.def(
"__str__",
[](PyPassManager &self) {
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 9ceeef81844c2..7d3d1f6ca873a 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -100,8 +100,29 @@ def __init__(self, *args, **kwargs):
# all dialects. It is being done here in order to preserve existing
# behavior. See: https://github.com/llvm/llvm-project/issues/56037
self.load_all_available_dialects()
-
ir.Context = Context
+ class MLIRError(Exception):
+ """
+ An exception with diagnostic information. Has the following fields:
+ message: str
+ error_diagnostics: List[ir.DiagnosticInfo]
+ """
+ def __init__(self, message, error_diagnostics):
+ self.message = message
+ self.error_diagnostics = error_diagnostics
+ super().__init__(message, error_diagnostics)
+
+ def __str__(self):
+ s = self.message
+ if self.error_diagnostics:
+ s += ':'
+ for diag in self.error_diagnostics:
+ s += "\nerror: " + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n ')
+ for note in diag.notes:
+ s += "\n note: " + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n ')
+ return s
+ ir.MLIRError = MLIRError
+
_site_initialize()
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 684d52c3ae28f..1e1589d6d5f4c 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -28,16 +28,17 @@ def testParsePrint():
# CHECK-LABEL: TEST: testParseError
-# TODO: Hook the diagnostic manager to capture a more meaningful error
-# message.
@run
def testParseError():
with Context():
try:
t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
- except ValueError as e:
- # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
- print("testParseError:", e)
+ except MLIRError as e:
+ # CHECK: testParseError: <
+ # CHECK: Unable to parse attribute:
+ # CHECK: error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value
+ # CHECK: >
+ print(f"testParseError: <{e}>")
else:
print("Exception not produced")
@@ -180,8 +181,9 @@ def testFloatAttr():
try:
fattr_invalid = FloatAttr.get(
IntegerType.get_signless(32), 42)
- except ValueError as e:
- # CHECK: invalid 'Type(i32)' and expected floating point type.
+ except MLIRError as e:
+ # CHECK: Invalid attribute:
+ # CHECK: error: unknown: expected floating point type
print(e)
else:
print("Exception not produced")
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 7af81859e3e7a..594cc6620e396 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -26,16 +26,17 @@ def testParsePrint():
# CHECK-LABEL: TEST: testParseError
-# TODO: Hook the diagnostic manager to capture a more meaningful error
-# message.
@run
def testParseError():
ctx = Context()
try:
t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx)
- except ValueError as e:
- # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST'
- print("testParseError:", e)
+ except MLIRError as e:
+ # CHECK: testParseError: <
+ # CHECK: Unable to parse type:
+ # CHECK: error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type
+ # CHECK: >
+ print(f"testParseError: <{e}>")
else:
print("Exception not produced")
@@ -292,8 +293,9 @@ def testVectorType():
none = NoneType.get()
try:
vector_invalid = VectorType.get(shape, none)
- except ValueError as e:
- # CHECK: invalid 'Type(none)' and expected floating point or integer type.
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
print(e)
else:
print("Exception not produced")
@@ -313,9 +315,9 @@ def testRankedTensorType():
none = NoneType.get()
try:
tensor_invalid = RankedTensorType.get(shape, none)
- except ValueError as e:
- # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
- # CHECK: or complex type.
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid tensor element type: 'none'
print(e)
else:
print("Exception not produced")
@@ -361,9 +363,9 @@ def testUnrankedTensorType():
none = NoneType.get()
try:
tensor_invalid = UnrankedTensorType.get(none)
- except ValueError as e:
- # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
- # CHECK: or complex type.
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid tensor element type: 'none'
print(e)
else:
print("Exception not produced")
@@ -400,9 +402,9 @@ def testMemRefType():
none = NoneType.get()
try:
memref_invalid = MemRefType.get(shape, none)
- except ValueError as e:
- # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
- # CHECK: or complex type.
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid memref element type
print(e)
else:
print("Exception not produced")
@@ -444,9 +446,9 @@ def testUnrankedMemRefType():
none = NoneType.get()
try:
memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2"))
- except ValueError as e:
- # CHECK: invalid 'Type(none)' and expected floating point, integer, vector
- # CHECK: or complex type.
+ except MLIRError as e:
+ # CHECK: Invalid type:
+ # CHECK: error: unknown: invalid memref element type
print(e)
else:
print("Exception not produced")
diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
index d973db24c366b..cc07f6eaf56ed 100644
--- a/mlir/test/python/ir/diagnostic_handler.py
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -89,6 +89,7 @@ def callback(d):
@run
def testDiagnosticNonEmptyNotes():
ctx = Context()
+ ctx.emit_error_diagnostics = True
def callback(d):
# CHECK: DIAGNOSTIC:
# CHECK: message='arith.addi' op requires one result
@@ -99,7 +100,10 @@ def callback(d):
return True
handler = ctx.attach_diagnostic_handler(callback)
loc = Location.unknown(ctx)
- Operation.create('arith.addi', loc=loc).verify()
+ try:
+ Operation.create('arith.addi', loc=loc).verify()
+ except MLIRError:
+ pass
assert not handler.had_error
# CHECK-LABEL: TEST: testDiagnosticCallbackException
diff --git a/mlir/test/python/ir/exception.py b/mlir/test/python/ir/exception.py
new file mode 100644
index 0000000000000..6cb2375a13247
--- /dev/null
+++ b/mlir/test/python/ir/exception.py
@@ -0,0 +1,77 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
+
+
+# CHECK-LABEL: TEST: test_exception
+@run
+def test_exception():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ try:
+ Operation.parse("""
+ func.func @foo() {
+ "test.use"(%0) : (i64) -> () loc("use")
+ %0 = "test.def"() : () -> i64 loc("def")
+ return
+ }
+ """, context=ctx)
+ except MLIRError as e:
+ # CHECK: Exception: <
+ # CHECK: Unable to parse operation assembly:
+ # CHECK: error: "use": operand #0 does not dominate this use
+ # CHECK: note: "use": see current operation: "test.use"(%0) : (i64) -> ()
+ # CHECK: note: "def": operand defined here (op in the same block)
+ # CHECK: >
+ print(f"Exception: <{e}>")
+
+ # CHECK: message: Unable to parse operation assembly
+ print(f"message: {e.message}")
+
+ # CHECK: error_diagnostics[0]: loc("use") operand #0 does not dominate this use
+ # CHECK: error_diagnostics[0].notes[0]: loc("use") see current operation: "test.use"(%0) : (i64) -> ()
+ # CHECK: error_diagnostics[0].notes[1]: loc("def") operand defined here (op in the same block)
+ print("error_diagnostics[0]: ", e.error_diagnostics[0].location, e.error_diagnostics[0].message)
+ print("error_diagnostics[0].notes[0]: ", e.error_diagnostics[0].notes[0].location, e.error_diagnostics[0].notes[0].message)
+ print("error_diagnostics[0].notes[1]: ", e.error_diagnostics[0].notes[1].location, e.error_diagnostics[0].notes[1].message)
+
+
+# CHECK-LABEL: test_emit_error_diagnostics
+@run
+def test_emit_error_diagnostics():
+ ctx = Context()
+ loc = Location.unknown(ctx)
+ handler_diags = []
+ def handler(d):
+ handler_diags.append(str(d))
+ return True
+ ctx.attach_diagnostic_handler(handler)
+
+ try:
+ Attribute.parse("not an attr", ctx)
+ except MLIRError as e:
+ # CHECK: emit_error_diagnostics=False:
+ # CHECK: e.error_diagnostics: ['expected attribute value']
+ # CHECK: handler_diags: []
+ print(f"emit_error_diagnostics=False:")
+ print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
+ print(f"handler_diags: {handler_diags}")
+
+ ctx.emit_error_diagnostics = True
+ try:
+ Attribute.parse("not an attr", ctx)
+ except MLIRError as e:
+ # CHECK: emit_error_diagnostics=True:
+ # CHECK: e.error_diagnostics: []
+ # CHECK: handler_diags: ['expected attribute value']
+ print(f"emit_error_diagnostics=True:")
+ print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}")
+ print(f"handler_diags: {handler_diags}")
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index f0b62435f514b..2d00923683339 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -28,14 +28,17 @@ def testParseSuccess():
# Verify parse error.
# CHECK-LABEL: TEST: testParseError
-# CHECK: testParseError: Unable to parse module assembly (see diagnostics)
+# CHECK: testParseError: <
+# CHECK: Unable to parse module assembly:
+# CHECK: error: "-":1:1: expected operation name in quotes
+# CHECK: >
@run
def testParseError():
ctx = Context()
try:
module = Module.parse(r"""}SYNTAX ERROR{""", ctx)
- except ValueError as e:
- print("testParseError:", e)
+ except MLIRError as e:
+ print(f"testParseError: <{e}>")
else:
print("Exception not produced")
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index be7467d12ff13..941420e8d1ff1 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -685,8 +685,19 @@ def testInvalidOperationStrSoftFails():
# CHECK: "builtin.module"() ({
# CHECK: }) : () -> ()
print(invalid_op)
- # CHECK: .verify = False
- print(f".verify = {invalid_op.operation.verify()}")
+ try:
+ invalid_op.verify()
+ except MLIRError as e:
+ # CHECK: Exception: <
+ # CHECK: Verification failed:
+ # CHECK: error: unknown: 'builtin.module' op requires one region
+ # CHECK: note: unknown: see current operation:
+ # CHECK: "builtin.module"() ({
+ # CHECK: ^bb0:
+ # CHECK: }, {
+ # CHECK: }) : () -> ()
+ # CHECK: >
+ print(f"Exception: <{e}>")
# CHECK-LABEL: TEST: testInvalidModuleStrSoftFails
@@ -920,7 +931,7 @@ def testOperationParse():
assert isinstance(m, ModuleOp)
try:
ModuleOp.parse('"test.foo"() : () -> ()')
- except ValueError as e:
+ except MLIRError as e:
# CHECK: error: Expected a 'builtin.module' op, got: 'test.foo'
print(f"error: {e}")
else:
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index b3acd359a207d..8b276537dddcc 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -118,7 +118,7 @@ def testInvalidNesting():
# Verify that a pass manager can execute on IR
-# CHECK-LABEL: TEST: testRun
+# CHECK-LABEL: TEST: testRunPipeline
def testRunPipeline():
with Context():
pm = PassManager.parse("any(print-op-stats{json=false})")
@@ -128,3 +128,20 @@ def testRunPipeline():
# CHECK: func.func , 1
# CHECK: func.return , 1
run(testRunPipeline)
+
+# CHECK-LABEL: TEST: testRunPipelineError
+@run
+def testRunPipelineError():
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+ op = Operation.parse('"test.op"() : () -> ()')
+ pm = PassManager.parse("any(cse)")
+ try:
+ pm.run(op)
+ except MLIRError as e:
+ # CHECK: Exception: <
+ # CHECK: Failure while executing pass pipeline:
+ # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation
+ # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> ()
+ # CHECK: >
+ print(f"Exception: <{e}>")
More information about the Mlir-commits mailing list