[Mlir-commits] [mlir] 0677e54 - [mlir][python] Allow contexts to be created with a custom thread pool. (#72042)
llvmlistbot at llvm.org
Sat Nov 11 21:42:00 PST 2023
Author: Stella Laurenzo
Date: 2023-11-11T21:41:56-08:00
New Revision: 0677e54653e593ee90bb747fd75605f0bed47137
URL: https://github.com/llvm/llvm-project/commit/0677e54653e593ee90bb747fd75605f0bed47137
DIFF: https://github.com/llvm/llvm-project/commit/0677e54653e593ee90bb747fd75605f0bed47137.diff
LOG: [mlir][python] Allow contexts to be created with a custom thread pool. (#72042)
The existing initialization sequence always enables multi-threading at
MLIRContext construction time, making it impractical to provide a
customized thread pool.
Here, this is changed to always create the context with threading
disabled, process all site-specific init hooks (which can set thread
pools), and then enable multi-threading unless a site initializer is
configured not to.
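To make the new ordering concrete, here is a minimal, self-contained Python sketch of the sequence. It does not use the real `mlir` package: `StubContext`, `set_thread_pool`, `create_context`, and `shared_pool_hook` are hypothetical stand-ins. The stub refuses to install a thread pool while multithreading is already on; this is an assumption, modeled loosely on the precondition of the C++ `MLIRContext::setThreadPool`, and it illustrates why enabling threading at construction time got in the way.

```python
from typing import Callable, List, Optional


class StubContext:
    """Hypothetical stand-in for an MLIR context (not the real mlir.ir.Context)."""

    def __init__(self, threading: bool) -> None:
        self.threading_enabled = threading
        self.thread_pool: Optional[str] = None

    def enable_multithreading(self, enable: bool) -> None:
        self.threading_enabled = enable

    def set_thread_pool(self, pool: str) -> None:
        # Assumed precondition: a custom pool may only be installed while
        # multithreading is still disabled.
        if self.threading_enabled:
            raise RuntimeError("disable multithreading before setting a thread pool")
        self.thread_pool = pool


def create_context(
    hooks: List[Callable[[StubContext], None]], disable_multithreading: bool
) -> StubContext:
    # 1. Always create the context with threading off.
    ctx = StubContext(threading=False)
    # 2. Run every site-specific hook; any of them may install a thread pool.
    for hook in hooks:
        hook(ctx)
    # 3. Turn threading on unless a site initializer opted out.
    if not disable_multithreading:
        ctx.enable_multithreading(True)
    return ctx


def shared_pool_hook(ctx: StubContext) -> None:
    # Hypothetical downstream hook that installs a shared pool.
    ctx.set_thread_pool("shared-pool")


ctx = create_context([shared_pool_hook], disable_multithreading=False)
print(ctx.thread_pool, ctx.threading_enabled)  # shared-pool True
```

Under the pre-patch order (a context created with threading already on), calling `shared_pool_hook(StubContext(threading=True))` raises the `RuntimeError`, which is the situation the new sequence avoids.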
This should preserve the existing user-visible initialization behavior
while also letting downstreams ensure that contexts are always created
with a shared thread pool. This was tested with IREE, which has such a
concept. Site-specific thread tuning gave up to a 2x improvement in
single compilation job behavior, and customizing batch compilation (i.e.
as part of a build system) cut memory use in half while running the
entire test suite ~2x faster. Given this, I believe that the additional
configurability can well pay for itself for implementations that use it.
We may also want to present user-level Python APIs for controlling
threading configuration in the future.
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/python/mlir/_mlir_libs/__init__.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0f2ca666ccc050e..745aa64e63b67d4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -597,7 +597,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) {
 }
 
 PyMlirContext *PyMlirContext::createNewContextForInit() {
-  MlirContext context = mlirContextCreate();
+  MlirContext context = mlirContextCreateWithThreading(false);
   return new PyMlirContext(context);
 }
 
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index d5fc447e49bf3a6..6ce77b4cb93f609 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -46,6 +46,13 @@ def get_include_dirs() -> Sequence[str]:
# c. If the module has a 'context_init_hook', it will be added to a list
# of callbacks that are invoked as the last step of Context
# initialization (and passed the Context under construction).
+# d. If the module has a 'disable_multithreading' attribute, it will be
+# taken as a boolean. If it is True for any initializer, then the
+# default behavior of enabling multithreading on the context
+# will be suppressed. This complies with the original behavior of all
+# contexts being created with multithreading enabled while allowing
+# this behavior to be changed if needed (i.e. if a context_init_hook
+# explicitly sets up multithreading).
#
# This facility allows downstreams to customize Context creation to their
# needs.
@@ -58,8 +65,10 @@ def _site_initialize():
     logger = logging.getLogger(__name__)
     registry = ir.DialectRegistry()
     post_init_hooks = []
+    disable_multithreading = False
 
     def process_initializer_module(module_name):
+        nonlocal disable_multithreading
         try:
             m = importlib.import_module(f".{module_name}", __name__)
         except ModuleNotFoundError:
@@ -79,6 +88,10 @@ def process_initializer_module(module_name):
         if hasattr(m, "context_init_hook"):
             logger.debug("Adding context init hook from %r", m)
             post_init_hooks.append(m.context_init_hook)
+        if hasattr(m, "disable_multithreading"):
+            if bool(m.disable_multithreading):
+                logger.debug("Disabling multi-threading for context")
+                disable_multithreading = True
         return True
 
     # If _mlirRegisterEverything is built, then include it as an initializer
@@ -100,6 +113,8 @@ def __init__(self, *args, **kwargs):
             self.append_dialect_registry(registry)
             for hook in post_init_hooks:
                 hook(self)
+            if not disable_multithreading:
+                self.enable_multithreading(True)
             # TODO: There is some debate about whether we should eagerly load
             # all dialects. It is being done here in order to preserve existing
             # behavior. See: https://github.com/llvm/llvm-project/issues/56037
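The `disable_multithreading` handling added to `process_initializer_module` can be exercised without building MLIR. The sketch below fakes initializer modules with `types.ModuleType` and folds their attributes the way the patch does: `hasattr`, then `bool()`, with any truthy value winning. The module names and the `collect_disable_flag` helper are hypothetical; the real code loads modules through `importlib` and also collects dialect registrations and init hooks.

```python
import types
from typing import Iterable


def collect_disable_flag(modules: Iterable[types.ModuleType]) -> bool:
    """Hypothetical helper mirroring the patch's aggregation logic."""
    disable_multithreading = False

    def process(m: types.ModuleType) -> None:
        # Without `nonlocal`, the assignment below would bind a new local
        # variable and the outer flag would silently stay False.
        nonlocal disable_multithreading
        if hasattr(m, "disable_multithreading"):
            if bool(m.disable_multithreading):
                disable_multithreading = True

    for m in modules:
        process(m)
    return disable_multithreading


# A module that does not define the attribute keeps the default (threading on).
plain = types.ModuleType("_hypothetical_init_plain")

# A module that opts out, e.g. because its context_init_hook manages threading.
opt_out = types.ModuleType("_hypothetical_init_opt_out")
opt_out.disable_multithreading = True

# A later module cannot re-enable threading: False never resets the flag.
opt_in = types.ModuleType("_hypothetical_init_opt_in")
opt_in.disable_multithreading = False

print(collect_disable_flag([plain]))                   # False
print(collect_disable_flag([plain, opt_out, opt_in]))  # True
```

Because the value is passed through `bool()`, a non-empty string such as `"False"` is truthy and would also suppress multithreading, so initializers are best off setting a real boolean.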