[Mlir-commits] [mlir] [MLIR][Python] use nb::typed for return signatures (PR #160221)

Maksim Levental llvmlistbot at llvm.org
Mon Sep 22 19:18:04 PDT 2025


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/160221

(No pull request description was provided.)

From 3376f4f23dbee295a81948b1dfb5d8c3459990e1 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 22 Sep 2025 18:57:51 -0700
Subject: [PATCH] [MLIR][Python] use nb::typed for return signatures

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 48 +++++++++++++----------
 mlir/lib/Bindings/Python/IRCore.cpp       | 40 +++++++++++++------
 mlir/lib/Bindings/Python/IRModule.h       | 10 +++--
 mlir/lib/Bindings/Python/IRTypes.cpp      | 21 +++++-----
 4 files changed, 69 insertions(+), 50 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 212228fbac91e..51c3c46bcd02b 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
 
     PyArrayAttributeIterator &dunderIter() { return *this; }
 
-    nb::object dunderNext() {
+    nb::typed<nb::object, PyAttribute> dunderNext() {
       // TODO: Throw is an inefficient way to stop iteration.
       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
         throw nb::stop_iteration();
@@ -526,7 +526,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
         "Gets a uniqued Array attribute");
     c.def(
          "__getitem__",
-         [](PyArrayAttribute &arr, intptr_t i) {
+         [](PyArrayAttribute &arr,
+            intptr_t i) -> nb::typed<nb::object, PyAttribute> {
            if (i >= mlirArrayAttrGetNumElements(arr))
              throw nb::index_error("ArrayAttribute index out of range");
            return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
@@ -1010,14 +1011,16 @@ class PyDenseElementsAttribute
                      [](PyDenseElementsAttribute &self) -> bool {
                        return mlirDenseElementsAttrIsSplat(self);
                      })
-        .def("get_splat_value", [](PyDenseElementsAttribute &self) {
-          if (!mlirDenseElementsAttrIsSplat(self))
-            throw nb::value_error(
-                "get_splat_value called on a non-splat attribute");
-          return PyAttribute(self.getContext(),
-                             mlirDenseElementsAttrGetSplatValue(self))
-              .maybeDownCast();
-        });
+        .def("get_splat_value",
+             [](PyDenseElementsAttribute &self)
+                 -> nb::typed<nb::object, PyAttribute> {
+               if (!mlirDenseElementsAttrIsSplat(self))
+                 throw nb::value_error(
+                     "get_splat_value called on a non-splat attribute");
+               return PyAttribute(self.getContext(),
+                                  mlirDenseElementsAttrGetSplatValue(self))
+                   .maybeDownCast();
+             });
   }
 
   static PyType_Slot slots[];
@@ -1522,13 +1525,15 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
         },
         nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
         "Gets an uniqued dict attribute");
-    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
-      MlirAttribute attr =
-          mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
-      if (mlirAttributeIsNull(attr))
-        throw nb::key_error("attempt to access a non-existent attribute");
-      return PyAttribute(self.getContext(), attr).maybeDownCast();
-    });
+    c.def("__getitem__",
+          [](PyDictAttribute &self,
+             const std::string &name) -> nb::typed<nb::object, PyAttribute> {
+            MlirAttribute attr =
+                mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+            if (mlirAttributeIsNull(attr))
+              throw nb::key_error("attempt to access a non-existent attribute");
+            return PyAttribute(self.getContext(), attr).maybeDownCast();
+          });
     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
       if (index < 0 || index >= self.dunderLen()) {
         throw nb::index_error("attempt to access out of bounds attribute");
@@ -1594,10 +1599,11 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
         },
         nb::arg("value"), nb::arg("context") = nb::none(),
         "Gets a uniqued Type attribute");
-    c.def_prop_ro("value", [](PyTypeAttribute &self) {
-      return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
-          .maybeDownCast();
-    });
+    c.def_prop_ro(
+        "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
+          return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
+              .maybeDownCast();
+        });
   }
 };
 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4b238e11c7fff..c7af6a6ce0d60 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1605,7 +1605,9 @@ class PyConcreteValue : public PyValue {
         },
         nb::arg("other_value"));
     cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
-            [](DerivedTy &self) { return self.maybeDownCast(); });
+            [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
+              return self.maybeDownCast();
+            });
     DerivedTy::bindDerived(cls);
   }
 
@@ -1638,9 +1640,9 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
 
 /// Returns the list of types of the values held by container.
 template <typename Container>
-static std::vector<nb::object> getValueTypes(Container &container,
-                                             PyMlirContextRef &context) {
-  std::vector<nb::object> result;
+static std::vector<nb::typed<nb::object, PyType>>
+getValueTypes(Container &container, PyMlirContextRef &context) {
+  std::vector<nb::typed<nb::object, PyType>> result;
   result.reserve(container.size());
   for (int i = 0, e = container.size(); i < e; ++i) {
     result.push_back(PyType(context->getRef(),
@@ -2677,7 +2679,8 @@ class PyOpAttributeMap {
   PyOpAttributeMap(PyOperationRef operation)
       : operation(std::move(operation)) {}
 
-  nb::object dunderGetItemNamed(const std::string &name) {
+  nb::typed<nb::object, PyAttribute>
+  dunderGetItemNamed(const std::string &name) {
     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
                                                          toMlirStringRef(name));
     if (mlirAttributeIsNull(attr)) {
@@ -3461,7 +3464,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "Returns the list of Operation results.")
       .def_prop_ro(
           "result",
-          [](PyOperationBase &self) {
+          [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
             auto &operation = self.getOperation();
             return PyOpResult(operation.getRef(), getUniqueResult(operation))
                 .maybeDownCast();
@@ -3982,7 +3985,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
       .def_static(
           "parse",
-          [](const std::string &attrSpec, DefaultingPyMlirContext context) {
+          [](const std::string &attrSpec, DefaultingPyMlirContext context)
+              -> nb::typed<nb::object, PyAttribute> {
             PyMlirContext::ErrorCapture errors(context->getRef());
             MlirAttribute attr = mlirAttributeParseGet(
                 context->get(), toMlirStringRef(attrSpec));
@@ -3998,7 +4002,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           [](PyAttribute &self) { return self.getContext().getObject(); },
           "Context that owns the Attribute")
       .def_prop_ro("type",
-                   [](PyAttribute &self) {
+                   [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
                      return PyType(self.getContext(),
                                    mlirAttributeGetType(self))
                          .maybeDownCast();
@@ -4049,7 +4053,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                             "mlirTypeID was expected to be non-null.");
                      return PyTypeID(mlirTypeID);
                    })
-      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAttribute::maybeDownCast);
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+           [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
+             return self.maybeDownCast();
+           });
 
   //----------------------------------------------------------------------------
   // Mapping of PyNamedAttribute
@@ -4094,7 +4101,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
       .def_static(
           "parse",
-          [](std::string typeSpec, DefaultingPyMlirContext context) {
+          [](std::string typeSpec,
+             DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
             PyMlirContext::ErrorCapture errors(context->getRef());
             MlirType type =
                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
@@ -4139,7 +4147,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              printAccum.parts.append(")");
              return printAccum.join();
            })
-      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyType::maybeDownCast)
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+           [](PyType &self) -> nb::typed<nb::object, PyType> {
+             return self.maybeDownCast();
+           })
       .def_prop_ro("typeid", [](PyType &self) {
         MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
         if (!mlirTypeIDIsNull(mlirTypeID))
@@ -4266,7 +4277,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           nb::arg("state"), kGetNameAsOperand)
       .def_prop_ro("type",
-                   [](PyValue &self) {
+                   [](PyValue &self) -> nb::typed<nb::object, PyType> {
                      return PyType(self.getParentOperation()->getContext(),
                                    mlirValueGetType(self.get()))
                          .maybeDownCast();
@@ -4332,7 +4343,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           nb::arg("with_"), nb::arg("exceptions"),
           kValueReplaceAllUsesExceptDocstring)
-      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyValue::maybeDownCast)
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+           [](PyValue &self) -> nb::typed<nb::object, PyValue> {
+             return self.maybeDownCast();
+           })
       .def_prop_ro(
           "location",
           [](MlirValue self) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 6e97c00d478f1..0aabab8991b46 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -1101,10 +1101,12 @@ class PyConcreteAttribute : public BaseTy {
           return DerivedTy::isaFunction(otherAttr);
         },
         nanobind::arg("other"));
-    cls.def_prop_ro("type", [](PyAttribute &attr) {
-      return PyType(attr.getContext(), mlirAttributeGetType(attr))
-          .maybeDownCast();
-    });
+    cls.def_prop_ro(
+        "type",
+        [](PyAttribute &attr) -> nanobind::typed<nanobind::object, PyType> {
+          return PyType(attr.getContext(), mlirAttributeGetType(attr))
+              .maybeDownCast();
+        });
     cls.def_prop_ro_static(
         "static_typeid",
         [](nanobind::object & /*class*/) -> PyTypeID {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index cab3bf549295b..07dc00521833f 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -501,7 +501,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
         "Create a complex type");
     c.def_prop_ro(
         "element_type",
-        [](PyComplexType &self) {
+        [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
           return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
               .maybeDownCast();
         },
@@ -515,7 +515,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
 void mlir::PyShapedType::bindDerived(ClassTy &c) {
   c.def_prop_ro(
       "element_type",
-      [](PyShapedType &self) {
+      [](PyShapedType &self) -> nb::typed<nb::object, PyType> {
         return PyType(self.getContext(), mlirShapedTypeGetElementType(self))
             .maybeDownCast();
       },
@@ -731,8 +731,7 @@ class PyRankedTensorType
           MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
           if (mlirAttributeIsNull(encoding))
             return std::nullopt;
-          return nb::cast<nb::typed<nb::object, PyAttribute>>(
-              PyAttribute(self.getContext(), encoding).maybeDownCast());
+          return PyAttribute(self.getContext(), encoding).maybeDownCast();
         });
   }
 };
@@ -794,9 +793,9 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
         .def_prop_ro(
             "layout",
             [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
-              return nb::cast<nb::typed<nb::object, PyAttribute>>(
-                  PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
-                      .maybeDownCast());
+              return PyAttribute(self.getContext(),
+                                 mlirMemRefTypeGetLayout(self))
+                  .maybeDownCast();
             },
             "The layout of the MemRef type.")
         .def(
@@ -825,8 +824,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
               MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
               if (mlirAttributeIsNull(a))
                 return std::nullopt;
-              return nb::cast<nb::typed<nb::object, PyAttribute>>(
-                  PyAttribute(self.getContext(), a).maybeDownCast());
+              return PyAttribute(self.getContext(), a).maybeDownCast();
             },
             "Returns the memory space of the given MemRef type.");
   }
@@ -867,8 +865,7 @@ class PyUnrankedMemRefType
               MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
               if (mlirAttributeIsNull(a))
                 return std::nullopt;
-              return nb::cast<nb::typed<nb::object, PyAttribute>>(
-                  PyAttribute(self.getContext(), a).maybeDownCast());
+              return PyAttribute(self.getContext(), a).maybeDownCast();
             },
             "Returns the memory space of the given Unranked MemRef type.");
   }
@@ -912,7 +909,7 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
         "Create a tuple type");
     c.def(
         "get_type",
-        [](PyTupleType &self, intptr_t pos) {
+        [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
           return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
               .maybeDownCast();
         },



More information about the Mlir-commits mailing list