[Mlir-commits] [mlir] [mlir][python] Make types in register_(dialect|operation) more narrow. (PR #115307)

Ingo Müller llvmlistbot at llvm.org
Thu Nov 7 04:25:43 PST 2024


https://github.com/ingomueller-net created https://github.com/llvm/llvm-project/pull/115307

This PR narrows the types of the class arguments of the pybind11 functions `register_dialect` and `register_operation` (`pyClass`, `dialectClass`, and the decorated `opClass`), as well as the corresponding return types, from `py::object` to `py::type`. As the argument names indicate, these functions must be called with a type object (that is, a class). The PR also updates the typing stubs of these functions (in the corresponding `.pyi` file) so that static type checkers see the narrower types. With the previous typing information, `pyright` raised errors on code generated by tablegen.

From bbcf647c1afcdc6f96b2252acc655d99140c73c4 Mon Sep 17 00:00:00 2001
From: Ingo Müller <ingomueller at google.com>
Date: Thu, 7 Nov 2024 12:21:17 +0000
Subject: [PATCH] [mlir][python] Make types in register_(dialect|operation)
 more narrow.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR narrows the types of the class arguments of the pybind11
functions `register_dialect` and `register_operation` (`pyClass`,
`dialectClass`, and the decorated `opClass`), as well as the
corresponding return types, from `py::object` to `py::type`. As the
argument names indicate, these functions must be called with a type
object (that is, a class). The PR also updates the typing stubs of
these functions (in the corresponding `.pyi` file) so that static type
checkers see the narrower types. With the previous typing information,
`pyright` raised errors on code generated by tablegen.

Signed-off-by: Ingo Müller <ingomueller at google.com>
---
 mlir/lib/Bindings/Python/MainModule.cpp        | 6 +++---
 mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 8da1ab16a4514b..7c27021902de31 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -58,7 +58,7 @@ PYBIND11_MODULE(_mlir, m) {
   // Registration decorators.
   m.def(
       "register_dialect",
-      [](py::object pyClass) {
+      [](py::type pyClass) {
         std::string dialectNamespace =
             pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
         PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
@@ -68,9 +68,9 @@ PYBIND11_MODULE(_mlir, m) {
       "Class decorator for registering a custom Dialect wrapper");
   m.def(
       "register_operation",
-      [](const py::object &dialectClass, bool replace) -> py::cpp_function {
+      [](const py::type &dialectClass, bool replace) -> py::cpp_function {
         return py::cpp_function(
-            [dialectClass, replace](py::object opClass) -> py::object {
+            [dialectClass, replace](py::type opClass) -> py::type {
               std::string operationName =
                   opClass.attr("OPERATION_NAME").cast<std::string>();
               PyGlobals::get().registerOperationImpl(operationName, opClass,
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
index 42694747e5f24f..03449b70b7fa38 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
@@ -8,5 +8,5 @@ class _Globals:
     def append_dialect_search_prefix(self, module_name: str) -> None: ...
     def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
 
-def register_dialect(dialect_class: type) -> object: ...
-def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ...
+def register_dialect(dialect_class: type) -> type: ...
+def register_operation(dialect_class: type, *, replace: bool = ...) -> type: ...


