[Mlir-commits] [mlir] [MLIR][Python] use nb::typed for return signatures (PR #160221)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 22 20:32:30 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
https://github.com/llvm/llvm-project/pull/160183 removed the `nb::typed` annotations to fix the Bazel build, but the breakage turned out to be simply a matter of not using the correct version of nanobind (see https://github.com/llvm/llvm-project/pull/160183#issuecomment-3321429155). This PR restores those annotations but (mostly) moves them to the return positions of the actual methods.
---
Patch is 41.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160221.diff
6 Files Affected:
- (modified) mlir/lib/Bindings/Python/IRAffine.cpp (+27-22)
- (modified) mlir/lib/Bindings/Python/IRAttributes.cpp (+28-22)
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+125-68)
- (modified) mlir/lib/Bindings/Python/IRInterfaces.cpp (+19-11)
- (modified) mlir/lib/Bindings/Python/IRModule.h (+8-6)
- (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+9-12)
``````````diff
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index bc6aa0dac6221..7147f2cbad149 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -574,7 +574,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
})
.def_prop_ro(
"context",
- [](PyAffineExpr &self) { return self.getContext().getObject(); })
+ [](PyAffineExpr &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ })
.def("compose",
[](PyAffineExpr &self, PyAffineMap &other) {
return PyAffineExpr(self.getContext(),
@@ -706,28 +708,29 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
[](PyAffineMap &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
- .def_static("compress_unused_symbols",
- [](const nb::list &affineMaps,
- DefaultingPyMlirContext context) {
- SmallVector<MlirAffineMap> maps;
- pyListToVector<PyAffineMap, MlirAffineMap>(
- affineMaps, maps, "attempting to create an AffineMap");
- std::vector<MlirAffineMap> compressed(affineMaps.size());
- auto populate = [](void *result, intptr_t idx,
- MlirAffineMap m) {
- static_cast<MlirAffineMap *>(result)[idx] = (m);
- };
- mlirAffineMapCompressUnusedSymbols(
- maps.data(), maps.size(), compressed.data(), populate);
- std::vector<PyAffineMap> res;
- res.reserve(compressed.size());
- for (auto m : compressed)
- res.emplace_back(context->getRef(), m);
- return res;
- })
+ .def_static(
+ "compress_unused_symbols",
+ [](const nb::list &affineMaps, DefaultingPyMlirContext context) {
+ SmallVector<MlirAffineMap> maps;
+ pyListToVector<PyAffineMap, MlirAffineMap>(
+ affineMaps, maps, "attempting to create an AffineMap");
+ std::vector<MlirAffineMap> compressed(affineMaps.size());
+ auto populate = [](void *result, intptr_t idx, MlirAffineMap m) {
+ static_cast<MlirAffineMap *>(result)[idx] = (m);
+ };
+ mlirAffineMapCompressUnusedSymbols(maps.data(), maps.size(),
+ compressed.data(), populate);
+ std::vector<PyAffineMap> res;
+ res.reserve(compressed.size());
+ for (auto m : compressed)
+ res.emplace_back(context->getRef(), m);
+ return res;
+ })
.def_prop_ro(
"context",
- [](PyAffineMap &self) { return self.getContext().getObject(); },
+ [](PyAffineMap &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
"Context that owns the Affine Map")
.def(
"dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
@@ -893,7 +896,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
})
.def_prop_ro(
"context",
- [](PyIntegerSet &self) { return self.getContext().getObject(); })
+ [](PyIntegerSet &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ })
.def(
"dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
kDumpDocstring)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 212228fbac91e..c77653f97e6dd 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
PyArrayAttributeIterator &dunderIter() { return *this; }
- nb::object dunderNext() {
+ nb::typed<nb::object, PyAttribute> dunderNext() {
// TODO: Throw is an inefficient way to stop iteration.
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
throw nb::stop_iteration();
@@ -526,7 +526,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
"Gets a uniqued Array attribute");
c.def(
"__getitem__",
- [](PyArrayAttribute &arr, intptr_t i) {
+ [](PyArrayAttribute &arr,
+ intptr_t i) -> nb::typed<nb::object, PyAttribute> {
if (i >= mlirArrayAttrGetNumElements(arr))
throw nb::index_error("ArrayAttribute index out of range");
return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
@@ -1010,14 +1011,16 @@ class PyDenseElementsAttribute
[](PyDenseElementsAttribute &self) -> bool {
return mlirDenseElementsAttrIsSplat(self);
})
- .def("get_splat_value", [](PyDenseElementsAttribute &self) {
- if (!mlirDenseElementsAttrIsSplat(self))
- throw nb::value_error(
- "get_splat_value called on a non-splat attribute");
- return PyAttribute(self.getContext(),
- mlirDenseElementsAttrGetSplatValue(self))
- .maybeDownCast();
- });
+ .def("get_splat_value",
+ [](PyDenseElementsAttribute &self)
+ -> nb::typed<nb::object, PyAttribute> {
+ if (!mlirDenseElementsAttrIsSplat(self))
+ throw nb::value_error(
+ "get_splat_value called on a non-splat attribute");
+ return PyAttribute(self.getContext(),
+ mlirDenseElementsAttrGetSplatValue(self))
+ .maybeDownCast();
+ });
}
static PyType_Slot slots[];
@@ -1332,7 +1335,7 @@ class PyDenseIntElementsAttribute
/// Returns the element at the given linear position. Asserts if the index
/// is out of range.
- nb::object dunderGetItem(intptr_t pos) {
+ nb::int_ dunderGetItem(intptr_t pos) {
if (pos < 0 || pos >= dunderLen()) {
throw nb::index_error("attempt to access out of bounds element");
}
@@ -1522,13 +1525,15 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
},
nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
"Gets an uniqued dict attribute");
- c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
- MlirAttribute attr =
- mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
- if (mlirAttributeIsNull(attr))
- throw nb::key_error("attempt to access a non-existent attribute");
- return PyAttribute(self.getContext(), attr).maybeDownCast();
- });
+ c.def("__getitem__",
+ [](PyDictAttribute &self,
+ const std::string &name) -> nb::typed<nb::object, PyAttribute> {
+ MlirAttribute attr =
+ mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+ if (mlirAttributeIsNull(attr))
+ throw nb::key_error("attempt to access a non-existent attribute");
+ return PyAttribute(self.getContext(), attr).maybeDownCast();
+ });
c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
if (index < 0 || index >= self.dunderLen()) {
throw nb::index_error("attempt to access out of bounds attribute");
@@ -1594,10 +1599,11 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
},
nb::arg("value"), nb::arg("context") = nb::none(),
"Gets a uniqued Type attribute");
- c.def_prop_ro("value", [](PyTypeAttribute &self) {
- return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
- .maybeDownCast();
- });
+ c.def_prop_ro(
+ "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
+ return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
+ .maybeDownCast();
+ });
}
};
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4b238e11c7fff..e9cfdc7c6d674 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -513,7 +513,7 @@ class PyOperationIterator {
PyOperationIterator &dunderIter() { return *this; }
- nb::object dunderNext() {
+ nb::typed<nb::object, PyOpView> dunderNext() {
parentOperation->checkValid();
if (mlirOperationIsNull(next)) {
throw nb::stop_iteration();
@@ -562,7 +562,7 @@ class PyOperationList {
return count;
}
- nb::object dunderGetItem(intptr_t index) {
+ nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
parentOperation->checkValid();
if (index < 0) {
index += dunderLen();
@@ -725,7 +725,7 @@ nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
new PyDiagnosticHandler(get(), std::move(callback));
nb::object pyHandlerObject =
nb::cast(pyHandler, nb::rv_policy::take_ownership);
- pyHandlerObject.inc_ref();
+ (void)pyHandlerObject.inc_ref();
// In these C callbacks, the userData is a PyDiagnosticHandler* that is
// guaranteed to be known to pybind.
@@ -1395,7 +1395,7 @@ nb::object PyOperation::getCapsule() {
return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
}
-nb::object PyOperation::createFromCapsule(nb::object capsule) {
+nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
if (mlirOperationIsNull(rawOperation))
throw nb::python_error();
@@ -1605,7 +1605,9 @@ class PyConcreteValue : public PyValue {
},
nb::arg("other_value"));
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](DerivedTy &self) { return self.maybeDownCast(); });
+ [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
+ return self.maybeDownCast();
+ });
DerivedTy::bindDerived(cls);
}
@@ -1623,13 +1625,14 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
using PyConcreteValue::PyConcreteValue;
static void bindDerived(ClassTy &c) {
- c.def_prop_ro("owner", [](PyOpResult &self) {
- assert(
- mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match that in the IR");
- return self.getParentOperation().getObject();
- });
+ c.def_prop_ro(
+ "owner", [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
+ assert(mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match that in "
+ "the IR");
+ return self.getParentOperation().getObject();
+ });
c.def_prop_ro("result_number", [](PyOpResult &self) {
return mlirOpResultGetResultNumber(self.get());
});
@@ -1638,9 +1641,9 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
/// Returns the list of types of the values held by container.
template <typename Container>
-static std::vector<nb::object> getValueTypes(Container &container,
- PyMlirContextRef &context) {
- std::vector<nb::object> result;
+static std::vector<nb::typed<nb::object, PyType>>
+getValueTypes(Container &container, PyMlirContextRef &context) {
+ std::vector<nb::typed<nb::object, PyType>> result;
result.reserve(container.size());
for (int i = 0, e = container.size(); i < e; ++i) {
result.push_back(PyType(context->getRef(),
@@ -1671,9 +1674,10 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
c.def_prop_ro("types", [](PyOpResultList &self) {
return getValueTypes(self, self.operation->getContext());
});
- c.def_prop_ro("owner", [](PyOpResultList &self) {
- return self.operation->createOpView();
- });
+ c.def_prop_ro("owner",
+ [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
+ return self.operation->createOpView();
+ });
}
PyOperationRef &getOperation() { return operation; }
@@ -2104,7 +2108,7 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
- return PyThreadContextEntry::pushInsertionPoint(insertPoint);
+ return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
}
void PyInsertionPoint::contextExit(const nb::object &excType,
@@ -2125,7 +2129,7 @@ nb::object PyAttribute::getCapsule() {
return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
}
-PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
+PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) {
MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
if (mlirAttributeIsNull(rawAttr))
throw nb::python_error();
@@ -2677,7 +2681,8 @@ class PyOpAttributeMap {
PyOpAttributeMap(PyOperationRef operation)
: operation(std::move(operation)) {}
- nb::object dunderGetItemNamed(const std::string &name) {
+ nb::typed<nb::object, PyAttribute>
+ dunderGetItemNamed(const std::string &name) {
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
toMlirStringRef(name));
if (mlirAttributeIsNull(attr)) {
@@ -2962,13 +2967,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
})
.def_static("_get_live_count", &PyMlirContext::getLiveCount)
.def("_get_context_again",
- [](PyMlirContext &self) {
+ [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyMlirContext::createFromCapsule)
.def("__enter__", &PyMlirContext::contextEnter)
.def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
nb::arg("exc_value").none(), nb::arg("traceback").none())
@@ -3123,7 +3129,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PyDialectRegistry>(m, "DialectRegistry")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+ &PyDialectRegistry::createFromCapsule)
.def(nb::init<>());
//----------------------------------------------------------------------------
@@ -3131,7 +3138,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PyLocation>(m, "Location")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
.def("__enter__", &PyLocation::contextEnter)
.def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
nb::arg("exc_value").none(), nb::arg("traceback").none())
@@ -3286,7 +3293,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"Gets a Location from a LocationAttr")
.def_prop_ro(
"context",
- [](PyLocation &self) { return self.getContext().getObject(); },
+ [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
"Context that owns the Location")
.def_prop_ro(
"attr",
@@ -3313,12 +3322,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
- .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
- kModuleCAPICreate)
+ .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+ kModuleCAPICreate)
.def("_clear_mlir_module", &PyModule::clearMlirModule)
.def_static(
"parse",
- [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
+ [](const std::string &moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirModule module = mlirModuleCreateParse(
context->get(), toMlirStringRef(moduleAsm));
@@ -3330,7 +3340,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"parse",
- [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
+ [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirModule module = mlirModuleCreateParse(
context->get(), toMlirStringRef(moduleAsm));
@@ -3342,7 +3353,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"parseFile",
- [](const std::string &path, DefaultingPyMlirContext context) {
+ [](const std::string &path, DefaultingPyMlirContext context)
+ -> nb::typed<nb::object, PyModule> {
PyMlirContext::ErrorCapture errors(context->getRef());
MlirModule module = mlirModuleCreateParseFromFile(
context->get(), toMlirStringRef(path));
@@ -3354,7 +3366,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"create",
- [](const std::optional<PyLocation> &loc) {
+ [](const std::optional<PyLocation> &loc)
+ -> nb::typed<nb::object, PyModule> {
PyLocation pyLoc = maybeGetTracebackLocation(loc);
MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
return PyModule::forModule(module).releaseObject();
@@ -3362,11 +3375,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("loc") = nb::none(), "Creates an empty module")
.def_prop_ro(
"context",
- [](PyModule &self) { return self.getContext().getObject(); },
+ [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
+ return self.getContext().getObject();
+ },
"Context that created the Module")
.def_prop_ro(
"operation",
- [](PyModule &self) {
+ [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
return PyOperation::forOperation(self.getContext(),
mlirModuleGetOperation(self.get()),
self.getRef().releaseObject())
@@ -3430,7 +3445,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
})
.def_prop_ro(
"context",
- [](PyOperationBase &self) {
+ [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
PyOperation &concreteOperation = self.getOperation();
concreteOperation.checkValid();
return concreteOperation.getContext().getObject();
@@ -3461,7 +347...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/160221
More information about the Mlir-commits
mailing list