[Mlir-commits] [mlir] Added PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings (PR #130109)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 6 06:58:42 PST 2025


https://github.com/vfdev-5 updated https://github.com/llvm/llvm-project/pull/130109

>From 1d683d0a36d2e8dc0913455f6d66e22a3a4a0a6f Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Thu, 6 Mar 2025 14:46:46 +0100
Subject: [PATCH 1/2] Added PyThreadPool as wrapper around MlirLlvmThreadPool
 in MLIR python bindings

---
 mlir/lib/Bindings/Python/IRCore.cpp | 10 ++++++++++
 mlir/lib/Bindings/Python/IRModule.h | 20 +++++++++++++++++++-
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12793f7dd15be..1ec52a1a9bcd4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2743,6 +2743,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   // __init__.py will subclass it with site-specific functionality and set a
   // "Context" attribute on this module.
   //----------------------------------------------------------------------------
+
+  // Expose the PyThreadPool wrapper to Python as "ThreadPool".
+  nb::class_<PyThreadPool>(m, "ThreadPool")
+      .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
+      .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency);
+
   nb::class_<PyMlirContext>(m, "_BaseContext")
       .def("__init__",
            [](PyMlirContext &self) {
@@ -2814,6 +2820,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             mlirContextEnableMultithreading(self.get(), enable);
           },
           nb::arg("enable"))
+      .def("set_thread_pool",
+           [](PyMlirContext &self, PyThreadPool &pool) {
+             mlirContextSetThreadPool(self.get(), pool.get());
+           })
       .def(
           "is_registered_operation",
           [](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1ed6240a6ca69..b7bbd646d982e 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -22,9 +22,10 @@
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Transforms.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ThreadPool.h"
 
 namespace mlir {
 namespace python {
@@ -158,6 +159,23 @@ class PyThreadContextEntry {
   FrameKind frameKind;
 };
 
+/// Owning wrapper around an llvm::DefaultThreadPool, handed to the C API as an
+/// MlirLlvmThreadPool; the pool is destroyed together with this object.
+class PyThreadPool {
+public:
+  PyThreadPool() {
+    ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
+  }
+  PyThreadPool(const PyThreadPool &) = delete;
+  PyThreadPool(PyThreadPool &&) = delete;
+
+  int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
+  MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
+
+private:
+  std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
+};
+
 /// Wrapper around MlirContext.
 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
 class PyMlirContext {

>From 6e0960a56b8a3aa3755e07eed7c748b7f0bfddef Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Thu, 6 Mar 2025 15:58:15 +0100
Subject: [PATCH 2/2] Added a test

---
 mlir/test/python/ir/context_lifecycle.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py
index c20270999425e..93a98e7f8e9f5 100644
--- a/mlir/test/python/ir/context_lifecycle.py
+++ b/mlir/test/python/ir/context_lifecycle.py
@@ -47,3 +47,17 @@
 assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
 c5 = mlir.ir.Context._CAPICreate(c4_capsule)
 assert c4 is c5
+c4 = None
+c5 = None
+gc.collect()
+
+# Create a global threadpool and use it in two contexts
+tp = mlir.ir.ThreadPool()
+assert tp.get_max_concurrency() > 0
+c5 = mlir.ir.Context()
+c5.enable_multithreading(False)
+c5.set_thread_pool(tp)
+c6 = mlir.ir.Context()
+c6.enable_multithreading(False)
+c6.set_thread_pool(tp)
+assert mlir.ir.Context._get_live_count() == 2



More information about the Mlir-commits mailing list