[Mlir-commits] [mlir] [MLIR] [Python] a few more fixes to type annotations (PR #186106)
Sergei Lebedev
llvmlistbot at llvm.org
Thu Mar 12 07:58:19 PDT 2026
https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/186106
>From 9786e4a26f2b8b36560ffcb121dc509b6a5b3811 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Thu, 12 Mar 2026 10:34:15 +0000
Subject: [PATCH] [MLIR] [Python] a few more fixes to type annotations
* `_OperationBase.walk` was missing a default.
* `MLIRError` is now fully defined in C++. The monkey-patching previously
done in `_site_initialize` was opaque to type checkers.
---
mlir/include/mlir/Bindings/Python/IRCore.h | 10 ++-
mlir/lib/Bindings/Python/IRCore.cpp | 73 +++++++++++++++++++++-
mlir/lib/Bindings/Python/MainModule.cpp | 14 -----
mlir/python/mlir/_mlir_libs/__init__.py | 34 ----------
4 files changed, 80 insertions(+), 51 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index bd2d49acbf681..557e32e9a612d 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -10,6 +10,7 @@
#define MLIR_BINDINGS_PYTHON_IRCORE_H
#include <cstddef>
+#include <exception>
#include <optional>
#include <sstream>
#include <utility>
@@ -1313,12 +1314,17 @@ class MLIR_PYTHON_API_EXPORTED PySymbolTable {
};
/// Custom exception that allows access to error diagnostic information. This is
-/// converted to the `ir.MLIRError` python exception when thrown.
-struct MLIR_PYTHON_API_EXPORTED MLIRError {
+/// translated to the `ir.MLIRError` python exception when thrown.
+struct MLIR_PYTHON_API_EXPORTED MLIRError : std::exception {
MLIRError(std::string message,
std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {})
: message(std::move(message)),
errorDiagnostics(std::move(errorDiagnostics)) {}
+ const char *what() const noexcept override { return message.c_str(); }
+
+ /// Bind the MLIRError exception class to the given module.
+ static void bind(nanobind::module_ &m);
+
std::string message;
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b8637c57a3f48..86f391ed7b84a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2774,6 +2774,74 @@ namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+static std::string formatMLIRError(const MLIRError &e) {
+ auto locStr = [](const PyLocation &loc) {
+ PyPrintAccumulator accum;
+ mlirLocationPrint(loc, accum.getCallback(), accum.getUserData());
+ std::string s = nb::cast<std::string>(nb::str(accum.join()));
+ std::string_view sv(s);
+ if (sv.size() > 5) {
+ sv.remove_prefix(4); // "loc("
+ sv.remove_suffix(1); // ")"
+ }
+ return std::string(sv);
+ };
+ auto indent = [](std::string s) {
+ size_t pos = 0;
+ while ((pos = s.find('\n', pos)) != std::string::npos) {
+ s.replace(pos, 1, "\n ");
+ pos += 3;
+ }
+ return s;
+ };
+
+ std::ostringstream os;
+ os << e.message;
+ if (!e.errorDiagnostics.empty())
+ os << ":";
+ for (const auto &diag : e.errorDiagnostics) {
+ os << "\nerror: " << locStr(diag.location) << ": " << indent(diag.message);
+ for (const auto &note : diag.notes)
+ os << "\n note: " << locStr(note.location) << ": "
+ << indent(note.message);
+ }
+ return os.str();
+}
+
+void MLIRError::bind(nb::module_ &m) {
+ auto cls = nb::exception<MLIRError>(m, "MLIRError", PyExc_Exception);
+ nb::register_exception_translator(
+ [](const std::exception_ptr &p, void *payload) {
+ try {
+ if (p)
+ std::rethrow_exception(p);
+ } catch (MLIRError &e) {
+ std::string formatted = formatMLIRError(e);
+ nb::object ty = nb::borrow(static_cast<PyObject *>(payload));
+ nb::object obj = ty(formatted);
+ obj.attr("_message") = nb::cast(std::move(e.message));
+ obj.attr("_error_diagnostics") =
+ nb::cast(std::move(e.errorDiagnostics));
+ PyErr_SetObject(static_cast<PyObject *>(payload), obj.ptr());
+ }
+ },
+ cls.ptr());
+ auto propertyType = nb::borrow<nb::type_object>(
+ reinterpret_cast<PyObject *>(&PyProperty_Type));
+ nb::setattr(
+ cls, "message",
+ propertyType(nb::cpp_function(
+ [](nb::object self) -> nb::str { return self.attr("_message"); },
+ nb::is_method())));
+ nb::setattr(cls, "error_diagnostics",
+ propertyType(nb::cpp_function(
+ [](nb::object self)
+ -> nb::typed<nb::list, PyDiagnostic::DiagnosticInfo> {
+ return self.attr("_error_diagnostics");
+ },
+ nb::is_method())));
+}
+
void populateRoot(nb::module_ &m) {
m.attr("T") = nb::type_var("T");
m.attr("U") = nb::type_var("U");
@@ -3861,7 +3929,7 @@ void populateIRCore(nb::module_ &m) {
.def("walk", &PyOperationBase::walk, "callback"_a,
"walk_order"_a = PyWalkOrder::PostOrder,
// clang-format off
- nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None"),
+ nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder = ...) -> None"),
// clang-format on
R"(
Walks the operation tree with a callback function.
@@ -4930,6 +4998,9 @@ void populateIRCore(nb::module_ &m) {
PyDynamicOpTrait::bind(m);
PyDynamicOpTraits::IsTerminator::bind(m);
PyDynamicOpTraits::NoTerminator::bind(m);
+
+ // MLIRError exception.
+ MLIRError::bind(m);
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 88f58d45cdd75..93ce34b2e89f5 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -54,18 +54,4 @@ NB_MODULE(_mlir, m) {
auto passManagerModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passManagerModule);
- nanobind::register_exception_translator(
- [](const std::exception_ptr &p, void *payload) {
- // We can't define exceptions with custom fields through pybind, so
- // instead the exception class is defined in python and imported here.
- try {
- if (p)
- std::rethrow_exception(p);
- } catch (const MLIRError &e) {
- nanobind::object obj =
- nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("MLIRError")(e.message, e.errorDiagnostics);
- PyErr_SetObject(PyExc_Exception, obj.ptr());
- }
- });
}
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 8f4cb385c09e8..886f38a6bf793 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -200,40 +200,6 @@ def __init__(
ir.Context = Context
- class MLIRError(Exception):
- """
- An exception with diagnostic information. Has the following fields:
- message: str
- error_diagnostics: List[ir.DiagnosticInfo]
- """
-
- def __init__(self, message, error_diagnostics):
- self.message = message
- self.error_diagnostics = error_diagnostics
- super().__init__(message, error_diagnostics)
-
- def __str__(self):
- s = self.message
- if self.error_diagnostics:
- s += ":"
- for diag in self.error_diagnostics:
- s += (
- "\nerror: "
- + str(diag.location)[4:-1]
- + ": "
- + diag.message.replace("\n", "\n ")
- )
- for note in diag.notes:
- s += (
- "\n note: "
- + str(note.location)[4:-1]
- + ": "
- + note.message.replace("\n", "\n ")
- )
- return s
-
- ir.MLIRError = MLIRError
-
# Register containers as Sequences, so they can be used with `match`.
Sequence.register(ir.BlockArgumentList)
More information about the Mlir-commits
mailing list