[Mlir-commits] [mlir] Added PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings (PR #130109)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 6 05:51:15 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: vfdev (vfdev-5)

<details>
<summary>Changes</summary>

cc @<!-- -->hawkinsp 

---
Full diff: https://github.com/llvm/llvm-project/pull/130109.diff


2 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+10) 
- (modified) mlir/lib/Bindings/Python/IRModule.h (+19-1) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12793f7dd15be..1ec52a1a9bcd4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2743,6 +2743,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   // __init__.py will subclass it with site-specific functionality and set a
   // "Context" attribute on this module.
   //----------------------------------------------------------------------------
+
+  // Expose llvm::DefaultThreadPool to Python as the `ThreadPool` class.
+  nb::class_<PyThreadPool>(m, "ThreadPool")
+      .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
+      .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency);
+
   nb::class_<PyMlirContext>(m, "_BaseContext")
       .def("__init__",
            [](PyMlirContext &self) {
@@ -2814,6 +2820,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             mlirContextEnableMultithreading(self.get(), enable);
           },
           nb::arg("enable"))
+      .def("set_thread_pool",
+           [](PyMlirContext &self, PyThreadPool &pool) {
+             mlirContextSetThreadPool(self.get(), pool.get());
+           })
       .def(
           "is_registered_operation",
           [](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1ed6240a6ca69..b7bbd646d982e 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -22,9 +22,10 @@
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "mlir-c/Transforms.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ThreadPool.h"
 
 namespace mlir {
 namespace python {
@@ -158,6 +159,23 @@ class PyThreadContextEntry {
   FrameKind frameKind;
 };
 
+/// Wrapper around MlirLlvmThreadPool.
+/// The Python object owns the C++ thread pool (an llvm::DefaultThreadPool).
+class PyThreadPool {
+public:
+  PyThreadPool() {
+    ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
+  }
+  PyThreadPool(const PyThreadPool &) = delete; // Not copy-constructible.
+  PyThreadPool(PyThreadPool &&) = delete;      // Not move-constructible.
+
+  int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
+  MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); } // Non-owning.
+
+private:
+  std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
+};
+
 /// Wrapper around MlirContext.
 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
 class PyMlirContext {

``````````

</details>


https://github.com/llvm/llvm-project/pull/130109


More information about the Mlir-commits mailing list