[Mlir-commits] [mlir] [MLIR][Python] use nb::typed for return signatures (PR #160221)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 22 20:32:30 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

https://github.com/llvm/llvm-project/pull/160183 removed the `nb::typed` annotations to fix bazel, but it turned out to simply be a matter of not using the correct version of nanobind (see https://github.com/llvm/llvm-project/pull/160183#issuecomment-3321429155). This PR restores those annotations but (mostly) moves them to the return positions of the actual methods.

---

Patch is 41.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160221.diff


6 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRAffine.cpp (+27-22) 
- (modified) mlir/lib/Bindings/Python/IRAttributes.cpp (+28-22) 
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+125-68) 
- (modified) mlir/lib/Bindings/Python/IRInterfaces.cpp (+19-11) 
- (modified) mlir/lib/Bindings/Python/IRModule.h (+8-6) 
- (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+9-12) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp
index bc6aa0dac6221..7147f2cbad149 100644
--- a/mlir/lib/Bindings/Python/IRAffine.cpp
+++ b/mlir/lib/Bindings/Python/IRAffine.cpp
@@ -574,7 +574,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
            })
       .def_prop_ro(
           "context",
-          [](PyAffineExpr &self) { return self.getContext().getObject(); })
+          [](PyAffineExpr &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          })
       .def("compose",
            [](PyAffineExpr &self, PyAffineMap &other) {
              return PyAffineExpr(self.getContext(),
@@ -706,28 +708,29 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
            [](PyAffineMap &self) {
              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
            })
-      .def_static("compress_unused_symbols",
-                  [](const nb::list &affineMaps,
-                     DefaultingPyMlirContext context) {
-                    SmallVector<MlirAffineMap> maps;
-                    pyListToVector<PyAffineMap, MlirAffineMap>(
-                        affineMaps, maps, "attempting to create an AffineMap");
-                    std::vector<MlirAffineMap> compressed(affineMaps.size());
-                    auto populate = [](void *result, intptr_t idx,
-                                       MlirAffineMap m) {
-                      static_cast<MlirAffineMap *>(result)[idx] = (m);
-                    };
-                    mlirAffineMapCompressUnusedSymbols(
-                        maps.data(), maps.size(), compressed.data(), populate);
-                    std::vector<PyAffineMap> res;
-                    res.reserve(compressed.size());
-                    for (auto m : compressed)
-                      res.emplace_back(context->getRef(), m);
-                    return res;
-                  })
+      .def_static(
+          "compress_unused_symbols",
+          [](const nb::list &affineMaps, DefaultingPyMlirContext context) {
+            SmallVector<MlirAffineMap> maps;
+            pyListToVector<PyAffineMap, MlirAffineMap>(
+                affineMaps, maps, "attempting to create an AffineMap");
+            std::vector<MlirAffineMap> compressed(affineMaps.size());
+            auto populate = [](void *result, intptr_t idx, MlirAffineMap m) {
+              static_cast<MlirAffineMap *>(result)[idx] = (m);
+            };
+            mlirAffineMapCompressUnusedSymbols(maps.data(), maps.size(),
+                                               compressed.data(), populate);
+            std::vector<PyAffineMap> res;
+            res.reserve(compressed.size());
+            for (auto m : compressed)
+              res.emplace_back(context->getRef(), m);
+            return res;
+          })
       .def_prop_ro(
           "context",
-          [](PyAffineMap &self) { return self.getContext().getObject(); },
+          [](PyAffineMap &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          },
           "Context that owns the Affine Map")
       .def(
           "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
@@ -893,7 +896,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
            })
       .def_prop_ro(
           "context",
-          [](PyIntegerSet &self) { return self.getContext().getObject(); })
+          [](PyIntegerSet &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          })
       .def(
           "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
           kDumpDocstring)
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index 212228fbac91e..c77653f97e6dd 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
 
     PyArrayAttributeIterator &dunderIter() { return *this; }
 
-    nb::object dunderNext() {
+    nb::typed<nb::object, PyAttribute> dunderNext() {
       // TODO: Throw is an inefficient way to stop iteration.
       if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
         throw nb::stop_iteration();
@@ -526,7 +526,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
         "Gets a uniqued Array attribute");
     c.def(
          "__getitem__",
-         [](PyArrayAttribute &arr, intptr_t i) {
+         [](PyArrayAttribute &arr,
+            intptr_t i) -> nb::typed<nb::object, PyAttribute> {
            if (i >= mlirArrayAttrGetNumElements(arr))
              throw nb::index_error("ArrayAttribute index out of range");
            return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
@@ -1010,14 +1011,16 @@ class PyDenseElementsAttribute
                      [](PyDenseElementsAttribute &self) -> bool {
                        return mlirDenseElementsAttrIsSplat(self);
                      })
-        .def("get_splat_value", [](PyDenseElementsAttribute &self) {
-          if (!mlirDenseElementsAttrIsSplat(self))
-            throw nb::value_error(
-                "get_splat_value called on a non-splat attribute");
-          return PyAttribute(self.getContext(),
-                             mlirDenseElementsAttrGetSplatValue(self))
-              .maybeDownCast();
-        });
+        .def("get_splat_value",
+             [](PyDenseElementsAttribute &self)
+                 -> nb::typed<nb::object, PyAttribute> {
+               if (!mlirDenseElementsAttrIsSplat(self))
+                 throw nb::value_error(
+                     "get_splat_value called on a non-splat attribute");
+               return PyAttribute(self.getContext(),
+                                  mlirDenseElementsAttrGetSplatValue(self))
+                   .maybeDownCast();
+             });
   }
 
   static PyType_Slot slots[];
@@ -1332,7 +1335,7 @@ class PyDenseIntElementsAttribute
 
   /// Returns the element at the given linear position. Asserts if the index
   /// is out of range.
-  nb::object dunderGetItem(intptr_t pos) {
+  nb::int_ dunderGetItem(intptr_t pos) {
     if (pos < 0 || pos >= dunderLen()) {
       throw nb::index_error("attempt to access out of bounds element");
     }
@@ -1522,13 +1525,15 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
         },
         nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
         "Gets an uniqued dict attribute");
-    c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
-      MlirAttribute attr =
-          mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
-      if (mlirAttributeIsNull(attr))
-        throw nb::key_error("attempt to access a non-existent attribute");
-      return PyAttribute(self.getContext(), attr).maybeDownCast();
-    });
+    c.def("__getitem__",
+          [](PyDictAttribute &self,
+             const std::string &name) -> nb::typed<nb::object, PyAttribute> {
+            MlirAttribute attr =
+                mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
+            if (mlirAttributeIsNull(attr))
+              throw nb::key_error("attempt to access a non-existent attribute");
+            return PyAttribute(self.getContext(), attr).maybeDownCast();
+          });
     c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
       if (index < 0 || index >= self.dunderLen()) {
         throw nb::index_error("attempt to access out of bounds attribute");
@@ -1594,10 +1599,11 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
         },
         nb::arg("value"), nb::arg("context") = nb::none(),
         "Gets a uniqued Type attribute");
-    c.def_prop_ro("value", [](PyTypeAttribute &self) {
-      return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
-          .maybeDownCast();
-    });
+    c.def_prop_ro(
+        "value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
+          return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
+              .maybeDownCast();
+        });
   }
 };
 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4b238e11c7fff..e9cfdc7c6d674 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -513,7 +513,7 @@ class PyOperationIterator {
 
   PyOperationIterator &dunderIter() { return *this; }
 
-  nb::object dunderNext() {
+  nb::typed<nb::object, PyOpView> dunderNext() {
     parentOperation->checkValid();
     if (mlirOperationIsNull(next)) {
       throw nb::stop_iteration();
@@ -562,7 +562,7 @@ class PyOperationList {
     return count;
   }
 
-  nb::object dunderGetItem(intptr_t index) {
+  nb::typed<nb::object, PyOpView> dunderGetItem(intptr_t index) {
     parentOperation->checkValid();
     if (index < 0) {
       index += dunderLen();
@@ -725,7 +725,7 @@ nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) {
       new PyDiagnosticHandler(get(), std::move(callback));
   nb::object pyHandlerObject =
       nb::cast(pyHandler, nb::rv_policy::take_ownership);
-  pyHandlerObject.inc_ref();
+  (void)pyHandlerObject.inc_ref();
 
   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
   // guaranteed to be known to pybind.
@@ -1395,7 +1395,7 @@ nb::object PyOperation::getCapsule() {
   return nb::steal<nb::object>(mlirPythonOperationToCapsule(get()));
 }
 
-nb::object PyOperation::createFromCapsule(nb::object capsule) {
+nb::object PyOperation::createFromCapsule(const nb::object &capsule) {
   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
   if (mlirOperationIsNull(rawOperation))
     throw nb::python_error();
@@ -1605,7 +1605,9 @@ class PyConcreteValue : public PyValue {
         },
         nb::arg("other_value"));
     cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
-            [](DerivedTy &self) { return self.maybeDownCast(); });
+            [](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
+              return self.maybeDownCast();
+            });
     DerivedTy::bindDerived(cls);
   }
 
@@ -1623,13 +1625,14 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
   using PyConcreteValue::PyConcreteValue;
 
   static void bindDerived(ClassTy &c) {
-    c.def_prop_ro("owner", [](PyOpResult &self) {
-      assert(
-          mlirOperationEqual(self.getParentOperation()->get(),
-                             mlirOpResultGetOwner(self.get())) &&
-          "expected the owner of the value in Python to match that in the IR");
-      return self.getParentOperation().getObject();
-    });
+    c.def_prop_ro(
+        "owner", [](PyOpResult &self) -> nb::typed<nb::object, PyOperation> {
+          assert(mlirOperationEqual(self.getParentOperation()->get(),
+                                    mlirOpResultGetOwner(self.get())) &&
+                 "expected the owner of the value in Python to match that in "
+                 "the IR");
+          return self.getParentOperation().getObject();
+        });
     c.def_prop_ro("result_number", [](PyOpResult &self) {
       return mlirOpResultGetResultNumber(self.get());
     });
@@ -1638,9 +1641,9 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
 
 /// Returns the list of types of the values held by container.
 template <typename Container>
-static std::vector<nb::object> getValueTypes(Container &container,
-                                             PyMlirContextRef &context) {
-  std::vector<nb::object> result;
+static std::vector<nb::typed<nb::object, PyType>>
+getValueTypes(Container &container, PyMlirContextRef &context) {
+  std::vector<nb::typed<nb::object, PyType>> result;
   result.reserve(container.size());
   for (int i = 0, e = container.size(); i < e; ++i) {
     result.push_back(PyType(context->getRef(),
@@ -1671,9 +1674,10 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
     c.def_prop_ro("types", [](PyOpResultList &self) {
       return getValueTypes(self, self.operation->getContext());
     });
-    c.def_prop_ro("owner", [](PyOpResultList &self) {
-      return self.operation->createOpView();
-    });
+    c.def_prop_ro("owner",
+                  [](PyOpResultList &self) -> nb::typed<nb::object, PyOpView> {
+                    return self.operation->createOpView();
+                  });
   }
 
   PyOperationRef &getOperation() { return operation; }
@@ -2104,7 +2108,7 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
-  return PyThreadContextEntry::pushInsertionPoint(insertPoint);
+  return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint));
 }
 
 void PyInsertionPoint::contextExit(const nb::object &excType,
@@ -2125,7 +2129,7 @@ nb::object PyAttribute::getCapsule() {
   return nb::steal<nb::object>(mlirPythonAttributeToCapsule(*this));
 }
 
-PyAttribute PyAttribute::createFromCapsule(nb::object capsule) {
+PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) {
   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
   if (mlirAttributeIsNull(rawAttr))
     throw nb::python_error();
@@ -2677,7 +2681,8 @@ class PyOpAttributeMap {
   PyOpAttributeMap(PyOperationRef operation)
       : operation(std::move(operation)) {}
 
-  nb::object dunderGetItemNamed(const std::string &name) {
+  nb::typed<nb::object, PyAttribute>
+  dunderGetItemNamed(const std::string &name) {
     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
                                                          toMlirStringRef(name));
     if (mlirAttributeIsNull(attr)) {
@@ -2962,13 +2967,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
            })
       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
       .def("_get_context_again",
-           [](PyMlirContext &self) {
+           [](PyMlirContext &self) -> nb::typed<nb::object, PyMlirContext> {
              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
              return ref.releaseObject();
            })
       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+                  &PyMlirContext::createFromCapsule)
       .def("__enter__", &PyMlirContext::contextEnter)
       .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(),
            nb::arg("exc_value").none(), nb::arg("traceback").none())
@@ -3123,7 +3129,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   //----------------------------------------------------------------------------
   nb::class_<PyDialectRegistry>(m, "DialectRegistry")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+                  &PyDialectRegistry::createFromCapsule)
       .def(nb::init<>());
 
   //----------------------------------------------------------------------------
@@ -3131,7 +3138,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   //----------------------------------------------------------------------------
   nb::class_<PyLocation>(m, "Location")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
       .def("__enter__", &PyLocation::contextEnter)
       .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(),
            nb::arg("exc_value").none(), nb::arg("traceback").none())
@@ -3286,7 +3293,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "Gets a Location from a LocationAttr")
       .def_prop_ro(
           "context",
-          [](PyLocation &self) { return self.getContext().getObject(); },
+          [](PyLocation &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          },
           "Context that owns the Location")
       .def_prop_ro(
           "attr",
@@ -3313,12 +3322,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   //----------------------------------------------------------------------------
   nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
-           kModuleCAPICreate)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+                  kModuleCAPICreate)
       .def("_clear_mlir_module", &PyModule::clearMlirModule)
       .def_static(
           "parse",
-          [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
+          [](const std::string &moduleAsm, DefaultingPyMlirContext context)
+              -> nb::typed<nb::object, PyModule> {
             PyMlirContext::ErrorCapture errors(context->getRef());
             MlirModule module = mlirModuleCreateParse(
                 context->get(), toMlirStringRef(moduleAsm));
@@ -3330,7 +3340,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           kModuleParseDocstring)
       .def_static(
           "parse",
-          [](nb::bytes moduleAsm, DefaultingPyMlirContext context) {
+          [](nb::bytes moduleAsm, DefaultingPyMlirContext context)
+              -> nb::typed<nb::object, PyModule> {
             PyMlirContext::ErrorCapture errors(context->getRef());
             MlirModule module = mlirModuleCreateParse(
                 context->get(), toMlirStringRef(moduleAsm));
@@ -3342,7 +3353,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           kModuleParseDocstring)
       .def_static(
           "parseFile",
-          [](const std::string &path, DefaultingPyMlirContext context) {
+          [](const std::string &path, DefaultingPyMlirContext context)
+              -> nb::typed<nb::object, PyModule> {
             PyMlirContext::ErrorCapture errors(context->getRef());
             MlirModule module = mlirModuleCreateParseFromFile(
                 context->get(), toMlirStringRef(path));
@@ -3354,7 +3366,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           kModuleParseDocstring)
       .def_static(
           "create",
-          [](const std::optional<PyLocation> &loc) {
+          [](const std::optional<PyLocation> &loc)
+              -> nb::typed<nb::object, PyModule> {
             PyLocation pyLoc = maybeGetTracebackLocation(loc);
             MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
             return PyModule::forModule(module).releaseObject();
@@ -3362,11 +3375,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           nb::arg("loc") = nb::none(), "Creates an empty module")
       .def_prop_ro(
           "context",
-          [](PyModule &self) { return self.getContext().getObject(); },
+          [](PyModule &self) -> nb::typed<nb::object, PyMlirContext> {
+            return self.getContext().getObject();
+          },
           "Context that created the Module")
       .def_prop_ro(
           "operation",
-          [](PyModule &self) {
+          [](PyModule &self) -> nb::typed<nb::object, PyOperation> {
             return PyOperation::forOperation(self.getContext(),
                                              mlirModuleGetOperation(self.get()),
                                              self.getRef().releaseObject())
@@ -3430,7 +3445,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                    })
       .def_prop_ro(
           "context",
-          [](PyOperationBase &self) {
+          [](PyOperationBase &self) -> nb::typed<nb::object, PyMlirContext> {
             PyOperation &concreteOperation = self.getOperation();
             concreteOperation.checkValid();
             return concreteOperation.getContext().getObject();
@@ -3461,7 +347...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/160221


More information about the Mlir-commits mailing list