[Mlir-commits] [mlir] [mlir][python] set the registry free (PR #72477)
Maksim Levental
llvmlistbot at llvm.org
Wed Nov 15 22:34:31 PST 2023
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/72477
>From 8b6c773305dfd477413b99db4ef2b775b78ad685 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 15 Nov 2023 23:48:17 -0600
Subject: [PATCH 1/2] [mlir][python] set the registry free
---
mlir/python/mlir/_mlir_libs/__init__.py | 206 ++++++++++++------------
1 file changed, 105 insertions(+), 101 deletions(-)
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 6ce77b4cb93f609..468925d278c61dd 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -56,110 +56,114 @@ def get_include_dirs() -> Sequence[str]:
#
# This facility allows downstreams to customize Context creation to their
# needs.
-def _site_initialize():
- import importlib
- import itertools
- import logging
- from ._mlir import ir
-
- logger = logging.getLogger(__name__)
- registry = ir.DialectRegistry()
- post_init_hooks = []
- disable_multithreading = False
-
- def process_initializer_module(module_name):
- nonlocal disable_multithreading
- try:
- m = importlib.import_module(f".{module_name}", __name__)
- except ModuleNotFoundError:
- return False
- except ImportError:
- message = (
- f"Error importing mlir initializer {module_name}. This may "
- "happen in unclean incremental builds but is likely a real bug if "
- "encountered otherwise and the MLIR Python API may not function."
+import importlib
+import itertools
+import logging
+from ._mlir import ir
+
+logger = logging.getLogger(__name__)
+registry = ir.DialectRegistry()
+post_init_hooks = []
+disable_multithreading = False
+
+
+def get_registry():
+ return registry
+
+
+def process_initializer_module(module_name):
+ global disable_multithreading
+ try:
+ m = importlib.import_module(f".{module_name}", __name__)
+ except ModuleNotFoundError:
+ return False
+ except ImportError:
+ message = (
+ f"Error importing mlir initializer {module_name}. This may "
+ "happen in unclean incremental builds but is likely a real bug if "
+ "encountered otherwise and the MLIR Python API may not function."
+ )
+ logger.warning(message, exc_info=True)
+
+ logger.debug("Initializing MLIR with module: %s", module_name)
+ if hasattr(m, "register_dialects"):
+ logger.debug("Registering dialects from initializer %r", m)
+ m.register_dialects(registry)
+ if hasattr(m, "context_init_hook"):
+ logger.debug("Adding context init hook from %r", m)
+ post_init_hooks.append(m.context_init_hook)
+ if hasattr(m, "disable_multithreading"):
+ if bool(m.disable_multithreading):
+ logger.debug("Disabling multi-threading for context")
+ disable_multithreading = True
+ return True
+
+
+# If _mlirRegisterEverything is built, then include it as an initializer
+# module.
+init_module = None
+if process_initializer_module("_mlirRegisterEverything"):
+ init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
+
+# Load all _site_initialize_{i} modules, where 'i' is a number starting
+# at 0.
+for i in itertools.count():
+ module_name = f"_site_initialize_{i}"
+ if not process_initializer_module(module_name):
+ break
+
+
+class Context(ir._BaseContext):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.append_dialect_registry(get_registry())
+ for hook in post_init_hooks:
+ hook(self)
+ if not disable_multithreading:
+ self.enable_multithreading(True)
+ # TODO: There is some debate about whether we should eagerly load
+ # all dialects. It is being done here in order to preserve existing
+ # behavior. See: https://github.com/llvm/llvm-project/issues/56037
+ self.load_all_available_dialects()
+ if init_module:
+ logger.debug("Registering translations from initializer %r", init_module)
+ init_module.register_llvm_translations(self)
+
+
+ir.Context = Context
+
+
+class MLIRError(Exception):
+ """
+ An exception with diagnostic information. Has the following fields:
+ message: str
+ error_diagnostics: List[ir.DiagnosticInfo]
+ """
+
+ def __init__(self, message, error_diagnostics):
+ self.message = message
+ self.error_diagnostics = error_diagnostics
+ super().__init__(message, error_diagnostics)
+
+ def __str__(self):
+ s = self.message
+ if self.error_diagnostics:
+ s += ":"
+ for diag in self.error_diagnostics:
+ s += (
+ "\nerror: "
+ + str(diag.location)[4:-1]
+ + ": "
+ + diag.message.replace("\n", "\n ")
)
- logger.warning(message, exc_info=True)
-
- logger.debug("Initializing MLIR with module: %s", module_name)
- if hasattr(m, "register_dialects"):
- logger.debug("Registering dialects from initializer %r", m)
- m.register_dialects(registry)
- if hasattr(m, "context_init_hook"):
- logger.debug("Adding context init hook from %r", m)
- post_init_hooks.append(m.context_init_hook)
- if hasattr(m, "disable_multithreading"):
- if bool(m.disable_multithreading):
- logger.debug("Disabling multi-threading for context")
- disable_multithreading = True
- return True
-
- # If _mlirRegisterEverything is built, then include it as an initializer
- # module.
- init_module = None
- if process_initializer_module("_mlirRegisterEverything"):
- init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
-
- # Load all _site_initialize_{i} modules, where 'i' is a number starting
- # at 0.
- for i in itertools.count():
- module_name = f"_site_initialize_{i}"
- if not process_initializer_module(module_name):
- break
-
- class Context(ir._BaseContext):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.append_dialect_registry(registry)
- for hook in post_init_hooks:
- hook(self)
- if not disable_multithreading:
- self.enable_multithreading(True)
- # TODO: There is some debate about whether we should eagerly load
- # all dialects. It is being done here in order to preserve existing
- # behavior. See: https://github.com/llvm/llvm-project/issues/56037
- self.load_all_available_dialects()
- if init_module:
- logger.debug(
- "Registering translations from initializer %r", init_module
- )
- init_module.register_llvm_translations(self)
-
- ir.Context = Context
-
- class MLIRError(Exception):
- """
- An exception with diagnostic information. Has the following fields:
- message: str
- error_diagnostics: List[ir.DiagnosticInfo]
- """
-
- def __init__(self, message, error_diagnostics):
- self.message = message
- self.error_diagnostics = error_diagnostics
- super().__init__(message, error_diagnostics)
-
- def __str__(self):
- s = self.message
- if self.error_diagnostics:
- s += ":"
- for diag in self.error_diagnostics:
+ for note in diag.notes:
s += (
- "\nerror: "
- + str(diag.location)[4:-1]
+ "\n note: "
+ + str(note.location)[4:-1]
+ ": "
- + diag.message.replace("\n", "\n ")
+ + note.message.replace("\n", "\n ")
)
- for note in diag.notes:
- s += (
- "\n note: "
- + str(note.location)[4:-1]
- + ": "
- + note.message.replace("\n", "\n ")
- )
- return s
-
- ir.MLIRError = MLIRError
+ return s
-_site_initialize()
+ir.MLIRError = MLIRError
>From b23938a217f9fb6e489b154a79c42182c4db4e34 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Thu, 16 Nov 2023 00:34:19 -0600
Subject: [PATCH 2/2] [mlir][python] hide everything in a namespace/module
---
mlir/python/mlir/_mlir_libs/__init__.py | 319 ++++++++++++------------
1 file changed, 166 insertions(+), 153 deletions(-)
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index 468925d278c61dd..2b2306d121cbfd5 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -2,168 +2,181 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-from typing import Any, Sequence
-
-import os
-
-_this_dir = os.path.dirname(__file__)
-
-
-def get_lib_dirs() -> Sequence[str]:
- """Gets the lib directory for linking to shared libraries.
-
- On some platforms, the package may need to be built specially to export
- development libraries.
- """
- return [_this_dir]
-
-
-def get_include_dirs() -> Sequence[str]:
- """Gets the include directory for compiling against exported C libraries.
-
- Depending on how the package was build, development C libraries may or may
- not be present.
- """
- return [os.path.join(_this_dir, "include")]
-
-
-# Perform Python level site initialization. This involves:
-# 1. Attempting to load initializer modules, specific to the distribution.
-# 2. Defining the concrete mlir.ir.Context that does site specific
-# initialization.
-#
-# Aside from just being far more convenient to do this at the Python level,
-# it is actually quite hard/impossible to have such __init__ hooks, given
-# the pybind memory model (i.e. there is not a Python reference to the object
-# in the scope of the base class __init__).
-#
-# For #1, we:
-# a. Probe for modules named '_mlirRegisterEverything' and
-# '_site_initialize_{i}', where 'i' is a number starting at zero and
-# proceeding so long as a module with the name is found.
-# b. If the module has a 'register_dialects' attribute, it will be called
-# immediately with a DialectRegistry to populate.
-# c. If the module has a 'context_init_hook', it will be added to a list
-# of callbacks that are invoked as the last step of Context
-# initialization (and passed the Context under construction).
-# d. If the module has a 'disable_multithreading' attribute, it will be
-# taken as a boolean. If it is True for any initializer, then the
-# default behavior of enabling multithreading on the context
-# will be suppressed. This complies with the original behavior of all
-# contexts being created with multithreading enabled while allowing
-# this behavior to be changed if needed (i.e. if a context_init_hook
-# explicitly sets up multithreading).
-#
-# This facility allows downstreams to customize Context creation to their
-# needs.
import importlib
import itertools
import logging
+import os
+import sys
+from typing import Sequence
+
from ._mlir import ir
+
+_this_dir = os.path.dirname(__file__)
+
logger = logging.getLogger(__name__)
-registry = ir.DialectRegistry()
-post_init_hooks = []
-disable_multithreading = False
-
-
-def get_registry():
- return registry
-
-
-def process_initializer_module(module_name):
- global disable_multithreading
- try:
- m = importlib.import_module(f".{module_name}", __name__)
- except ModuleNotFoundError:
- return False
- except ImportError:
- message = (
- f"Error importing mlir initializer {module_name}. This may "
- "happen in unclean incremental builds but is likely a real bug if "
- "encountered otherwise and the MLIR Python API may not function."
- )
- logger.warning(message, exc_info=True)
-
- logger.debug("Initializing MLIR with module: %s", module_name)
- if hasattr(m, "register_dialects"):
- logger.debug("Registering dialects from initializer %r", m)
- m.register_dialects(registry)
- if hasattr(m, "context_init_hook"):
- logger.debug("Adding context init hook from %r", m)
- post_init_hooks.append(m.context_init_hook)
- if hasattr(m, "disable_multithreading"):
- if bool(m.disable_multithreading):
- logger.debug("Disabling multi-threading for context")
- disable_multithreading = True
- return True
-
-
-# If _mlirRegisterEverything is built, then include it as an initializer
-# module.
-init_module = None
-if process_initializer_module("_mlirRegisterEverything"):
- init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
-
-# Load all _site_initialize_{i} modules, where 'i' is a number starting
-# at 0.
-for i in itertools.count():
- module_name = f"_site_initialize_{i}"
- if not process_initializer_module(module_name):
- break
-
-
-class Context(ir._BaseContext):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.append_dialect_registry(get_registry())
- for hook in post_init_hooks:
- hook(self)
- if not disable_multithreading:
- self.enable_multithreading(True)
- # TODO: There is some debate about whether we should eagerly load
- # all dialects. It is being done here in order to preserve existing
- # behavior. See: https://github.com/llvm/llvm-project/issues/56037
- self.load_all_available_dialects()
- if init_module:
- logger.debug("Registering translations from initializer %r", init_module)
- init_module.register_llvm_translations(self)
-
-
-ir.Context = Context
-
-
-class MLIRError(Exception):
- """
- An exception with diagnostic information. Has the following fields:
- message: str
- error_diagnostics: List[ir.DiagnosticInfo]
- """
-
- def __init__(self, message, error_diagnostics):
- self.message = message
- self.error_diagnostics = error_diagnostics
- super().__init__(message, error_diagnostics)
-
- def __str__(self):
- s = self.message
- if self.error_diagnostics:
- s += ":"
- for diag in self.error_diagnostics:
- s += (
- "\nerror: "
- + str(diag.location)[4:-1]
- + ": "
- + diag.message.replace("\n", "\n ")
+
+_path = __path__
+_spec = __spec__
+_name = __name__
+
+
+class _M:
+ __path__ = _path
+ __spec__ = _spec
+ __name__ = _name
+
+ @staticmethod
+ def get_lib_dirs() -> Sequence[str]:
+ """Gets the lib directory for linking to shared libraries.
+
+ On some platforms, the package may need to be built specially to export
+ development libraries.
+ """
+ return [_this_dir]
+
+ @staticmethod
+ def get_include_dirs() -> Sequence[str]:
+ """Gets the include directory for compiling against exported C libraries.
+
+ Depending on how the package was built, development C libraries may or may
+ not be present.
+ """
+ return [os.path.join(_this_dir, "include")]
+
+ # Perform Python level site initialization. This involves:
+ # 1. Attempting to load initializer modules, specific to the distribution.
+ # 2. Defining the concrete mlir.ir.Context that does site specific
+ # initialization.
+ #
+ # Aside from just being far more convenient to do this at the Python level,
+ # it is actually quite hard/impossible to have such __init__ hooks, given
+ # the pybind memory model (i.e. there is not a Python reference to the object
+ # in the scope of the base class __init__).
+ #
+ # For #1, we:
+ # a. Probe for modules named '_mlirRegisterEverything' and
+ # '_site_initialize_{i}', where 'i' is a number starting at zero and
+ # proceeding so long as a module with the name is found.
+ # b. If the module has a 'register_dialects' attribute, it will be called
+ # immediately with a DialectRegistry to populate.
+ # c. If the module has a 'context_init_hook', it will be added to a list
+ # of callbacks that are invoked as the last step of Context
+ # initialization (and passed the Context under construction).
+ # d. If the module has a 'disable_multithreading' attribute, it will be
+ # taken as a boolean. If it is True for any initializer, then the
+ # default behavior of enabling multithreading on the context
+ # will be suppressed. This complies with the original behavior of all
+ # contexts being created with multithreading enabled while allowing
+ # this behavior to be changed if needed (i.e. if a context_init_hook
+ # explicitly sets up multithreading).
+ #
+ # This facility allows downstreams to customize Context creation to their
+ # needs.
+
+ __registry = ir.DialectRegistry()
+ __post_init_hooks = []
+ __disable_multithreading = False
+ from . import _mlir as _mlir
+
+ def __get_registry(self):
+ return self.__registry
+
+ def process_initializer_module(self, module_name):
+ try:
+ m = importlib.import_module(f".{module_name}", __name__)
+ except ModuleNotFoundError:
+ return False
+ except ImportError:
+ message = (
+ f"Error importing mlir initializer {module_name}. This may "
+ "happen in unclean incremental builds but is likely a real bug if "
+ "encountered otherwise and the MLIR Python API may not function."
)
- for note in diag.notes:
+ logger.warning(message, exc_info=True)
+
+ logger.debug("Initializing MLIR with module: %s", module_name)
+ if hasattr(m, "register_dialects"):
+ logger.debug("Registering dialects from initializer %r", m)
+ m.register_dialects(self.__get_registry())
+ if hasattr(m, "context_init_hook"):
+ logger.debug("Adding context init hook from %r", m)
+ self.__post_init_hooks.append(m.context_init_hook)
+ if hasattr(m, "disable_multithreading"):
+ if bool(m.disable_multithreading):
+ logger.debug("Disabling multi-threading for context")
+ self.__disable_multithreading = True
+ return True
+
+ def __init__(self):
+ # If _mlirRegisterEverything is built, then include it as an initializer
+ # module.
+ init_module = None
+ if self.process_initializer_module("_mlirRegisterEverything"):
+ init_module = importlib.import_module(f"._mlirRegisterEverything", __name__)
+
+ # Load all _site_initialize_{i} modules, where 'i' is a number starting
+ # at 0.
+ for i in itertools.count():
+ module_name = f"_site_initialize_{i}"
+ if not self.process_initializer_module(module_name):
+ break
+
+ that = self
+
+ class Context(ir._BaseContext):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.append_dialect_registry(that._M__get_registry())
+ for hook in that._M__post_init_hooks:
+ hook(self)
+ if not that._M__disable_multithreading:
+ self.enable_multithreading(True)
+ # TODO: There is some debate about whether we should eagerly load
+ # all dialects. It is being done here in order to preserve existing
+ # behavior. See: https://github.com/llvm/llvm-project/issues/56037
+ self.load_all_available_dialects()
+ if init_module:
+ logger.debug(
+ "Registering translations from initializer %r", init_module
+ )
+ init_module.register_llvm_translations(self)
+
+ ir.Context = Context
+
+ class MLIRError(Exception):
+ """
+ An exception with diagnostic information. Has the following fields:
+ message: str
+ error_diagnostics: List[ir.DiagnosticInfo]
+ """
+
+ def __init__(self, message, error_diagnostics):
+ self.message = message
+ self.error_diagnostics = error_diagnostics
+ super().__init__(message, error_diagnostics)
+
+ def __str__(self):
+ s = self.message
+ if self.error_diagnostics:
+ s += ":"
+ for diag in self.error_diagnostics:
s += (
- "\n note: "
- + str(note.location)[4:-1]
+ "\nerror: "
+ + str(diag.location)[4:-1]
+ ": "
- + note.message.replace("\n", "\n ")
+ + diag.message.replace("\n", "\n ")
)
- return s
+ for note in diag.notes:
+ s += (
+ "\n note: "
+ + str(note.location)[4:-1]
+ + ": "
+ + note.message.replace("\n", "\n ")
+ )
+ return s
+
+ ir.MLIRError = MLIRError
-ir.MLIRError = MLIRError
+sys.modules[__name__] = _M()
More information about the Mlir-commits
mailing list