[Mlir-commits] [mlir] [MLIR][Python] restore types (PR #160194)

Maksim Levental llvmlistbot at llvm.org
Mon Sep 22 14:00:22 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/160194

>From 7468dff85f01a66722e9543131f5eefef0815175 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Mon, 22 Sep 2025 16:27:46 -0400
Subject: [PATCH] [MLIR][Python] restore types

---
 mlir/lib/Bindings/Python/IRAttributes.cpp | 38 ++++++-----
 mlir/lib/Bindings/Python/IRCore.cpp       | 82 ++++++++++++++---------
 mlir/lib/Bindings/Python/IRModule.h       |  5 +-
 mlir/lib/Bindings/Python/IRTypes.cpp      | 44 ++++++------
 4 files changed, 101 insertions(+), 68 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 212228fbac91e..404c4d842e02c 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -485,13 +485,14 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
 
     PyArrayAttributeIterator &dunderIter() { return *this; }
 
-    nb::object dunderNext() {
+    nb::typed<nb::object, PyAttribute> dunderNext() {
       // TODO: Throw is an inefficient way to stop iteration.
       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
         throw nb::stop_iteration();
-      return PyAttribute(this->attr.getContext(),
-                         mlirArrayAttrGetElement(attr.get(), nextIndex++))
-          .maybeDownCast();
+      return nb::cast<nb::typed<nb::object, PyAttribute>>(
+          PyAttribute(this->attr.getContext(),
+                      mlirArrayAttrGetElement(attr.get(), nextIndex++))
+              .maybeDownCast());
     }
 
     static void bind(nb::module_ &m) {
@@ -524,13 +525,13 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
         },
         nb::arg("attributes"), nb::arg("context") = nb::none(),
         "Gets a uniqued Array attribute");
-    c.def(
-         "__getitem__",
-         [](PyArrayAttribute &arr, intptr_t i) {
-           if (i >= mlirArrayAttrGetNumElements(arr))
-             throw nb::index_error("ArrayAttribute index out of range");
-           return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
-         })
+    c.def("__getitem__",
+          [](PyArrayAttribute &arr, intptr_t i) {
+            if (i >= mlirArrayAttrGetNumElements(arr))
+              throw nb::index_error("ArrayAttribute index out of range");
+            return nb::cast<nb::typed<nb::object, PyAttribute>>(
+                PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast());
+          })
         .def("__len__",
              [](const PyArrayAttribute &arr) {
                return mlirArrayAttrGetNumElements(arr);
@@ -1014,9 +1015,10 @@ class PyDenseElementsAttribute
           if (!mlirDenseElementsAttrIsSplat(self))
             throw nb::value_error(
                 "get_splat_value called on a non-splat attribute");
-          return PyAttribute(self.getContext(),
-                             mlirDenseElementsAttrGetSplatValue(self))
-              .maybeDownCast();
+          return nb::cast<nb::typed<nb::object, PyAttribute>>(
+              PyAttribute(self.getContext(),
+                          mlirDenseElementsAttrGetSplatValue(self))
+                  .maybeDownCast());
         });
   }
 
@@ -1527,7 +1529,8 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
           mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
       if (mlirAttributeIsNull(attr))
         throw nb::key_error("attempt to access a non-existent attribute");
-      return PyAttribute(self.getContext(), attr).maybeDownCast();
+      return nb::cast<nb::typed<nb::object, PyAttribute>>(
+          PyAttribute(self.getContext(), attr).maybeDownCast());
     });
     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
       if (index < 0 || index >= self.dunderLen()) {
@@ -1595,8 +1598,9 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
         nb::arg("value"), nb::arg("context") = nb::none(),
         "Gets a uniqued Type attribute");
     c.def_prop_ro("value", [](PyTypeAttribute &self) {
-      return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
-          .maybeDownCast();
+      return nb::cast<nb::typed<nb::object, PyType>>(
+          PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
+              .maybeDownCast());
     });
   }
 };
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 81386f2227a7f..5a6edfa737fd7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -528,7 +528,8 @@ class PyOperationIterator {
   static void bind(nb::module_ &m) {
     nb::class_<PyOperationIterator>(m, "OperationIterator")
         .def("__iter__", &PyOperationIterator::dunderIter)
-        .def("__next__", &PyOperationIterator::dunderNext);
+        .def("__next__", &PyOperationIterator::dunderNext,
+             nb::sig("def __next__(self) -> OpView"));
   }
 
 private:
@@ -1604,8 +1605,9 @@ class PyConcreteValue : public PyValue {
           return DerivedTy::isaFunction(otherValue);
         },
         nb::arg("other_value"));
-    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
-            [](DerivedTy &self) { return self.maybeDownCast(); });
+    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](DerivedTy &self) {
+      return nb::cast<nb::typed<nb::object, DerivedTy>>(self.maybeDownCast());
+    });
     DerivedTy::bindDerived(cls);
   }
 
@@ -1638,14 +1640,15 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
 
 /// Returns the list of types of the values held by container.
 template <typename Container>
-static std::vector<nb::object> getValueTypes(Container &container,
-                                             PyMlirContextRef &context) {
-  std::vector<nb::object> result;
+static std::vector<nb::typed<nb::object, PyType>>
+getValueTypes(Container &container, PyMlirContextRef &context) {
+  std::vector<nb::typed<nb::object, PyType>> result;
   result.reserve(container.size());
   for (int i = 0, e = container.size(); i < e; ++i) {
-    result.push_back(PyType(context->getRef(),
-                            mlirValueGetType(container.getElement(i).get()))
-                         .maybeDownCast());
+    result.push_back(nb::cast<nb::typed<nb::object, PyType>>(
+        PyType(context->getRef(),
+               mlirValueGetType(container.getElement(i).get()))
+            .maybeDownCast()));
   }
   return result;
 }
@@ -2677,13 +2680,15 @@ class PyOpAttributeMap {
   PyOpAttributeMap(PyOperationRef operation)
       : operation(std::move(operation)) {}
 
-  nb::object dunderGetItemNamed(const std::string &name) {
+  nb::typed<nb::object, PyAttribute>
+  dunderGetItemNamed(const std::string &name) {
     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
                                                          toMlirStringRef(name));
     if (mlirAttributeIsNull(attr)) {
       throw nb::key_error("attempt to access a non-existent attribute");
     }
-    return PyAttribute(operation->getContext(), attr).maybeDownCast();
+    return nb::cast<nb::typed<nb::object, PyAttribute>>(
+        PyAttribute(operation->getContext(), attr).maybeDownCast());
   }
 
   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
@@ -2961,14 +2966,17 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              new (&self) PyMlirContext(context);
            })
       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
-      .def("_get_context_again",
-           [](PyMlirContext &self) {
-             PyMlirContextRef ref = PyMlirContext::forContext(self.get());
-             return ref.releaseObject();
-           })
+      .def(
+          "_get_context_again",
+          [](PyMlirContext &self) {
+            PyMlirContextRef ref = PyMlirContext::forContext(self.get());
+            return ref.releaseObject();
+          },
+          nb::sig("def _get_context_again(self) -> Context"))
       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule,
+           nb::sig("def _CAPICreate(self) -> Context"))
       .def("__enter__", &PyMlirContext::contextEnter)
       .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
            nb::arg("exc_value").none(), nb::arg("traceback").none())
@@ -3463,8 +3471,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "result",
           [](PyOperationBase &self) {
             auto &operation = self.getOperation();
-            return PyOpResult(operation.getRef(), getUniqueResult(operation))
-                .maybeDownCast();
+            return nb::cast<nb::typed<nb::object, PyOpResult>>(
+                PyOpResult(operation.getRef(), getUniqueResult(operation))
+                    .maybeDownCast());
           },
           "Shortcut to get an op result if it has only one (throws an error "
           "otherwise).")
@@ -3988,7 +3997,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                 context->get(), toMlirStringRef(attrSpec));
             if (mlirAttributeIsNull(attr))
               throw MLIRError("Unable to parse attribute", errors.take());
-            return PyAttribute(context.get()->getRef(), attr).maybeDownCast();
+            return nb::cast<nb::typed<nb::object, PyAttribute>>(
+                PyAttribute(context.get()->getRef(), attr).maybeDownCast());
           },
           nb::arg("asm"), nb::arg("context") = nb::none(),
           "Parses an attribute from an assembly form. Raises an MLIRError on "
@@ -3999,9 +4009,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "Context that owns the Attribute")
       .def_prop_ro("type",
                    [](PyAttribute &self) {
-                     return PyType(self.getContext(),
-                                   mlirAttributeGetType(self))
-                         .maybeDownCast();
+                     return nb::cast<nb::typed<nb::object, PyType>>(
+                         PyType(self.getContext(), mlirAttributeGetType(self))
+                             .maybeDownCast());
                    })
       .def(
           "get_named",
@@ -4049,7 +4059,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                              "mlirTypeID was expected to be non-null.");
                       return PyTypeID(mlirTypeID);
                     })
-      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAttribute::maybeDownCast);
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) {
+        return nb::cast<nb::typed<nb::object, PyAttribute>>(self.maybeDownCast());
+      });
 
   //----------------------------------------------------------------------------
   // Mapping of PyNamedAttribute
@@ -4100,7 +4112,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
             if (mlirTypeIsNull(type))
               throw MLIRError("Unable to parse type", errors.take());
-            return PyType(context.get()->getRef(), type).maybeDownCast();
+            return nb::cast<nb::typed<nb::object, PyType>>(
+                PyType(context.get()->getRef(), type).maybeDownCast());
           },
           nb::arg("asm"), nb::arg("context") = nb::none(),
           kContextParseTypeDocstring)
@@ -4139,7 +4152,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              printAccum.parts.append(")");
              return printAccum.join();
            })
-      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyType::maybeDownCast)
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+           [](PyType &self) {
+             return nb::cast<nb::typed<nb::object, PyType>>(
+                 self.maybeDownCast());
+           })
       .def_prop_ro("typeid", [](PyType &self) {
         MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
         if (!mlirTypeIDIsNull(mlirTypeID))
@@ -4267,9 +4284,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           nb::arg("state"), kGetNameAsOperand)
       .def_prop_ro("type",
                    [](PyValue &self) {
-                     return PyType(self.getParentOperation()->getContext(),
-                                   mlirValueGetType(self.get()))
-                         .maybeDownCast();
+                     return nb::cast<nb::typed<nb::object, PyType>>(
+                         PyType(self.getParentOperation()->getContext(),
+                                mlirValueGetType(self.get()))
+                             .maybeDownCast());
                    })
       .def(
           "set_type",
@@ -4305,7 +4323,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           nb::arg("with_"), nb::arg("exceptions"),
           kValueReplaceAllUsesExceptDocstring)
-      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyValue::maybeDownCast)
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+           [](PyValue &self) {
+             return nb::cast<nb::typed<nb::object, PyValue>>(
+                 self.maybeDownCast());
+           })
       .def_prop_ro(
           "location",
           [](MlirValue self) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 6e97c00d478f1..dc9913bc5ebb2 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -1102,8 +1102,9 @@ class PyConcreteAttribute : public BaseTy {
         },
         nanobind::arg("other"));
     cls.def_prop_ro("type", [](PyAttribute &attr) {
-      return PyType(attr.getContext(), mlirAttributeGetType(attr))
-          .maybeDownCast();
+      return nanobind::cast<nanobind::typed<nanobind::object, PyType>>(
+          PyType(attr.getContext(), mlirAttributeGetType(attr))
+              .maybeDownCast());
     });
     cls.def_prop_ro_static(
         "static_typeid",
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index a7aa1c65c6c43..a228ca4418c4a 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -502,8 +502,9 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
     c.def_prop_ro(
         "element_type",
         [](PyComplexType &self) {
-          return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
-              .maybeDownCast();
+          return nb::cast<nb::typed<nb::object, PyType>>(
+              PyType(self.getContext(), mlirComplexTypeGetElementType(self))
+                  .maybeDownCast());
         },
         "Returns element type.");
   }
@@ -516,8 +517,9 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
   c.def_prop_ro(
       "element_type",
       [](PyShapedType &self) {
-        return PyType(self.getContext(), mlirShapedTypeGetElementType(self))
-            .maybeDownCast();
+        return nb::cast<nb::typed<nb::object, PyType>>(
+            PyType(self.getContext(), mlirShapedTypeGetElementType(self))
+                .maybeDownCast());
       },
       "Returns the element type of the shaped type.");
   c.def_prop_ro(
@@ -898,11 +900,21 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
         },
         nb::arg("elements"), nb::arg("context") = nb::none(),
         "Create a tuple type");
+    c.def_static(
+        "get_tuple",
+        [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+                                        elements.data());
+          return PyTupleType(context->getRef(), t);
+        },
+        nb::arg("elements"), nb::arg("context") = nb::none(),
+        "Create a tuple type");
     c.def(
         "get_type",
         [](PyTupleType &self, intptr_t pos) {
-          return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
-              .maybeDownCast();
+          return nb::cast<nb::typed<nb::object, PyType>>(
+              PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
+                  .maybeDownCast());
         },
         nb::arg("pos"), "Returns the pos-th type in the tuple type.");
     c.def_prop_ro(
@@ -926,23 +938,17 @@ class PyFunctionType : public PyConcreteType<PyFunctionType> {
   static void bindDerived(ClassTy &c) {
     c.def_static(
         "get",
-        [](std::vector<PyType> inputs, std::vector<PyType> results,
+        [](std::vector<MlirType> inputs, std::vector<MlirType> results,
            DefaultingPyMlirContext context) {
-          std::vector<MlirType> mlirInputs;
-          mlirInputs.reserve(inputs.size());
-          for (const auto &input : inputs)
-            mlirInputs.push_back(input.get());
-          std::vector<MlirType> mlirResults;
-          mlirResults.reserve(results.size());
-          for (const auto &result : results)
-            mlirResults.push_back(result.get());
-
-          MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
-                                           mlirInputs.data(), results.size(),
-                                           mlirResults.data());
+          MlirType t =
+              mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+                                  results.size(), results.data());
           return PyFunctionType(context->getRef(), t);
         },
         nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
+        // clang-format off
+        nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: mlir.ir.Context | None = None) -> FunctionType"),
+        // clang-format on
         "Gets a FunctionType from a list of input and result types");
     c.def_prop_ro(
         "inputs",



More information about the Mlir-commits mailing list