[Mlir-commits] [mlir] [MLIR] [Python] Fixed a few issues in the type annotations (PR #183021)
Sergei Lebedev
llvmlistbot at llvm.org
Sun Mar 8 06:07:45 PDT 2026
https://github.com/superbobry updated https://github.com/llvm/llvm-project/pull/183021
>From c45a7fc244c10ae717982266f899dbbb8666f84e Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Tue, 24 Feb 2026 09:49:52 +0000
Subject: [PATCH 1/3] [MLIR] [Python] Fixed a few issues in the type
annotations
* Removed the explicit `nb::sig` for `static_typeid`. The inferred type works
just fine, whereas the unqualified `TypeID` used previously only resolves for
core types in the `ir` submodule.
* The `DefaultingPyMlirContext` and `DefaultingPyLocation` helpers now produce
qualified type names, e.g. `_mlir.ir.Location` instead of bare `Location`.
* `__enter__` on `ir.Context`, `ir.Location`, `ir.InsertionPoint` and
`ir.DiagnosticHandler` is now annotated with a concrete return type instead of
`object`, e.g. `ir.Context.__enter__` returns `Context`.
* `loc_tracebacks` uses `Generator` as the return type, since this is what
`contextmanager` expects in typeshed.
* Changed the `get` and `get_splat` static factory methods on subclasses of
`DenseElementsAttr` to be annotated as returning that concrete subclass,
instead of `DenseElementsAttr`.
I also sent wjakob/nanobind#1302 and wjakob/nanobind#1303, which should improve
the quality of the auto-generated stubs.
---
.../mlir/Bindings/Python/IRAttributes.h | 7 ++
mlir/include/mlir/Bindings/Python/IRCore.h | 36 ++++-------
mlir/lib/Bindings/Python/IRAttributes.cpp | 64 ++++++++++++-------
mlir/lib/Bindings/Python/IRCore.cpp | 12 ++--
mlir/python/mlir/ir.py | 4 +-
5 files changed, 73 insertions(+), 50 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/IRAttributes.h b/mlir/include/mlir/Bindings/Python/IRAttributes.h
index 173674e0091d2..64a31d3f3e42d 100644
--- a/mlir/include/mlir/Bindings/Python/IRAttributes.h
+++ b/mlir/include/mlir/Bindings/Python/IRAttributes.h
@@ -424,6 +424,13 @@ class MLIR_PYTHON_API_EXPORTED PyDenseElementsAttribute
static PyType_Slot slots[];
+protected:
+ /// Registers get/get_splat factory methods with the concrete return
+ /// type in the nb::sig. Subclasses call this from their bindDerived
+ /// to override the return type in generated stubs.
+ template <typename ClassT>
+ static void bindFactoryMethods(ClassT &c, const char *pyClassName);
+
private:
static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags);
static void bf_releasebuffer(PyObject *, Py_buffer *buffer);
diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 5953f26d07370..bd2d49acbf681 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -278,7 +278,7 @@ class MLIR_PYTHON_API_EXPORTED DefaultingPyMlirContext
: public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
public:
using Defaulting::Defaulting;
- static constexpr const char kTypeDescription[] = "Context";
+ static constexpr const char kTypeDescription[] = "_mlir.ir.Context";
static PyMlirContext &resolve();
};
@@ -524,7 +524,7 @@ class MLIR_PYTHON_API_EXPORTED DefaultingPyLocation
: public Defaulting<DefaultingPyLocation, PyLocation> {
public:
using Defaulting::Defaulting;
- static constexpr const char kTypeDescription[] = "Location";
+ static constexpr const char kTypeDescription[] = "_mlir.ir.Location";
static PyLocation &resolve();
operator MlirLocation() const { return *get(); }
@@ -957,16 +957,12 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
auto cls = ClassTy(m, DerivedTy::pyClassName, nanobind::is_generic());
cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
nanobind::arg("cast_from_type"));
- cls.def_prop_ro_static(
- "static_typeid",
- [](nanobind::object & /*class*/) {
- if (DerivedTy::getTypeIdFunction)
- return PyTypeID(DerivedTy::getTypeIdFunction());
- throw nanobind::attribute_error(
- (DerivedTy::pyClassName + std::string(" has no typeid."))
- .c_str());
- },
- nanobind::sig("def static_typeid(/) -> TypeID"));
+ cls.def_prop_ro_static("static_typeid", [](nanobind::object & /*class*/) {
+ if (DerivedTy::getTypeIdFunction)
+ return PyTypeID(DerivedTy::getTypeIdFunction());
+ throw nanobind::attribute_error(
+ (DerivedTy::pyClassName + std::string(" has no typeid.")).c_str());
+ });
cls.def_prop_ro("typeid", [](PyType &self) {
return nanobind::cast<PyTypeID>(nanobind::cast(self).attr("typeid"));
});
@@ -1100,16 +1096,12 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
return PyType(attr.getContext(), mlirAttributeGetType(attr))
.maybeDownCast();
});
- cls.def_prop_ro_static(
- "static_typeid",
- [](nanobind::object & /*class*/) -> PyTypeID {
- if (DerivedTy::getTypeIdFunction)
- return PyTypeID(DerivedTy::getTypeIdFunction());
- throw nanobind::attribute_error(
- (DerivedTy::pyClassName + std::string(" has no typeid."))
- .c_str());
- },
- nanobind::sig("def static_typeid(/) -> TypeID"));
+ cls.def_prop_ro_static("static_typeid", [](nanobind::object & /*class*/) {
+ if (DerivedTy::getTypeIdFunction)
+ return PyTypeID(DerivedTy::getTypeIdFunction());
+ throw nanobind::attribute_error(
+ (DerivedTy::pyClassName + std::string(" has no typeid.")).c_str());
+ });
cls.def_prop_ro("typeid", [](PyAttribute &self) {
return nanobind::cast<PyTypeID>(nanobind::cast(self).attr("typeid"));
});
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 59eb9b8e81cf0..997f16978fa58 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -421,12 +421,9 @@ void PyIntegerAttribute::bindDerived(ClassTy &c) {
c.def_prop_ro("value", toPyInt, "Returns the value of the integer attribute");
c.def("__int__", toPyInt,
"Converts the value of the integer attribute to a Python int");
- c.def_prop_ro_static(
- "static_typeid",
- [](nb::object & /*class*/) {
- return PyTypeID(mlirIntegerAttrGetTypeID());
- },
- nb::sig("def static_typeid(/) -> TypeID"));
+ c.def_prop_ro_static("static_typeid", [](nb::object & /*class*/) {
+ return PyTypeID(mlirIntegerAttrGetTypeID());
+ });
}
nb::object PyIntegerAttribute::toPyInt(PyIntegerAttribute &self) {
@@ -774,27 +771,48 @@ std::unique_ptr<nb_buffer_info> PyDenseElementsAttribute::accessBuffer() {
"unsupported data type for conversion to Python buffer");
}
-void PyDenseElementsAttribute::bindDerived(ClassTy &c) {
- c.def("__len__", &PyDenseElementsAttribute::dunderLen)
- .def_static(
- "get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
- nb::arg("signless") = true, nb::arg("type") = nb::none(),
- nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
- // clang-format off
- nb::sig("def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> DenseElementsAttr"),
- // clang-format on
- kDenseElementsAttrGetDocstring)
+template <typename ClassT>
+void PyDenseElementsAttribute::bindFactoryMethods(ClassT &c,
+ const char *pyClassName) {
+ std::string getSig1 =
+ // clang-format off
+ "def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> " +
+ // clang-format on
+ std::string(pyClassName);
+ std::string getSig2 =
+ // clang-format off
+ "def get(attrs: Sequence[Attribute], type: Type | None = None, context: Context | None = None) -> " +
+ // clang-format on
+ std::string(pyClassName);
+ std::string getSplatSig =
+ // clang-format off
+ "def get_splat(shaped_type: Type, element_attr: Attribute) -> " +
+ // clang-format on
+ std::string(pyClassName);
+
+ c.def_static("get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
+ nb::arg("signless") = true, nb::arg("type") = nb::none(),
+ nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
+ nb::sig(getSig1.c_str()), kDenseElementsAttrGetDocstring)
.def_static("get", PyDenseElementsAttribute::getFromList,
nb::arg("attrs"), nb::arg("type") = nb::none(),
- nb::arg("context") = nb::none(),
+ nb::arg("context") = nb::none(), nb::sig(getSig2.c_str()),
kDenseElementsAttrGetFromListDocstring)
.def_static("get_splat", PyDenseElementsAttribute::getSplat,
nb::arg("shaped_type"), nb::arg("element_attr"),
- "Gets a DenseElementsAttr where all values are the same")
- .def_prop_ro("is_splat",
- [](PyDenseElementsAttribute &self) -> bool {
- return mlirDenseElementsAttrIsSplat(self);
- })
+ nb::sig(getSplatSig.c_str()),
+ ("Gets a " + std::string(pyClassName) +
+ " where all values are the same")
+ .c_str());
+}
+
+void PyDenseElementsAttribute::bindDerived(ClassTy &c) {
+ c.def("__len__", &PyDenseElementsAttribute::dunderLen);
+ bindFactoryMethods(c, pyClassName);
+ c.def_prop_ro("is_splat",
+ [](PyDenseElementsAttribute &self) -> bool {
+ return mlirDenseElementsAttrIsSplat(self);
+ })
.def("get_splat_value",
[](PyDenseElementsAttribute &self)
-> nb::typed<nb::object, PyAttribute> {
@@ -1037,6 +1055,7 @@ nb::int_ PyDenseIntElementsAttribute::dunderGetItem(intptr_t pos) const {
}
void PyDenseIntElementsAttribute::bindDerived(ClassTy &c) {
+ PyDenseElementsAttribute::bindFactoryMethods(c, pyClassName);
c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
}
@@ -1215,6 +1234,7 @@ nb::float_ PyDenseFPElementsAttribute::dunderGetItem(intptr_t pos) const {
}
void PyDenseFPElementsAttribute::bindDerived(ClassTy &c) {
+ PyDenseElementsAttribute::bindFactoryMethods(c, pyClassName);
c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
}
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4d389c656e58d..b8637c57a3f48 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2988,7 +2988,8 @@ void populateIRCore(nb::module_ &m) {
"Returns True if an error was encountered during diagnostic "
"handling.")
.def("__enter__", &PyDiagnosticHandler::contextEnter,
- "Enters the diagnostic handler as a context manager.")
+ "Enters the diagnostic handler as a context manager.",
+ nb::sig("def __enter__(self, /) -> DiagnosticHandler"))
.def("__exit__", &PyDiagnosticHandler::contextExit, "exc_type"_a.none(),
"exc_value"_a.none(), "traceback"_a.none(),
"Exits the diagnostic handler context manager.");
@@ -3034,7 +3035,8 @@ void populateIRCore(nb::module_ &m) {
&PyMlirContext::createFromCapsule,
"Creates a Context from a capsule wrapping MlirContext.")
.def("__enter__", &PyMlirContext::contextEnter,
- "Enters the context as a context manager.")
+ "Enters the context as a context manager.",
+ nb::sig("def __enter__(self, /) -> Context"))
.def("__exit__", &PyMlirContext::contextExit, "exc_type"_a.none(),
"exc_value"_a.none(), "traceback"_a.none(),
"Exits the context manager.")
@@ -3260,7 +3262,8 @@ void populateIRCore(nb::module_ &m) {
.def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule,
"Creates a Location from a capsule wrapping MlirLocation.")
.def("__enter__", &PyLocation::contextEnter,
- "Enters the location as a context manager.")
+ "Enters the location as a context manager.",
+ nb::sig("def __enter__(self, /) -> Location"))
.def("__exit__", &PyLocation::contextExit, "exc_type"_a.none(),
"exc_value"_a.none(), "traceback"_a.none(),
"Exits the location context manager.")
@@ -4290,7 +4293,8 @@ void populateIRCore(nb::module_ &m) {
.def(nb::init<PyBlock &>(), "block"_a,
"Inserts after the last operation but still inside the block.")
.def("__enter__", &PyInsertionPoint::contextEnter,
- "Enters the insertion point as a context manager.")
+ "Enters the insertion point as a context manager.",
+ nb::sig("def __enter__(self, /) -> InsertionPoint"))
.def("__exit__", &PyInsertionPoint::contextExit, "exc_type"_a.none(),
"exc_value"_a.none(), "traceback"_a.none(),
"Exits the insertion point context manager.")
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index f4aa2d6b051c8..f84792d4095f4 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -4,7 +4,7 @@
from __future__ import annotations
-from collections.abc import Iterable
+from collections.abc import Generator
from contextlib import contextmanager
from ._mlir_libs._mlir.ir import *
@@ -22,7 +22,7 @@
@contextmanager
-def loc_tracebacks(*, max_depth: int | None = None) -> Iterable[None]:
+def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None]:
"""Enables automatic traceback-based locations for MLIR operations.
Operations created within this context will have their location
>From 5cea4d0dd283e63a3c74def0ee6d30052147ba86 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <slebedev at google.com>
Date: Fri, 6 Mar 2026 16:57:14 +0000
Subject: [PATCH 2/3] [MLIR] [Python] Allow specifying a custom stubgen version
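Stubgen is largely independent of the rest of nanobind and can therefore be
versioned separately. Set `MLIR_NB_STUBGEN` to the path of a `stubgen.py`, or
set `MLIR_NB_STUBGEN_VERSION` to a nanobind tag or commit to download
`stubgen.py` from GitHub at configure time.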
---
mlir/cmake/modules/AddMLIRPython.cmake | 24 +++++++++++++++++++++++-
1 file changed, 23 insertions(+), 1 deletion(-)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 1821cfbf35d2a..1f0a180668914 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -123,8 +123,30 @@ function(mlir_generate_type_stubs)
"IMPORT_PATHS;DEPENDS_TARGETS;OUTPUTS;DEPENDS_TARGET_SRC_DEPS"
${ARGN})
+ # Allow overriding the stubgen.py path or fetching a specific version
+ # from the nanobind repository, independent of the nanobind used for
+ # building. This is useful when a newer stubgen has bug fixes or features
+ # not yet available in the nanobind version used for compilation.
+ if(MLIR_NB_STUBGEN)
+ set(NB_STUBGEN "${MLIR_NB_STUBGEN}")
+ elseif(MLIR_NB_STUBGEN_VERSION)
+ set(_stubgen_path "${MLIR_BINARY_DIR}/stubgen/${MLIR_NB_STUBGEN_VERSION}/stubgen.py")
+ if(NOT EXISTS "${_stubgen_path}")
+ message(STATUS "Downloading stubgen.py from nanobind ${MLIR_NB_STUBGEN_VERSION}...")
+ file(DOWNLOAD
+ "https://raw.githubusercontent.com/wjakob/nanobind/${MLIR_NB_STUBGEN_VERSION}/src/stubgen.py"
+ "${_stubgen_path}"
+ STATUS _download_status
+ )
+ list(GET _download_status 0 _download_code)
+ if(NOT _download_code EQUAL 0)
+ list(GET _download_status 1 _download_error)
+ message(FATAL_ERROR "Failed to download stubgen.py: ${_download_error}")
+ endif()
+ endif()
+ set(NB_STUBGEN "${_stubgen_path}")
# for people installing a distro (e.g., pip install) of nanobind
- if(EXISTS ${nanobind_DIR}/../src/stubgen.py)
+ elseif(EXISTS ${nanobind_DIR}/../src/stubgen.py)
set(NB_STUBGEN "${nanobind_DIR}/../src/stubgen.py")
elseif(EXISTS ${nanobind_DIR}/../stubgen.py)
set(NB_STUBGEN "${nanobind_DIR}/../stubgen.py")
>From b4455b74e2be0685cd7540b83a1cae28fe6bf790 Mon Sep 17 00:00:00 2001
From: Sergei Lebedev <185856+superbobry at users.noreply.github.com>
Date: Sun, 8 Mar 2026 13:07:36 +0000
Subject: [PATCH 3/3] Update mlir/cmake/modules/AddMLIRPython.cmake: create the stubgen.py download directory before fetching
Co-authored-by: Maksim Levental <maksim.levental at gmail.com>
---
mlir/cmake/modules/AddMLIRPython.cmake | 15 +++++++++++++++
1 file changed, 15 insertions(+)
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 1f0a180668914..436148a4d6fb8 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -133,6 +133,21 @@ function(mlir_generate_type_stubs)
set(_stubgen_path "${MLIR_BINARY_DIR}/stubgen/${MLIR_NB_STUBGEN_VERSION}/stubgen.py")
if(NOT EXISTS "${_stubgen_path}")
message(STATUS "Downloading stubgen.py from nanobind ${MLIR_NB_STUBGEN_VERSION}...")
+ file(MAKE_DIRECTORY "${MLIR_BINARY_DIR}/stubgen/${MLIR_NB_STUBGEN_VERSION}" RESULT _created_dir)
+ if(NOT _created_dir EQUAL 0)
+ list(GET _created_dir 1 _created_dir_error)
+ message(FATAL_ERROR "Failed to create parent dir for stubgen.py: ${_created_dir_error}")
+ endif()
+ file(DOWNLOAD
+ "https://raw.githubusercontent.com/wjakob/nanobind/${MLIR_NB_STUBGEN_VERSION}/src/stubgen.py"
+ "${_stubgen_path}"
+ STATUS _download_status
+ )
+ list(GET _download_status 0 _download_code)
+ if(NOT _download_code EQUAL 0)
+ list(GET _download_status 1 _download_error)
+ message(FATAL_ERROR "Failed to download stubgen.py: ${_download_error}")
+ endif()
file(DOWNLOAD
"https://raw.githubusercontent.com/wjakob/nanobind/${MLIR_NB_STUBGEN_VERSION}/src/stubgen.py"
"${_stubgen_path}"
More information about the Mlir-commits
mailing list