[Mlir-commits] [mlir] Added PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings (PR #130109)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 6 06:58:42 PST 2025
https://github.com/vfdev-5 updated https://github.com/llvm/llvm-project/pull/130109
>From 1d683d0a36d2e8dc0913455f6d66e22a3a4a0a6f Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Thu, 6 Mar 2025 14:46:46 +0100
Subject: [PATCH 1/2] Added PyThreadPool as wrapper around MlirLlvmThreadPool
in MLIR python bindings
---
mlir/lib/Bindings/Python/IRCore.cpp | 10 ++++++++++
mlir/lib/Bindings/Python/IRModule.h | 20 +++++++++++++++++++-
2 files changed, 29 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12793f7dd15be..1ec52a1a9bcd4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2743,6 +2743,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// __init__.py will subclass it with site-specific functionality and set a
// "Context" attribute on this module.
//----------------------------------------------------------------------------
+
+ // Expose DefaultThreadPool to python
+ nb::class_<PyThreadPool>(m, "ThreadPool")
+ .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
+ .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency);
+
nb::class_<PyMlirContext>(m, "_BaseContext")
.def("__init__",
[](PyMlirContext &self) {
@@ -2814,6 +2820,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirContextEnableMultithreading(self.get(), enable);
},
nb::arg("enable"))
+ .def("set_thread_pool",
+ [](PyMlirContext &self, PyThreadPool &pool) {
+ mlirContextSetThreadPool(self.get(), pool.get());
+ })
.def(
"is_registered_operation",
[](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1ed6240a6ca69..b7bbd646d982e 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -22,9 +22,10 @@
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ThreadPool.h"
namespace mlir {
namespace python {
@@ -158,6 +159,23 @@ class PyThreadContextEntry {
FrameKind frameKind;
};
+/// Wrapper around MlirLlvmThreadPool exposed to Python.
+///
+/// The Python object owns the underlying C++ thread pool (an
+/// llvm::DefaultThreadPool created on construction). `get()` hands out a
+/// non-owning MlirLlvmThreadPool handle to that pool, so the handle must not
+/// be used after this object has been destroyed.
+class PyThreadPool {
+public:
+  /// Creates and takes ownership of a new llvm::DefaultThreadPool.
+  PyThreadPool()
+      : ownedThreadPool(std::make_unique<llvm::DefaultThreadPool>()) {}
+
+  // The pool is handed out by raw (non-owning) handle, so this owner must
+  // stay at a fixed address and must never be duplicated.
+  PyThreadPool(const PyThreadPool &) = delete;
+  PyThreadPool(PyThreadPool &&) = delete;
+  PyThreadPool &operator=(const PyThreadPool &) = delete;
+  PyThreadPool &operator=(PyThreadPool &&) = delete;
+
+  /// Forwards to the owned pool's getMaxConcurrency(). Returns `unsigned` to
+  /// match llvm::ThreadPoolInterface and avoid a narrowing conversion.
+  unsigned getMaxConcurrency() const {
+    return ownedThreadPool->getMaxConcurrency();
+  }
+
+  /// Returns a non-owning C API handle to the owned thread pool.
+  MlirLlvmThreadPool get() const { return wrap(ownedThreadPool.get()); }
+
+private:
+  std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
+};
+
/// Wrapper around MlirContext.
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
class PyMlirContext {
>From 6e0960a56b8a3aa3755e07eed7c748b7f0bfddef Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Thu, 6 Mar 2025 15:58:15 +0100
Subject: [PATCH 2/2] Added a test
---
mlir/test/python/ir/context_lifecycle.py | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py
index c20270999425e..93a98e7f8e9f5 100644
--- a/mlir/test/python/ir/context_lifecycle.py
+++ b/mlir/test/python/ir/context_lifecycle.py
@@ -47,3 +47,17 @@
assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
c5 = mlir.ir.Context._CAPICreate(c4_capsule)
assert c4 is c5
+# Drop the remaining context references and force a collection so that,
+# presumably, only the contexts created below are alive for the live-count
+# assertion at the end.
+c4 = None
+c5 = None
+gc.collect()
+
+# Create a global threadpool and use it in two contexts
+tp = mlir.ir.ThreadPool()
+assert tp.get_max_concurrency() > 0
+# Multithreading is disabled on each context before installing the shared
+# pool. NOTE(review): presumably required because the context must not have an
+# active internal pool when set_thread_pool is called; confirm.
+c5 = mlir.ir.Context()
+c5.enable_multithreading(False)
+c5.set_thread_pool(tp)
+c6 = mlir.ir.Context()
+c6.enable_multithreading(False)
+c6.set_thread_pool(tp)
+# Both contexts sharing `tp` should be the only live contexts at this point.
+assert mlir.ir.Context._get_live_count() == 2
More information about the Mlir-commits mailing list