[Mlir-commits] [mlir] [MLIR][Python] use nb::typed for return signatures (PR #160221)
Maksim Levental
llvmlistbot at llvm.org
Mon Sep 22 19:18:04 PDT 2025
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/160221
(No PR description provided.)
>From 3376f4f23dbee295a81948b1dfb5d8c3459990e1 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 22 Sep 2025 18:57:51 -0700
Subject: [PATCH] [MLIR][Python] use nb::typed for return signatures
---
mlir/lib/Bindings/Python/IRAttributes.cpp | 48 +++++++++++++----------
mlir/lib/Bindings/Python/IRCore.cpp | 40 +++++++++++++------
mlir/lib/Bindings/Python/IRModule.h | 10 +++--
mlir/lib/Bindings/Python/IRTypes.cpp | 21 +++++-----
4 files changed, 69 insertions(+), 50 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 212228fbac91e..51c3c46bcd02b 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
PyArrayAttributeIterator &dunderIter() { return *this; }
- nb::object dunderNext() {
+ nb::typed<nb::object, PyAttribute> dunderNext() {
// TODO: Throw is an inefficient way to stop iteration.
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
throw nb::stop_iteration();
@@ -526,7 +526,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
"Gets a uniqued Array attribute");
c.def(
"__getitem__",
- [](PyArrayAttribute &arr, intptr_t i) {
+ [](PyArrayAttribute &arr,
+ intptr_t i) -> nb::typed<nb::object, PyAttribute> {
if (i >= mlirArrayAttrGetNumElements(arr))
throw nb::index_error("ArrayAttribute index out of range");
return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
@@ -1010,14 +1011,16 @@ class PyDenseElementsAttribute
[](PyDenseElementsAttribute &self) -> bool {
return mlirDenseElementsAttrIsSplat(self);
})
- .def("get_splat_value", [](PyDenseElementsAttribute &self) {
- if (!mlirDenseElementsAttrIsSplat(self))
- throw nb::value_error(
- "get_splat_value called on a non-splat attribute");
- return PyAttribute(self.getContext(),
- mlirDenseElementsAttrGetSplatValue(self))
- .maybeDownCast();
- });
+ .def("get_splat_value",
+ [](PyDenseElementsAttribute &self)
+ -> nb::typed<nb::object, PyAttribute> {
+ if (!mlirDenseElementsAttrIsSplat(self))
+ throw nb::value_error(
+ "get_splat_value called on a non-splat attribute");
+ return PyAttribute(self.getContext(),
+ mlirDenseElementsAttrGetSplatValue(self))
+ .maybeDownCast();
+ });
}
static PyType_Slot slots[];
@@ -1522,13 +1525,15 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
},
nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
"Gets an uniqued dict attribute");
- c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
- MlirAttribute attr =
- mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
- if (mlirAttributeIsNull(attr))
- throw nb::key_error("attempt to access a non-existent attribute");
- return PyAttribute(self.getContext(), attr).maybeDownCast();
- });
+ c.def("__getitem__",
+ [](PyDictAttribute &self,
+ const std::string &name) -> nb::typed<nb::object, PyAttribute> {
+ MlirAttribute attr =
+ mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+ if (mlirAttributeIsNull(attr))
+ throw nb::key_error("attempt to access a non-existent attribute");
+ return PyAttribute(self.getContext(), attr).maybeDownCast();
+ });
c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
if (index < 0 || index >= self.dunderLen()) {
throw nb::index_error("attempt to access out of bounds attribute");
@@ -1594,10 +1599,11 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
},
nb::arg("value"), nb::arg("context") = nb::none(),
"Gets a uniqued Type attribute");
- c.def_prop_ro("value", [](PyTypeAttribute &self) {
- return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
- .maybeDownCast();
- });
+ c.def_prop_ro(
+ "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
+ .maybeDownCast();
+ });
}
};
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4b238e11c7fff..c7af6a6ce0d60 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1605,7 +1605,9 @@ class PyConcreteValue : public PyValue {
},
nb::arg("other_value"));
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](DerivedTy &self) { return self.maybeDownCast(); });
+ [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
+ return self.maybeDownCast();
+ });
DerivedTy::bindDerived(cls);
}
@@ -1638,9 +1640,9 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
/// Returns the list of types of the values held by container.
template <typename Container>
-static std::vector<nb::object> getValueTypes(Container &container,
- PyMlirContextRef &context) {
- std::vector<nb::object> result;
+static std::vector<nb::typed<nb::object, PyType>>
+getValueTypes(Container &container, PyMlirContextRef &context) {
+ std::vector<nb::typed<nb::object, PyType>> result;
result.reserve(container.size());
for (int i = 0, e = container.size(); i < e; ++i) {
result.push_back(PyType(context->getRef(),
@@ -2677,7 +2679,8 @@ class PyOpAttributeMap {
PyOpAttributeMap(PyOperationRef operation)
: operation(std::move(operation)) {}
- nb::object dunderGetItemNamed(const std::string &name) {
+ nb::typed<nb::object, PyAttribute>
+ dunderGetItemNamed(const std::string &name) {
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
@@ -3461,7 +3464,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Returns the list of Operation results.")
.def_prop_ro(
"result",
- [](PyOperationBase &self) {
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
auto &operation = self.getOperation();
return PyOpResult(operation.getRef(), getUniqueResult(operation))
.maybeDownCast();
@@ -3982,7 +3985,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
.def_static(
"parse",
- [](const std::string &attrSpec, DefaultingPyMlirContext context) {
+ [](const std::string &attrSpec, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyAttribute> {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirAttribute attr = mlirAttributeParseGet(
context->get(), toMlirStringRef(attrSpec));
@@ -3998,7 +4002,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyAttribute &self) { return self.getContext().getObject(); },
"Context that owns the Attribute")
.def_prop_ro("type",
- [](PyAttribute &self) {
+ [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
return PyType(self.getContext(),
mlirAttributeGetType(self))
.maybeDownCast();
@@ -4049,7 +4053,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"mlirTypeID was expected to be non-null.");
return PyTypeID(mlirTypeID);
})
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAttribute::maybeDownCast);
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
+ return self.maybeDownCast();
+ });
//----------------------------------------------------------------------------
// Mapping of PyNamedAttribute
@@ -4094,7 +4101,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
.def_static(
"parse",
- [](std::string typeSpec, DefaultingPyMlirContext context) {
+ [](std::string typeSpec,
+ DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirType type =
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
@@ -4139,7 +4147,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
printAccum.parts.append(")");
return printAccum.join();
})
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyType::maybeDownCast)
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyType &self) -> nb::typed<nb::object, PyType> {
+ return self.maybeDownCast();
+ })
.def_prop_ro("typeid", [](PyType &self) {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
if (!mlirTypeIDIsNull(mlirTypeID))
@@ -4266,7 +4277,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("state"), kGetNameAsOperand)
.def_prop_ro("type",
- [](PyValue &self) {
+ [](PyValue &self) -> nb::typed<nb::object, PyType> {
return PyType(self.getParentOperation()->getContext(),
mlirValueGetType(self.get()))
.maybeDownCast();
@@ -4332,7 +4343,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("with_"), nb::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyValue::maybeDownCast)
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyValue &self) -> nb::typed<nb::object, PyValue> {
+ return self.maybeDownCast();
+ })
.def_prop_ro(
"location",
[](MlirValue self) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 6e97c00d478f1..0aabab8991b46 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -1101,10 +1101,12 @@ class PyConcreteAttribute : public BaseTy {
return DerivedTy::isaFunction(otherAttr);
},
nanobind::arg("other"));
- cls.def_prop_ro("type", [](PyAttribute &attr) {
- return PyType(attr.getContext(), mlirAttributeGetType(attr))
- .maybeDownCast();
- });
+ cls.def_prop_ro(
+ "type",
+ [](PyAttribute &attr) -> nanobind::typed<nanobind::object, PyType> {
+ return PyType(attr.getContext(), mlirAttributeGetType(attr))
+ .maybeDownCast();
+ });
cls.def_prop_ro_static(
"static_typeid",
[](nanobind::object & /*class*/) -> PyTypeID {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index cab3bf549295b..07dc00521833f 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -501,7 +501,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
"Create a complex type");
c.def_prop_ro(
"element_type",
- [](PyComplexType &self) {
+ [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
.maybeDownCast();
},
@@ -515,7 +515,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
void mlir::PyShapedType::bindDerived(ClassTy &c) {
c.def_prop_ro(
"element_type",
- [](PyShapedType &self) {
+ [](PyShapedType &self) -> nb::typed<nb::object, PyType> {
return PyType(self.getContext(), mlirShapedTypeGetElementType(self))
.maybeDownCast();
},
@@ -731,8 +731,7 @@ class PyRankedTensorType
MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
if (mlirAttributeIsNull(encoding))
return std::nullopt;
- return nb::cast<nb::typed<nb::object, PyAttribute>>(
- PyAttribute(self.getContext(), encoding).maybeDownCast());
+ return PyAttribute(self.getContext(), encoding).maybeDownCast();
});
}
};
@@ -794,9 +793,9 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
.def_prop_ro(
"layout",
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
- return nb::cast<nb::typed<nb::object, PyAttribute>>(
- PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
- .maybeDownCast());
+ return PyAttribute(self.getContext(),
+ mlirMemRefTypeGetLayout(self))
+ .maybeDownCast();
},
"The layout of the MemRef type.")
.def(
@@ -825,8 +824,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
if (mlirAttributeIsNull(a))
return std::nullopt;
- return nb::cast<nb::typed<nb::object, PyAttribute>>(
- PyAttribute(self.getContext(), a).maybeDownCast());
+ return PyAttribute(self.getContext(), a).maybeDownCast();
},
"Returns the memory space of the given MemRef type.");
}
@@ -867,8 +865,7 @@ class PyUnrankedMemRefType
MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
if (mlirAttributeIsNull(a))
return std::nullopt;
- return nb::cast<nb::typed<nb::object, PyAttribute>>(
- PyAttribute(self.getContext(), a).maybeDownCast());
+ return PyAttribute(self.getContext(), a).maybeDownCast();
},
"Returns the memory space of the given Unranked MemRef type.");
}
@@ -912,7 +909,7 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
"Create a tuple type");
c.def(
"get_type",
- [](PyTupleType &self, intptr_t pos) {
+ [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
.maybeDownCast();
},
More information about the Mlir-commits mailing list