[Mlir-commits] [mlir] [MLIR][Python] use nb::typed for return signatures (PR #160221)

Maksim Levental llvmlistbot at llvm.org
Mon Sep 22 20:35:50 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/160221

From fbd9df278abb825145e6a307786a759d1eff2d38 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Mon, 22 Sep 2025 18:57:51 -0700
Subject: [PATCH] [MLIR][Python] use nb::typed for return signatures

---
 mlir/lib/Bindings/Python/IRAffine.cpp     |  49 +++---
 mlir/lib/Bindings/Python/IRAttributes.cpp |  50 +++---
 mlir/lib/Bindings/Python/IRCore.cpp       | 194 ++++++++++++++--------
 mlir/lib/Bindings/Python/IRInterfaces.cpp |  30 ++--
 mlir/lib/Bindings/Python/IRModule.h       |  14 +-
 mlir/lib/Bindings/Python/IRTypes.cpp      |  21 +--
 6 files changed, 216 insertions(+), 142 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index bc6aa0dac6221..7147f2cbad149 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -574,7 +574,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
            })
       .def_prop_ro(
           "context",
-          [](PyAffineExpr &self) { return self.getContext().getObject(); })
+          [](PyAffineExpr &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          })
       .def("compose",
            [](PyAffineExpr &self, PyAffineMap &other) {
              return PyAffineExpr(self.getContext(),
@@ -706,28 +708,29 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
            [](PyAffineMap &self) {
              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
            })
-      .def_static("compress_unused_symbols",
-                  [](const nb::list &affineMaps,
-                     DefaultingPyMlirContext context) {
-                    SmallVector<MlirAffineMap> maps;
-                    pyListToVector<PyAffineMap, MlirAffineMap>(
-                        affineMaps, maps, "attempting to create an AffineMap");
-                    std::vector<MlirAffineMap> compressed(affineMaps.size());
-                    auto populate = [](void *result, intptr_t idx,
-                                       MlirAffineMap m) {
-                      static_cast<MlirAffineMap *>(result)[idx] = (m);
-                    };
-                    mlirAffineMapCompressUnusedSymbols(
-                        maps.data(), maps.size(), compressed.data(), populate);
-                    std::vector<PyAffineMap> res;
-                    res.reserve(compressed.size());
-                    for (auto m : compressed)
-                      res.emplace_back(context->getRef(), m);
-                    return res;
-                  })
+      .def_static(
+          "compress_unused_symbols",
+          [](const nb::list &affineMaps, DefaultingPyMlirContext context) {
+            SmallVector<MlirAffineMap> maps;
+            pyListToVector<PyAffineMap, MlirAffineMap>(
+                affineMaps, maps, "attempting to create an AffineMap");
+            std::vector<MlirAffineMap> compressed(affineMaps.size());
+            auto populate = [](void *result, intptr_t idx, MlirAffineMap m) {
+              static_cast<MlirAffineMap *>(result)[idx] = (m);
+            };
+            mlirAffineMapCompressUnusedSymbols(maps.data(), maps.size(),
+                                               compressed.data(), populate);
+            std::vector<PyAffineMap> res;
+            res.reserve(compressed.size());
+            for (auto m : compressed)
+              res.emplace_back(context->getRef(), m);
+            return res;
+          })
       .def_prop_ro(
           "context",
-          [](PyAffineMap &self) { return self.getContext().getObject(); },
+          [](PyAffineMap &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          },
           "Context that owns the Affine Map")
       .def(
           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
@@ -893,7 +896,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
            })
       .def_prop_ro(
           "context",
-          [](PyIntegerSet &self) { return self.getContext().getObject(); })
+          [](PyIntegerSet &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          })
       .def(
           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
           kDumpDocstring)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 212228fbac91e..c77653f97e6dd 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
 
     PyArrayAttributeIterator &dunderIter() { return *this; }
 
-    nb::object dunderNext() {
+    nb::typed<nb::object, PyAttribute> dunderNext() {
       // TODO: Throw is an inefficient way to stop iteration.
       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
         throw nb::stop_iteration();
@@ -526,7 +526,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
         "Gets a uniqued Array attribute");
     c.def(
          "__getitem__",
-         [](PyArrayAttribute &arr, intptr_t i) {
+         [](PyArrayAttribute &arr,
+            intptr_t i) -> nb::typed<nb::object, PyAttribute> {
            if (i >= mlirArrayAttrGetNumElements(arr))
              throw nb::index_error("ArrayAttribute index out of range");
            return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
@@ -1010,14 +1011,16 @@ class PyDenseElementsAttribute
                      [](PyDenseElementsAttribute &self) -> bool {
                        return mlirDenseElementsAttrIsSplat(self);
                      })
-        .def("get_splat_value", [](PyDenseElementsAttribute &self) {
-          if (!mlirDenseElementsAttrIsSplat(self))
-            throw nb::value_error(
-                "get_splat_value called on a non-splat attribute");
-          return PyAttribute(self.getContext(),
-                             mlirDenseElementsAttrGetSplatValue(self))
-              .maybeDownCast();
-        });
+        .def("get_splat_value",
+             [](PyDenseElementsAttribute &self)
+                 -> nb::typed<nb::object, PyAttribute> {
+               if (!mlirDenseElementsAttrIsSplat(self))
+                 throw nb::value_error(
+                     "get_splat_value called on a non-splat attribute");
+               return PyAttribute(self.getContext(),
+                                  mlirDenseElementsAttrGetSplatValue(self))
+                   .maybeDownCast();
+             });
   }
 
   static PyType_Slot slots[];
@@ -1332,7 +1335,7 @@ class PyDenseIntElementsAttribute
 
   /// Returns the element at the given linear position. Asserts if the index
   /// is out of range.
-  nb::object dunderGetItem(intptr_t pos) {
+  nb::int_ dunderGetItem(intptr_t pos) {
     if (pos < 0 || pos >= dunderLen()) {
       throw nb::index_error("attempt to access out of bounds element");
     }
@@ -1522,13 +1525,15 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
         },
         nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
         "Gets an uniqued dict attribute");
-    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
-      MlirAttribute attr =
-          mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
-      if (mlirAttributeIsNull(attr))
-        throw nb::key_error("attempt to access a non-existent attribute");
-      return PyAttribute(self.getContext(), attr).maybeDownCast();
-    });
+    c.def("__getitem__",
+          [](PyDictAttribute &self,
+             const std::string &name) -> nb::typed<nb::object, PyAttribute> {
+            MlirAttribute attr =
+                mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+            if (mlirAttributeIsNull(attr))
+              throw nb::key_error("attempt to access a non-existent attribute");
+            return PyAttribute(self.getContext(), attr).maybeDownCast();
+          });
     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
       if (index < 0 || index >= self.dunderLen()) {
         throw nb::index_error("attempt to access out of bounds attribute");
@@ -1594,10 +1599,11 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
         },
         nb::arg("value"), nb::arg("context") = nb::none(),
         "Gets a uniqued Type attribute");
-    c.def_prop_ro("value", [](PyTypeAttribute &self) {
-      return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
-          .maybeDownCast();
-    });
+    c.def_prop_ro(
+        "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
+          return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
+              .maybeDownCast();
+        });
   }
 };
 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4b238e11c7fff..5bad942d70374 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -513,7 +513,7 @@ class PyOperationIterator {
 
   PyOperationIterator &dunderIter() { return *this; }
 
-  nb::object dunderNext() {
+  nb::typed<nb::object, PyOpView> dunderNext() {
     parentOperation->checkValid();
     if (mlirOperationIsNull(next)) {
       throw nb::stop_iteration();
@@ -562,7 +562,7 @@ class PyOperationList {
     return count;
   }
 
-  nb::object dunderGetItem(intptr_t index) {
+  nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
     parentOperation->checkValid();
     if (index < 0) {
       index += dunderLen();
@@ -725,7 +725,7 @@ nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
       new PyDiagnosticHandler(get(), std::move(callback));
   nb::object pyHandlerObject =
       nb::cast(pyHandler, nb::rv_policy::take_ownership);
-  pyHandlerObject.inc_ref();
+  (void)pyHandlerObject.inc_ref();
 
   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
   // guaranteed to be known to pybind.
@@ -1395,7 +1395,7 @@ nb::object PyOperation::getCapsule() {
   return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
 }
 
-nb::object PyOperation::createFromCapsule(nb::object capsule) {
+nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
   if (mlirOperationIsNull(rawOperation))
     throw nb::python_error();
@@ -1605,7 +1605,9 @@ class PyConcreteValue : public PyValue {
         },
         nb::arg("other_value"));
     cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
-            [](DerivedTy &self) { return self.maybeDownCast(); });
+            [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
+              return self.maybeDownCast();
+            });
     DerivedTy::bindDerived(cls);
   }
 
@@ -1623,13 +1625,14 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
   using PyConcreteValue::PyConcreteValue;
 
   static void bindDerived(ClassTy &c) {
-    c.def_prop_ro("owner", [](PyOpResult &self) {
-      assert(
-          mlirOperationEqual(self.getParentOperation()->get(),
-                             mlirOpResultGetOwner(self.get())) &&
-          "expected the owner of the value in Python to match that in the IR");
-      return self.getParentOperation().getObject();
-    });
+    c.def_prop_ro(
+        "owner", [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
+          assert(mlirOperationEqual(self.getParentOperation()->get(),
+                                    mlirOpResultGetOwner(self.get())) &&
+                 "expected the owner of the value in Python to match that in "
+                 "the IR");
+          return self.getParentOperation().getObject();
+        });
     c.def_prop_ro("result_number", [](PyOpResult &self) {
       return mlirOpResultGetResultNumber(self.get());
     });
@@ -1638,9 +1641,9 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
 
 /// Returns the list of types of the values held by container.
 template <typename Container>
-static std::vector<nb::object> getValueTypes(Container &container,
-                                             PyMlirContextRef &context) {
-  std::vector<nb::object> result;
+static std::vector<nb::typed<nb::object, PyType>>
+getValueTypes(Container &container, PyMlirContextRef &context) {
+  std::vector<nb::typed<nb::object, PyType>> result;
   result.reserve(container.size());
   for (int i = 0, e = container.size(); i < e; ++i) {
     result.push_back(PyType(context->getRef(),
@@ -1671,9 +1674,10 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
     c.def_prop_ro("types", [](PyOpResultList &self) {
       return getValueTypes(self, self.operation->getContext());
     });
-    c.def_prop_ro("owner", [](PyOpResultList &self) {
-      return self.operation->createOpView();
-    });
+    c.def_prop_ro("owner",
+                  [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
+                    return self.operation->createOpView();
+                  });
   }
 
   PyOperationRef &getOperation() { return operation; }
@@ -2104,7 +2108,7 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
-  return PyThreadContextEntry::pushInsertionPoint(insertPoint);
+  return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
 }
 
 void PyInsertionPoint::contextExit(const nb::object &excType,
@@ -2125,7 +2129,7 @@ nb::object PyAttribute::getCapsule() {
   return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
 }
 
-PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
+PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) {
   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
   if (mlirAttributeIsNull(rawAttr))
     throw nb::python_error();
@@ -2677,7 +2681,8 @@ class PyOpAttributeMap {
   PyOpAttributeMap(PyOperationRef operation)
       : operation(std::move(operation)) {}
 
-  nb::object dunderGetItemNamed(const std::string &name) {
+  nb::typed<nb::object, PyAttribute>
+  dunderGetItemNamed(const std::string &name) {
     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
                                                          toMlirStringRef(name));
     if (mlirAttributeIsNull(attr)) {
@@ -2962,13 +2967,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
            })
       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
       .def("_get_context_again",
-           [](PyMlirContext &self) {
+           [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
              return ref.releaseObject();
            })
       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+                  &PyMlirContext::createFromCapsule)
       .def("__enter__", &PyMlirContext::contextEnter)
       .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
            nb::arg("exc_value").none(), nb::arg("traceback").none())
@@ -3123,7 +3129,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   //----------------------------------------------------------------------------
   nb::class_<PyDialectRegistry>(m, "DialectRegistry")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+                  &PyDialectRegistry::createFromCapsule)
       .def(nb::init<>());
 
   //----------------------------------------------------------------------------
@@ -3131,7 +3138,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   //----------------------------------------------------------------------------
   nb::class_<PyLocation>(m, "Location")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
       .def("__enter__", &PyLocation::contextEnter)
       .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
            nb::arg("exc_value").none(), nb::arg("traceback").none())
@@ -3286,7 +3293,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "Gets a Location from a LocationAttr")
       .def_prop_ro(
           "context",
-          [](PyLocation &self) { return self.getContext().getObject(); },
+          [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          },
           "Context that owns the Location")
       .def_prop_ro(
           "attr",
@@ -3313,12 +3322,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   //----------------------------------------------------------------------------
   nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
-           kModuleCAPICreate)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+                  kModuleCAPICreate)
       .def("_clear_mlir_module", &PyModule::clearMlirModule)
       .def_static(
           "parse",
-          [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
+          [](const std::string &moduleAsm, DefaultingPyMlirContext context)
+              -> nb::typed<nb::object, PyModule> {
             PyMlirContext::ErrorCapture errors(context->getRef());
             MlirModule module = mlirModuleCreateParse(
                 context->get(), toMlirStringRef(moduleAsm));
@@ -3330,7 +3340,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           kModuleParseDocstring)
       .def_static(
           "parse",
-          [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
+          [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
+              -> nb::typed<nb::object, PyModule> {
             PyMlirContext::ErrorCapture errors(context->getRef());
             MlirModule module = mlirModuleCreateParse(
                 context->get(), toMlirStringRef(moduleAsm));
@@ -3342,7 +3353,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           kModuleParseDocstring)
       .def_static(
           "parseFile",
-          [](const std::string &path, DefaultingPyMlirContext context) {
+          [](const std::string &path, DefaultingPyMlirContext context)
+              -> nb::typed<nb::object, PyModule> {
             PyMlirContext::ErrorCapture errors(context->getRef());
             MlirModule module = mlirModuleCreateParseFromFile(
                 context->get(), toMlirStringRef(path));
@@ -3354,7 +3366,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           kModuleParseDocstring)
       .def_static(
           "create",
-          [](const std::optional<PyLocation> &loc) {
+          [](const std::optional<PyLocation> &loc)
+              -> nb::typed<nb::object, PyModule> {
             PyLocation pyLoc = maybeGetTracebackLocation(loc);
             MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
             return PyModule::forModule(module).releaseObject();
@@ -3362,11 +3375,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           nb::arg("loc") = nb::none(), "Creates an empty module")
       .def_prop_ro(
           "context",
-          [](PyModule &self) { return self.getContext().getObject(); },
+          [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          },
           "Context that created the Module")
       .def_prop_ro(
           "operation",
-          [](PyModule &self) {
+          [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
             return PyOperation::forOperation(self.getContext(),
                                              mlirModuleGetOperation(self.get()),
                                              self.getRef().releaseObject())
@@ -3430,7 +3445,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                    })
       .def_prop_ro(
           "context",
-          [](PyOperationBase &self) {
+          [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
             PyOperation &concreteOperation = self.getOperation();
             concreteOperation.checkValid();
             return concreteOperation.getContext().getObject();
@@ -3461,7 +3476,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "Returns the list of Operation results.")
       .def_prop_ro(
           "result",
-          [](PyOperationBase &self) {
+          [](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
             auto &operation = self.getOperation();
             return PyOpResult(operation.getRef(), getUniqueResult(operation))
                 .maybeDownCast();
@@ -3478,11 +3493,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "Returns the source location the operation was defined or derived "
           "from.")
       .def_prop_ro("parent",
-                   [](PyOperationBase &self) -> nb::object {
+                   [](PyOperationBase &self)
+                       -> std::optional<nb::typed<nb::object, PyOperation>> {
                      auto parent = self.getOperation().getParentOperation();
                      if (parent)
                        return parent->getObject();
-                     return nb::none();
+                     return {};
                    })
       .def(
           "__str__",
@@ -3553,13 +3569,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
            "of the parent block.")
       .def(
           "clone",
-          [](PyOperationBase &self, nb::object ip) {
+          [](PyOperationBase &self,
+             const nb::object &ip) -> nb::typed<nb::object, PyOperation> {
             return self.getOperation().clone(ip);
           },
           nb::arg("ip") = nb::none())
       .def(
           "detach_from_parent",
-          [](PyOperationBase &self) {
+          [](PyOperationBase &self) -> nb::typed<nb::object, PyOpView> {
             PyOperation &operation = self.getOperation();
             operation.checkValid();
             if (!operation.isAttached())
@@ -3595,7 +3612,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              std::optional<nb::dict> attributes,
              std::optional<std::vector<PyBlock *>> successors, int regions,
              const std::optional<PyLocation> &location,
-             const nb::object &maybeIp, bool inferType) {
+             const nb::object &maybeIp,
+             bool inferType) -> nb::typed<nb::object, PyOperation> {
             // Unpack/validate operands.
             llvm::SmallVector<MlirValue, 4> mlirOperands;
             if (operands) {
@@ -3620,7 +3638,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
       .def_static(
           "parse",
           [](const std::string &sourceStr, const std::string &sourceName,
-             DefaultingPyMlirContext context) {
+             DefaultingPyMlirContext context)
+              -> nb::typed<nb::object, PyOpView> {
             return PyOperation::parse(context->getRef(), sourceStr, sourceName)
                 ->createOpView();
           },
@@ -3629,9 +3648,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "Parses an operation. Supports both text assembly format and binary "
           "bytecode format.")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
-      .def_prop_ro("operation", [](nb::object self) { return self; })
-      .def_prop_ro("opview", &PyOperation::createOpView)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+                  &PyOperation::createFromCapsule)
+      .def_prop_ro("operation",
+                   [](nb::object self) -> nb::typed<nb::object, PyOperation> {
+                     return self;
+                   })
+      .def_prop_ro("opview",
+                   [](PyOperation &self) -> nb::typed<nb::object, PyOpView> {
+                     return self.createOpView();
+                   })
       .def_prop_ro("block", &PyOperation::getBlock)
       .def_prop_ro(
           "successors",
@@ -3644,7 +3670,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
 
   auto opViewClass =
       nb::class_<PyOpView, PyOperationBase>(m, "OpView")
-          .def(nb::init<nb::object>(), nb::arg("operation"))
+          .def(nb::init<nb::typed<nb::object, PyOperation>>(),
+               nb::arg("operation"))
           .def(
               "__init__",
               [](PyOpView *self, std::string_view name,
@@ -3671,9 +3698,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
               nb::arg("successors") = nb::none(),
               nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(),
               nb::arg("ip") = nb::none())
-
-          .def_prop_ro("operation", &PyOpView::getOperationObject)
-          .def_prop_ro("opview", [](nb::object self) { return self; })
+          .def_prop_ro(
+              "operation",
+              [](PyOpView &self) -> nb::typed<nb::object, PyOperation> {
+                return self.getOperationObject();
+              })
+          .def_prop_ro("opview",
+                       [](nb::object self) -> nb::typed<nb::object, PyOpView> {
+                         return self;
+                       })
           .def(
               "__str__",
               [](PyOpView &self) { return nb::str(self.getOperationObject()); })
@@ -3717,7 +3750,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
       "Builds a specific, generated OpView based on class level attributes.");
   opViewClass.attr("parse") = classmethod(
       [](const nb::object &cls, const std::string &sourceStr,
-         const std::string &sourceName, DefaultingPyMlirContext context) {
+         const std::string &sourceName,
+         DefaultingPyMlirContext context) -> nb::typed<nb::object, PyOpView> {
         PyOperationRef parsed =
             PyOperation::parse(context->getRef(), sourceStr, sourceName);
 
@@ -3752,7 +3786,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "Returns a forward-optimized sequence of blocks.")
       .def_prop_ro(
           "owner",
-          [](PyRegion &self) {
+          [](PyRegion &self) -> nb::typed<nb::object, PyOpView> {
             return self.getParentOperation()->createOpView();
           },
           "Returns the operation owning this region.")
@@ -3777,7 +3811,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule)
       .def_prop_ro(
           "owner",
-          [](PyBlock &self) {
+          [](PyBlock &self) -> nb::typed<nb::object, PyOpView> {
             return self.getParentOperation()->createOpView();
           },
           "Returns the owning operation of this block.")
@@ -3960,11 +3994,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "Returns the block that this InsertionPoint points to.")
       .def_prop_ro(
           "ref_operation",
-          [](PyInsertionPoint &self) -> nb::object {
+          [](PyInsertionPoint &self)
+              -> std::optional<nb::typed<nb::object, PyOperation>> {
             auto refOperation = self.getRefOperation();
             if (refOperation)
               return refOperation->getObject();
-            return nb::none();
+            return {};
           },
           "The reference operation before which new operations are "
           "inserted, or None if the insertion point is at the end of "
@@ -3979,10 +4014,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
       .def(nb::init<PyAttribute &>(), nb::arg("cast_from_type"),
            "Casts the passed attribute to the generic Attribute")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+                  &PyAttribute::createFromCapsule)
       .def_static(
           "parse",
-          [](const std::string &attrSpec, DefaultingPyMlirContext context) {
+          [](const std::string &attrSpec, DefaultingPyMlirContext context)
+              -> nb::typed<nb::object, PyAttribute> {
             PyMlirContext::ErrorCapture errors(context->getRef());
             MlirAttribute attr = mlirAttributeParseGet(
                 context->get(), toMlirStringRef(attrSpec));
@@ -3995,10 +4032,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "failure.")
       .def_prop_ro(
           "context",
-          [](PyAttribute &self) { return self.getContext().getObject(); },
+          [](PyAttribute &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          },
           "Context that owns the Attribute")
       .def_prop_ro("type",
-                   [](PyAttribute &self) {
+                   [](PyAttribute &self) -> nb::typed<nb::object, PyType> {
                      return PyType(self.getContext(),
                                    mlirAttributeGetType(self))
                          .maybeDownCast();
@@ -4049,7 +4088,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                             "mlirTypeID was expected to be non-null.");
                      return PyTypeID(mlirTypeID);
                    })
-      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAttribute::maybeDownCast);
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+           [](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
+             return self.maybeDownCast();
+           });
 
   //----------------------------------------------------------------------------
   // Mapping of PyNamedAttribute
@@ -4091,10 +4133,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
       .def(nb::init<PyType &>(), nb::arg("cast_from_type"),
            "Casts the passed type to the generic Type")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
       .def_static(
           "parse",
-          [](std::string typeSpec, DefaultingPyMlirContext context) {
+          [](std::string typeSpec,
+             DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
             PyMlirContext::ErrorCapture errors(context->getRef());
             MlirType type =
                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
@@ -4105,7 +4148,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           nb::arg("asm"), nb::arg("context") = nb::none(),
           kContextParseTypeDocstring)
       .def_prop_ro(
-          "context", [](PyType &self) { return self.getContext().getObject(); },
+          "context",
+          [](PyType &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          },
           "Context that owns the Type")
       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
       .def(
@@ -4139,7 +4185,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              printAccum.parts.append(")");
              return printAccum.join();
            })
-      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyType::maybeDownCast)
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+           [](PyType &self) -> nb::typed<nb::object, PyType> {
+             return self.maybeDownCast();
+           })
       .def_prop_ro("typeid", [](PyType &self) {
         MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
         if (!mlirTypeIDIsNull(mlirTypeID))
@@ -4154,7 +4203,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   //----------------------------------------------------------------------------
   nb::class_<PyTypeID>(m, "TypeID")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule)
       // Note, this tests whether the underlying TypeIDs are the same,
       // not whether the wrapper MlirTypeIDs are the same, nor whether
       // the Python objects are the same (i.e., PyTypeID is a value type).
@@ -4175,10 +4224,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   nb::class_<PyValue>(m, "Value")
       .def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"))
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
       .def_prop_ro(
           "context",
-          [](PyValue &self) {
+          [](PyValue &self) -> nb::typed<nb::object, PyMlirContext> {
             return self.getParentOperation()->getContext().getObject();
           },
           "Context in which the value lives.")
@@ -4266,7 +4315,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           nb::arg("state"), kGetNameAsOperand)
       .def_prop_ro("type",
-                   [](PyValue &self) {
+                   [](PyValue &self) -> nb::typed<nb::object, PyType> {
                      return PyType(self.getParentOperation()->getContext(),
                                    mlirValueGetType(self.get()))
                          .maybeDownCast();
@@ -4332,7 +4381,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           nb::arg("with_"), nb::arg("exceptions"),
           kValueReplaceAllUsesExceptDocstring)
-      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyValue::maybeDownCast)
+      .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+           [](PyValue &self) -> nb::typed<nb::object, PyValue> {
+             return self.maybeDownCast();
+           })
       .def_prop_ro(
           "location",
           [](MlirValue self) {
@@ -4357,7 +4409,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   //----------------------------------------------------------------------------
   nb::class_<PySymbolTable>(m, "SymbolTable")
       .def(nb::init<PyOperationBase &>())
-      .def("__getitem__", &PySymbolTable::dunderGetItem)
+      .def("__getitem__",
+           [](PySymbolTable &self,
+              const std::string &name) -> nb::typed<nb::object, PyOpView> {
+             return self.dunderGetItem(name);
+           })
       .def("insert", &PySymbolTable::insert, nb::arg("operation"))
       .def("erase", &PySymbolTable::erase, nb::arg("operation"))
       .def("__delitem__", &PySymbolTable::dunderDel)
diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 44aad10ded082..32d52b48668cf 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -196,9 +196,19 @@ class PyConcreteOpInterface {
     nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
     cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
             nb::arg("context") = nb::none(), constructorDoc)
-        .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
-                     operationDoc)
-        .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
+        .def_prop_ro(
+            "operation",
+            [](PyConcreteOpInterface &self)
+                -> nb::typed<nb::object, PyOperation> {
+              return self.getOperationObject();
+            },
+            operationDoc)
+        .def_prop_ro(
+            "opview",
+            [](PyConcreteOpInterface &self) -> nb::typed<nb::object, PyOpView> {
+              return self.getOpView();
+            },
+            opviewDoc);
     ConcreteIface::bindDerived(cls);
   }
 
@@ -362,10 +372,9 @@ class PyShapedTypeComponents {
             "Returns whether the given shaped type component is ranked.")
         .def_prop_ro(
             "rank",
-            [](PyShapedTypeComponents &self) -> nb::object {
-              if (!self.ranked) {
-                return nb::none();
-              }
+            [](PyShapedTypeComponents &self) -> std::optional<nb::int_> {
+              if (!self.ranked)
+                return {};
               return nb::int_(self.shape.size());
             },
             "Returns the rank of the given ranked shaped type components. If "
@@ -373,10 +382,9 @@ class PyShapedTypeComponents {
             "returned.")
         .def_prop_ro(
             "shape",
-            [](PyShapedTypeComponents &self) -> nb::object {
-              if (!self.ranked) {
-                return nb::none();
-              }
+            [](PyShapedTypeComponents &self) -> std::optional<nb::list> {
+              if (!self.ranked)
+                return {};
               return nb::list(self.shape);
             },
             "Returns the shape of the ranked shaped type components as a list "
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 6e97c00d478f1..598ae0188464a 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -671,7 +671,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   /// Creates a PyOperation from the MlirOperation wrapped by a capsule.
   /// Ownership of the underlying MlirOperation is taken by calling this
   /// function.
-  static nanobind::object createFromCapsule(nanobind::object capsule);
+  static nanobind::object createFromCapsule(const nanobind::object &capsule);
 
   /// Creates an operation. See corresponding python docstring.
   static nanobind::object
@@ -1020,7 +1020,7 @@ class PyAttribute : public BaseContextObject {
   /// Note that PyAttribute instances are uniqued, so the returned object
   /// may be a pre-existing object. Ownership of the underlying MlirAttribute
   /// is taken by calling this function.
-  static PyAttribute createFromCapsule(nanobind::object capsule);
+  static PyAttribute createFromCapsule(const nanobind::object &capsule);
 
   nanobind::object maybeDownCast();
 
@@ -1101,10 +1101,12 @@ class PyConcreteAttribute : public BaseTy {
           return DerivedTy::isaFunction(otherAttr);
         },
         nanobind::arg("other"));
-    cls.def_prop_ro("type", [](PyAttribute &attr) {
-      return PyType(attr.getContext(), mlirAttributeGetType(attr))
-          .maybeDownCast();
-    });
+    cls.def_prop_ro(
+        "type",
+        [](PyAttribute &attr) -> nanobind::typed<nanobind::object, PyType> {
+          return PyType(attr.getContext(), mlirAttributeGetType(attr))
+              .maybeDownCast();
+        });
     cls.def_prop_ro_static(
         "static_typeid",
         [](nanobind::object & /*class*/) -> PyTypeID {
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index cab3bf549295b..07dc00521833f 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -501,7 +501,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
         "Create a complex type");
     c.def_prop_ro(
         "element_type",
-        [](PyComplexType &self) {
+        [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
           return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
               .maybeDownCast();
         },
@@ -515,7 +515,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
 void mlir::PyShapedType::bindDerived(ClassTy &c) {
   c.def_prop_ro(
       "element_type",
-      [](PyShapedType &self) {
+      [](PyShapedType &self) -> nb::typed<nb::object, PyType> {
         return PyType(self.getContext(), mlirShapedTypeGetElementType(self))
             .maybeDownCast();
       },
@@ -731,8 +731,7 @@ class PyRankedTensorType
           MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
           if (mlirAttributeIsNull(encoding))
             return std::nullopt;
-          return nb::cast<nb::typed<nb::object, PyAttribute>>(
-              PyAttribute(self.getContext(), encoding).maybeDownCast());
+          return PyAttribute(self.getContext(), encoding).maybeDownCast();
         });
   }
 };
@@ -794,9 +793,9 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
         .def_prop_ro(
             "layout",
             [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
-              return nb::cast<nb::typed<nb::object, PyAttribute>>(
-                  PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
-                      .maybeDownCast());
+              return PyAttribute(self.getContext(),
+                                 mlirMemRefTypeGetLayout(self))
+                  .maybeDownCast();
             },
             "The layout of the MemRef type.")
         .def(
@@ -825,8 +824,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
               MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
               if (mlirAttributeIsNull(a))
                 return std::nullopt;
-              return nb::cast<nb::typed<nb::object, PyAttribute>>(
-                  PyAttribute(self.getContext(), a).maybeDownCast());
+              return PyAttribute(self.getContext(), a).maybeDownCast();
             },
             "Returns the memory space of the given MemRef type.");
   }
@@ -867,8 +865,7 @@ class PyUnrankedMemRefType
               MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
               if (mlirAttributeIsNull(a))
                 return std::nullopt;
-              return nb::cast<nb::typed<nb::object, PyAttribute>>(
-                  PyAttribute(self.getContext(), a).maybeDownCast());
+              return PyAttribute(self.getContext(), a).maybeDownCast();
             },
             "Returns the memory space of the given Unranked MemRef type.");
   }
@@ -912,7 +909,7 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
         "Create a tuple type");
     c.def(
         "get_type",
-        [](PyTupleType &self, intptr_t pos) {
+        [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
           return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
               .maybeDownCast();
         },



More information about the Mlir-commits mailing list