[Mlir-commits] [mlir] [MLIR][Python] use nb::typed for return signatures (PR #160221)
Maksim Levental
llvmlistbot at llvm.org
Mon Sep 22 20:35:50 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/160221
From fbd9df278abb825145e6a307786a759d1eff2d38 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 22 Sep 2025 18:57:51 -0700
Subject: [PATCH] [MLIR][Python] use nb::typed for return signatures
---
mlir/lib/Bindings/Python/IRAffine.cpp | 49 +++---
mlir/lib/Bindings/Python/IRAttributes.cpp | 50 +++---
mlir/lib/Bindings/Python/IRCore.cpp | 194 ++++++++++++++--------
mlir/lib/Bindings/Python/IRInterfaces.cpp | 30 ++--
mlir/lib/Bindings/Python/IRModule.h | 14 +-
mlir/lib/Bindings/Python/IRTypes.cpp | 21 +--
6 files changed, 216 insertions(+), 142 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index bc6aa0dac6221..7147f2cbad149 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -574,7 +574,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
})
.def_prop_ro(
"context",
- [](PyAffineExpr &self) { return self.getContext().getObject(); })
+ [](PyAffineExpr &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ })
.def("compose",
[](PyAffineExpr &self, PyAffineMap &other) {
return PyAffineExpr(self.getContext(),
@@ -706,28 +708,29 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
[](PyAffineMap &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
- .def_static("compress_unused_symbols",
- [](const nb::list &affineMaps,
- DefaultingPyMlirContext context) {
- SmallVector<MlirAffineMap> maps;
- pyListToVector<PyAffineMap, MlirAffineMap>(
- affineMaps, maps, "attempting to create an AffineMap");
- std::vector<MlirAffineMap> compressed(affineMaps.size());
- auto populate = [](void *result, intptr_t idx,
- MlirAffineMap m) {
- static_cast<MlirAffineMap *>(result)[idx] = (m);
- };
- mlirAffineMapCompressUnusedSymbols(
- maps.data(), maps.size(), compressed.data(), populate);
- std::vector<PyAffineMap> res;
- res.reserve(compressed.size());
- for (auto m : compressed)
- res.emplace_back(context->getRef(), m);
- return res;
- })
+ .def_static(
+ "compress_unused_symbols",
+ [](const nb::list &affineMaps, DefaultingPyMlirContext context) {
+ SmallVector<MlirAffineMap> maps;
+ pyListToVector<PyAffineMap, MlirAffineMap>(
+ affineMaps, maps, "attempting to create an AffineMap");
+ std::vector<MlirAffineMap> compressed(affineMaps.size());
+ auto populate = [](void *result, intptr_t idx, MlirAffineMap m) {
+ static_cast<MlirAffineMap *>(result)[idx] = (m);
+ };
+ mlirAffineMapCompressUnusedSymbols(maps.data(), maps.size(),
+ compressed.data(), populate);
+ std::vector<PyAffineMap> res;
+ res.reserve(compressed.size());
+ for (auto m : compressed)
+ res.emplace_back(context->getRef(), m);
+ return res;
+ })
.def_prop_ro(
"context",
- [](PyAffineMap &self) { return self.getContext().getObject(); },
+ [](PyAffineMap &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
"Context that owns the Affine Map")
.def(
"dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
@@ -893,7 +896,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
})
.def_prop_ro(
"context",
- [](PyIntegerSet &self) { return self.getContext().getObject(); })
+ [](PyIntegerSet &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ })
.def(
"dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
kDumpDocstring)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 212228fbac91e..c77653f97e6dd 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
PyArrayAttributeIterator &dunderIter() { return *this; }
- nb::object dunderNext() {
+ nb::typed<nb::object, PyAttribute> dunderNext() {
// TODO: Throw is an inefficient way to stop iteration.
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
throw nb::stop_iteration();
@@ -526,7 +526,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
"Gets a uniqued Array attribute");
c.def(
"__getitem__",
- [](PyArrayAttribute &arr, intptr_t i) {
+ [](PyArrayAttribute &arr,
+ intptr_t i) -> nb::typed<nb::object, PyAttribute> {
if (i >= mlirArrayAttrGetNumElements(arr))
throw nb::index_error("ArrayAttribute index out of range");
return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
@@ -1010,14 +1011,16 @@ class PyDenseElementsAttribute
[](PyDenseElementsAttribute &self) -> bool {
return mlirDenseElementsAttrIsSplat(self);
})
- .def("get_splat_value", [](PyDenseElementsAttribute &self) {
- if (!mlirDenseElementsAttrIsSplat(self))
- throw nb::value_error(
- "get_splat_value called on a non-splat attribute");
- return PyAttribute(self.getContext(),
- mlirDenseElementsAttrGetSplatValue(self))
- .maybeDownCast();
- });
+ .def("get_splat_value",
+ [](PyDenseElementsAttribute &self)
+ -> nb::typed<nb::object, PyAttribute> {
+ if (!mlirDenseElementsAttrIsSplat(self))
+ throw nb::value_error(
+ "get_splat_value called on a non-splat attribute");
+ return PyAttribute(self.getContext(),
+ mlirDenseElementsAttrGetSplatValue(self))
+ .maybeDownCast();
+ });
}
static PyType_Slot slots[];
@@ -1332,7 +1335,7 @@ class PyDenseIntElementsAttribute
/// Returns the element at the given linear position. Asserts if the index
/// is out of range.
- nb::object dunderGetItem(intptr_t pos) {
+ nb::int_ dunderGetItem(intptr_t pos) {
if (pos < 0 || pos >= dunderLen()) {
throw nb::index_error("attempt to access out of bounds element");
}
@@ -1522,13 +1525,15 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
},
nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
"Gets an uniqued dict attribute");
- c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
- MlirAttribute attr =
- mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
- if (mlirAttributeIsNull(attr))
- throw nb::key_error("attempt to access a non-existent attribute");
- return PyAttribute(self.getContext(), attr).maybeDownCast();
- });
+ c.def("__getitem__",
+ [](PyDictAttribute &self,
+ const std::string &name) -> nb::typed<nb::object, PyAttribute> {
+ MlirAttribute attr =
+ mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+ if (mlirAttributeIsNull(attr))
+ throw nb::key_error("attempt to access a non-existent attribute");
+ return PyAttribute(self.getContext(), attr).maybeDownCast();
+ });
c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
if (index < 0 || index >= self.dunderLen()) {
throw nb::index_error("attempt to access out of bounds attribute");
@@ -1594,10 +1599,11 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
},
nb::arg("value"), nb::arg("context") = nb::none(),
"Gets a uniqued Type attribute");
- c.def_prop_ro("value", [](PyTypeAttribute &self) {
- return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
- .maybeDownCast();
- });
+ c.def_prop_ro(
+ "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
+ .maybeDownCast();
+ });
}
};
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4b238e11c7fff..5bad942d70374 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -513,7 +513,7 @@ class PyOperationIterator {
PyOperationIterator &dunderIter() { return *this; }
- nb::object dunderNext() {
+ nb::typed<nb::object, PyOpView> dunderNext() {
parentOperation->checkValid();
if (mlirOperationIsNull(next)) {
throw nb::stop_iteration();
@@ -562,7 +562,7 @@ class PyOperationList {
return count;
}
- nb::object dunderGetItem(intptr_t index) {
+ nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
parentOperation->checkValid();
if (index < 0) {
index += dunderLen();
@@ -725,7 +725,7 @@ nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
new PyDiagnosticHandler(get(), std::move(callback));
nb::object pyHandlerObject =
nb::cast(pyHandler, nb::rv_policy::take_ownership);
- pyHandlerObject.inc_ref();
+ (void)pyHandlerObject.inc_ref();
// In these C callbacks, the userData is a PyDiagnosticHandler* that is
// guaranteed to be known to pybind.
@@ -1395,7 +1395,7 @@ nb::object PyOperation::getCapsule() {
return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
}
-nb::object PyOperation::createFromCapsule(nb::object capsule) {
+nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
if (mlirOperationIsNull(rawOperation))
throw nb::python_error();
@@ -1605,7 +1605,9 @@ class PyConcreteValue : public PyValue {
},
nb::arg("other_value"));
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](DerivedTy &self) { return self.maybeDownCast(); });
+ [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
+ return self.maybeDownCast();
+ });
DerivedTy::bindDerived(cls);
}
@@ -1623,13 +1625,14 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
using PyConcreteValue::PyConcreteValue;
static void bindDerived(ClassTy &c) {
- c.def_prop_ro("owner", [](PyOpResult &self) {
- assert(
- mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match that in the IR");
- return self.getParentOperation().getObject();
- });
+ c.def_prop_ro(
+ "owner", [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
+ assert(mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match that in "
+ "the IR");
+ return self.getParentOperation().getObject();
+ });
c.def_prop_ro("result_number", [](PyOpResult &self) {
return mlirOpResultGetResultNumber(self.get());
});
@@ -1638,9 +1641,9 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
/// Returns the list of types of the values held by container.
template <typename Container>
-static std::vector<nb::object> getValueTypes(Container &container,
- PyMlirContextRef &context) {
- std::vector<nb::object> result;
+static std::vector<nb::typed<nb::object, PyType>>
+getValueTypes(Container &container, PyMlirContextRef &context) {
+ std::vector<nb::typed<nb::object, PyType>> result;
result.reserve(container.size());
for (int i = 0, e = container.size(); i < e; ++i) {
result.push_back(PyType(context->getRef(),
@@ -1671,9 +1674,10 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
c.def_prop_ro("types", [](PyOpResultList &self) {
return getValueTypes(self, self.operation->getContext());
});
- c.def_prop_ro("owner", [](PyOpResultList &self) {
- return self.operation->createOpView();
- });
+ c.def_prop_ro("owner",
+ [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
+ return self.operation->createOpView();
+ });
}
PyOperationRef &getOperation() { return operation; }
@@ -2104,7 +2108,7 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
- return PyThreadContextEntry::pushInsertionPoint(insertPoint);
+ return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
}
void PyInsertionPoint::contextExit(const nb::object &excType,
@@ -2125,7 +2129,7 @@ nb::object PyAttribute::getCapsule() {
return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
}
-PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
+PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) {
MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
if (mlirAttributeIsNull(rawAttr))
throw nb::python_error();
@@ -2677,7 +2681,8 @@ class PyOpAttributeMap {
PyOpAttributeMap(PyOperationRef operation)
: operation(std::move(operation)) {}
- nb::object dunderGetItemNamed(const std::string &name) {
+ nb::typed<nb::object, PyAttribute>
+ dunderGetItemNamed(const std::string &name) {
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
@@ -2962,13 +2967,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
})
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
.def("_get_context_again",
- [](PyMlirContext &self) {
+ [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyMlirContext::createFromCapsule)
.def("__enter__", &PyMlirContext::contextEnter)
.def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
nb::arg("exc_value").none(), nb::arg("traceback").none())
@@ -3123,7 +3129,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PyDialectRegistry>(m, "DialectRegistry")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyDialectRegistry::createFromCapsule)
.def(nb::init<>());
//----------------------------------------------------------------------------
@@ -3131,7 +3138,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PyLocation>(m, "Location")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
.def("__enter__", &PyLocation::contextEnter)
.def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
nb::arg("exc_value").none(), nb::arg("traceback").none())
@@ -3286,7 +3293,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Gets a Location from a LocationAttr")
.def_prop_ro(
"context",
- [](PyLocation &self) { return self.getContext().getObject(); },
+ [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
"Context that owns the Location")
.def_prop_ro(
"attr",
@@ -3313,12 +3322,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
- kModuleCAPICreate)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+ kModuleCAPICreate)
.def("_clear_mlir_module", &PyModule::clearMlirModule)
.def_static(
"parse",
- [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
+ [](const std::string &moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirModule module = mlirModuleCreateParse(
context->get(), toMlirStringRef(moduleAsm));
@@ -3330,7 +3340,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"parse",
- [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
+ [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirModule module = mlirModuleCreateParse(
context->get(), toMlirStringRef(moduleAsm));
@@ -3342,7 +3353,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"parseFile",
- [](const std::string &path, DefaultingPyMlirContext context) {
+ [](const std::string &path, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirModule module = mlirModuleCreateParseFromFile(
context->get(), toMlirStringRef(path));
@@ -3354,7 +3366,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"create",
- [](const std::optional<PyLocation> &loc) {
+ [](const std::optional<PyLocation> &loc)
+ -> nb::typed<nb::object, PyModule> {
PyLocation pyLoc = maybeGetTracebackLocation(loc);
MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
return PyModule::forModule(module).releaseObject();
@@ -3362,11 +3375,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("loc") = nb::none(), "Creates an empty module")
.def_prop_ro(
"context",
- [](PyModule &self) { return self.getContext().getObject(); },
+ [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
"Context that created the Module")
.def_prop_ro(
"operation",
- [](PyModule &self) {
+ [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
return PyOperation::forOperation(self.getContext(),
mlirModuleGetOperation(self.get()),
self.getRef().releaseObject())
@@ -3430,7 +3445,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
})
.def_prop_ro(
"context",
- [](PyOperationBase &self) {
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
PyOperation &concreteOperation = self.getOperation();
concreteOperation.checkValid();
return concreteOperation.getContext().getObject();
@@ -3461,7 +3476,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Returns the list of Operation results.")
.def_prop_ro(
"result",
- [](PyOperationBase &self) {
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
auto &operation = self.getOperation();
return PyOpResult(operation.getRef(), getUniqueResult(operation))
.maybeDownCast();
@@ -3478,11 +3493,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Returns the source location the operation was defined or derived "
"from.")
.def_prop_ro("parent",
- [](PyOperationBase &self) -> nb::object {
+ [](PyOperationBase &self)
+ -> std::optional<nb::typed<nb::object, PyOperation>> {
auto parent = self.getOperation().getParentOperation();
if (parent)
return parent->getObject();
- return nb::none();
+ return {};
})
.def(
"__str__",
@@ -3553,13 +3569,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"of the parent block.")
.def(
"clone",
- [](PyOperationBase &self, nb::object ip) {
+ [](PyOperationBase &self,
+ const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
return self.getOperation().clone(ip);
},
nb::arg("ip") = nb::none())
.def(
"detach_from_parent",
- [](PyOperationBase &self) {
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
PyOperation &operation = self.getOperation();
operation.checkValid();
if (!operation.isAttached())
@@ -3595,7 +3612,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
const std::optional<PyLocation> &location,
- const nb::object &maybeIp, bool inferType) {
+ const nb::object &maybeIp,
+ bool inferType) -> nb::typed<nb::object, PyOperation> {
// Unpack/validate operands.
llvm::SmallVector<MlirValue, 4> mlirOperands;
if (operands) {
@@ -3620,7 +3638,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def_static(
"parse",
[](const std::string &sourceStr, const std::string &sourceName,
- DefaultingPyMlirContext context) {
+ DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyOpView> {
return PyOperation::parse(context->getRef(), sourceStr, sourceName)
->createOpView();
},
@@ -3629,9 +3648,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Parses an operation. Supports both text assembly format and binary "
"bytecode format.")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
- .def_prop_ro("operation", [](nb::object self) { return self; })
- .def_prop_ro("opview", &PyOperation::createOpView)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyOperation::createFromCapsule)
+ .def_prop_ro("operation",
+ [](nb::object self) -> nb::typed<nb::object, PyOperation> {
+ return self;
+ })
+ .def_prop_ro("opview",
+ [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
+ return self.createOpView();
+ })
.def_prop_ro("block", &PyOperation::getBlock)
.def_prop_ro(
"successors",
@@ -3644,7 +3670,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
auto opViewClass =
nb::class_<PyOpView, PyOperationBase>(m, "OpView")
- .def(nb::init<nb::object>(), nb::arg("operation"))
+ .def(nb::init<nb::typed<nb::object, PyOperation>>(),
+ nb::arg("operation"))
.def(
"__init__",
[](PyOpView *self, std::string_view name,
@@ -3671,9 +3698,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("successors") = nb::none(),
nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
nb::arg("ip") = nb::none())
-
- .def_prop_ro("operation", &PyOpView::getOperationObject)
- .def_prop_ro("opview", [](nb::object self) { return self; })
+ .def_prop_ro(
+ "operation",
+ [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
+ return self.getOperationObject();
+ })
+ .def_prop_ro("opview",
+ [](nb::object self) -> nb::typed<nb::object, PyOpView> {
+ return self;
+ })
.def(
"__str__",
[](PyOpView &self) { return nb::str(self.getOperationObject()); })
@@ -3717,7 +3750,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Builds a specific, generated OpView based on class level attributes.");
opViewClass.attr("parse") = classmethod(
[](const nb::object &cls, const std::string &sourceStr,
- const std::string &sourceName, DefaultingPyMlirContext context) {
+ const std::string &sourceName,
+ DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
PyOperationRef parsed =
PyOperation::parse(context->getRef(), sourceStr, sourceName);
@@ -3752,7 +3786,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Returns a forward-optimized sequence of blocks.")
.def_prop_ro(
"owner",
- [](PyRegion &self) {
+ [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
return self.getParentOperation()->createOpView();
},
"Returns the operation owning this region.")
@@ -3777,7 +3811,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
.def_prop_ro(
"owner",
- [](PyBlock &self) {
+ [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
return self.getParentOperation()->createOpView();
},
"Returns the owning operation of this block.")
@@ -3960,11 +3994,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Returns the block that this InsertionPoint points to.")
.def_prop_ro(
"ref_operation",
- [](PyInsertionPoint &self) -> nb::object {
+ [](PyInsertionPoint &self)
+ -> std::optional<nb::typed<nb::object, PyOperation>> {
auto refOperation = self.getRefOperation();
if (refOperation)
return refOperation->getObject();
- return nb::none();
+ return {};
},
"The reference operation before which new operations are "
"inserted, or None if the insertion point is at the end of "
@@ -3979,10 +4014,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
"Casts the passed attribute to the generic Attribute")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyAttribute::createFromCapsule)
.def_static(
"parse",
- [](const std::string &attrSpec, DefaultingPyMlirContext context) {
+ [](const std::string &attrSpec, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyAttribute> {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirAttribute attr = mlirAttributeParseGet(
context->get(), toMlirStringRef(attrSpec));
@@ -3995,10 +4032,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"failure.")
.def_prop_ro(
"context",
- [](PyAttribute &self) { return self.getContext().getObject(); },
+ [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
"Context that owns the Attribute")
.def_prop_ro("type",
- [](PyAttribute &self) {
+ [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
return PyType(self.getContext(),
mlirAttributeGetType(self))
.maybeDownCast();
@@ -4049,7 +4088,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"mlirTypeID was expected to be non-null.");
return PyTypeID(mlirTypeID);
})
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAttribute::maybeDownCast);
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
+ return self.maybeDownCast();
+ });
//----------------------------------------------------------------------------
// Mapping of PyNamedAttribute
@@ -4091,10 +4133,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def(nb::init<PyType &>(), nb::arg("cast_from_type"),
"Casts the passed type to the generic Type")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
.def_static(
"parse",
- [](std::string typeSpec, DefaultingPyMlirContext context) {
+ [](std::string typeSpec,
+ DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirType type =
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
@@ -4105,7 +4148,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("asm"), nb::arg("context") = nb::none(),
kContextParseTypeDocstring)
.def_prop_ro(
- "context", [](PyType &self) { return self.getContext().getObject(); },
+ "context",
+ [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
"Context that owns the Type")
.def("__eq__", [](PyType &self, PyType &other) { return self == other; })
.def(
@@ -4139,7 +4185,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
printAccum.parts.append(")");
return printAccum.join();
})
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyType::maybeDownCast)
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyType &self) -> nb::typed<nb::object, PyType> {
+ return self.maybeDownCast();
+ })
.def_prop_ro("typeid", [](PyType &self) {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
if (!mlirTypeIDIsNull(mlirTypeID))
@@ -4154,7 +4203,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PyTypeID>(m, "TypeID")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
// Note, this tests whether the underlying TypeIDs are the same,
// not whether the wrapper MlirTypeIDs are the same, nor whether
// the Python objects are the same (i.e., PyTypeID is a value type).
@@ -4175,10 +4224,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::class_<PyValue>(m, "Value")
.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
.def_prop_ro(
"context",
- [](PyValue &self) {
+ [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
return self.getParentOperation()->getContext().getObject();
},
"Context in which the value lives.")
@@ -4266,7 +4315,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("state"), kGetNameAsOperand)
.def_prop_ro("type",
- [](PyValue &self) {
+ [](PyValue &self) -> nb::typed<nb::object, PyType> {
return PyType(self.getParentOperation()->getContext(),
mlirValueGetType(self.get()))
.maybeDownCast();
@@ -4332,7 +4381,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
},
nb::arg("with_"), nb::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
- .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyValue::maybeDownCast)
+ .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](PyValue &self) -> nb::typed<nb::object, PyValue> {
+ return self.maybeDownCast();
+ })
.def_prop_ro(
"location",
[](MlirValue self) {
@@ -4357,7 +4409,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PySymbolTable>(m, "SymbolTable")
.def(nb::init<PyOperationBase &>())
- .def("__getitem__", &PySymbolTable::dunderGetItem)
+ .def("__getitem__",
+ [](PySymbolTable &self,
+ const std::string &name) -> nb::typed<nb::object, PyOpView> {
+ return self.dunderGetItem(name);
+ })
.def("insert", &PySymbolTable::insert, nb::arg("operation"))
.def("erase", &PySymbolTable::erase, nb::arg("operation"))
.def("__delitem__", &PySymbolTable::dunderDel)
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 44aad10ded082..32d52b48668cf 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -196,9 +196,19 @@ class PyConcreteOpInterface {
nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
nb::arg("context") = nb::none(), constructorDoc)
- .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
- operationDoc)
- .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
+ .def_prop_ro(
+ "operation",
+ [](PyConcreteOpInterface &self)
+ -> nb::typed<nb::object, PyOperation> {
+ return self.getOperationObject();
+ },
+ operationDoc)
+ .def_prop_ro(
+ "opview",
+ [](PyConcreteOpInterface &self) -> nb::typed<nb::object, PyOpView> {
+ return self.getOpView();
+ },
+ opviewDoc);
ConcreteIface::bindDerived(cls);
}
@@ -362,10 +372,9 @@ class PyShapedTypeComponents {
"Returns whether the given shaped type component is ranked.")
.def_prop_ro(
"rank",
- [](PyShapedTypeComponents &self) -> nb::object {
- if (!self.ranked) {
- return nb::none();
- }
+ [](PyShapedTypeComponents &self) -> std::optional<nb::int_> {
+ if (!self.ranked)
+ return {};
return nb::int_(self.shape.size());
},
"Returns the rank of the given ranked shaped type components. If "
@@ -373,10 +382,9 @@ class PyShapedTypeComponents {
"returned.")
.def_prop_ro(
"shape",
- [](PyShapedTypeComponents &self) -> nb::object {
- if (!self.ranked) {
- return nb::none();
- }
+ [](PyShapedTypeComponents &self) -> std::optional<nb::list> {
+ if (!self.ranked)
+ return {};
return nb::list(self.shape);
},
"Returns the shape of the ranked shaped type components as a list "
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 6e97c00d478f1..598ae0188464a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -671,7 +671,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
/// Creates a PyOperation from the MlirOperation wrapped by a capsule.
/// Ownership of the underlying MlirOperation is taken by calling this
/// function.
- static nanobind::object createFromCapsule(nanobind::object capsule);
+ static nanobind::object createFromCapsule(const nanobind::object &capsule);
/// Creates an operation. See corresponding python docstring.
static nanobind::object
@@ -1020,7 +1020,7 @@ class PyAttribute : public BaseContextObject {
/// Note that PyAttribute instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAttribute
/// is taken by calling this function.
- static PyAttribute createFromCapsule(nanobind::object capsule);
+ static PyAttribute createFromCapsule(const nanobind::object &capsule);
nanobind::object maybeDownCast();
@@ -1101,10 +1101,12 @@ class PyConcreteAttribute : public BaseTy {
return DerivedTy::isaFunction(otherAttr);
},
nanobind::arg("other"));
- cls.def_prop_ro("type", [](PyAttribute &attr) {
- return PyType(attr.getContext(), mlirAttributeGetType(attr))
- .maybeDownCast();
- });
+ cls.def_prop_ro(
+ "type",
+ [](PyAttribute &attr) -> nanobind::typed<nanobind::object, PyType> {
+ return PyType(attr.getContext(), mlirAttributeGetType(attr))
+ .maybeDownCast();
+ });
cls.def_prop_ro_static(
"static_typeid",
[](nanobind::object & /*class*/) -> PyTypeID {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index cab3bf549295b..07dc00521833f 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -501,7 +501,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
"Create a complex type");
c.def_prop_ro(
"element_type",
- [](PyComplexType &self) {
+ [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
.maybeDownCast();
},
@@ -515,7 +515,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
void mlir::PyShapedType::bindDerived(ClassTy &c) {
c.def_prop_ro(
"element_type",
- [](PyShapedType &self) {
+ [](PyShapedType &self) -> nb::typed<nb::object, PyType> {
return PyType(self.getContext(), mlirShapedTypeGetElementType(self))
.maybeDownCast();
},
@@ -731,8 +731,7 @@ class PyRankedTensorType
MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
if (mlirAttributeIsNull(encoding))
return std::nullopt;
- return nb::cast<nb::typed<nb::object, PyAttribute>>(
- PyAttribute(self.getContext(), encoding).maybeDownCast());
+ return PyAttribute(self.getContext(), encoding).maybeDownCast();
});
}
};
@@ -794,9 +793,9 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
.def_prop_ro(
"layout",
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
- return nb::cast<nb::typed<nb::object, PyAttribute>>(
- PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
- .maybeDownCast());
+ return PyAttribute(self.getContext(),
+ mlirMemRefTypeGetLayout(self))
+ .maybeDownCast();
},
"The layout of the MemRef type.")
.def(
@@ -825,8 +824,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
if (mlirAttributeIsNull(a))
return std::nullopt;
- return nb::cast<nb::typed<nb::object, PyAttribute>>(
- PyAttribute(self.getContext(), a).maybeDownCast());
+ return PyAttribute(self.getContext(), a).maybeDownCast();
},
"Returns the memory space of the given MemRef type.");
}
@@ -867,8 +865,7 @@ class PyUnrankedMemRefType
MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
if (mlirAttributeIsNull(a))
return std::nullopt;
- return nb::cast<nb::typed<nb::object, PyAttribute>>(
- PyAttribute(self.getContext(), a).maybeDownCast());
+ return PyAttribute(self.getContext(), a).maybeDownCast();
},
"Returns the memory space of the given Unranked MemRef type.");
}
@@ -912,7 +909,7 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
"Create a tuple type");
c.def(
"get_type",
- [](PyTupleType &self, intptr_t pos) {
+ [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
.maybeDownCast();
},
More information about the Mlir-commits mailing list