[Mlir-commits] [mlir] ab18cc2 - [MLIR][py] Add PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings (#130109)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 10 03:19:27 PDT 2025


Author: vfdev
Date: 2025-03-10T11:19:23+01:00
New Revision: ab18cc246c2490564043161db5d9646cf1163de4

URL: https://github.com/llvm/llvm-project/commit/ab18cc246c2490564043161db5d9646cf1163de4
DIFF: https://github.com/llvm/llvm-project/commit/ab18cc246c2490564043161db5d9646cf1163de4.diff

LOG: [MLIR][py] Add PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings (#130109)

In some projects, such as JAX, ir.Context objects are created with multi-threading
disabled to avoid caching multiple thread pools:

https://github.com/jax-ml/jax/blob/623865fe9538100d877ba9d36f788d0f95a11ed2/jax/_src/interpreters/mlir.py#L606-L611

However, when a context has multi-threading enabled, it also takes locks on
the StorageUniquers, which helps avoid data races during
multi-threaded execution (for example with free-threaded CPython,
https://github.com/jax-ml/jax/issues/26272).
With this PR, the user can enable multi-threading, which 1) enables the additional
locking and 2) allows setting a shared thread pool so that cached contexts can
all use one global pool.

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/python/mlir/_mlir_libs/__init__.py
    mlir/test/python/ir/context_lifecycle.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 7fd6a41fb435b..1a8e8737f7fed 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -162,6 +162,15 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
 MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
                                                  MlirLlvmThreadPool threadPool);
 
+/// Gets the number of threads of the thread pool of the context when
+/// multithreading is enabled. Returns 1 if no multithreading.
+MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context);
+
+/// Gets the thread pool of the context when enabled multithreading, otherwise
+/// an assertion is raised.
+MLIR_CAPI_EXPORTED MlirLlvmThreadPool
+mlirContextGetThreadPool(MlirContext context);
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 9fd061d1c8dd9..78ba144acf1e9 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2743,6 +2743,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   // __init__.py will subclass it with site-specific functionality and set a
   // "Context" attribute on this module.
   //----------------------------------------------------------------------------
+
+  // Expose DefaultThreadPool to python
+  nb::class_<PyThreadPool>(m, "ThreadPool")
+      .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
+      .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency)
+      .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr);
+
   nb::class_<PyMlirContext>(m, "_BaseContext")
       .def("__init__",
            [](PyMlirContext &self) {
@@ -2814,6 +2821,25 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             mlirContextEnableMultithreading(self.get(), enable);
           },
           nb::arg("enable"))
+      .def("set_thread_pool",
+           [](PyMlirContext &self, PyThreadPool &pool) {
+             // we should disable multi-threading first before setting
+             // new thread pool otherwise the assert in
+             // MLIRContext::setThreadPool will be raised.
+             mlirContextEnableMultithreading(self.get(), false);
+             mlirContextSetThreadPool(self.get(), pool.get());
+           })
+      .def("get_num_threads",
+           [](PyMlirContext &self) {
+             return mlirContextGetNumThreads(self.get());
+           })
+      .def("_mlir_thread_pool_ptr",
+           [](PyMlirContext &self) {
+             MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
+             std::stringstream ss;
+             ss << pool.ptr;
+             return ss.str();
+           })
       .def(
           "is_registered_operation",
           [](PyMlirContext &self, std::string &name) {

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1ed6240a6ca69..9befcce725bb7 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -11,6 +11,7 @@
 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
 
 #include <optional>
+#include <sstream>
 #include <utility>
 #include <vector>
 
@@ -22,9 +23,10 @@
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Transforms.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ThreadPool.h"
 
 namespace mlir {
 namespace python {
@@ -158,6 +160,29 @@ class PyThreadContextEntry {
   FrameKind frameKind;
 };
 
+/// Wrapper around MlirLlvmThreadPool
+/// Python object owns the C++ thread pool
+class PyThreadPool {
+public:
+  PyThreadPool() {
+    ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
+  }
+  PyThreadPool(const PyThreadPool &) = delete;
+  PyThreadPool(PyThreadPool &&) = delete;
+
+  int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
+  MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
+
+  std::string _mlir_thread_pool_ptr() const {
+    std::stringstream ss;
+    ss << ownedThreadPool.get();
+    return ss.str();
+  }
+
+private:
+  std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
+};
+
 /// Wrapper around MlirContext.
 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
 class PyMlirContext {

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 378d7d739ba62..e0e386d55ede1 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -114,6 +114,14 @@ void mlirContextSetThreadPool(MlirContext context,
   unwrap(context)->setThreadPool(*unwrap(threadPool));
 }
 
+unsigned mlirContextGetNumThreads(MlirContext context) {
+  return unwrap(context)->getNumThreads();
+}
+
+MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) {
+  return wrap(&unwrap(context)->getThreadPool());
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index d021dde05dd87..083a9075fe4c5 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -148,13 +148,25 @@ def process_initializer_module(module_name):
             break
 
     class Context(ir._BaseContext):
-        def __init__(self, load_on_create_dialects=None, *args, **kwargs):
+        def __init__(
+            self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs
+        ):
             super().__init__(*args, **kwargs)
             self.append_dialect_registry(get_dialect_registry())
             for hook in post_init_hooks:
                 hook(self)
+            if disable_multithreading and thread_pool is not None:
+                raise ValueError(
+                    "Context constructor has given thread_pool argument, "
+                    "but disable_multithreading flag is True. "
+                    "Please, set thread_pool argument to None or "
+                    "set disable_multithreading flag to False."
+                )
             if not disable_multithreading:
-                self.enable_multithreading(True)
+                if thread_pool is None:
+                    self.enable_multithreading(True)
+                else:
+                    self.set_thread_pool(thread_pool)
             if load_on_create_dialects is not None:
                 logger.debug(
                     "Loading all dialects from load_on_create_dialects arg %r",

diff  --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py
index c20270999425e..230db8277c8e7 100644
--- a/mlir/test/python/ir/context_lifecycle.py
+++ b/mlir/test/python/ir/context_lifecycle.py
@@ -47,3 +47,26 @@
 assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
 c5 = mlir.ir.Context._CAPICreate(c4_capsule)
 assert c4 is c5
+c4 = None
+c5 = None
+gc.collect()
+
+# Create a global threadpool and use it in two contexts
+tp = mlir.ir.ThreadPool()
+assert tp.get_max_concurrency() > 0
+c5 = mlir.ir.Context()
+c5.set_thread_pool(tp)
+assert c5.get_num_threads() == tp.get_max_concurrency()
+assert c5._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
+c6 = mlir.ir.Context()
+c6.set_thread_pool(tp)
+assert c6.get_num_threads() == tp.get_max_concurrency()
+assert c6._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
+c7 = mlir.ir.Context(thread_pool=tp)
+assert c7.get_num_threads() == tp.get_max_concurrency()
+assert c7._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
+assert mlir.ir.Context._get_live_count() == 3
+c5 = None
+c6 = None
+c7 = None
+gc.collect()


        


More information about the Mlir-commits mailing list