[Mlir-commits] [mlir] 7ee25bc - [mlir][python] Add bindings for diagnostic handler.

Stella Laurenzo llvmlistbot at llvm.org
Tue Jan 4 11:12:13 PST 2022


Author: Stella Laurenzo
Date: 2022-01-04T11:04:37-08:00
New Revision: 7ee25bc56f92495eb6d289b5ec18a07f27f1f44b

URL: https://github.com/llvm/llvm-project/commit/7ee25bc56f92495eb6d289b5ec18a07f27f1f44b
DIFF: https://github.com/llvm/llvm-project/commit/7ee25bc56f92495eb6d289b5ec18a07f27f1f44b.diff

LOG: [mlir][python] Add bindings for diagnostic handler.

I considered multiple approaches for this but settled on this one because I could make the lifetime management work in a reasonably easy way (others had issues with not being able to cast to a Python reference from a C++ constructor). We could stand to have more formatting helpers, but best to get the core mechanism in first.

Differential Revision: https://reviews.llvm.org/D116568

Added: 
    mlir/test/python/ir/diagnostic_handler.py

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b39a1ea844e4e..1a7eb46f75292 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -511,6 +511,57 @@ void PyMlirContext::contextExit(const pybind11::object &excType,
   PyThreadContextEntry::popContext(*this);
 }
 
+py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
+  // Note that ownership is transferred to the delete callback below by way of
+  // an explicit inc_ref (borrow).
+  PyDiagnosticHandler *pyHandler =
+      new PyDiagnosticHandler(get(), std::move(callback));
+  py::object pyHandlerObject =
+      py::cast(pyHandler, py::return_value_policy::take_ownership);
+  pyHandlerObject.inc_ref();
+
+  // In these C callbacks, the userData is a PyDiagnosticHandler* that is
+  // guaranteed to be known to pybind.
+  auto handlerCallback =
+      +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
+    PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
+    py::object pyDiagnosticObject =
+        py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
+
+    auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
+    bool result = false;
+    {
+      // Since this can be called from arbitrary C++ contexts, always get the
+      // gil.
+      py::gil_scoped_acquire gil;
+      try {
+        result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
+      } catch (std::exception &e) {
+        fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
+                e.what());
+        pyHandler->hadError = true;
+      }
+    }
+
+    pyDiagnostic->invalidate();
+    return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
+  };
+  auto deleteCallback = +[](void *userData) {
+    auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
+    assert(pyHandler->registeredID && "handler is not registered");
+    pyHandler->registeredID.reset();
+
+    // Decrement reference, balancing the inc_ref() above.
+    py::object pyHandlerObject =
+        py::cast(pyHandler, py::return_value_policy::reference);
+    pyHandlerObject.dec_ref();
+  };
+
+  pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
+      get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
+  return pyHandlerObject;
+}
+
 PyMlirContext &DefaultingPyMlirContext::resolve() {
   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
   if (!context) {
@@ -656,6 +707,78 @@ void PyThreadContextEntry::popLocation(PyLocation &location) {
   stack.pop_back();
 }
 
+//------------------------------------------------------------------------------
+// PyDiagnostic*
+//------------------------------------------------------------------------------
+
+void PyDiagnostic::invalidate() {
+  valid = false;
+  if (materializedNotes) {
+    for (auto &noteObject : *materializedNotes) {
+      PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
+      note->invalidate();
+    }
+  }
+}
+
+PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
+                                         py::object callback)
+    : context(context), callback(std::move(callback)) {}
+
+PyDiagnosticHandler::~PyDiagnosticHandler() {}
+
+void PyDiagnosticHandler::detach() {
+  if (!registeredID)
+    return;
+  MlirDiagnosticHandlerID localID = *registeredID;
+  mlirContextDetachDiagnosticHandler(context, localID);
+  assert(!registeredID && "should have unregistered");
+  // Not strictly necessary but keeps stale pointers from being around to cause
+  // issues.
+  context = {nullptr};
+}
+
+void PyDiagnostic::checkValid() {
+  if (!valid) {
+    throw std::invalid_argument(
+        "Diagnostic is invalid (used outside of callback)");
+  }
+}
+
+MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
+  checkValid();
+  return mlirDiagnosticGetSeverity(diagnostic);
+}
+
+PyLocation PyDiagnostic::getLocation() {
+  checkValid();
+  MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
+  MlirContext context = mlirLocationGetContext(loc);
+  return PyLocation(PyMlirContext::forContext(context), loc);
+}
+
+py::str PyDiagnostic::getMessage() {
+  checkValid();
+  py::object fileObject = py::module::import("io").attr("StringIO")();
+  PyFileAccumulator accum(fileObject, /*binary=*/false);
+  mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
+  return fileObject.attr("getvalue")();
+}
+
+py::tuple PyDiagnostic::getNotes() {
+  checkValid();
+  if (materializedNotes)
+    return *materializedNotes;
+  intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
+  materializedNotes = py::tuple(numNotes);
+  for (intptr_t i = 0; i < numNotes; ++i) {
+    MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
+    py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
+    PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr());
+  }
+  return *materializedNotes;
+}
+
 //------------------------------------------------------------------------------
 // PyDialect, PyDialectDescriptor, PyDialects
 //------------------------------------------------------------------------------
@@ -2024,6 +2147,36 @@ class PyOpAttributeMap {
 //------------------------------------------------------------------------------
 
 void mlir::python::populateIRCore(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Enums.
+  //----------------------------------------------------------------------------
+  py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
+      .value("ERROR", MlirDiagnosticError)
+      .value("WARNING", MlirDiagnosticWarning)
+      .value("NOTE", MlirDiagnosticNote)
+      .value("REMARK", MlirDiagnosticRemark);
+
+  //----------------------------------------------------------------------------
+  // Mapping of Diagnostics.
+  //----------------------------------------------------------------------------
+  py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
+      .def_property_readonly("severity", &PyDiagnostic::getSeverity)
+      .def_property_readonly("location", &PyDiagnostic::getLocation)
+      .def_property_readonly("message", &PyDiagnostic::getMessage)
+      .def_property_readonly("notes", &PyDiagnostic::getNotes)
+      .def("__str__", [](PyDiagnostic &self) -> py::str {
+        if (!self.isValid())
+          return "<Invalid Diagnostic>";
+        return self.getMessage();
+      });
+
+  py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
+      .def("detach", &PyDiagnosticHandler::detach)
+      .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
+      .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
+      .def("__enter__", &PyDiagnosticHandler::contextEnter)
+      .def("__exit__", &PyDiagnosticHandler::contextExit);
+
   //----------------------------------------------------------------------------
   // Mapping of MlirContext.
   //----------------------------------------------------------------------------
@@ -2079,6 +2232,9 @@ void mlir::python::populateIRCore(py::module &m) {
           [](PyMlirContext &self, bool value) {
             mlirContextSetAllowUnregisteredDialects(self.get(), value);
           })
+      .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
+           py::arg("callback"),
+           "Attaches a diagnostic handler that will receive callbacks")
       .def(
           "enable_multithreading",
           [](PyMlirContext &self, bool enable) {
@@ -2204,7 +2360,8 @@ void mlir::python::populateIRCore(py::module &m) {
           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
       .def_static(
           "fused",
-          [](const std::vector<PyLocation> &pyLocations, llvm::Optional<PyAttribute> metadata,
+          [](const std::vector<PyLocation> &pyLocations,
+             llvm::Optional<PyAttribute> metadata,
              DefaultingPyMlirContext context) {
             if (pyLocations.empty())
               throw py::value_error("No locations provided");
@@ -2236,6 +2393,12 @@ void mlir::python::populateIRCore(py::module &m) {
           "context",
           [](PyLocation &self) { return self.getContext().getObject(); },
           "Context that owns the Location")
+      .def(
+          "emit_error",
+          [](PyLocation &self, std::string message) {
+            mlirEmitError(self, message.c_str());
+          },
+          py::arg("message"), "Emits an error at this location")
       .def("__repr__", [](PyLocation &self) {
         PyPrintAccumulator printAccum;
         mlirLocationPrint(self, printAccum.getCallback(),

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 117435d633b16..2f354d6d12620 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -15,6 +15,7 @@
 
 #include "mlir-c/AffineExpr.h"
 #include "mlir-c/AffineMap.h"
+#include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "llvm/ADT/DenseMap.h"
@@ -24,6 +25,8 @@ namespace mlir {
 namespace python {
 
 class PyBlock;
+class PyDiagnostic;
+class PyDiagnosticHandler;
 class PyInsertionPoint;
 class PyLocation;
 class DefaultingPyLocation;
@@ -207,6 +210,10 @@ class PyMlirContext {
                    const pybind11::object &excVal,
                    const pybind11::object &excTb);
 
+  /// Attaches a Python callback as a diagnostic handler, returning a
+  /// registration object (internally a PyDiagnosticHandler).
+  pybind11::object attachDiagnosticHandler(pybind11::object callback);
+
 private:
   PyMlirContext(MlirContext context);
   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
@@ -267,6 +274,75 @@ class BaseContextObject {
   PyMlirContextRef contextRef;
 };
 
+/// Python class mirroring the C MlirDiagnostic struct. Note that these structs
+/// are only valid for the duration of a diagnostic callback and attempting
+/// to access them outside of that will raise an exception. This applies to
+/// nested diagnostics (in the notes) as well.
+class PyDiagnostic {
+public:
+  PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
+  void invalidate();
+  bool isValid() { return valid; }
+  MlirDiagnosticSeverity getSeverity();
+  PyLocation getLocation();
+  pybind11::str getMessage();
+  pybind11::tuple getNotes();
+
+private:
+  MlirDiagnostic diagnostic;
+
+  void checkValid();
+  /// If notes have been materialized from the diagnostic, then this will
+  /// be populated with the corresponding objects (all castable to
+  /// PyDiagnostic).
+  llvm::Optional<pybind11::tuple> materializedNotes;
+  bool valid = true;
+};
+
+/// Represents a diagnostic handler attached to the context. The handler's
+/// callback will be invoked with PyDiagnostic instances until the detach()
+/// method is called or the context is destroyed. A diagnostic handler can be
+/// the subject of a `with` block, which will detach it when the block exits.
+///
+/// Since diagnostic handlers can call back into Python code which can do
+/// unsafe things (i.e. recursively emitting diagnostics, raising exceptions,
+/// etc), this is generally not deemed to be a great user-level API. Users
+/// should generally use some form of DiagnosticCollector. If the handler raises
+/// any exceptions, they will just be emitted to stderr and dropped.
+///
+/// The unique usage of this class means that its lifetime management is
+/// different from most other parts of the API. Instances are always created
+/// in an attached state and can transition to a detached state by either:
+///   a) The context being destroyed and unregistering all handlers.
+///   b) An explicit call to detach().
+/// The object may remain live from a Python perspective for an arbitrary time
+/// after detachment, but there is nothing the user can do with it (since there
+/// is no way to attach an existing handler object).
+class PyDiagnosticHandler {
+public:
+  PyDiagnosticHandler(MlirContext context, pybind11::object callback);
+  ~PyDiagnosticHandler();
+
+  bool isAttached() { return registeredID.hasValue(); }
+  bool getHadError() { return hadError; }
+
+  /// Detaches the handler. Does nothing if not attached.
+  void detach();
+
+  pybind11::object contextEnter() { return pybind11::cast(this); }
+  void contextExit(pybind11::object excType, pybind11::object excVal,
+                   pybind11::object excTb) {
+    detach();
+  }
+
+private:
+  MlirContext context;
+  pybind11::object callback;
+  llvm::Optional<MlirDiagnosticHandlerID> registeredID;
+  bool hadError = false;
+  friend class PyMlirContext;
+};
+
 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
 /// order to differentiate it from the `Dialect` base class which is extended by
 /// plugins which extend dialect functionality through extension python code.

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index e61e34a176b03..affe54c3e11ba 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -7,7 +7,7 @@
 #   * Local edits to signatures and types that MyPy did not auto detect (or
 #     detected incorrectly).
 
-from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple
 
 from typing import overload
 
@@ -43,6 +43,9 @@ __all__ = [
     "Dialect",
     "DialectDescriptor",
     "Dialects",
+    "Diagnostic",
+    "DiagnosticHandler",
+    "DiagnosticSeverity",
     "DictAttr",
     "F16Type",
     "F32Type",
@@ -425,8 +428,9 @@ class Context:
     def _get_live_count() -> int: ...
     def _get_live_module_count(self) -> int: ...
     def _get_live_operation_count(self) -> int: ...
+    def attach_diagnostic_handler(self, callback: Callable[["Diagnostic"], bool]) -> "DiagnosticHandler": ...
     def enable_multithreading(self, enable: bool) -> None: ...
-    def get_dialect_descriptor(name: dialect_name: str) -> "DialectDescriptor": ...
+    def get_dialect_descriptor(dialect_name: str) -> "DialectDescriptor": ...
     def is_registered_operation(self, operation_name: str) -> bool: ...
     def __enter__(self) -> "Context": ...
     def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
@@ -479,6 +483,31 @@ class Dialects:
     def __getattr__(self, arg0: str) -> "Dialect": ...
     def __getitem__(self, arg0: str) -> "Dialect": ...
 
+class Diagnostic:
+    @property
+    def severity(self) -> "DiagnosticSeverity": ...
+    @property
+    def location(self) -> "Location": ...
+    @property
+    def message(self) -> str: ...
+    @property
+    def notes(self) -> Tuple["Diagnostic"]: ...
+
+class DiagnosticHandler:
+    def detach(self) -> None: ...
+    @property
+    def attached(self) -> bool: ...
+    @property
+    def had_error(self) -> bool: ...
+    def __enter__(self) -> "DiagnosticHandler": ...
+    def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
+
+class DiagnosticSeverity:
+    ERROR: "DiagnosticSeverity"
+    WARNING: "DiagnosticSeverity"
+    NOTE: "DiagnosticSeverity"
+    REMARK: "DiagnosticSeverity"
+
 # TODO: Auto-generated. Audit and fix.
 class DictAttr(Attribute):
     def __init__(self, cast_from_attr: Attribute) -> None: ...

diff  --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
new file mode 100644
index 0000000000000..f38187a6f3be2
--- /dev/null
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -0,0 +1,172 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+  return f
+
+
+@run
+def testLifecycleContextDestroy():
+  ctx = Context()
+  def callback(foo): ...
+  handler = ctx.attach_diagnostic_handler(callback)
+  assert handler.attached
+  # If context is destroyed before the handler, it should auto-detach.
+  ctx = None
+  gc.collect()
+  assert not handler.attached
+
+  # And finally collecting the handler should be fine.
+  handler = None
+  gc.collect()
+
+
+@run
+def testLifecycleExplicitDetach():
+  ctx = Context()
+  def callback(foo): ...
+  handler = ctx.attach_diagnostic_handler(callback)
+  assert handler.attached
+  handler.detach()
+  assert not handler.attached
+
+
+@run
+def testLifecycleWith():
+  ctx = Context()
+  def callback(foo): ...
+  with ctx.attach_diagnostic_handler(callback) as handler:
+    assert handler.attached
+  assert not handler.attached
+
+
+@run
+def testLifecycleWithAndExplicitDetach():
+  ctx = Context()
+  def callback(foo): ...
+  with ctx.attach_diagnostic_handler(callback) as handler:
+    assert handler.attached
+    handler.detach()
+  assert not handler.attached
+
+
+# CHECK-LABEL: TEST: testDiagnosticCallback
+@run
+def testDiagnosticCallback():
+  ctx = Context()
+  def callback(d):
+    # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown)
+    print(f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}")
+    return True
+  handler = ctx.attach_diagnostic_handler(callback)
+  loc = Location.unknown(ctx)
+  loc.emit_error("foobar")
+  assert not handler.had_error
+
+
+# CHECK-LABEL: TEST: testDiagnosticEmptyNotes
+# TODO: Come up with a way to inject a diagnostic with notes from this API.
+@run
+def testDiagnosticEmptyNotes():
+  ctx = Context()
+  def callback(d):
+    # CHECK: DIAGNOSTIC: notes=()
+    print(f"DIAGNOSTIC: notes={d.notes}")
+    return True
+  handler = ctx.attach_diagnostic_handler(callback)
+  loc = Location.unknown(ctx)
+  loc.emit_error("foobar")
+  assert not handler.had_error
+
+
+# CHECK-LABEL: TEST: testDiagnosticCallbackException
+@run
+def testDiagnosticCallbackException():
+  ctx = Context()
+  def callback(d):
+    raise ValueError("Error in handler")
+  handler = ctx.attach_diagnostic_handler(callback)
+  loc = Location.unknown(ctx)
+  loc.emit_error("foobar")
+  assert handler.had_error
+
+
+# CHECK-LABEL: TEST: testEscapingDiagnostic
+@run
+def testEscapingDiagnostic():
+  ctx = Context()
+  diags = []
+  def callback(d):
+    diags.append(d)
+    return True
+  handler = ctx.attach_diagnostic_handler(callback)
+  loc = Location.unknown(ctx)
+  loc.emit_error("foobar")
+  assert not handler.had_error
+
+  # CHECK: DIAGNOSTIC: <Invalid Diagnostic>
+  print(f"DIAGNOSTIC: {str(diags[0])}")
+  try:
+    diags[0].severity
+    raise RuntimeError("expected exception")
+  except ValueError:
+    pass
+  try:
+    diags[0].location
+    raise RuntimeError("expected exception")
+  except ValueError:
+    pass
+  try:
+    diags[0].message
+    raise RuntimeError("expected exception")
+  except ValueError:
+    pass
+  try:
+    diags[0].notes
+    raise RuntimeError("expected exception")
+  except ValueError:
+    pass
+
+
+
+# CHECK-LABEL: TEST: testDiagnosticReturnTrueHandles
+@run
+def testDiagnosticReturnTrueHandles():
+  ctx = Context()
+  def callback1(d):
+    print(f"CALLBACK1: {d}")
+    return True
+  def callback2(d):
+    print(f"CALLBACK2: {d}")
+    return True
+  ctx.attach_diagnostic_handler(callback1)
+  ctx.attach_diagnostic_handler(callback2)
+  loc = Location.unknown(ctx)
+  # CHECK-NOT: CALLBACK1
+  # CHECK: CALLBACK2: foobar
+  # CHECK-NOT: CALLBACK1
+  loc.emit_error("foobar")
+
+
+# CHECK-LABEL: TEST: testDiagnosticReturnFalseDoesNotHandle
+@run
+def testDiagnosticReturnFalseDoesNotHandle():
+  ctx = Context()
+  def callback1(d):
+    print(f"CALLBACK1: {d}")
+    return True
+  def callback2(d):
+    print(f"CALLBACK2: {d}")
+    return False
+  ctx.attach_diagnostic_handler(callback1)
+  ctx.attach_diagnostic_handler(callback2)
+  loc = Location.unknown(ctx)
+  # CHECK: CALLBACK2: foobar
+  # CHECK: CALLBACK1: foobar
+  loc.emit_error("foobar")


        


More information about the Mlir-commits mailing list