[Mlir-commits] [mlir] Added PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings (PR #130109)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 8 05:37:22 PST 2025
https://github.com/vfdev-5 updated https://github.com/llvm/llvm-project/pull/130109
>From 922415d694262c71a8c78fa3cd6244ff34280171 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Thu, 6 Mar 2025 14:46:46 +0100
Subject: [PATCH 1/4] Added PyThreadPool as a wrapper around MlirLlvmThreadPool
in the MLIR Python bindings
---
mlir/lib/Bindings/Python/IRCore.cpp | 10 ++++++++++
mlir/lib/Bindings/Python/IRModule.h | 20 +++++++++++++++++++-
2 files changed, 29 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12793f7dd15be..1ec52a1a9bcd4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2743,6 +2743,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// __init__.py will subclass it with site-specific functionality and set a
// "Context" attribute on this module.
//----------------------------------------------------------------------------
+
+ // Expose DefaultThreadPool to python
+ nb::class_<PyThreadPool>(m, "ThreadPool")
+ .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
+ .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency);
+
nb::class_<PyMlirContext>(m, "_BaseContext")
.def("__init__",
[](PyMlirContext &self) {
@@ -2814,6 +2820,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirContextEnableMultithreading(self.get(), enable);
},
nb::arg("enable"))
+ .def("set_thread_pool",
+ [](PyMlirContext &self, PyThreadPool &pool) {
+ mlirContextSetThreadPool(self.get(), pool.get());
+ })
.def(
"is_registered_operation",
[](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1ed6240a6ca69..b7bbd646d982e 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -22,9 +22,10 @@
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ThreadPool.h"
namespace mlir {
namespace python {
@@ -158,6 +159,23 @@ class PyThreadContextEntry {
FrameKind frameKind;
};
+/// Wrapper around MlirLlvmThreadPool; the Python object owns the C++ pool.
+/// NOTE(review): contexts presumably only borrow it; keep it alive (confirm).
+class PyThreadPool {
+public:
+ PyThreadPool() {
+ ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
+ }
+ PyThreadPool(const PyThreadPool &) = delete;
+ PyThreadPool(PyThreadPool &&) = delete;
+
+ int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
+ MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
+
+private:
+ std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
+};
+
/// Wrapper around MlirContext.
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
class PyMlirContext {
>From 5b6ca6aa0661d9efe4abf2db48e59d16f6192435 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Thu, 6 Mar 2025 15:58:15 +0100
Subject: [PATCH 2/4] Added a test
---
mlir/test/python/ir/context_lifecycle.py | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py
index c20270999425e..93a98e7f8e9f5 100644
--- a/mlir/test/python/ir/context_lifecycle.py
+++ b/mlir/test/python/ir/context_lifecycle.py
@@ -47,3 +47,17 @@
assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
c5 = mlir.ir.Context._CAPICreate(c4_capsule)
assert c4 is c5
+c4 = None
+c5 = None
+gc.collect()
+
+# Create a global thread pool and share it across multiple contexts
+tp = mlir.ir.ThreadPool()
+assert tp.get_max_concurrency() > 0
+c5 = mlir.ir.Context()
+c5.enable_multithreading(False)
+c5.set_thread_pool(tp)
+c6 = mlir.ir.Context()
+c6.enable_multithreading(False)
+c6.set_thread_pool(tp)
+assert mlir.ir.Context._get_live_count() == 2
>From e64830122e0cac98b8a2832307f1134be63d49af Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Fri, 7 Mar 2025 11:44:27 +0100
Subject: [PATCH 3/4] Added get_num_threads and _mlir_thread_pool_ptr methods
to _BaseContext. Added a thread_pool argument to the constructor:
`mlir.ir.Context(thread_pool=tp)`
---
mlir/include/mlir-c/IR.h | 9 +++++++++
mlir/lib/Bindings/Python/IRCore.cpp | 18 +++++++++++++++++-
mlir/lib/Bindings/Python/IRModule.h | 7 +++++++
mlir/lib/CAPI/IR/IR.cpp | 8 ++++++++
mlir/python/mlir/_mlir_libs/__init__.py | 9 +++++++--
mlir/test/python/ir/context_lifecycle.py | 15 ++++++++++++---
6 files changed, 60 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index d562da1f90757..001660ee51311 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -162,6 +162,15 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
MlirLlvmThreadPool threadPool);
+/// Gets the number of threads in the context's thread pool when
+/// multithreading is enabled. Returns 1 if multithreading is disabled.
+MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context);
+
+/// Gets the thread pool of the context. Multithreading must be enabled;
+/// otherwise an assertion fails.
+MLIR_CAPI_EXPORTED MlirLlvmThreadPool
+mlirContextGetThreadPool(MlirContext context);
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 1ec52a1a9bcd4..22d6d117573b9 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2747,7 +2747,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Expose DefaultThreadPool to python
nb::class_<PyThreadPool>(m, "ThreadPool")
.def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
- .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency);
+ .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency)
+ .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr);
nb::class_<PyMlirContext>(m, "_BaseContext")
.def("__init__",
@@ -2822,8 +2823,23 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("enable"))
.def("set_thread_pool",
[](PyMlirContext &self, PyThreadPool &pool) {
+ // Disable multithreading before installing the new pool; otherwise the
+ // assertion in MLIRContext::setThreadPool fails. NOTE(review): the context
+ // presumably only borrows `pool`; consider nb::keep_alive<1, 2>().
+ mlirContextEnableMultithreading(self.get(), false);
mlirContextSetThreadPool(self.get(), pool.get());
})
+ .def("get_num_threads",
+ [](PyMlirContext &self) {
+ return mlirContextGetNumThreads(self.get());
+ })
+ .def("_mlir_thread_pool_ptr",
+ [](PyMlirContext &self) {
+ MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get()); // asserts unless multithreading is on
+ std::stringstream ss;
+ ss << pool.ptr;
+ return ss.str();
+ })
.def(
"is_registered_operation",
[](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index b7bbd646d982e..9befcce725bb7 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -11,6 +11,7 @@
#define MLIR_BINDINGS_PYTHON_IRMODULES_H
#include <optional>
+#include <sstream>
#include <utility>
#include <vector>
@@ -172,6 +173,12 @@ class PyThreadPool {
int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
+ std::string _mlir_thread_pool_ptr() const { // pool address as text, for identity checks
+ std::stringstream ss;
+ ss << ownedThreadPool.get();
+ return ss.str();
+ }
+
private:
std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
};
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 6cd9ba2aef233..649f3b7056fb0 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -114,6 +114,14 @@ void mlirContextSetThreadPool(MlirContext context,
unwrap(context)->setThreadPool(*unwrap(threadPool));
}
+unsigned mlirContextGetNumThreads(MlirContext context) {
+ return unwrap(context)->getNumThreads(); // 1 when multithreading is disabled
+}
+
+MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) {
+ return wrap(&unwrap(context)->getThreadPool()); // requires multithreading enabled
+}
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index d021dde05dd87..c480a0035313d 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -148,13 +148,18 @@ def process_initializer_module(module_name):
break
class Context(ir._BaseContext):
- def __init__(self, load_on_create_dialects=None, *args, **kwargs):
+ def __init__(
+ self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs
+ ):
super().__init__(*args, **kwargs)
self.append_dialect_registry(get_dialect_registry())
for hook in post_init_hooks:
hook(self)
if not disable_multithreading:
- self.enable_multithreading(True)
+ if thread_pool is None:
+ self.enable_multithreading(True)
+ else:
+ self.set_thread_pool(thread_pool)
if load_on_create_dialects is not None:
logger.debug(
"Loading all dialects from load_on_create_dialects arg %r",
diff --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py
index 93a98e7f8e9f5..230db8277c8e7 100644
--- a/mlir/test/python/ir/context_lifecycle.py
+++ b/mlir/test/python/ir/context_lifecycle.py
@@ -55,9 +55,18 @@
tp = mlir.ir.ThreadPool()
assert tp.get_max_concurrency() > 0
c5 = mlir.ir.Context()
-c5.enable_multithreading(False)
c5.set_thread_pool(tp)
+assert c5.get_num_threads() == tp.get_max_concurrency()
+assert c5._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
c6 = mlir.ir.Context()
-c6.enable_multithreading(False)
c6.set_thread_pool(tp)
-assert mlir.ir.Context._get_live_count() == 2
+assert c6.get_num_threads() == tp.get_max_concurrency()
+assert c6._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
+c7 = mlir.ir.Context(thread_pool=tp)
+assert c7.get_num_threads() == tp.get_max_concurrency()
+assert c7._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
+assert mlir.ir.Context._get_live_count() == 3
+c5 = None
+c6 = None
+c7 = None
+gc.collect()
>From a557554c7f60f21f66d0c261fd45c4082c4269e9 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Sat, 8 Mar 2025 14:36:13 +0100
Subject: [PATCH 4/4] Raise an error if thread_pool is given while
disable_multithreading is set
---
mlir/python/mlir/_mlir_libs/__init__.py | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py
index c480a0035313d..083a9075fe4c5 100644
--- a/mlir/python/mlir/_mlir_libs/__init__.py
+++ b/mlir/python/mlir/_mlir_libs/__init__.py
@@ -155,6 +155,13 @@ def __init__(
self.append_dialect_registry(get_dialect_registry())
for hook in post_init_hooks:
hook(self)
+ if disable_multithreading and thread_pool is not None:
+ raise ValueError(
+ "Context constructor has given thread_pool argument, "
+ "but disable_multithreading flag is True. "
+ "Please, set thread_pool argument to None or "
+ "set disable_multithreading flag to False."
+ )
if not disable_multithreading:
if thread_pool is None:
self.enable_multithreading(True)
More information about the Mlir-commits
mailing list