[Mlir-commits] [mlir] [mlir][python] Allow contexts to be created with a custom thread pool. (PR #72042)

llvmlistbot at llvm.org
Sat Nov 11 17:26:16 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Stella Laurenzo (stellaraccident)

<details>
<summary>Changes</summary>

The existing initialization sequence always enables multi-threading at MLIRContext construction time, making it impractical to provide a customized thread pool.

Here, this is changed to always create the context with threading disabled, process all site-specific init hooks (which can set thread pools), and finally enable multi-threading unless a site initializer opts out.
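
The reordering can be modeled in plain Python. This is a toy sketch, not the bindings' code: `FakeContext`, `create_context`, `install_shared_pool`, and the string pool handle are hypothetical names. It assumes, based on our reading of the C++ API, that `MLIRContext::setThreadPool` expects multi-threading to be disabled when it is called. That assumption is what makes a context that comes up with threading already enabled a dead end for a hook that wants to install its own pool.

```python
from typing import Callable, List, Optional


class FakeContext:
    """Hypothetical stand-in for mlir.ir.Context that only tracks threading."""

    def __init__(self, threading: bool) -> None:
        self.threading = threading
        self.thread_pool: Optional[str] = None

    def enable_multithreading(self, enable: bool) -> None:
        self.threading = enable

    def set_thread_pool(self, pool: str) -> None:
        # Assumed precondition, modeled on the C++ MLIRContext::setThreadPool:
        # a custom pool may only be installed while threading is disabled.
        if self.threading:
            raise RuntimeError("disable multi-threading before setting a pool")
        self.thread_pool = pool
        self.threading = True  # installing a pool turns threading on


Hook = Callable[[FakeContext], None]


def create_context(hooks: List[Hook], disable_multithreading: bool,
                   threading_at_construction: bool = False) -> FakeContext:
    # New order: construct with threading off (the C++ side of this PR now
    # calls mlirContextCreateWithThreading(false)). True models the old order.
    ctx = FakeContext(threading=threading_at_construction)
    # Site hooks run while threading is still off, so they can install a pool.
    for hook in hooks:
        hook(ctx)
    # Restore the historical default unless some initializer opted out.
    if not disable_multithreading:
        ctx.enable_multithreading(True)
    return ctx


def install_shared_pool(ctx: FakeContext) -> None:
    # Hypothetical downstream hook; "shared-pool" is a placeholder pool handle.
    ctx.set_thread_pool("shared-pool")


# Default behavior is preserved: no hooks, no opt-out, threading ends up on.
default_ctx = create_context([], disable_multithreading=False)

# A downstream hook can now install its own pool during initialization.
custom_ctx = create_context([install_shared_pool], disable_multithreading=True)

# Under the old order (threading on at construction) the same hook fails.
try:
    create_context([install_shared_pool], disable_multithreading=True,
                   threading_at_construction=True)
    old_order_failed = False
except RuntimeError:
    old_order_failed = True
```

The key design point is that the "enable threading" step moves from construction time to the very end of initialization, so the user-visible default is unchanged while hooks get a window in which threading is still off.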

This should preserve the existing user-visible initialization behavior while also letting downstreams ensure that contexts are always created with a shared thread pool. This was tested with IREE, which has such a concept. Site-specific thread tuning gave up to a 2x improvement for a single compilation job, and customizing batch compilation (e.g. as part of a build system) let the entire test suite run ~2x faster while using half the memory. Given this, I believe the additional configurability more than pays for itself for implementations that use it. We may also want to expose user-level Python APIs for controlling threading configuration in the future.

---
Full diff: https://github.com/llvm/llvm-project/pull/72042.diff


2 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+1-1) 
- (modified) mlir/python/mlir/_mlir_libs/__init__.py (+15) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 0f2ca666ccc050e..745aa64e63b67d4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -597,7 +597,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) {
 }
 
 PyMlirContext *PyMlirContext::createNewContextForInit() {
-  MlirContext context = mlirContextCreate();
+  MlirContext context = mlirContextCreateWithThreading(false);
   return new PyMlirContext(context);
 }
 
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index d5fc447e49bf3a6..6ce77b4cb93f609 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -46,6 +46,13 @@ def get_include_dirs() -> Sequence[str]:
 #   c. If the module has a 'context_init_hook', it will be added to a list
 #     of callbacks that are invoked as the last step of Context
 #     initialization (and passed the Context under construction).
+#   d. If the module has a 'disable_multithreading' attribute, it will be
+#     taken as a boolean. If it is True for any initializer, then the
+#     default behavior of enabling multithreading on the context
+#     will be suppressed. This complies with the original behavior of all
+#     contexts being created with multithreading enabled while allowing
+#     this behavior to be changed if needed (i.e. if a context_init_hook
+#     explicitly sets up multithreading).
 #
 # This facility allows downstreams to customize Context creation to their
 # needs.
@@ -58,8 +65,10 @@ def _site_initialize():
     logger = logging.getLogger(__name__)
     registry = ir.DialectRegistry()
     post_init_hooks = []
+    disable_multithreading = False
 
     def process_initializer_module(module_name):
+        nonlocal disable_multithreading
         try:
             m = importlib.import_module(f".{module_name}", __name__)
         except ModuleNotFoundError:
@@ -79,6 +88,10 @@ def process_initializer_module(module_name):
         if hasattr(m, "context_init_hook"):
             logger.debug("Adding context init hook from %r", m)
             post_init_hooks.append(m.context_init_hook)
+        if hasattr(m, "disable_multithreading"):
+            if bool(m.disable_multithreading):
+                logger.debug("Disabling multi-threading for context")
+                disable_multithreading = True
         return True
 
     # If _mlirRegisterEverything is built, then include it as an initializer
@@ -100,6 +113,8 @@ def __init__(self, *args, **kwargs):
             self.append_dialect_registry(registry)
             for hook in post_init_hooks:
                 hook(self)
+            if not disable_multithreading:
+                self.enable_multithreading(True)
             # TODO: There is some debate about whether we should eagerly load
             # all dialects. It is being done here in order to preserve existing
             # behavior. See: https://github.com/llvm/llvm-project/issues/56037

``````````
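
The new `disable_multithreading` handling in `process_initializer_module` amounts to a logical OR across initializer modules. The following is a minimal sketch, assuming initializer modules can be stood in for by `types.SimpleNamespace` objects; `collect_site_config` and the two example modules are hypothetical, and the real `_site_initialize` also performs the imports, dialect registration, and other attribute checks that are omitted here.

```python
import logging
from types import SimpleNamespace
from typing import Callable, Iterable, List, Tuple

logger = logging.getLogger(__name__)


def collect_site_config(modules: Iterable[object]) -> Tuple[List[Callable], bool]:
    """Mirror the hook and opt-out checks that process_initializer_module makes."""
    post_init_hooks: List[Callable] = []
    disable_multithreading = False
    for m in modules:
        if hasattr(m, "context_init_hook"):
            logger.debug("Adding context init hook from %r", m)
            post_init_hooks.append(m.context_init_hook)
        # Truthiness is used, as with bool(m.disable_multithreading) in the diff.
        # Once any module opts out, later modules cannot restore the default.
        if hasattr(m, "disable_multithreading") and bool(m.disable_multithreading):
            logger.debug("Disabling multi-threading for context")
            disable_multithreading = True
    return post_init_hooks, disable_multithreading


# Hypothetical downstream initializer that configures threading in its own hook.
downstream = SimpleNamespace(
    disable_multithreading=True,
    context_init_hook=lambda ctx: None,  # a real hook would install a pool here
)
# Hypothetical initializer that declares the attribute but leaves it falsy.
plain = SimpleNamespace(disable_multithreading=0)

hooks, disabled = collect_site_config([downstream, plain])
# Order does not matter: a single truthy value is enough.
_, disabled_reversed = collect_site_config([plain, downstream])
```

Because the flag is sticky, a downstream that sets up threading in its `context_init_hook` only has to set `disable_multithreading = True` in its own module, independent of the order in which initializers are processed.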

</details>


https://github.com/llvm/llvm-project/pull/72042

