[Mlir-commits] [mlir] [MLIR][Python] restore types (PR #160194)
Maksim Levental
llvmlistbot at llvm.org
Mon Sep 22 14:00:22 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/160194
>From 7468dff85f01a66722e9543131f5eefef0815175 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Mon, 22 Sep 2025 16:27:46 -0400
Subject: [PATCH] [MLIR][Python] restore types
---
mlir/lib/Bindings/Python/IRAttributes.cpp | 38 ++++++-----
mlir/lib/Bindings/Python/IRCore.cpp | 82 ++++++++++++++---------
mlir/lib/Bindings/Python/IRModule.h | 5 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 44 ++++++------
4 files changed, 101 insertions(+), 68 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 212228fbac91e..404c4d842e02c 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -485,13 +485,14 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
PyArrayAttributeIterator &dunderIter() { return *this; }
- nb::object dunderNext() {
+ nb::typed<nb::object, PyAttribute> dunderNext() {
// TODO: Throw is an inefficient way to stop iteration.
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
throw nb::stop_iteration();
- return PyAttribute(this->attr.getContext(),
- mlirArrayAttrGetElement(attr.get(), nextIndex++))
- .maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyAttribute>>(
+ PyAttribute(this->attr.getContext(),
+ mlirArrayAttrGetElement(attr.get(), nextIndex++))
+ .maybeDownCast());
}
static void bind(nb::module_ &m) {
@@ -524,13 +525,13 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
},
nb::arg("attributes"), nb::arg("context") = nb::none(),
"Gets a uniqued Array attribute");
- c.def(
- "__getitem__",
- [](PyArrayAttribute &arr, intptr_t i) {
- if (i >= mlirArrayAttrGetNumElements(arr))
- throw nb::index_error("ArrayAttribute index out of range");
- return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
- })
+ c.def("__getitem__",
+ [](PyArrayAttribute &arr, intptr_t i) {
+ if (i >= mlirArrayAttrGetNumElements(arr))
+ throw nb::index_error("ArrayAttribute index out of range");
+ return nb::cast<nb::typed<nb::object, PyAttribute>>(
+ PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast());
+ })
.def("__len__",
[](const PyArrayAttribute &arr) {
return mlirArrayAttrGetNumElements(arr);
@@ -1014,9 +1015,10 @@ class PyDenseElementsAttribute
if (!mlirDenseElementsAttrIsSplat(self))
throw nb::value_error(
"get_splat_value called on a non-splat attribute");
- return PyAttribute(self.getContext(),
- mlirDenseElementsAttrGetSplatValue(self))
- .maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyAttribute>>(
+ PyAttribute(self.getContext(),
+ mlirDenseElementsAttrGetSplatValue(self))
+ .maybeDownCast());
});
}
@@ -1527,7 +1529,8 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
if (mlirAttributeIsNull(attr))
throw nb::key_error("attempt to access a non-existent attribute");
- return PyAttribute(self.getContext(), attr).maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyAttribute>>(
+ PyAttribute(self.getContext(), attr).maybeDownCast());
});
c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
if (index < 0 || index >= self.dunderLen()) {
@@ -1595,8 +1598,9 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
nb::arg("value"), nb::arg("context") = nb::none(),
"Gets a uniqued Type attribute");
c.def_prop_ro("value", [](PyTypeAttribute &self) {
- return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
- .maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyType>>(
+ PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
+ .maybeDownCast());
});
}
};
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 81386f2227a7f..5a6edfa737fd7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -528,7 +528,8 @@ class PyOperationIterator {
static void bind(nb::module_ &m) {
nb::class_<PyOperationIterator>(m, "OperationIterator")
.def("__iter__", &PyOperationIterator::dunderIter)
- .def("__next__", &PyOperationIterator::dunderNext);
+ .def("__next__", &PyOperationIterator::dunderNext,
+ nb::sig("def __next__(self) -> OpView"));
}
private:
@@ -1604,8 +1605,9 @@ class PyConcreteValue : public PyValue {
return DerivedTy::isaFunction(otherValue);
},
nb::arg("other_value"));
- cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](DerivedTy &self) { return self.maybeDownCast(); });
+ cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](DerivedTy &self) {
+ return nb::cast<nb::typed<nb::object, DerivedTy>>(self.maybeDownCast());
+ });
DerivedTy::bindDerived(cls);
}
@@ -1638,14 +1640,15 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
/// Returns the list of types of the values held by container.
template <typename Container>
-static std::vector<nb::object> getValueTypes(Container &container,
- PyMlirContextRef &context) {
- std::vector<nb::object> result;
+static std::vector<nb::typed<nb::object, PyType>>
+getValueTypes(Container &container, PyMlirContextRef &context) {
+ std::vector<nb::typed<nb::object, PyType>> result;
result.reserve(container.size());
for (int i = 0, e = container.size(); i < e; ++i) {
- result.push_back(PyType(context->getRef(),
- mlirValueGetType(container.getElement(i).get()))
- .maybeDownCast());
+ result.push_back(nb::cast<nb::typed<nb::object, PyType>>(
+ PyType(context->getRef(),
+ mlirValueGetType(container.getElement(i).get()))
+ .maybeDownCast()));
}
return result;
}
@@ -2677,13 +2680,15 @@ class PyOpAttributeMap {
PyOpAttributeMap(PyOperationRef operation)
: operation(std::move(operation)) {}
- nb::object dunderGetItemNamed(const std::string &name) {
+ nb::typed<nb::object, PyAttribute>
+ dunderGetItemNamed(const std::string &name) {
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
throw nb::key_error("attempt to access a non-existent attribute");
}
- return PyAttribute(operation->getContext(), attr).maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyAttribute>>(
+ PyAttribute(operation->getContext(), attr).maybeDownCast());
}
PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
@@ -2961,14 +2966,17 @@ void mlir::python::populateIRCore(nb::module_ &m) {
new (&self) PyMlirContext(context);
})
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
- .def("_get_context_again",
- [](PyMlirContext &self) {
- PyMlirContextRef ref = PyMlirContext::forContext(self.get());
- return ref.releaseObject();
- })
+ .def(
+ "_get_context_again",
+ [](PyMlirContext &self) {
+ PyMlirContextRef ref = PyMlirContext::forContext(self.get());
+ return ref.releaseObject();
+ },
+ nb::sig("def _get_context_again(self) -> Context"))
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule,
+ nb::sig("def " MLIR_PYTHON_CAPI_FACTORY_ATTR "(self) -> Context"))
.def("__enter__", &PyMlirContext::contextEnter)
.def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
nb::arg("exc_value").none(), nb::arg("traceback").none())
@@ -3463,8 +3471,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"result",
[](PyOperationBase &self) {
auto &operation = self.getOperation();
- return PyOpResult(operation.getRef(), getUniqueResult(operation))
- .maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyOpResult>>(
+ PyOpResult(operation.getRef(), getUniqueResult(operation))
+ .maybeDownCast());
},
"Shortcut to get an op result if it has only one (throws an error "
"otherwise).")
@@ -3988,7 +3997,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
context->get(), toMlirStringRef(attrSpec));
if (mlirAttributeIsNull(attr))
throw MLIRError("Unable to parse attribute", errors.take());
- return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyAttribute>>(
+ PyAttribute(context.get()->getRef(), attr).maybeDownCast());
},
nb::arg("asm"), nb::arg("context") = nb::none(),
"Parses an attribute from an assembly form. Raises an MLIRError on "
@@ -3999,9 +4009,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Context that owns the Attribute")
.def_prop_ro("type",
[](PyAttribute &self) {
- return PyType(self.getContext(),
- mlirAttributeGetType(self))
- .maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyType>>(
+ PyType(self.getContext(), mlirAttributeGetType(self))
+ .maybeDownCast());
})
.def(
"get_named",
@@ -4049,7 +4059,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"mlirTypeID was expected to be non-null.");
return PyTypeID(mlirTypeID);
})
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAttribute::maybeDownCast);
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
+ return nb::cast<nb::typed<nb::object, PyAttribute>>(self.maybeDownCast());
+ });
//----------------------------------------------------------------------------
// Mapping of PyNamedAttribute
@@ -4100,7 +4112,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
if (mlirTypeIsNull(type))
throw MLIRError("Unable to parse type", errors.take());
- return PyType(context.get()->getRef(), type).maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyType>>(
+ PyType(context.get()->getRef(), type).maybeDownCast());
},
nb::arg("asm"), nb::arg("context") = nb::none(),
kContextParseTypeDocstring)
@@ -4139,7 +4152,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
printAccum.parts.append(")");
return printAccum.join();
})
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyType::maybeDownCast)
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyType &self) {
+ return nb::cast<nb::typed<nb::object, PyType>>(
+ self.maybeDownCast());
+ })
.def_prop_ro("typeid", [](PyType &self) {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
if (!mlirTypeIDIsNull(mlirTypeID))
@@ -4267,9 +4284,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("state"), kGetNameAsOperand)
.def_prop_ro("type",
[](PyValue &self) {
- return PyType(self.getParentOperation()->getContext(),
- mlirValueGetType(self.get()))
- .maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyType>>(
+ PyType(self.getParentOperation()->getContext(),
+ mlirValueGetType(self.get()))
+ .maybeDownCast());
})
.def(
"set_type",
@@ -4305,7 +4323,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("with_"), nb::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyValue::maybeDownCast)
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyValue &self) {
+ return nb::cast<nb::typed<nb::object, PyValue>>(
+ self.maybeDownCast());
+ })
.def_prop_ro(
"location",
[](MlirValue self) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 6e97c00d478f1..dc9913bc5ebb2 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -1102,8 +1102,9 @@ class PyConcreteAttribute : public BaseTy {
},
nanobind::arg("other"));
cls.def_prop_ro("type", [](PyAttribute &attr) {
- return PyType(attr.getContext(), mlirAttributeGetType(attr))
- .maybeDownCast();
+ return nanobind::cast<nanobind::typed<nanobind::object, PyType>>(
+ PyType(attr.getContext(), mlirAttributeGetType(attr))
+ .maybeDownCast());
});
cls.def_prop_ro_static(
"static_typeid",
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a7aa1c65c6c43..a228ca4418c4a 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -502,8 +502,9 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
c.def_prop_ro(
"element_type",
[](PyComplexType &self) {
- return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
- .maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyType>>(
+ PyType(self.getContext(), mlirComplexTypeGetElementType(self))
+ .maybeDownCast());
},
"Returns element type.");
}
@@ -516,8 +517,9 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
c.def_prop_ro(
"element_type",
[](PyShapedType &self) {
- return PyType(self.getContext(), mlirShapedTypeGetElementType(self))
- .maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyType>>(
+ PyType(self.getContext(), mlirShapedTypeGetElementType(self))
+ .maybeDownCast());
},
"Returns the element type of the shaped type.");
c.def_prop_ro(
@@ -898,11 +900,21 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
},
nb::arg("elements"), nb::arg("context") = nb::none(),
"Create a tuple type");
+ c.def_static(
+ "get_tuple",
+ [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+ MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+ elements.data());
+ return PyTupleType(context->getRef(), t);
+ },
+ nb::arg("elements"), nb::arg("context") = nb::none(),
+ "Create a tuple type");
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) {
- return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
- .maybeDownCast();
+ return nb::cast<nb::typed<nb::object, PyType>>(
+ PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
+ .maybeDownCast());
},
nb::arg("pos"), "Returns the pos-th type in the tuple type.");
c.def_prop_ro(
@@ -926,23 +938,17 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
static void bindDerived(ClassTy &c) {
c.def_static(
"get",
- [](std::vector<PyType> inputs, std::vector<PyType> results,
+ [](std::vector<MlirType> inputs, std::vector<MlirType> results,
DefaultingPyMlirContext context) {
- std::vector<MlirType> mlirInputs;
- mlirInputs.reserve(inputs.size());
- for (const auto &input : inputs)
- mlirInputs.push_back(input.get());
- std::vector<MlirType> mlirResults;
- mlirResults.reserve(results.size());
- for (const auto &result : results)
- mlirResults.push_back(result.get());
-
- MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
- mlirInputs.data(), results.size(),
- mlirResults.data());
+ MlirType t =
+ mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+ results.size(), results.data());
return PyFunctionType(context->getRef(), t);
},
nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
+ // clang-format off
+ nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"),
+ // clang-format on
"Gets a FunctionType from a list of input and result types");
c.def_prop_ro(
"inputs",
More information about the Mlir-commits
mailing list