[Mlir-commits] [mlir] [mlir][python] set the registry free (PR #72477)

Maksim Levental llvmlistbot at llvm.org
Wed Nov 15 23:20:34 PST 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/72477

>From 8b6c773305dfd477413b99db4ef2b775b78ad685 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 15 Nov 2023 23:48:17 -0600
Subject: [PATCH 1/3] [mlir][python] set the registry free

---
 mlir/python/mlir/_mlir_libs/__init__.py | 206 ++++++++++++------------
 1 file changed, 105 insertions(+), 101 deletions(-)

diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 6ce77b4cb93f609..468925d278c61dd 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -56,110 +56,114 @@ def get_include_dirs() -> Sequence[str]:
 #
 # This facility allows downstreams to customize Context creation to their
 # needs.
-def _site_initialize():
-    import importlib
-    import itertools
-    import logging
-    from ._mlir import ir
-
-    logger = logging.getLogger(__name__)
-    registry = ir.DialectRegistry()
-    post_init_hooks = []
-    disable_multithreading = False
-
-    def process_initializer_module(module_name):
-        nonlocal disable_multithreading
-        try:
-            m = importlib.import_module(f".{module_name}", __name__)
-        except ModuleNotFoundError:
-            return False
-        except ImportError:
-            message = (
-                f"Error importing mlir initializer {module_name}. This may "
-                "happen in unclean incremental builds but is likely a real bug if "
-                "encountered otherwise and the MLIR Python API may not function."
+import importlib
+import itertools
+import logging
+from ._mlir import ir
+
+logger = logging.getLogger(__name__)
+registry = ir.DialectRegistry()
+post_init_hooks = []
+disable_multithreading = False
+
+
+def get_registry():
+    return registry
+
+
+def process_initializer_module(module_name):
+    global disable_multithreading
+    try:
+        m = importlib.import_module(f".{module_name}", __name__)
+    except ModuleNotFoundError:
+        return False
+    except ImportError:
+        message = (
+            f"Error importing mlir initializer {module_name}. This may "
+            "happen in unclean incremental builds but is likely a real bug if "
+            "encountered otherwise and the MLIR Python API may not function."
+        )
+        logger.warning(message, exc_info=True)
+
+    logger.debug("Initializing MLIR with module: %s", module_name)
+    if hasattr(m, "register_dialects"):
+        logger.debug("Registering dialects from initializer %r", m)
+        m.register_dialects(registry)
+    if hasattr(m, "context_init_hook"):
+        logger.debug("Adding context init hook from %r", m)
+        post_init_hooks.append(m.context_init_hook)
+    if hasattr(m, "disable_multithreading"):
+        if bool(m.disable_multithreading):
+            logger.debug("Disabling multi-threading for context")
+            disable_multithreading = True
+    return True
+
+
+# If _mlirRegisterEverything is built, then include it as an initializer
+# module.
+init_module = None
+if process_initializer_module("_mlirRegisterEverything"):
+    init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
+
+# Load all _site_initialize_{i} modules, where 'i' is a number starting
+# at 0.
+for i in itertools.count():
+    module_name = f"_site_initialize_{i}"
+    if not process_initializer_module(module_name):
+        break
+
+
+class Context(ir._BaseContext):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.append_dialect_registry(get_registry())
+        for hook in post_init_hooks:
+            hook(self)
+        if not disable_multithreading:
+            self.enable_multithreading(True)
+        # TODO: There is some debate about whether we should eagerly load
+        # all dialects. It is being done here in order to preserve existing
+        # behavior. See: https://github.com/llvm/llvm-project/issues/56037
+        self.load_all_available_dialects()
+        if init_module:
+            logger.debug("Registering translations from initializer %r", init_module)
+            init_module.register_llvm_translations(self)
+
+
+ir.Context = Context
+
+
+class MLIRError(Exception):
+    """
+    An exception with diagnostic information. Has the following fields:
+      message: str
+      error_diagnostics: List[ir.DiagnosticInfo]
+    """
+
+    def __init__(self, message, error_diagnostics):
+        self.message = message
+        self.error_diagnostics = error_diagnostics
+        super().__init__(message, error_diagnostics)
+
+    def __str__(self):
+        s = self.message
+        if self.error_diagnostics:
+            s += ":"
+        for diag in self.error_diagnostics:
+            s += (
+                "\nerror: "
+                + str(diag.location)[4:-1]
+                + ": "
+                + diag.message.replace("\n", "\n  ")
             )
-            logger.warning(message, exc_info=True)
-
-        logger.debug("Initializing MLIR with module: %s", module_name)
-        if hasattr(m, "register_dialects"):
-            logger.debug("Registering dialects from initializer %r", m)
-            m.register_dialects(registry)
-        if hasattr(m, "context_init_hook"):
-            logger.debug("Adding context init hook from %r", m)
-            post_init_hooks.append(m.context_init_hook)
-        if hasattr(m, "disable_multithreading"):
-            if bool(m.disable_multithreading):
-                logger.debug("Disabling multi-threading for context")
-                disable_multithreading = True
-        return True
-
-    # If _mlirRegisterEverything is built, then include it as an initializer
-    # module.
-    init_module = None
-    if process_initializer_module("_mlirRegisterEverything"):
-        init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
-
-    # Load all _site_initialize_{i} modules, where 'i' is a number starting
-    # at 0.
-    for i in itertools.count():
-        module_name = f"_site_initialize_{i}"
-        if not process_initializer_module(module_name):
-            break
-
-    class Context(ir._BaseContext):
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self.append_dialect_registry(registry)
-            for hook in post_init_hooks:
-                hook(self)
-            if not disable_multithreading:
-                self.enable_multithreading(True)
-            # TODO: There is some debate about whether we should eagerly load
-            # all dialects. It is being done here in order to preserve existing
-            # behavior. See: https://github.com/llvm/llvm-project/issues/56037
-            self.load_all_available_dialects()
-            if init_module:
-                logger.debug(
-                    "Registering translations from initializer %r", init_module
-                )
-                init_module.register_llvm_translations(self)
-
-    ir.Context = Context
-
-    class MLIRError(Exception):
-        """
-        An exception with diagnostic information. Has the following fields:
-          message: str
-          error_diagnostics: List[ir.DiagnosticInfo]
-        """
-
-        def __init__(self, message, error_diagnostics):
-            self.message = message
-            self.error_diagnostics = error_diagnostics
-            super().__init__(message, error_diagnostics)
-
-        def __str__(self):
-            s = self.message
-            if self.error_diagnostics:
-                s += ":"
-            for diag in self.error_diagnostics:
+            for note in diag.notes:
                 s += (
-                    "\nerror: "
-                    + str(diag.location)[4:-1]
+                    "\n note: "
+                    + str(note.location)[4:-1]
                     + ": "
-                    + diag.message.replace("\n", "\n  ")
+                    + note.message.replace("\n", "\n  ")
                 )
-                for note in diag.notes:
-                    s += (
-                        "\n note: "
-                        + str(note.location)[4:-1]
-                        + ": "
-                        + note.message.replace("\n", "\n  ")
-                    )
-            return s
-
-    ir.MLIRError = MLIRError
+        return s
 
 
-_site_initialize()
+ir.MLIRError = MLIRError

>From b23938a217f9fb6e489b154a79c42182c4db4e34 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 16 Nov 2023 00:34:19 -0600
Subject: [PATCH 2/3] [mlir][python] hide everything in a namespace/module

---
 mlir/python/mlir/_mlir_libs/__init__.py | 319 ++++++++++++------------
 1 file changed, 166 insertions(+), 153 deletions(-)

diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 468925d278c61dd..2b2306d121cbfd5 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -2,168 +2,181 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from typing import Any, Sequence
-
-import os
-
-_this_dir = os.path.dirname(__file__)
-
-
-def get_lib_dirs() -> Sequence[str]:
-    """Gets the lib directory for linking to shared libraries.
-
-    On some platforms, the package may need to be built specially to export
-    development libraries.
-    """
-    return [_this_dir]
-
-
-def get_include_dirs() -> Sequence[str]:
-    """Gets the include directory for compiling against exported C libraries.
-
-    Depending on how the package was build, development C libraries may or may
-    not be present.
-    """
-    return [os.path.join(_this_dir, "include")]
-
-
-# Perform Python level site initialization. This involves:
-#   1. Attempting to load initializer modules, specific to the distribution.
-#   2. Defining the concrete mlir.ir.Context that does site specific
-#      initialization.
-#
-# Aside from just being far more convenient to do this at the Python level,
-# it is actually quite hard/impossible to have such __init__ hooks, given
-# the pybind memory model (i.e. there is not a Python reference to the object
-# in the scope of the base class __init__).
-#
-# For #1, we:
-#   a. Probe for modules named '_mlirRegisterEverything' and
-#     '_site_initialize_{i}', where 'i' is a number starting at zero and
-#     proceeding so long as a module with the name is found.
-#   b. If the module has a 'register_dialects' attribute, it will be called
-#     immediately with a DialectRegistry to populate.
-#   c. If the module has a 'context_init_hook', it will be added to a list
-#     of callbacks that are invoked as the last step of Context
-#     initialization (and passed the Context under construction).
-#   d. If the module has a 'disable_multithreading' attribute, it will be
-#     taken as a boolean. If it is True for any initializer, then the
-#     default behavior of enabling multithreading on the context
-#     will be suppressed. This complies with the original behavior of all
-#     contexts being created with multithreading enabled while allowing
-#     this behavior to be changed if needed (i.e. if a context_init_hook
-#     explicitly sets up multithreading).
-#
-# This facility allows downstreams to customize Context creation to their
-# needs.
 import importlib
 import itertools
 import logging
+import os
+import sys
+from typing import Sequence
+
 from ._mlir import ir
 
+
+_this_dir = os.path.dirname(__file__)
+
 logger = logging.getLogger(__name__)
-registry = ir.DialectRegistry()
-post_init_hooks = []
-disable_multithreading = False
-
-
-def get_registry():
-    return registry
-
-
-def process_initializer_module(module_name):
-    global disable_multithreading
-    try:
-        m = importlib.import_module(f".{module_name}", __name__)
-    except ModuleNotFoundError:
-        return False
-    except ImportError:
-        message = (
-            f"Error importing mlir initializer {module_name}. This may "
-            "happen in unclean incremental builds but is likely a real bug if "
-            "encountered otherwise and the MLIR Python API may not function."
-        )
-        logger.warning(message, exc_info=True)
-
-    logger.debug("Initializing MLIR with module: %s", module_name)
-    if hasattr(m, "register_dialects"):
-        logger.debug("Registering dialects from initializer %r", m)
-        m.register_dialects(registry)
-    if hasattr(m, "context_init_hook"):
-        logger.debug("Adding context init hook from %r", m)
-        post_init_hooks.append(m.context_init_hook)
-    if hasattr(m, "disable_multithreading"):
-        if bool(m.disable_multithreading):
-            logger.debug("Disabling multi-threading for context")
-            disable_multithreading = True
-    return True
-
-
-# If _mlirRegisterEverything is built, then include it as an initializer
-# module.
-init_module = None
-if process_initializer_module("_mlirRegisterEverything"):
-    init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
-
-# Load all _site_initialize_{i} modules, where 'i' is a number starting
-# at 0.
-for i in itertools.count():
-    module_name = f"_site_initialize_{i}"
-    if not process_initializer_module(module_name):
-        break
-
-
-class Context(ir._BaseContext):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.append_dialect_registry(get_registry())
-        for hook in post_init_hooks:
-            hook(self)
-        if not disable_multithreading:
-            self.enable_multithreading(True)
-        # TODO: There is some debate about whether we should eagerly load
-        # all dialects. It is being done here in order to preserve existing
-        # behavior. See: https://github.com/llvm/llvm-project/issues/56037
-        self.load_all_available_dialects()
-        if init_module:
-            logger.debug("Registering translations from initializer %r", init_module)
-            init_module.register_llvm_translations(self)
-
-
-ir.Context = Context
-
-
-class MLIRError(Exception):
-    """
-    An exception with diagnostic information. Has the following fields:
-      message: str
-      error_diagnostics: List[ir.DiagnosticInfo]
-    """
-
-    def __init__(self, message, error_diagnostics):
-        self.message = message
-        self.error_diagnostics = error_diagnostics
-        super().__init__(message, error_diagnostics)
-
-    def __str__(self):
-        s = self.message
-        if self.error_diagnostics:
-            s += ":"
-        for diag in self.error_diagnostics:
-            s += (
-                "\nerror: "
-                + str(diag.location)[4:-1]
-                + ": "
-                + diag.message.replace("\n", "\n  ")
+
+_path = __path__
+_spec = __spec__
+_name = __name__
+
+
+class _M:
+    __path__ = _path
+    __spec__ = _spec
+    __name__ = _name
+
+    @staticmethod
+    def get_lib_dirs() -> Sequence[str]:
+        """Gets the lib directory for linking to shared libraries.
+
+        On some platforms, the package may need to be built specially to export
+        development libraries.
+        """
+        return [_this_dir]
+
+    @staticmethod
+    def get_include_dirs() -> Sequence[str]:
+        """Gets the include directory for compiling against exported C libraries.
+
+        Depending on how the package was built, development C libraries may or may
+        not be present.
+        """
+        return [os.path.join(_this_dir, "include")]
+
+    # Perform Python level site initialization. This involves:
+    #   1. Attempting to load initializer modules, specific to the distribution.
+    #   2. Defining the concrete mlir.ir.Context that does site specific
+    #      initialization.
+    #
+    # Aside from just being far more convenient to do this at the Python level,
+    # it is actually quite hard/impossible to have such __init__ hooks, given
+    # the pybind memory model (i.e. there is not a Python reference to the object
+    # in the scope of the base class __init__).
+    #
+    # For #1, we:
+    #   a. Probe for modules named '_mlirRegisterEverything' and
+    #     '_site_initialize_{i}', where 'i' is a number starting at zero and
+    #     proceeding so long as a module with the name is found.
+    #   b. If the module has a 'register_dialects' attribute, it will be called
+    #     immediately with a DialectRegistry to populate.
+    #   c. If the module has a 'context_init_hook', it will be added to a list
+    #     of callbacks that are invoked as the last step of Context
+    #     initialization (and passed the Context under construction).
+    #   d. If the module has a 'disable_multithreading' attribute, it will be
+    #     taken as a boolean. If it is True for any initializer, then the
+    #     default behavior of enabling multithreading on the context
+    #     will be suppressed. This complies with the original behavior of all
+    #     contexts being created with multithreading enabled while allowing
+    #     this behavior to be changed if needed (i.e. if a context_init_hook
+    #     explicitly sets up multithreading).
+    #
+    # This facility allows downstreams to customize Context creation to their
+    # needs.
+
+    __registry = ir.DialectRegistry()
+    __post_init_hooks = []
+    __disable_multithreading = False
+    from . import _mlir as _mlir
+
+    def __get_registry(self):
+        return self.__registry
+
+    def process_initializer_module(self, module_name):
+        try:
+            m = importlib.import_module(f".{module_name}", __name__)
+        except ModuleNotFoundError:
+            return False
+        except ImportError:
+            message = (
+                f"Error importing mlir initializer {module_name}. This may "
+                "happen in unclean incremental builds but is likely a real bug if "
+                "encountered otherwise and the MLIR Python API may not function."
             )
-            for note in diag.notes:
+            logger.warning(message, exc_info=True)
+
+        logger.debug("Initializing MLIR with module: %s", module_name)
+        if hasattr(m, "register_dialects"):
+            logger.debug("Registering dialects from initializer %r", m)
+            m.register_dialects(self.__get_registry())
+        if hasattr(m, "context_init_hook"):
+            logger.debug("Adding context init hook from %r", m)
+            self.__post_init_hooks.append(m.context_init_hook)
+        if hasattr(m, "disable_multithreading"):
+            if bool(m.disable_multithreading):
+                logger.debug("Disabling multi-threading for context")
+                self.__disable_multithreading = True
+        return True
+
+    def __init__(self):
+        # If _mlirRegisterEverything is built, then include it as an initializer
+        # module.
+        init_module = None
+        if self.process_initializer_module("_mlirRegisterEverything"):
+            init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
+
+        # Load all _site_initialize_{i} modules, where 'i' is a number starting
+        # at 0.
+        for i in itertools.count():
+            module_name = f"_site_initialize_{i}"
+            if not self.process_initializer_module(module_name):
+                break
+
+        that = self
+
+        class Context(ir._BaseContext):
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.append_dialect_registry(that._M__get_registry())
+                for hook in that._M__post_init_hooks:
+                    hook(self)
+                if not that._M__disable_multithreading:
+                    self.enable_multithreading(True)
+                # TODO: There is some debate about whether we should eagerly load
+                # all dialects. It is being done here in order to preserve existing
+                # behavior. See: https://github.com/llvm/llvm-project/issues/56037
+                self.load_all_available_dialects()
+                if init_module:
+                    logger.debug(
+                        "Registering translations from initializer %r", init_module
+                    )
+                    init_module.register_llvm_translations(self)
+
+        ir.Context = Context
+
+    class MLIRError(Exception):
+        """
+        An exception with diagnostic information. Has the following fields:
+          message: str
+          error_diagnostics: List[ir.DiagnosticInfo]
+        """
+
+        def __init__(self, message, error_diagnostics):
+            self.message = message
+            self.error_diagnostics = error_diagnostics
+            super().__init__(message, error_diagnostics)
+
+        def __str__(self):
+            s = self.message
+            if self.error_diagnostics:
+                s += ":"
+            for diag in self.error_diagnostics:
                 s += (
-                    "\n note: "
-                    + str(note.location)[4:-1]
+                    "\nerror: "
+                    + str(diag.location)[4:-1]
                     + ": "
-                    + note.message.replace("\n", "\n  ")
+                    + diag.message.replace("\n", "\n  ")
                 )
-        return s
+                for note in diag.notes:
+                    s += (
+                        "\n note: "
+                        + str(note.location)[4:-1]
+                        + ": "
+                        + note.message.replace("\n", "\n  ")
+                    )
+            return s
+
+    ir.MLIRError = MLIRError
 
 
-ir.MLIRError = MLIRError
+sys.modules[__name__] = _M()

>From b58593b2f79f2531588eaa3278e23798a15b15e1 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 16 Nov 2023 01:20:22 -0600
Subject: [PATCH 3/3] [mlir][python] demo registering

---
 mlir/python/mlir/_mlir_libs/__init__.py   | 10 +++++-----
 mlir/python/mlir/dialects/python_test.py  |  6 ------
 mlir/test/python/dialects/python_test.py  | 17 ++++-------------
 mlir/test/python/lib/PythonTestModule.cpp |  9 +++++++++
 4 files changed, 18 insertions(+), 24 deletions(-)

diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 2b2306d121cbfd5..f5dfd1edf5a3e61 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -82,9 +82,9 @@ def get_include_dirs() -> Sequence[str]:
     def __get_registry(self):
         return self.__registry
 
-    def process_initializer_module(self, module_name):
+    def process_c_ext_module(self, module_name):
         try:
-            m = importlib.import_module(f".{module_name}", __name__)
+            m = importlib.import_module(f"{module_name}", __name__)
         except ModuleNotFoundError:
             return False
         except ImportError:
@@ -112,14 +112,14 @@ def __init__(self):
         # If _mlirRegisterEverything is built, then include it as an initializer
         # module.
         init_module = None
-        if self.process_initializer_module("_mlirRegisterEverything"):
+        if self.process_c_ext_module("._mlirRegisterEverything"):
             init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
 
         # Load all _site_initialize_{i} modules, where 'i' is a number starting
         # at 0.
         for i in itertools.count():
-            module_name = f"_site_initialize_{i}"
-            if not self.process_initializer_module(module_name):
+            module_name = f"._site_initialize_{i}"
+            if not self.process_c_ext_module(module_name):
                 break
 
         that = self
diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 6579e02d8549efa..8b4f718d8a53b13 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -9,9 +9,3 @@
     TestTensorValue,
     TestIntegerRankedTensorType,
 )
-
-
-def register_python_test_dialect(context, load=True):
-    from .._mlir_libs import _mlirPythonTest
-
-    _mlirPythonTest.register_python_test_dialect(context, load)
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index f313a400b73c0a5..309de8037049c2c 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -1,5 +1,9 @@
 # RUN: %PYTHON %s | FileCheck %s
 
+from mlir import _mlir_libs
+
+_mlir_libs.process_c_ext_module("mlir._mlir_libs._mlirPythonTest")
+
 from mlir.ir import *
 import mlir.dialects.func as func
 import mlir.dialects.python_test as test
@@ -17,7 +21,6 @@ def run(f):
 @run
 def testAttributes():
     with Context() as ctx, Location.unknown():
-        test.register_python_test_dialect(ctx)
         #
         # Check op construction with attributes.
         #
@@ -138,7 +141,6 @@ def testAttributes():
 @run
 def attrBuilder():
     with Context() as ctx, Location.unknown():
-        test.register_python_test_dialect(ctx)
         # CHECK: python_test.attributes_op
         op = test.AttributesOp(
             # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
@@ -215,7 +217,6 @@ def attrBuilder():
 @run
 def inferReturnTypes():
     with Context() as ctx, Location.unknown(ctx):
-        test.register_python_test_dialect(ctx)
         module = Module.create()
         with InsertionPoint(module.body):
             op = test.InferResultsOp()
@@ -260,7 +261,6 @@ def inferReturnTypes():
 @run
 def resultTypesDefinedByTraits():
     with Context() as ctx, Location.unknown(ctx):
-        test.register_python_test_dialect(ctx)
         module = Module.create()
         with InsertionPoint(module.body):
             inferred = test.InferResultsOp()
@@ -295,7 +295,6 @@ def resultTypesDefinedByTraits():
 @run
 def testOptionalOperandOp():
     with Context() as ctx, Location.unknown():
-        test.register_python_test_dialect(ctx)
 
         module = Module.create()
         with InsertionPoint(module.body):
@@ -312,7 +311,6 @@ def testOptionalOperandOp():
 @run
 def testCustomAttribute():
     with Context() as ctx:
-        test.register_python_test_dialect(ctx)
         a = test.TestAttr.get()
         # CHECK: #python_test.test_attr
         print(a)
@@ -350,7 +348,6 @@ def testCustomAttribute():
 @run
 def testCustomType():
     with Context() as ctx:
-        test.register_python_test_dialect(ctx)
         a = test.TestType.get()
         # CHECK: !python_test.test_type
         print(a)
@@ -397,8 +394,6 @@ def testCustomType():
 # CHECK-LABEL: TEST: testTensorValue
 def testTensorValue():
     with Context() as ctx, Location.unknown():
-        test.register_python_test_dialect(ctx)
-
         i8 = IntegerType.get_signless(8)
 
         class Tensor(test.TestTensorValue):
@@ -436,7 +431,6 @@ def __str__(self):
 @run
 def inferReturnTypeComponents():
     with Context() as ctx, Location.unknown(ctx):
-        test.register_python_test_dialect(ctx)
         module = Module.create()
         i32 = IntegerType.get_signless(32)
         with InsertionPoint(module.body):
@@ -488,8 +482,6 @@ def inferReturnTypeComponents():
 @run
 def testCustomTypeTypeCaster():
     with Context() as ctx, Location.unknown():
-        test.register_python_test_dialect(ctx)
-
         a = test.TestType.get()
         assert a.typeid is not None
 
@@ -542,7 +534,6 @@ def type_caster(pytype):
 @run
 def testInferTypeOpInterface():
     with Context() as ctx, Location.unknown(ctx):
-        test.register_python_test_dialect(ctx)
         module = Module.create()
         with InsertionPoint(module.body):
             i64 = IntegerType.get_signless(64)
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index aff414894cb825a..9e7decefa7166bc 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -34,6 +34,15 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
       },
       py::arg("context"), py::arg("load") = true);
 
+  m.def(
+      "register_dialects",
+      [](MlirDialectRegistry registry) {
+        MlirDialectHandle pythonTestDialect =
+            mlirGetDialectHandle__python_test__();
+        mlirDialectHandleInsertDialect(pythonTestDialect, registry);
+      },
+      py::arg("registry"));
+
   mlir_attribute_subclass(m, "TestAttr",
                           mlirAttributeIsAPythonTestTestAttribute)
       .def_classmethod(



More information about the Mlir-commits mailing list