[Mlir-commits] [mlir] 0677e54 - [mlir][python] Allow contexts to be created with a custom thread pool. (#72042)

llvmlistbot at llvm.org
Sat Nov 11 21:42:00 PST 2023


Author: Stella Laurenzo
Date: 2023-11-11T21:41:56-08:00
New Revision: 0677e54653e593ee90bb747fd75605f0bed47137

URL: https://github.com/llvm/llvm-project/commit/0677e54653e593ee90bb747fd75605f0bed47137
DIFF: https://github.com/llvm/llvm-project/commit/0677e54653e593ee90bb747fd75605f0bed47137.diff

LOG: [mlir][python] Allow contexts to be created with a custom thread pool. (#72042)

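The existing initialization sequence always enables multi-threading at
MLIRContext construction time. Because a custom thread pool can only be
installed on a context whose multi-threading is still disabled, this makes
it impractical to provide a customized thread pool.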

Here, this is changed to always create the context with threading
disabled, process all site-specific init hooks (which can set thread
pools), and then enable multi-threading unless the site configuration
opts out.
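The reordered sequence can be modeled in plain Python. This is a minimal simulation, not the MLIR bindings: `FakeContext`, `create_context`, and `install_shared_pool` are hypothetical stand-ins. `set_thread_pool` loosely mirrors the precondition of the C++ `MLIRContext::setThreadPool` (a pool may only be installed while multi-threading is disabled, and installing it turns threading on).

```python
from typing import Callable, List, Optional


class FakeContext:
    """Hypothetical stand-in for an MLIR context; tracks threading state only."""

    def __init__(self) -> None:
        # Mirrors mlirContextCreateWithThreading(false): start single-threaded.
        self.multithreading = False
        self.thread_pool: Optional[str] = None

    def enable_multithreading(self, enable: bool) -> None:
        self.multithreading = enable

    def set_thread_pool(self, pool: str) -> None:
        # Simulated precondition: installing a pool requires threading to be off.
        if self.multithreading:
            raise RuntimeError("multithreading must be disabled to set a thread pool")
        self.thread_pool = pool
        self.multithreading = True


def create_context(
    hooks: List[Callable[[FakeContext], None]], disable_multithreading: bool
) -> FakeContext:
    ctx = FakeContext()
    # Site hooks run while threading is still off, so they may install a pool.
    for hook in hooks:
        hook(ctx)
    # Preserve the old default: end up multithreaded unless a site opted out.
    if not disable_multithreading:
        ctx.enable_multithreading(True)
    return ctx


def install_shared_pool(ctx: FakeContext) -> None:
    # Hypothetical downstream hook that shares one pool across contexts.
    ctx.set_thread_pool("shared-pool")


# Default: no hooks, no opt-out, so the context ends up multithreaded.
default_ctx = create_context([], disable_multithreading=False)
# A downstream opts out of the default and installs its own pool instead.
pooled_ctx = create_context([install_shared_pool], disable_multithreading=True)
# Opting out with no pool-installing hook leaves the context single-threaded.
single_ctx = create_context([], disable_multithreading=True)
```

Under the old ordering (threading enabled at construction), `install_shared_pool` would hit the precondition and raise, which is the situation this change avoids.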

This should preserve the existing user-visible initialization behavior
while also letting downstreams ensure that contexts are always created
with a shared thread pool. This was tested with IREE, which has such a
concept. Site-specific thread tuning produced up to a 2x improvement for a
single compilation job, and customizing batch compilation (e.g. as part of
a build system) let it use half the memory and run the entire test suite
~2x faster. Given this, I believe that the additional configurability can
pay for itself in implementations that use it.
We may also want to present user-level Python APIs for controlling
threading configuration in the future.

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/python/mlir/_mlir_libs/__init__.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0f2ca666ccc050e..745aa64e63b67d4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -597,7 +597,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) {
 }
 
 PyMlirContext *PyMlirContext::createNewContextForInit() {
-  MlirContext context = mlirContextCreate();
+  MlirContext context = mlirContextCreateWithThreading(false);
   return new PyMlirContext(context);
 }
 

diff  --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index d5fc447e49bf3a6..6ce77b4cb93f609 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -46,6 +46,13 @@ def get_include_dirs() -> Sequence[str]:
 #   c. If the module has a 'context_init_hook', it will be added to a list
 #     of callbacks that are invoked as the last step of Context
 #     initialization (and passed the Context under construction).
+#   d. If the module has a 'disable_multithreading' attribute, it will be
+#     taken as a boolean. If it is True for any initializer, then the
+#     default behavior of enabling multithreading on the context
+#     will be suppressed. This complies with the original behavior of all
+#     contexts being created with multithreading enabled while allowing
+#     this behavior to be changed if needed (i.e. if a context_init_hook
+#     explicitly sets up multithreading).
 #
 # This facility allows downstreams to customize Context creation to their
 # needs.
@@ -58,8 +65,10 @@ def _site_initialize():
     logger = logging.getLogger(__name__)
     registry = ir.DialectRegistry()
     post_init_hooks = []
+    disable_multithreading = False
 
     def process_initializer_module(module_name):
+        nonlocal disable_multithreading
         try:
             m = importlib.import_module(f".{module_name}", __name__)
         except ModuleNotFoundError:
@@ -79,6 +88,10 @@ def process_initializer_module(module_name):
         if hasattr(m, "context_init_hook"):
             logger.debug("Adding context init hook from %r", m)
             post_init_hooks.append(m.context_init_hook)
+        if hasattr(m, "disable_multithreading"):
+            if bool(m.disable_multithreading):
+                logger.debug("Disabling multi-threading for context")
+                disable_multithreading = True
         return True
 
     # If _mlirRegisterEverything is built, then include it as an initializer
@@ -100,6 +113,8 @@ def __init__(self, *args, **kwargs):
             self.append_dialect_registry(registry)
             for hook in post_init_hooks:
                 hook(self)
+            if not disable_multithreading:
+                self.enable_multithreading(True)
             # TODO: There is some debate about whether we should eagerly load
             # all dialects. It is being done here in order to preserve existing
             # behavior. See: https://github.com/llvm/llvm-project/issues/56037
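The attribute checks added to `process_initializer_module` can be exercised without building MLIR. The sketch below is a hypothetical, stripped-down rewrite: `collect_site_config` is an illustrative name, not part of the bindings, and it takes already-built module objects instead of importing initializer modules. It shows that a truthy `disable_multithreading` in any module suppresses the default, while a falsy value has no effect.

```python
import types
from typing import Callable, Iterable, List, Tuple


def collect_site_config(
    modules: Iterable[types.ModuleType],
) -> Tuple[List[Callable], bool]:
    hooks: List[Callable] = []
    disable_multithreading = False
    for m in modules:
        if hasattr(m, "context_init_hook"):
            hooks.append(m.context_init_hook)
        # Any initializer with a truthy value wins; nothing re-enables it later.
        if hasattr(m, "disable_multithreading") and bool(m.disable_multithreading):
            disable_multithreading = True
    return hooks, disable_multithreading


def tune_threads(ctx) -> None:
    # Hypothetical hook; a real one might install a shared thread pool.
    pass


# Hypothetical in-memory initializer modules used only for illustration.
plain = types.ModuleType("example_plain_init")

opt_out = types.ModuleType("example_opt_out_init")
opt_out.disable_multithreading = True
opt_out.context_init_hook = tune_threads

falsy = types.ModuleType("example_falsy_init")
falsy.disable_multithreading = 0  # Coerced with bool(), so this does not opt out.

print(collect_site_config([plain]))           # no hooks, default kept
print(collect_site_config([plain, opt_out]))  # one hook, default suppressed
```

Because the flag is an "any module wins" OR across initializers, a downstream that opts out takes responsibility for enabling threading itself, typically from its `context_init_hook`.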