[Mlir-commits] [mlir] Added PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings (PR #130109)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 6 05:51:15 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: vfdev (vfdev-5)
<details>
<summary>Changes</summary>
cc @<!-- -->hawkinsp
---
Full diff: https://github.com/llvm/llvm-project/pull/130109.diff
2 Files Affected:
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+10)
- (modified) mlir/lib/Bindings/Python/IRModule.h (+19-1)
``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12793f7dd15be..1ec52a1a9bcd4 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2743,6 +2743,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// __init__.py will subclass it with site-specific functionality and set a
// "Context" attribute on this module.
//----------------------------------------------------------------------------
+
+ // Expose an owned llvm::DefaultThreadPool to Python as the `ThreadPool` class.
+ nb::class_<PyThreadPool>(m, "ThreadPool")
+ .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }) // construct in place in the storage nanobind provides
+ .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency);
+
nb::class_<PyMlirContext>(m, "_BaseContext")
.def("__init__",
[](PyMlirContext &self) {
@@ -2814,6 +2820,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
mlirContextEnableMultithreading(self.get(), enable);
},
nb::arg("enable"))
+ .def("set_thread_pool",
+ [](PyMlirContext &self, PyThreadPool &pool) {
+ mlirContextSetThreadPool(self.get(), pool.get());
+ })
.def(
"is_registered_operation",
[](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1ed6240a6ca69..b7bbd646d982e 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -22,9 +22,10 @@
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ThreadPool.h"
namespace mlir {
namespace python {
@@ -158,6 +159,23 @@ class PyThreadContextEntry {
FrameKind frameKind;
};
+/// Owns an llvm::DefaultThreadPool and hands it to the C API as MlirLlvmThreadPool.
+/// The Python object owns the pool; NOTE(review): presumably it must outlive any context it is set on - confirm.
+class PyThreadPool {
+public:
+ PyThreadPool() {
+ ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
+ }
+ PyThreadPool(const PyThreadPool &) = delete; // Non-copyable and non-movable: this object
+ PyThreadPool(PyThreadPool &&) = delete; // uniquely owns the pool that get() exposes by raw handle.
+
+ int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); } // Forwards to the owned pool.
+ MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); } // Non-owning C handle; ownership stays here.
+
+private:
+ std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool; // Held via the interface; created as a DefaultThreadPool.
+};
+
/// Wrapper around MlirContext.
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
class PyMlirContext {
``````````
</details>
https://github.com/llvm/llvm-project/pull/130109
More information about the Mlir-commits mailing list