[Mlir-commits] [mlir] 7ee25bc - [mlir][python] Add bindings for diagnostic handler.
Stella Laurenzo
llvmlistbot at llvm.org
Tue Jan 4 11:12:13 PST 2022
Author: Stella Laurenzo
Date: 2022-01-04T11:04:37-08:00
New Revision: 7ee25bc56f92495eb6d289b5ec18a07f27f1f44b
URL: https://github.com/llvm/llvm-project/commit/7ee25bc56f92495eb6d289b5ec18a07f27f1f44b
DIFF: https://github.com/llvm/llvm-project/commit/7ee25bc56f92495eb6d289b5ec18a07f27f1f44b.diff
LOG: [mlir][python] Add bindings for diagnostic handler.
I considered multiple approaches for this but settled on this one because I could make the lifetime management work in a reasonably easy way (others had issues with not being able to cast to a Python reference from a C++ constructor). We could stand to have more formatting helpers, but best to get the core mechanism in first.
Differential Revision: https://reviews.llvm.org/D116568
Added:
mlir/test/python/ir/diagnostic_handler.py
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b39a1ea844e4e..1a7eb46f75292 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -511,6 +511,57 @@ void PyMlirContext::contextExit(const pybind11::object &excType,
PyThreadContextEntry::popContext(*this);
}
+/// Attaches `callback` as a diagnostic handler on this context and returns
+/// the Python-visible PyDiagnosticHandler registration object.
+py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
+  // Note that ownership is transferred to the delete callback below by way of
+  // an explicit inc_ref (borrow).
+  PyDiagnosticHandler *pyHandler =
+      new PyDiagnosticHandler(get(), std::move(callback));
+  py::object pyHandlerObject =
+      py::cast(pyHandler, py::return_value_policy::take_ownership);
+  pyHandlerObject.inc_ref();
+
+  // In these C callbacks, the userData is a PyDiagnosticHandler* that is
+  // guaranteed to be known to pybind.
+  auto handlerCallback =
+      +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
+    // Since this can be called from arbitrary C++ contexts, always get the
+    // GIL, and hold it for the whole callback. Creating the Python wrapper,
+    // calling into Python, invalidate() (which casts note objects) and the
+    // wrapper's final dec_ref all touch Python state. `gil` is declared first
+    // so that it is released only after pyDiagnosticObject is destroyed.
+    py::gil_scoped_acquire gil;
+    auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
+    PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
+    py::object pyDiagnosticObject =
+        py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
+
+    bool result = false;
+    try {
+      result = py::cast<bool>(pyHandler->callback(pyDiagnosticObject));
+    } catch (const std::exception &e) {
+      // Exceptions must not unwind through the C API: report and record.
+      fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
+              e.what());
+      pyHandler->hadError = true;
+    }
+
+    // The MlirDiagnostic is only valid for the duration of this callback, so
+    // poison any Python references to it (or its notes) that escaped.
+    pyDiagnostic->invalidate();
+    return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
+  };
+  auto deleteCallback = +[](void *userData) {
+    // Runs when the handler is detached or the context is destroyed; either
+    // way Python reference counts are modified below, which needs the GIL.
+    py::gil_scoped_acquire gil;
+    auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
+    assert(pyHandler->registeredID && "handler is not registered");
+    pyHandler->registeredID.reset();
+
+    // Decrement reference, balancing the inc_ref() above.
+    py::object pyHandlerObject =
+        py::cast(pyHandler, py::return_value_policy::reference);
+    pyHandlerObject.dec_ref();
+  };
+
+  pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
+      get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
+  return pyHandlerObject;
+}
+
PyMlirContext &DefaultingPyMlirContext::resolve() {
PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
if (!context) {
@@ -656,6 +707,78 @@ void PyThreadContextEntry::popLocation(PyLocation &location) {
stack.pop_back();
}
+//------------------------------------------------------------------------------
+// PyDiagnostic*
+//------------------------------------------------------------------------------
+
+/// Marks this diagnostic as unusable and recursively does the same for every
+/// note object previously handed out by getNotes(). Casts Python objects, so
+/// the caller must hold the GIL when notes have been materialized.
+void PyDiagnostic::invalidate() {
+  valid = false;
+  if (!materializedNotes)
+    return;
+  for (py::handle noteObject : *materializedNotes) {
+    PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
+    note->invalidate();
+  }
+}
+
+PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
+ py::object callback)
+ : context(context), callback(std::move(callback)) {}
+
+PyDiagnosticHandler::~PyDiagnosticHandler() {}
+
+/// Unregisters the handler from its context; a no-op when already detached
+/// (explicitly, or because the context was destroyed).
+void PyDiagnosticHandler::detach() {
+  if (!registeredID.hasValue())
+    return;
+  // Copy the ID out first: detaching triggers the delete callback, which
+  // resets registeredID.
+  const MlirDiagnosticHandlerID id = registeredID.getValue();
+  mlirContextDetachDiagnosticHandler(context, id);
+  assert(!registeredID && "should have unregistered");
+  // Not strictly necessary but keeps stale pointers from being around to
+  // cause issues.
+  context = MlirContext{nullptr};
+}
+
+/// Throws std::invalid_argument once invalidate() has run, i.e. when a
+/// diagnostic is touched after its callback returned.
+void PyDiagnostic::checkValid() {
+  if (valid)
+    return;
+  throw std::invalid_argument(
+      "Diagnostic is invalid (used outside of callback)");
+}
+
+/// Severity of the wrapped diagnostic.
+MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
+  checkValid();
+  const MlirDiagnosticSeverity severity = mlirDiagnosticGetSeverity(diagnostic);
+  return severity;
+}
+
+/// Location of the wrapped diagnostic, bound to the context that owns it.
+PyLocation PyDiagnostic::getLocation() {
+  checkValid();
+  MlirLocation location = mlirDiagnosticGetLocation(diagnostic);
+  return PyLocation(
+      PyMlirContext::forContext(mlirLocationGetContext(location)), location);
+}
+
+/// Renders the diagnostic text by printing into an in-memory io.StringIO.
+py::str PyDiagnostic::getMessage() {
+  checkValid();
+  py::object buffer = py::module::import("io").attr("StringIO")();
+  PyFileAccumulator accumulator(buffer, /*binary=*/false);
+  mlirDiagnosticPrint(diagnostic, accumulator.getCallback(),
+                      accumulator.getUserData());
+  return buffer.attr("getvalue")();
+}
+
+/// Returns the attached notes as a tuple of Diagnostic objects. The tuple is
+/// built once and cached, so repeated accesses return the same objects and
+/// invalidate() can later find and poison them.
+py::tuple PyDiagnostic::getNotes() {
+  checkValid();
+  if (materializedNotes)
+    return *materializedNotes;
+  intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
+  materializedNotes = py::tuple(numNotes);
+  for (intptr_t i = 0; i < numNotes; ++i) {
+    MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
+    py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
+    // PyTuple_SET_ITEM steals a reference: release() hands ours over so that
+    // pyNoteDiag's destructor does not decrement it a second time.
+    PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.release().ptr());
+  }
+  return *materializedNotes;
+}
+
//------------------------------------------------------------------------------
// PyDialect, PyDialectDescriptor, PyDialects
//------------------------------------------------------------------------------
@@ -2024,6 +2147,36 @@ class PyOpAttributeMap {
//------------------------------------------------------------------------------
void mlir::python::populateIRCore(py::module &m) {
+ //----------------------------------------------------------------------------
+ // Enums.
+ //----------------------------------------------------------------------------
+ py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
+ .value("ERROR", MlirDiagnosticError)
+ .value("WARNING", MlirDiagnosticWarning)
+ .value("NOTE", MlirDiagnosticNote)
+ .value("REMARK", MlirDiagnosticRemark);
+
+ //----------------------------------------------------------------------------
+ // Mapping of Diagnostics.
+ //----------------------------------------------------------------------------
+ py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
+ .def_property_readonly("severity", &PyDiagnostic::getSeverity)
+ .def_property_readonly("location", &PyDiagnostic::getLocation)
+ .def_property_readonly("message", &PyDiagnostic::getMessage)
+ .def_property_readonly("notes", &PyDiagnostic::getNotes)
+ .def("__str__", [](PyDiagnostic &self) -> py::str {
+ if (!self.isValid())
+ return "<Invalid Diagnostic>";
+ return self.getMessage();
+ });
+
+ py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
+ .def("detach", &PyDiagnosticHandler::detach)
+ .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
+ .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
+ .def("__enter__", &PyDiagnosticHandler::contextEnter)
+ .def("__exit__", &PyDiagnosticHandler::contextExit);
+
//----------------------------------------------------------------------------
// Mapping of MlirContext.
//----------------------------------------------------------------------------
@@ -2079,6 +2232,9 @@ void mlir::python::populateIRCore(py::module &m) {
[](PyMlirContext &self, bool value) {
mlirContextSetAllowUnregisteredDialects(self.get(), value);
})
+ .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
+ py::arg("callback"),
+ "Attaches a diagnostic handler that will receive callbacks")
.def(
"enable_multithreading",
[](PyMlirContext &self, bool enable) {
@@ -2204,7 +2360,8 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("context") = py::none(), kContextGetFileLocationDocstring)
.def_static(
"fused",
- [](const std::vector<PyLocation> &pyLocations, llvm::Optional<PyAttribute> metadata,
+ [](const std::vector<PyLocation> &pyLocations,
+ llvm::Optional<PyAttribute> metadata,
DefaultingPyMlirContext context) {
if (pyLocations.empty())
throw py::value_error("No locations provided");
@@ -2236,6 +2393,12 @@ void mlir::python::populateIRCore(py::module &m) {
"context",
[](PyLocation &self) { return self.getContext().getObject(); },
"Context that owns the Location")
+ .def(
+ "emit_error",
+ [](PyLocation &self, std::string message) {
+ mlirEmitError(self, message.c_str());
+ },
+ py::arg("message"), "Emits an error at this location")
.def("__repr__", [](PyLocation &self) {
PyPrintAccumulator printAccum;
mlirLocationPrint(self, printAccum.getCallback(),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 117435d633b16..2f354d6d12620 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -15,6 +15,7 @@
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
+#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "llvm/ADT/DenseMap.h"
@@ -24,6 +25,8 @@ namespace mlir {
namespace python {
class PyBlock;
+class PyDiagnostic;
+class PyDiagnosticHandler;
class PyInsertionPoint;
class PyLocation;
class DefaultingPyLocation;
@@ -207,6 +210,10 @@ class PyMlirContext {
const pybind11::object &excVal,
const pybind11::object &excTb);
+ /// Attaches a Python callback as a diagnostic handler, returning a
+ /// registration object (internally a PyDiagnosticHandler).
+ pybind11::object attachDiagnosticHandler(pybind11::object callback);
+
private:
PyMlirContext(MlirContext context);
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
@@ -267,6 +274,75 @@ class BaseContextObject {
PyMlirContextRef contextRef;
};
+/// Python class mirroring the C MlirDiagnostic struct. Note that these structs
+/// are only valid for the duration of a diagnostic callback and attempting
+/// to access them outside of that will raise an exception. This applies to
+/// nested diagnostics (in the notes) as well.
+class PyDiagnostic {
+public:
+  /// Wraps a diagnostic handed to a handler callback. Explicit so that a raw
+  /// MlirDiagnostic is never silently converted into a wrapper.
+  explicit PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
+  /// Marks this diagnostic and any already-materialized notes as invalid.
+  void invalidate();
+  bool isValid() const { return valid; }
+  // The accessors below call checkValid() and throw std::invalid_argument
+  // once the diagnostic has been invalidated.
+  MlirDiagnosticSeverity getSeverity();
+  PyLocation getLocation();
+  pybind11::str getMessage();
+  pybind11::tuple getNotes();
+
+private:
+  MlirDiagnostic diagnostic;
+
+  /// Throws std::invalid_argument if invalidate() has been called.
+  void checkValid();
+  /// If notes have been materialized from the diagnostic, then this will
+  /// be populated with the corresponding objects (all castable to
+  /// PyDiagnostic).
+  llvm::Optional<pybind11::tuple> materializedNotes;
+  bool valid = true;
+};
+
+/// Represents a diagnostic handler attached to the context. The handler's
+/// callback will be invoked with PyDiagnostic instances until the detach()
+/// method is called or the context is destroyed. A diagnostic handler can be
+/// the subject of a `with` block, which will detach it when the block exits.
+///
+/// Since diagnostic handlers can call back into Python code which can do
+/// unsafe things (e.g. recursively emitting diagnostics, raising exceptions,
+/// etc), this is generally not deemed to be a great user-level API. Users
+/// should generally use some form of DiagnosticCollector. If the handler raises
+/// any exceptions, they will just be emitted to stderr and dropped.
+///
+/// The unique usage of this class means that its lifetime management is
+/// different from most other parts of the API. Instances are always created
+/// in an attached state and can transition to a detached state by either:
+/// a) The context being destroyed and unregistering all handlers.
+/// b) An explicit call to detach().
+/// The object may remain live from a Python perspective for an arbitrary time
+/// after detachment, but there is nothing the user can do with it (since there
+/// is no way to attach an existing handler object).
+class PyDiagnosticHandler {
+public:
+ /// Stores the context and the Python callable; registration with the context
+ /// is done separately by PyMlirContext::attachDiagnosticHandler().
+ PyDiagnosticHandler(MlirContext context, pybind11::object callback);
+ ~PyDiagnosticHandler();
+
+ /// True while the handler is registered with its context.
+ bool isAttached() { return registeredID.hasValue(); }
+ /// True if invoking the Python callback (or converting its result) threw.
+ bool getHadError() { return hadError; }
+
+ /// Detaches the handler. Does nothing if not attached.
+ void detach();
+
+ /// `__enter__`: returns this handler so it can be bound with `with ... as`.
+ pybind11::object contextEnter() { return pybind11::cast(this); }
+ /// `__exit__`: detaches the handler. Returns nothing, so exceptions raised
+ /// inside the `with` block are not suppressed.
+ void contextExit(pybind11::object excType, pybind11::object excVal,
+ pybind11::object excTb) {
+ detach();
+ }
+
+private:
+ MlirContext context;
+ /// The Python callable invoked with each diagnostic.
+ pybind11::object callback;
+ /// Set while registered; reset by the context's delete callback.
+ llvm::Optional<MlirDiagnosticHandlerID> registeredID;
+ bool hadError = false;
+ friend class PyMlirContext;
+};
+
/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
/// order to differentiate it from the `Dialect` base class which is extended by
/// plugins which extend dialect functionality through extension python code.
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index e61e34a176b03..affe54c3e11ba 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -7,7 +7,7 @@
# * Local edits to signatures and types that MyPy did not auto detect (or
# detected incorrectly).
-from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple
from typing import overload
@@ -43,6 +43,9 @@ __all__ = [
"Dialect",
"DialectDescriptor",
"Dialects",
+ "Diagnostic",
+ "DiagnosticHandler",
+ "DiagnosticSeverity",
"DictAttr",
"F16Type",
"F32Type",
@@ -425,8 +428,9 @@ class Context:
def _get_live_count() -> int: ...
def _get_live_module_count(self) -> int: ...
def _get_live_operation_count(self) -> int: ...
+ def attach_diagnostic_handler(self, callback: Callable[["Diagnostic"], bool]) -> "DiagnosticHandler": ...
def enable_multithreading(self, enable: bool) -> None: ...
- def get_dialect_descriptor(name: dialect_name: str) -> "DialectDescriptor": ...
+ def get_dialect_descriptor(dialect_name: str) -> "DialectDescriptor": ...
def is_registered_operation(self, operation_name: str) -> bool: ...
def __enter__(self) -> "Context": ...
def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
@@ -479,6 +483,31 @@ class Dialects:
def __getattr__(self, arg0: str) -> "Dialect": ...
def __getitem__(self, arg0: str) -> "Dialect": ...
+class Diagnostic:
+ @property
+ def severity(self) -> "DiagnosticSeverity": ...
+ @property
+ def location(self) -> "Location": ...
+ @property
+ def message(self) -> str: ...
+ @property
+ def notes(self) -> Tuple["Diagnostic"]: ...
+
+class DiagnosticHandler:
+ # Registration object returned by Context.attach_diagnostic_handler. It can
+ # be used as a context manager, and detaches itself when the block exits.
+ def detach(self) -> None: ...
+ @property
+ def attached(self) -> bool: ...
+ # True if the callback raised (the exception is printed to stderr, not
+ # propagated).
+ @property
+ def had_error(self) -> bool: ...
+ def __enter__(self) -> "DiagnosticHandler": ...
+ def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
+
+class DiagnosticSeverity:
+ # Enum bound from the C MlirDiagnosticSeverity values.
+ ERROR: "DiagnosticSeverity"
+ WARNING: "DiagnosticSeverity"
+ NOTE: "DiagnosticSeverity"
+ REMARK: "DiagnosticSeverity"
+
# TODO: Auto-generated. Audit and fix.
class DictAttr(Attribute):
def __init__(self, cast_from_attr: Attribute) -> None: ...
diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
new file mode 100644
index 0000000000000..f38187a6f3be2
--- /dev/null
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -0,0 +1,172 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+ return f
+
+
+ at run
+def testLifecycleContextDestroy():
+ ctx = Context()
+ def callback(foo): ...
+ handler = ctx.attach_diagnostic_handler(callback)
+ assert handler.attached
+ # If context is destroyed before the handler, it should auto-detach.
+ ctx = None
+ gc.collect()
+ assert not handler.attached
+
+ # And finally collecting the handler should be fine.
+ handler = None
+ gc.collect()
+
+
+ at run
+def testLifecycleExplicitDetach():
+ ctx = Context()
+ def callback(foo): ...
+ handler = ctx.attach_diagnostic_handler(callback)
+ assert handler.attached
+ handler.detach()
+ assert not handler.attached
+
+
+ at run
+def testLifecycleWith():
+ ctx = Context()
+ def callback(foo): ...
+ with ctx.attach_diagnostic_handler(callback) as handler:
+ assert handler.attached
+ assert not handler.attached
+
+
+ at run
+def testLifecycleWithAndExplicitDetach():
+ ctx = Context()
+ def callback(foo): ...
+ with ctx.attach_diagnostic_handler(callback) as handler:
+ assert handler.attached
+ handler.detach()
+ assert not handler.attached
+
+
+# CHECK-LABEL: TEST: testDiagnosticCallback
+ at run
+def testDiagnosticCallback():
+ ctx = Context()
+ def callback(d):
+ # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown)
+ print(f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}")
+ return True
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ loc.emit_error("foobar")
+ assert not handler.had_error
+
+
+# CHECK-LABEL: TEST: testDiagnosticEmptyNotes
+# TODO: Come up with a way to inject a diagnostic with notes from this API.
+ at run
+def testDiagnosticEmptyNotes():
+ ctx = Context()
+ def callback(d):
+ # CHECK: DIAGNOSTIC: notes=()
+ print(f"DIAGNOSTIC: notes={d.notes}")
+ return True
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ loc.emit_error("foobar")
+ assert not handler.had_error
+
+
+# CHECK-LABEL: TEST: testDiagnosticCallbackException
+ at run
+def testDiagnosticCallbackException():
+ ctx = Context()
+ def callback(d):
+ raise ValueError("Error in handler")
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ loc.emit_error("foobar")
+ assert handler.had_error
+
+
+# CHECK-LABEL: TEST: testEscapingDiagnostic
+ at run
+def testEscapingDiagnostic():
+ ctx = Context()
+ diags = []
+ def callback(d):
+ diags.append(d)
+ return True
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ loc.emit_error("foobar")
+ assert not handler.had_error
+
+ # CHECK: DIAGNOSTIC: <Invalid Diagnostic>
+ print(f"DIAGNOSTIC: {str(diags[0])}")
+ try:
+ diags[0].severity
+ raise RuntimeError("expected exception")
+ except ValueError:
+ pass
+ try:
+ diags[0].location
+ raise RuntimeError("expected exception")
+ except ValueError:
+ pass
+ try:
+ diags[0].message
+ raise RuntimeError("expected exception")
+ except ValueError:
+ pass
+ try:
+ diags[0].notes
+ raise RuntimeError("expected exception")
+ except ValueError:
+ pass
+
+
+
+# CHECK-LABEL: TEST: testDiagnosticReturnTrueHandles
+ at run
+def testDiagnosticReturnTrueHandles():
+ ctx = Context()
+ def callback1(d):
+ print(f"CALLBACK1: {d}")
+ return True
+ def callback2(d):
+ print(f"CALLBACK2: {d}")
+ return True
+ ctx.attach_diagnostic_handler(callback1)
+ ctx.attach_diagnostic_handler(callback2)
+ loc = Location.unknown(ctx)
+ # CHECK-NOT: CALLBACK1
+ # CHECK: CALLBACK2: foobar
+ # CHECK-NOT: CALLBACK1
+ loc.emit_error("foobar")
+
+
+# CHECK-LABEL: TEST: testDiagnosticReturnFalseDoesNotHandle
+ at run
+def testDiagnosticReturnFalseDoesNotHandle():
+ ctx = Context()
+ def callback1(d):
+ print(f"CALLBACK1: {d}")
+ return True
+ def callback2(d):
+ print(f"CALLBACK2: {d}")
+ return False
+ ctx.attach_diagnostic_handler(callback1)
+ ctx.attach_diagnostic_handler(callback2)
+ loc = Location.unknown(ctx)
+ # CHECK: CALLBACK2: foobar
+ # CHECK: CALLBACK1: foobar
+ loc.emit_error("foobar")
More information about the Mlir-commits
mailing list