[Mlir-commits] [mlir] [mlir][python] set the registry free (PR #72477)
Maksim Levental
llvmlistbot at llvm.org
Wed Nov 15 22:10:33 PST 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/72477
>From 8b6c773305dfd477413b99db4ef2b775b78ad685 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 15 Nov 2023 23:48:17 -0600
Subject: [PATCH] [mlir][python] set the registry free
---
mlir/python/mlir/_mlir_libs/__init__.py | 206 ++++++++++++------------
1 file changed, 105 insertions(+), 101 deletions(-)
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 6ce77b4cb93f609..468925d278c61dd 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -56,110 +56,114 @@ def get_include_dirs() -> Sequence[str]:
#
# This facility allows downstreams to customize Context creation to their
# needs.
-def _site_initialize():
- import importlib
- import itertools
- import logging
- from ._mlir import ir
-
- logger = logging.getLogger(__name__)
- registry = ir.DialectRegistry()
- post_init_hooks = []
- disable_multithreading = False
-
- def process_initializer_module(module_name):
- nonlocal disable_multithreading
- try:
- m = importlib.import_module(f".{module_name}", __name__)
- except ModuleNotFoundError:
- return False
- except ImportError:
- message = (
- f"Error importing mlir initializer {module_name}. This may "
- "happen in unclean incremental builds but is likely a real bug if "
- "encountered otherwise and the MLIR Python API may not function."
+import importlib
+import itertools
+import logging
+from ._mlir import ir
+
+logger = logging.getLogger(__name__)
+registry = ir.DialectRegistry()  # filled by each initializer's register_dialects()
+post_init_hooks = []  # initializers' context_init_hook callables, run on every new Context
+disable_multithreading = False  # flipped to True if any initializer requests it
+
+
+def get_registry():  # the registry appended to every Context in Context.__init__
+ return registry
+
+
+def process_initializer_module(module_name):  # -> True iff the module imported and was applied
+    global disable_multithreading
+    try:
+        m = importlib.import_module(f".{module_name}", __name__)
+    except ModuleNotFoundError:
+        return False  # no such initializer (NOTE(review): also hides a missing transitive import)
+    except ImportError:
+        message = (
+            f"Error importing mlir initializer {module_name}. This may "
+            "happen in unclean incremental builds but is likely a real bug if "
+            "encountered otherwise and the MLIR Python API may not function."
+        )
+        logger.warning(message, exc_info=True)
+        return False  # `m` is unbound here; falling through raised UnboundLocalError
+    logger.debug("Initializing MLIR with module: %s", module_name)
+    if hasattr(m, "register_dialects"):
+        logger.debug("Registering dialects from initializer %r", m)
+        m.register_dialects(registry)
+    if hasattr(m, "context_init_hook"):
+        logger.debug("Adding context init hook from %r", m)
+        post_init_hooks.append(m.context_init_hook)
+    if hasattr(m, "disable_multithreading"):
+        if bool(m.disable_multithreading):
+            logger.debug("Disabling multi-threading for context")
+            disable_multithreading = True
+    return True
+
+
+# If _mlirRegisterEverything is built, include it as an initializer module;
+# the second import_module call below just returns the cached module.
+init_module = None
+if process_initializer_module("_mlirRegisterEverything"):
+    init_module = importlib.import_module("._mlirRegisterEverything", __name__)
+
+# Load all _site_initialize_{i} modules, where 'i' is a number starting
+# at 0, stopping at the first index that is not processed successfully.
+for i in itertools.count():
+    module_name = f"_site_initialize_{i}"
+    if not process_initializer_module(module_name):
+        break
+
+
+class Context(ir._BaseContext):
+    def __init__(self, *args, **kwargs):  # apply every initializer's configuration
+        super().__init__(*args, **kwargs)
+        self.append_dialect_registry(get_registry())
+        for init_hook in post_init_hooks:
+            init_hook(self)
+        if not disable_multithreading:
+            self.enable_multithreading(True)
+        # TODO: eager loading of every available dialect only preserves the
+        # existing behavior; whether to keep it is still debated, see
+        # https://github.com/llvm/llvm-project/issues/56037
+        self.load_all_available_dialects()
+        if init_module is not None:
+            logger.debug("Registering translations from initializer %r", init_module)
+            init_module.register_llvm_translations(self)
+
+
+ir.Context = Context
+
+
+class MLIRError(Exception):
+ """
+ An exception with diagnostic information. Has the following fields:
+ message: str
+ error_diagnostics: List[ir.DiagnosticInfo]
+ """
+
+ def __init__(self, message, error_diagnostics):
+ self.message = message
+ self.error_diagnostics = error_diagnostics
+ super().__init__(message, error_diagnostics)
+
+ def __str__(self):
+ s = self.message
+ if self.error_diagnostics:
+ s += ":"
+ for diag in self.error_diagnostics:
+ s += (
+ "\nerror: "
+ + str(diag.location)[4:-1]
+ + ": "
+ + diag.message.replace("\n", "\n ")
)
- logger.warning(message, exc_info=True)
-
- logger.debug("Initializing MLIR with module: %s", module_name)
- if hasattr(m, "register_dialects"):
- logger.debug("Registering dialects from initializer %r", m)
- m.register_dialects(registry)
- if hasattr(m, "context_init_hook"):
- logger.debug("Adding context init hook from %r", m)
- post_init_hooks.append(m.context_init_hook)
- if hasattr(m, "disable_multithreading"):
- if bool(m.disable_multithreading):
- logger.debug("Disabling multi-threading for context")
- disable_multithreading = True
- return True
-
- # If _mlirRegisterEverything is built, then include it as an initializer
- # module.
- init_module = None
- if process_initializer_module("_mlirRegisterEverything"):
- init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
-
- # Load all _site_initialize_{i} modules, where 'i' is a number starting
- # at 0.
- for i in itertools.count():
- module_name = f"_site_initialize_{i}"
- if not process_initializer_module(module_name):
- break
-
- class Context(ir._BaseContext):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.append_dialect_registry(registry)
- for hook in post_init_hooks:
- hook(self)
- if not disable_multithreading:
- self.enable_multithreading(True)
- # TODO: There is some debate about whether we should eagerly load
- # all dialects. It is being done here in order to preserve existing
- # behavior. See: https://github.com/llvm/llvm-project/issues/56037
- self.load_all_available_dialects()
- if init_module:
- logger.debug(
- "Registering translations from initializer %r", init_module
- )
- init_module.register_llvm_translations(self)
-
- ir.Context = Context
-
- class MLIRError(Exception):
- """
- An exception with diagnostic information. Has the following fields:
- message: str
- error_diagnostics: List[ir.DiagnosticInfo]
- """
-
- def __init__(self, message, error_diagnostics):
- self.message = message
- self.error_diagnostics = error_diagnostics
- super().__init__(message, error_diagnostics)
-
- def __str__(self):
- s = self.message
- if self.error_diagnostics:
- s += ":"
- for diag in self.error_diagnostics:
+ for note in diag.notes:
s += (
- "\nerror: "
- + str(diag.location)[4:-1]
+ "\n note: "
+ + str(note.location)[4:-1]
+ ": "
- + diag.message.replace("\n", "\n ")
+ + note.message.replace("\n", "\n ")
)
- for note in diag.notes:
- s += (
- "\n note: "
- + str(note.location)[4:-1]
- + ": "
- + note.message.replace("\n", "\n ")
- )
- return s
-
- ir.MLIRError = MLIRError
+ return s
-_site_initialize()
+ir.MLIRError = MLIRError
More information about the Mlir-commits
mailing list