[Mlir-commits] [mlir] [mlir][python] set the registry free (PR #72477)

Maksim Levental llvmlistbot at llvm.org
Wed Nov 15 22:10:33 PST 2023


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/72477

From 8b6c773305dfd477413b99db4ef2b775b78ad685 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 15 Nov 2023 23:48:17 -0600
Subject: [PATCH] [mlir][python] set the registry free

---
 mlir/python/mlir/_mlir_libs/__init__.py | 206 ++++++++++++------------
 1 file changed, 105 insertions(+), 101 deletions(-)

diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 6ce77b4cb93f609..468925d278c61dd 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -56,110 +56,114 @@ def get_include_dirs() -> Sequence[str]:
 #
 # This facility allows downstreams to customize Context creation to their
 # needs.
-def _site_initialize():
-    import importlib
-    import itertools
-    import logging
-    from ._mlir import ir
-
-    logger = logging.getLogger(__name__)
-    registry = ir.DialectRegistry()
-    post_init_hooks = []
-    disable_multithreading = False
-
-    def process_initializer_module(module_name):
-        nonlocal disable_multithreading
-        try:
-            m = importlib.import_module(f".{module_name}", __name__)
-        except ModuleNotFoundError:
-            return False
-        except ImportError:
-            message = (
-                f"Error importing mlir initializer {module_name}. This may "
-                "happen in unclean incremental builds but is likely a real bug if "
-                "encountered otherwise and the MLIR Python API may not function."
+import importlib
+import itertools
+import logging
+from ._mlir import ir
+
+logger = logging.getLogger(__name__)
+registry = ir.DialectRegistry()
+post_init_hooks = []
+disable_multithreading = False
+
+
+def get_registry():
+    return registry
+
+
+def process_initializer_module(module_name):
+    global disable_multithreading
+    try:
+        m = importlib.import_module(f".{module_name}", __name__)
+    except ModuleNotFoundError:
+        return False
+    except ImportError:
+        message = (
+            f"Error importing mlir initializer {module_name}. This may "
+            "happen in unclean incremental builds but is likely a real bug if "
+            "encountered otherwise and the MLIR Python API may not function."
+        )
+        logger.warning(message, exc_info=True)
+        return False
+    logger.debug("Initializing MLIR with module: %s", module_name)
+    if hasattr(m, "register_dialects"):
+        logger.debug("Registering dialects from initializer %r", m)
+        m.register_dialects(registry)
+    if hasattr(m, "context_init_hook"):
+        logger.debug("Adding context init hook from %r", m)
+        post_init_hooks.append(m.context_init_hook)
+    if hasattr(m, "disable_multithreading"):
+        if bool(m.disable_multithreading):
+            logger.debug("Disabling multi-threading for context")
+            disable_multithreading = True
+    return True
+
+
+# If _mlirRegisterEverything is built, then include it as an initializer
+# module.
+init_module = None
+if process_initializer_module("_mlirRegisterEverything"):
+    init_module = importlib.import_module("._mlirRegisterEverything", __name__)
+
+# Load _site_initialize_{i} modules for i = 0, 1, 2, ..., stopping at the
+# first index that is missing or fails to import.
+for i in itertools.count():
+    module_name = f"_site_initialize_{i}"
+    if not process_initializer_module(module_name):
+        break
+
+
+class Context(ir._BaseContext):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.append_dialect_registry(get_registry())
+        for hook in post_init_hooks:
+            hook(self)
+        if not disable_multithreading:
+            self.enable_multithreading(True)
+        # TODO: There is some debate about whether we should eagerly load
+        # all dialects. It is being done here in order to preserve existing
+        # behavior. See: https://github.com/llvm/llvm-project/issues/56037
+        self.load_all_available_dialects()
+        if init_module:
+            logger.debug("Registering translations from initializer %r", init_module)
+            init_module.register_llvm_translations(self)
+
+
+ir.Context = Context
+
+
+class MLIRError(Exception):
+    """
+    An exception with diagnostic information. Has the following fields:
+      message: str
+      error_diagnostics: List[ir.DiagnosticInfo]
+    """
+
+    def __init__(self, message, error_diagnostics):
+        self.message = message
+        self.error_diagnostics = error_diagnostics
+        super().__init__(message, error_diagnostics)
+
+    def __str__(self):
+        s = self.message
+        if self.error_diagnostics:
+            s += ":"
+        for diag in self.error_diagnostics:
+            s += (
+                "\nerror: "
+                + str(diag.location)[4:-1]
+                + ": "
+                + diag.message.replace("\n", "\n  ")
             )
-            logger.warning(message, exc_info=True)
-
-        logger.debug("Initializing MLIR with module: %s", module_name)
-        if hasattr(m, "register_dialects"):
-            logger.debug("Registering dialects from initializer %r", m)
-            m.register_dialects(registry)
-        if hasattr(m, "context_init_hook"):
-            logger.debug("Adding context init hook from %r", m)
-            post_init_hooks.append(m.context_init_hook)
-        if hasattr(m, "disable_multithreading"):
-            if bool(m.disable_multithreading):
-                logger.debug("Disabling multi-threading for context")
-                disable_multithreading = True
-        return True
-
-    # If _mlirRegisterEverything is built, then include it as an initializer
-    # module.
-    init_module = None
-    if process_initializer_module("_mlirRegisterEverything"):
-        init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
-
-    # Load all _site_initialize_{i} modules, where 'i' is a number starting
-    # at 0.
-    for i in itertools.count():
-        module_name = f"_site_initialize_{i}"
-        if not process_initializer_module(module_name):
-            break
-
-    class Context(ir._BaseContext):
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-            self.append_dialect_registry(registry)
-            for hook in post_init_hooks:
-                hook(self)
-            if not disable_multithreading:
-                self.enable_multithreading(True)
-            # TODO: There is some debate about whether we should eagerly load
-            # all dialects. It is being done here in order to preserve existing
-            # behavior. See: https://github.com/llvm/llvm-project/issues/56037
-            self.load_all_available_dialects()
-            if init_module:
-                logger.debug(
-                    "Registering translations from initializer %r", init_module
-                )
-                init_module.register_llvm_translations(self)
-
-    ir.Context = Context
-
-    class MLIRError(Exception):
-        """
-        An exception with diagnostic information. Has the following fields:
-          message: str
-          error_diagnostics: List[ir.DiagnosticInfo]
-        """
-
-        def __init__(self, message, error_diagnostics):
-            self.message = message
-            self.error_diagnostics = error_diagnostics
-            super().__init__(message, error_diagnostics)
-
-        def __str__(self):
-            s = self.message
-            if self.error_diagnostics:
-                s += ":"
-            for diag in self.error_diagnostics:
+            for note in diag.notes:
                 s += (
-                    "\nerror: "
-                    + str(diag.location)[4:-1]
+                    "\n note: "
+                    + str(note.location)[4:-1]
                     + ": "
-                    + diag.message.replace("\n", "\n  ")
+                    + note.message.replace("\n", "\n  ")
                 )
-                for note in diag.notes:
-                    s += (
-                        "\n note: "
-                        + str(note.location)[4:-1]
-                        + ": "
-                        + note.message.replace("\n", "\n  ")
-                    )
-            return s
-
-    ir.MLIRError = MLIRError
+        return s
 
 
-_site_initialize()
+ir.MLIRError = MLIRError



More information about the Mlir-commits mailing list