[Mlir-commits] [mlir] [MLIR] [Python] a few more fixes to type annotations (PR #186106)

Sergei Lebedev llvmlistbot at llvm.org
Thu Mar 12 10:49:08 PDT 2026


https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/186106

>From 2fa7410fbc736eccf3b0931d3c11e2264f151af7 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Thu, 12 Mar 2026 10:34:15 +0000
Subject: [PATCH] [MLIR] [Python] a few more fixes to type annotations

* The `walk_order` parameter in the `_OperationBase.walk` signature was
  missing its default.
* `MLIRError` (including its exception translator) is now fully defined in
  C++. The monkey-patching previously done in `_site_initialize` was opaque
  to type checkers.
---
 mlir/include/mlir/Bindings/Python/IRCore.h | 10 ++-
 mlir/lib/Bindings/Python/IRCore.cpp        | 74 +++++++++++++++++++++-
 mlir/lib/Bindings/Python/MainModule.cpp    | 14 ----
 mlir/python/mlir/_mlir_libs/__init__.py    | 34 ----------
 4 files changed, 81 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index bd2d49acbf681..557e32e9a612d 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -10,6 +10,7 @@
 #define MLIR_BINDINGS_PYTHON_IRCORE_H
 
 #include <cstddef>
+#include <exception>
 #include <optional>
 #include <sstream>
 #include <utility>
@@ -1313,12 +1314,17 @@ class MLIR_PYTHON_API_EXPORTED PySymbolTable {
 };
 
 /// Custom exception that allows access to error diagnostic information. This is
-/// converted to the `ir.MLIRError` python exception when thrown.
-struct MLIR_PYTHON_API_EXPORTED MLIRError {
+/// translated to the `ir.MLIRError` Python exception when thrown.
+struct MLIR_PYTHON_API_EXPORTED MLIRError : std::exception {
   MLIRError(std::string message,
             std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {})
       : message(std::move(message)),
         errorDiagnostics(std::move(errorDiagnostics)) {}
+  /// Returns only the top-level `message`, not the per-diagnostic details.
+  const char *what() const noexcept override { return message.c_str(); }
+  /// Binds the `MLIRError` Python exception class to `m` and registers the
+  static void bind(nanobind::module_ &m); // translator that raises it.
+
   std::string message;
   std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
 };
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b8637c57a3f48..92ea44605b01a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2774,6 +2774,75 @@ namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
 
+/// Renders `e` the way the former pure-Python `MLIRError.__str__` did.
+static std::string formatMLIRError(const MLIRError &e) {
+  // Prints a location, dropping the surrounding "loc(" ... ")" wrapper.
+  auto locStr = [](const PyLocation &loc) {
+    PyPrintAccumulator accum;
+    mlirLocationPrint(loc, accum.getCallback(), accum.getUserData());
+    std::string s = nb::cast<std::string>(nb::str(accum.join()));
+    std::string_view sv(s);
+    if (sv.size() >= 5 && sv.substr(0, 4) == "loc(" && sv.back() == ')')
+      sv = sv.substr(4, sv.size() - 5);
+    return std::string(sv);
+  };
+  // Indents continuation lines of a multi-line diagnostic message.
+  auto indent = [](std::string s) {
+    for (size_t pos = 0; (pos = s.find('\n', pos)) != std::string::npos;
+         pos += 3)
+      s.replace(pos, 1, "\n  ");
+    return s;
+  };
+
+  // "<message>[:]" then one "error:" entry (with its notes) per diagnostic.
+  std::ostringstream os;
+  os << e.message;
+  if (!e.errorDiagnostics.empty())
+    os << ":";
+  for (const auto &diag : e.errorDiagnostics) {
+    os << "\nerror: " << locStr(diag.location) << ": " << indent(diag.message);
+    for (const auto &note : diag.notes) {
+      os << "\n note: " << locStr(note.location) << ": "
+         << indent(note.message);
+    }
+  }
+  return os.str();
+}
+
+/// Registers `MLIRError` on `m` with a translator that attaches its payload.
+void MLIRError::bind(nb::module_ &m) {
+  auto cls = nb::exception<MLIRError>(m, "MLIRError", PyExc_Exception);
+  nb::register_exception_translator(
+      [](const std::exception_ptr &p, void *payload) {
+        try {
+          if (p)
+            std::rethrow_exception(p);
+        } catch (const MLIRError &e) {
+          // Copy, don't move: `e` is the exception object `p` refers to.
+          auto *type = static_cast<PyObject *>(payload);
+          nb::object obj = nb::borrow(type)(formatMLIRError(e));
+          obj.attr("_message") = nb::cast(e.message);
+          obj.attr("_error_diagnostics") = nb::cast(e.errorDiagnostics);
+          PyErr_SetObject(type, obj.ptr());
+        }
+      },
+      cls.ptr());
+  auto propertyType = nb::borrow<nb::type_object>(
+      reinterpret_cast<PyObject *>(&PyProperty_Type));
+  nb::setattr(
+      cls, "message",
+      propertyType(nb::cpp_function(
+          [](nb::object self) -> nb::str { return self.attr("_message"); },
+          nb::is_method())));
+  nb::setattr(cls, "error_diagnostics",
+              propertyType(nb::cpp_function(
+                  [](nb::object self)
+                      -> nb::typed<nb::list, PyDiagnostic::DiagnosticInfo> {
+                    return self.attr("_error_diagnostics");
+                  },
+                  nb::is_method())));
+}
+
 void populateRoot(nb::module_ &m) {
   m.attr("T") = nb::type_var("T");
   m.attr("U") = nb::type_var("U");
@@ -3861,7 +3930,7 @@ void populateIRCore(nb::module_ &m) {
       .def("walk", &PyOperationBase::walk, "callback"_a,
            "walk_order"_a = PyWalkOrder::PostOrder,
            // clang-format off
-          nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
+          nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = ...) -> None"),
            // clang-format on
            R"(
              Walks the operation tree with a callback function.
@@ -4930,6 +4999,9 @@ void populateIRCore(nb::module_ &m) {
   PyDynamicOpTrait::bind(m);
   PyDynamicOpTraits::IsTerminator::bind(m);
   PyDynamicOpTraits::NoTerminator::bind(m);
+
+  // MLIRError exception.
+  MLIRError::bind(m);
 }
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 88f58d45cdd75..93ce34b2e89f5 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -54,18 +54,4 @@ NB_MODULE(_mlir, m) {
   auto passManagerModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
   populatePassManagerSubmodule(passManagerModule);
-  nanobind::register_exception_translator(
-      [](const std::exception_ptr &p, void *payload) {
-        // We can't define exceptions with custom fields through pybind, so
-        // instead the exception class is defined in python and imported here.
-        try {
-          if (p)
-            std::rethrow_exception(p);
-        } catch (const MLIRError &e) {
-          nanobind::object obj =
-              nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                  .attr("MLIRError")(e.message, e.errorDiagnostics);
-          PyErr_SetObject(PyExc_Exception, obj.ptr());
-        }
-      });
 }
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 8f4cb385c09e8..886f38a6bf793 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -200,40 +200,6 @@ def __init__(
 
     ir.Context = Context
 
-    class MLIRError(Exception):
-        """
-        An exception with diagnostic information. Has the following fields:
-          message: str
-          error_diagnostics: List[ir.DiagnosticInfo]
-        """
-
-        def __init__(self, message, error_diagnostics):
-            self.message = message
-            self.error_diagnostics = error_diagnostics
-            super().__init__(message, error_diagnostics)
-
-        def __str__(self):
-            s = self.message
-            if self.error_diagnostics:
-                s += ":"
-            for diag in self.error_diagnostics:
-                s += (
-                    "\nerror: "
-                    + str(diag.location)[4:-1]
-                    + ": "
-                    + diag.message.replace("\n", "\n  ")
-                )
-                for note in diag.notes:
-                    s += (
-                        "\n note: "
-                        + str(note.location)[4:-1]
-                        + ": "
-                        + note.message.replace("\n", "\n  ")
-                    )
-            return s
-
-    ir.MLIRError = MLIRError
-
     # Register containers as Sequences, so they can be used with `match`.
 
     Sequence.register(ir.BlockArgumentList)



More information about the Mlir-commits mailing list