[Mlir-commits] [mlir] Added free-threading CPython mode support in MLIR Python bindings (PR #107103)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 14 04:58:27 PST 2024
https://github.com/vfdev-5 updated https://github.com/llvm/llvm-project/pull/107103
>From a5ed9d06afaefaeabd4c1b7208ccfb6531aced34 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Tue, 3 Sep 2024 15:02:12 +0200
Subject: [PATCH 1/3] Added free-threading CPython mode support in Python
bindings - temporarily updated requirements
---
mlir/lib/Bindings/Python/AsyncPasses.cpp | 4 +++-
mlir/lib/Bindings/Python/DialectGPU.cpp | 2 +-
mlir/lib/Bindings/Python/DialectLLVM.cpp | 2 +-
mlir/lib/Bindings/Python/DialectLinalg.cpp | 2 +-
mlir/lib/Bindings/Python/DialectNVGPU.cpp | 2 +-
mlir/lib/Bindings/Python/DialectPDL.cpp | 2 +-
mlir/lib/Bindings/Python/DialectQuant.cpp | 2 +-
mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 2 +-
mlir/lib/Bindings/Python/DialectTransform.cpp | 2 +-
mlir/lib/Bindings/Python/ExecutionEngineModule.cpp | 2 +-
mlir/lib/Bindings/Python/GPUPasses.cpp | 4 +++-
mlir/lib/Bindings/Python/LinalgPasses.cpp | 4 +++-
mlir/lib/Bindings/Python/MainModule.cpp | 2 +-
mlir/lib/Bindings/Python/RegisterEverything.cpp | 2 +-
mlir/lib/Bindings/Python/SparseTensorPasses.cpp | 4 +++-
mlir/lib/Bindings/Python/TransformInterpreter.cpp | 2 +-
mlir/python/requirements.txt | 8 +++++---
17 files changed, 29 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Bindings/Python/AsyncPasses.cpp b/mlir/lib/Bindings/Python/AsyncPasses.cpp
index b611a758dbbb37..d34a164b6e30c2 100644
--- a/mlir/lib/Bindings/Python/AsyncPasses.cpp
+++ b/mlir/lib/Bindings/Python/AsyncPasses.cpp
@@ -11,11 +11,13 @@
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
+namespace py = pybind11;
+
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
-PYBIND11_MODULE(_mlirAsyncPasses, m) {
+PYBIND11_MODULE(_mlirAsyncPasses, m, py::mod_gil_not_used()) {
m.doc() = "MLIR Async Dialect Passes";
// Register all Async passes on load.
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 560a54bcd15919..5acfad007c3055 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -23,7 +23,7 @@ using namespace mlir::python::adaptors;
// Module initialization.
// -----------------------------------------------------------------------------
-PYBIND11_MODULE(_mlirDialectsGPU, m) {
+PYBIND11_MODULE(_mlirDialectsGPU, m, py::mod_gil_not_used()) {
m.doc() = "MLIR GPU Dialect";
//===-------------------------------------------------------------------===//
// AsyncTokenType
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index 42a4c8c0793ba8..2af133a061eb4b 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -134,7 +134,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
});
}
-PYBIND11_MODULE(_mlirDialectsLLVM, m) {
+PYBIND11_MODULE(_mlirDialectsLLVM, m, py::mod_gil_not_used()) {
m.doc() = "MLIR LLVM Dialect";
populateDialectLLVMSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 2e54ebeb61fb10..118c4ab3f2f573 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -21,7 +21,7 @@ static void populateDialectLinalgSubmodule(py::module m) {
"op.");
}
-PYBIND11_MODULE(_mlirDialectsLinalg, m) {
+PYBIND11_MODULE(_mlirDialectsLinalg, m, py::mod_gil_not_used()) {
m.doc() = "MLIR Linalg dialect.";
populateDialectLinalgSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
index 754e0a75b0abc7..4c962c403082cb 100644
--- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
@@ -34,7 +34,7 @@ static void populateDialectNVGPUSubmodule(const pybind11::module &m) {
py::arg("ctx") = py::none());
}
-PYBIND11_MODULE(_mlirDialectsNVGPU, m) {
+PYBIND11_MODULE(_mlirDialectsNVGPU, m, py::mod_gil_not_used()) {
m.doc() = "MLIR NVGPU dialect.";
populateDialectNVGPUSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index 8d3f9a7ab1d6ac..e8542d5e777a65 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -100,7 +100,7 @@ void populateDialectPDLSubmodule(const pybind11::module &m) {
py::arg("context") = py::none());
}
-PYBIND11_MODULE(_mlirDialectsPDL, m) {
+PYBIND11_MODULE(_mlirDialectsPDL, m, py::mod_gil_not_used()) {
m.doc() = "MLIR PDL dialect.";
populateDialectPDLSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index af9cdc7bdd2d89..fc6ef9f46ce8e5 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -307,7 +307,7 @@ static void populateDialectQuantSubmodule(const py::module &m) {
});
}
-PYBIND11_MODULE(_mlirDialectsQuant, m) {
+PYBIND11_MODULE(_mlirDialectsQuant, m, py::mod_gil_not_used()) {
m.doc() = "MLIR Quantization dialect";
populateDialectQuantSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index a730bf500be98c..00d8482a91df7a 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -143,7 +143,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
});
}
-PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {
+PYBIND11_MODULE(_mlirDialectsSparseTensor, m, py::mod_gil_not_used()) {
m.doc() = "MLIR SparseTensor dialect.";
populateDialectSparseTensorSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 6b57e652aa9d8b..df665dd66bdc28 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -117,7 +117,7 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
"Get the type this ParamType is associated with.");
}
-PYBIND11_MODULE(_mlirDialectsTransform, m) {
+PYBIND11_MODULE(_mlirDialectsTransform, m, py::mod_gil_not_used()) {
m.doc() = "MLIR Transform dialect.";
populateDialectTransformSubmodule(m);
}
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index b3df30583fc963..ddd81d1e7d592e 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -64,7 +64,7 @@ class PyExecutionEngine {
} // namespace
/// Create the `mlir.execution_engine` module here.
-PYBIND11_MODULE(_mlirExecutionEngine, m) {
+PYBIND11_MODULE(_mlirExecutionEngine, m, py::mod_gil_not_used()) {
m.doc() = "MLIR Execution Engine";
//----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/GPUPasses.cpp b/mlir/lib/Bindings/Python/GPUPasses.cpp
index e276a3ce3a56a0..bfc43e99946bb9 100644
--- a/mlir/lib/Bindings/Python/GPUPasses.cpp
+++ b/mlir/lib/Bindings/Python/GPUPasses.cpp
@@ -11,11 +11,13 @@
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
+namespace py = pybind11;
+
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
-PYBIND11_MODULE(_mlirGPUPasses, m) {
+PYBIND11_MODULE(_mlirGPUPasses, m, py::mod_gil_not_used()) {
m.doc() = "MLIR GPU Dialect Passes";
// Register all GPU passes on load.
diff --git a/mlir/lib/Bindings/Python/LinalgPasses.cpp b/mlir/lib/Bindings/Python/LinalgPasses.cpp
index 3f230207a42114..e3d8f237e2bab3 100644
--- a/mlir/lib/Bindings/Python/LinalgPasses.cpp
+++ b/mlir/lib/Bindings/Python/LinalgPasses.cpp
@@ -10,11 +10,13 @@
#include <pybind11/pybind11.h>
+namespace py = pybind11;
+
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
-PYBIND11_MODULE(_mlirLinalgPasses, m) {
+PYBIND11_MODULE(_mlirLinalgPasses, m, py::mod_gil_not_used()) {
m.doc() = "MLIR Linalg Dialect Passes";
// Register all Linalg passes on load.
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 8da1ab16a4514b..de713e7031a01e 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -22,7 +22,7 @@ using namespace mlir::python;
// Module initialization.
// -----------------------------------------------------------------------------
-PYBIND11_MODULE(_mlir, m) {
+PYBIND11_MODULE(_mlir, m, py::mod_gil_not_used()) {
m.doc() = "MLIR Python Native Extension";
py::class_<PyGlobals>(m, "_Globals", py::module_local())
diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp
index 6b2f6b0a6a3b86..5c5c6e32102712 100644
--- a/mlir/lib/Bindings/Python/RegisterEverything.cpp
+++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp
@@ -9,7 +9,7 @@
#include "mlir-c/RegisterEverything.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
-PYBIND11_MODULE(_mlirRegisterEverything, m) {
+PYBIND11_MODULE(_mlirRegisterEverything, m, py::mod_gil_not_used()) {
m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration";
m.def("register_dialects", [](MlirDialectRegistry registry) {
diff --git a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
index 2a8e2b802df9c4..1bbdf2f3ccf4e5 100644
--- a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
+++ b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
@@ -10,11 +10,13 @@
#include <pybind11/pybind11.h>
+namespace py = pybind11;
+
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
-PYBIND11_MODULE(_mlirSparseTensorPasses, m) {
+PYBIND11_MODULE(_mlirSparseTensorPasses, m, py::mod_gil_not_used()) {
m.doc() = "MLIR SparseTensor Dialect Passes";
// Register all SparseTensor passes on load.
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index f6b4532b1b6be4..93ab447d52bec1 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -99,7 +99,7 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
py::arg("target"), py::arg("other"));
}
-PYBIND11_MODULE(_mlirTransformInterpreter, m) {
+PYBIND11_MODULE(_mlirTransformInterpreter, m, py::mod_gil_not_used()) {
m.doc() = "MLIR Transform dialect interpreter functionality.";
populateTransformInterpreterSubmodule(m);
}
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index d1b5418cca5b23..49b8471c6b771c 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,4 +1,6 @@
-numpy>=1.19.5, <=1.26
-pybind11>=2.9.0, <=2.10.3
+numpy>=1.19.5, <3.0
+# pybind11>=2.14.0, <2.15.0
+# Temporarily install pybind11 from master until the release following 2.13.6 is available; then switch to the commented-out pin above.
+pybind11 @ git+https://github.com/pybind/pybind11@master
PyYAML>=5.4.0, <=6.0.1
-ml_dtypes>=0.1.0, <=0.4.0 # provides several NumPy dtype extensions, including the bf16
+ml_dtypes>=0.5.0, <=0.6.0 # provides several NumPy dtype extensions, including the bf16
>From cdd2b03eb5dece62ea9c6edb2822100b886a7675 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Fri, 20 Sep 2024 22:30:49 +0200
Subject: [PATCH 2/3] [skip-ci] Added lock on PyGlobals::get and PyMlirContext
liveContexts WIP on adding multithreaded_tests
---
mlir/docs/Bindings/Python.md | 4 +-
.../standalone/python/StandaloneExtension.cpp | 4 +-
mlir/lib/Bindings/Python/Globals.h | 22 +++
mlir/lib/Bindings/Python/IRCore.cpp | 72 +++++---
mlir/lib/Bindings/Python/IRModule.h | 21 +++
mlir/test/python/execution_engine.py | 2 +-
mlir/test/python/lib/PythonTestModule.cpp | 2 +-
mlir/test/python/multithreaded_tests.py | 154 ++++++++++++++++++
8 files changed, 251 insertions(+), 30 deletions(-)
create mode 100644 mlir/test/python/multithreaded_tests.py
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 6e52c4deaad9aa..8cbbc44463db98 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1035,7 +1035,7 @@ class ConstantOp(_ods_ir.OpView):
...
```
-expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
+expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`:
```python
@@ -1179,7 +1179,7 @@ make the passes available along with the dialect.
Dialect functionality other than IR objects or passes, such as helper functions,
can be exposed to Python similarly to attributes and types. C API is expected to
exist for this functionality, which can then be wrapped using pybind11 and
-`[include/mlir/Bindings/Python/PybindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)`
+[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)
utilities to connect to the rest of Python API. The bindings can be located in a
separate pybind11 module or in the same module as attributes and types, and
loaded along with the dialect.
diff --git a/mlir/examples/standalone/python/StandaloneExtension.cpp b/mlir/examples/standalone/python/StandaloneExtension.cpp
index 5e83060cd48d82..2e30acaec756aa 100644
--- a/mlir/examples/standalone/python/StandaloneExtension.cpp
+++ b/mlir/examples/standalone/python/StandaloneExtension.cpp
@@ -9,9 +9,11 @@
#include "Standalone-c/Dialects.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
+namespace py = pybind11;
+
using namespace mlir::python::adaptors;
-PYBIND11_MODULE(_standaloneDialects, m) {
+PYBIND11_MODULE(_standaloneDialects, m, py::mod_gil_not_used()) {
//===--------------------------------------------------------------------===//
// standalone dialect
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index a022067f5c7e57..05400608ba9ffa 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -36,6 +36,20 @@ class PyGlobals {
return *instance;
}
+#ifdef Py_GIL_DISABLED
+  /// RAII guard for the PyGlobals mutex (free-threaded builds only).
+  /// It is reentrant on the owning thread: only the outermost guard on a
+  /// thread locks and unlocks the mutex, so nested withInstance calls made
+  /// while a callback is running on the same thread do not self-deadlock.
+  /// The per-thread depth lives in a non-template function so that every
+  /// withInstance<F> instantiation shares the same counter.
+  class GlobalsLockGuard {
+  public:
+    GlobalsLockGuard() {
+      if (depth()++ == 0)
+        PyMutex_Lock(&getLock());
+    }
+    ~GlobalsLockGuard() {
+      if (--depth() == 0)
+        PyMutex_Unlock(&getLock());
+    }
+    GlobalsLockGuard(const GlobalsLockGuard &) = delete;
+    GlobalsLockGuard &operator=(const GlobalsLockGuard &) = delete;
+
+  private:
+    static unsigned &depth() {
+      static thread_local unsigned value = 0;
+      return value;
+    }
+  };
+#endif
+
+  /// Calls `cb(PyGlobals &)` on the global instance and returns its result.
+  /// On free-threaded (Py_GIL_DISABLED) builds the call is serialized by the
+  /// PyGlobals mutex:
+  ///  - the mutex is released by an RAII guard, so an exception escaping `cb`
+  ///    cannot leave it locked and block every other thread forever;
+  ///  - the lock is reentrant on the owning thread, so `cb` may end up calling
+  ///    withInstance again (register_operation in MainModule.cpp calls it);
+  ///  - `cb` may return void, and a reference result is forwarded as a
+  ///    reference instead of being copied into a local and returned dangling.
+  /// NOTE(review): running Python code (e.g. imports) while this lock is held
+  /// can still deadlock against other threads through lock-order inversion;
+  /// keep callbacks short.
+  template <typename F>
+  static inline decltype(auto) withInstance(const F &cb) {
+    auto &instance = get();
+#ifdef Py_GIL_DISABLED
+    GlobalsLockGuard guard;
+#endif
+    return cb(instance);
+  }
+
/// Get and set the list of parent modules to search for dialect
/// implementation classes.
std::vector<std::string> &getDialectSearchPrefixes() {
@@ -125,6 +139,14 @@ class PyGlobals {
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
+
+#ifdef Py_GIL_DISABLED
+  /// Process-wide mutex protecting PyGlobals in free-threaded builds.
+  /// Value-initialized static storage, i.e. a zeroed (unlocked) PyMutex.
+  static PyMutex &getLock() {
+    static PyMutex mutex{};
+    return mutex;
+  }
+#endif
+
};
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index c12f75e7d224a8..d1c6696559d1d7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -192,7 +192,9 @@ py::object classmethod(Func f, Args... args) {
static py::object
createCustomDialectWrapper(const std::string &dialectNamespace,
py::object dialectDescriptor) {
- auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
+ auto dialectClass = PyGlobals::withInstance([&](PyGlobals& instance) {
+ return instance.lookupDialectClass(dialectNamespace);
+ });
if (!dialectClass) {
// Use the base class.
return py::cast(PyDialect(std::move(dialectDescriptor)));
@@ -595,8 +597,10 @@ class PyOpOperandIterator {
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
py::gil_scoped_acquire acquire;
-  auto &liveContexts = getLiveContexts();
-  liveContexts[context.ptr] = this;
+  // Record this wrapper in the static map of live contexts, keyed by the
+  // underlying MlirContext pointer. The callback's return value is discarded.
+  withLiveContexts([this, &context](LiveContextMap &contexts) {
+    contexts[context.ptr] = this;
+    return this;
+  });
}
PyMlirContext::~PyMlirContext() {
@@ -604,7 +608,12 @@ PyMlirContext::~PyMlirContext() {
// forContext method, which always puts the associated handle into
// liveContexts.
py::gil_scoped_acquire acquire;
- getLiveContexts().erase(context.ptr);
+
+ withLiveContexts([&](LiveContextMap& liveContexts) {
+ liveContexts.erase(context.ptr);
+ return this;
+ });
+
mlirContextDestroy(context);
}
@@ -626,19 +635,20 @@ PyMlirContext *PyMlirContext::createNewContextForInit() {
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
py::gil_scoped_acquire acquire;
-  auto &liveContexts = getLiveContexts();
-  auto it = liveContexts.find(context.ptr);
-  if (it == liveContexts.end()) {
-    // Create.
-    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
-    py::object pyRef = py::cast(unownedContextWrapper);
-    assert(pyRef && "cast to py::object failed");
-    liveContexts[context.ptr] = unownedContextWrapper;
-    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
-  }
-  // Use existing.
-  py::object pyRef = py::cast(it->second);
-  return PyMlirContextRef(it->second, std::move(pyRef));
+  // Find-or-create runs inside a single withLiveContexts callback, so the
+  // lookup and the registration of a new wrapper happen in one critical
+  // section in free-threaded builds (no duplicate wrappers per MlirContext).
+  return withLiveContexts([&](LiveContextMap& liveContexts) {
+    auto it = liveContexts.find(context.ptr);
+    if (it == liveContexts.end()) {
+      // Create.
+      // NOTE(review): the PyMlirContext constructor itself calls
+      // withLiveContexts to register the new wrapper, so this nests the
+      // live-context lock on the same thread. With a non-reentrant PyMutex
+      // this self-deadlocks in Py_GIL_DISABLED builds.
+      PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
+      py::object pyRef = py::cast(unownedContextWrapper);
+      assert(pyRef && "cast to py::object failed");
+      // NOTE(review): redundant - the constructor already stored this entry.
+      liveContexts[context.ptr] = unownedContextWrapper;
+      return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
+    }
+    // Use existing.
+    py::object pyRef = py::cast(it->second);
+    return PyMlirContextRef(it->second, std::move(pyRef));
+  });
}
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
@@ -646,7 +656,11 @@ PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
return liveContexts;
}
-size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
+/// Number of entries in the live-context map, read through withLiveContexts.
+size_t PyMlirContext::getLiveCount() {
+  size_t count = withLiveContexts(
+      [](LiveContextMap &contexts) { return contexts.size(); });
+  return count;
+}
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
@@ -1550,8 +1564,10 @@ py::object PyOperation::createOpView() {
checkValid();
MlirIdentifier ident = mlirOperationGetName(get());
MlirStringRef identStr = mlirIdentifierStr(ident);
- auto operationCls = PyGlobals::get().lookupOperationClass(
- StringRef(identStr.data, identStr.length));
+ auto operationCls = PyGlobals::withInstance([&](PyGlobals& instance){
+ return instance.lookupOperationClass(
+ StringRef(identStr.data, identStr.length));
+ });
if (operationCls)
return PyOpView::constructDerived(*operationCls, *getRef().get());
return py::cast(PyOpView(getRef().getObject()));
@@ -2002,7 +2018,9 @@ pybind11::object PyValue::maybeDownCast() {
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
std::optional<pybind11::function> valueCaster =
- PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
+ PyGlobals::withInstance([&](PyGlobals& instance) {
+ return instance.lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
+ });
// py::return_value_policy::move means use std::move to move the return value
// contents into a new instance that will be owned by Python.
py::object thisObj = py::cast(this, py::return_value_policy::move);
@@ -3481,8 +3499,10 @@ void mlir::python::populateIRCore(py::module &m) {
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
std::optional<pybind11::function> typeCaster =
- PyGlobals::get().lookupTypeCaster(mlirTypeID,
- mlirAttributeGetDialect(self));
+ PyGlobals::withInstance([&](PyGlobals& instance){
+ return instance.lookupTypeCaster(mlirTypeID,
+ mlirAttributeGetDialect(self));
+ });
if (!typeCaster)
return py::cast(self);
return typeCaster.value()(self);
@@ -3579,9 +3599,11 @@ void mlir::python::populateIRCore(py::module &m) {
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
- std::optional<pybind11::function> typeCaster =
- PyGlobals::get().lookupTypeCaster(mlirTypeID,
+ std::optional<pybind11::function> typeCaster =
+ PyGlobals::withInstance([&](PyGlobals& instance){
+ return instance.lookupTypeCaster(mlirTypeID,
mlirTypeGetDialect(self));
+ });
if (!typeCaster)
return py::cast(self);
return typeCaster.value()(self);
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 172898cfda0c52..cbafed86c3ea85 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -263,6 +263,27 @@ class PyMlirContext {
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
static LiveContextMap &getLiveContexts();
+#ifdef Py_GIL_DISABLED
+  /// Mutex protecting the live-context map in free-threaded builds.
+  static PyMutex &getLock() {
+    static PyMutex lock;
+    return lock;
+  }
+
+  /// RAII guard for getLock() that is reentrant on the owning thread: only
+  /// the outermost guard on a thread locks and unlocks the mutex. The
+  /// per-thread depth lives in a non-template function so that every
+  /// withLiveContexts<F> instantiation shares the same counter.
+  class LiveContextsGuard {
+  public:
+    LiveContextsGuard() {
+      if (depth()++ == 0)
+        PyMutex_Lock(&getLock());
+    }
+    ~LiveContextsGuard() {
+      if (--depth() == 0)
+        PyMutex_Unlock(&getLock());
+    }
+    LiveContextsGuard(const LiveContextsGuard &) = delete;
+    LiveContextsGuard &operator=(const LiveContextsGuard &) = delete;
+
+  private:
+    static unsigned &depth() {
+      static thread_local unsigned value = 0;
+      return value;
+    }
+  };
+#endif
+
+  /// Calls `cb(LiveContextMap &)` with the live-context map and returns its
+  /// result. On free-threaded (Py_GIL_DISABLED) builds the call is serialized
+  /// by getLock():
+  ///  - the lock is reentrant on the owning thread. forContext() constructs a
+  ///    PyMlirContext from inside its callback and that constructor calls
+  ///    withLiveContexts again; with a plain non-recursive PyMutex the nested
+  ///    call would self-deadlock;
+  ///  - the mutex is released by an RAII guard, so an exception escaping `cb`
+  ///    (e.g. from py::cast) cannot leave it locked;
+  ///  - `cb` may return void, and a reference result is forwarded as a
+  ///    reference instead of dangling to a local copy.
+  template <typename F>
+  static inline decltype(auto) withLiveContexts(const F &cb) {
+    auto &liveContexts = getLiveContexts();
+#ifdef Py_GIL_DISABLED
+    LiveContextsGuard guard;
+#endif
+    return cb(liveContexts);
+  }
+
+
// Interns all live modules associated with this context. Modules tracked
// in this map are valid. When a module is invalidated, it is removed
// from this map, and while it still exists as an instance, any
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 7c375ce81de0eb..1435ae0f7569ed 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -306,7 +306,7 @@ def callback(a):
log(arr)
with Context():
- # The module takes a subview of the argument memref, casts it to an unranked memref and
+ # The module takes a subview of the argument memref, casts it to an unranked memref and
# calls the callback with it.
module = Module.parse(
r"""
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index a4f538dcb55944..f92994868114e7 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -21,7 +21,7 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
}
-PYBIND11_MODULE(_mlirPythonTest, m) {
+PYBIND11_MODULE(_mlirPythonTest, m, py::mod_gil_not_used()) {
m.def(
"register_python_test_dialect",
[](MlirContext context, bool load) {
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
new file mode 100644
index 00000000000000..fa545861cb0d7d
--- /dev/null
+++ b/mlir/test/python/multithreaded_tests.py
@@ -0,0 +1,154 @@
+import concurrent.futures
+import functools
+import importlib.util
+import sys
+import threading
+import tempfile
+
+from collections import defaultdict
+from contextlib import redirect_stderr, redirect_stdout
+from pathlib import Path
+from typing import Optional
+
+import pytest
+
+import mlir.dialects.arith as arith
+from mlir.ir import Context, Location, Module, IntegerType, F64Type, InsertionPoint
+
+
+
+def import_from_path(module_name: str, file_path: Path):
+    """Import the Python file at `file_path` under the name `module_name`.
+
+    The module object is placed in `sys.modules` before its body executes,
+    and is returned afterwards.
+    """
+    module_spec = importlib.util.spec_from_file_location(module_name, file_path)
+    loaded = importlib.util.module_from_spec(module_spec)
+    sys.modules[module_name] = loaded
+    module_spec.loader.exec_module(loaded)
+    return loaded
+
+
+def copy_and_update(src_filepath: Path, dst_filepath: Path):
+    """Copy a test file, dropping its top-level test-runner lines.
+
+    Lines starting with `run(`, `@run` or `@constructAndPrintInModule` are
+    omitted so that importing the copy does not execute tests; every other
+    line is written unchanged.
+    """
+    skip_prefixes = ("run(", "@run", "@constructAndPrintInModule")
+    with open(src_filepath, "r") as reader, open(dst_filepath, "w") as writer:
+        for src_line in reader:
+            # str.startswith accepts a tuple and matches any of its entries.
+            if src_line.startswith(skip_prefixes):
+                continue
+            writer.write(src_line)
+
+
+test_modules = [
+ "execution_engine",
+ # "pass_manager",
+]
+
+
+def add_existing_tests(test_prefix: str = "_original_test"):
+    """Class decorator that attaches tests from existing MLIR Python test files.
+
+    Each module listed in `test_modules` is copied (via `copy_and_update`) into
+    a temporary directory stored on the class as `output_folder`, then
+    imported. Every callable attribute of that module whose name starts with
+    "test" is attached to the class as `<test_prefix>_<module>__<attr>`, with
+    "/" in the module path replaced by "_". The attached wrapper accepts and
+    ignores any arguments (including `self`) and calls the test with none.
+    """
+    def decorator(test_cls):
+        this_folder = Path(__file__).parent.absolute()
+        test_cls.output_folder = tempfile.TemporaryDirectory()
+        output_folder = Path(test_cls.output_folder.name)
+
+        for test_module_name in test_modules:
+            src_filepath = this_folder / f"{test_module_name}.py"
+            dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
+            if not dst_filepath.parent.exists():
+                dst_filepath.parent.mkdir(parents=True)
+            copy_and_update(src_filepath, dst_filepath)
+            test_mod = import_from_path(test_module_name, dst_filepath)
+            for attr_name in dir(test_mod):
+                if attr_name.startswith("test"):
+                    obj = getattr(test_mod, attr_name)
+                    if callable(obj):
+                        test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}"
+                        def wrapped_test_fn(*args, __test_fn__=obj, **kwargs):
+                            __test_fn__()
+
+                        setattr(test_cls, test_name, wrapped_test_fn)
+        return test_cls
+    return decorator
+
+
+def multi_threaded(
+    num_workers: int,
+    num_runs: int = 5,
+    skip_tests: Optional[list[str]] = None,
+    test_prefix: str = "_original_test",
+):
+    """Class decorator that adds multi-threaded variants of prefixed tests.
+
+    For every callable attribute of the decorated class whose name starts with
+    `test_prefix`, a method named `test<rest>_multi_threaded` is added, where
+    `<rest>` is the attribute name with `test_prefix` removed. That method
+    starts `num_workers` threads which wait on a shared barrier and then each
+    call the original test `num_runs` times. An exception raised in a worker
+    is re-raised, and any stderr output captured through pytest's `capfd`
+    fixture is turned into a RuntimeError.
+
+    Args:
+        num_workers: number of worker threads (and barrier parties).
+        num_runs: number of times each worker calls the test.
+        skip_tests: optional substrings; a test whose `test<rest>` name
+            contains any of them gets no multi-threaded variant.
+        test_prefix: prefix selecting the class attributes to wrap.
+    """
+    def decorator(test_cls):
+        # Iterate over a copy: new attributes are added inside the loop.
+        for name, test_fn in test_cls.__dict__.copy().items():
+            if not (name.startswith(test_prefix) and callable(test_fn)):
+                continue
+
+            # e.g. "_original_test_foo" -> "test_foo" with the default prefix.
+            name = f"test{name[len(test_prefix):]}"
+            if skip_tests is not None:
+                if any(test_name in name for test_name in skip_tests):
+                    continue
+
+            # `__test_fn__=test_fn` binds the current test at definition time;
+            # a plain closure would only see the loop variable's last value.
+            # `capfd` is requested from pytest by parameter name.
+            def multi_threaded_test_fn(self, capfd, *args, __test_fn__=test_fn, **kwargs):
+                # All workers block on the barrier so they start together.
+                barrier = threading.Barrier(num_workers)
+
+                def closure():
+                    barrier.wait()
+                    for _ in range(num_runs):
+                        __test_fn__(self, *args, **kwargs)
+
+                with concurrent.futures.ThreadPoolExecutor(
+                    max_workers=num_workers
+                ) as executor:
+                    futures = []
+                    for _ in range(num_workers):
+                        futures.append(executor.submit(closure))
+                    # Calling future.result() re-raises any exception the test
+                    # raised in that worker thread.
+                    list(f.result() for f in futures)
+
+                captured = capfd.readouterr()
+                if len(captured.err) > 0:
+                    if "ThreadSanitizer" in captured.err:
+                        raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured.err}")
+                    else:
+                        raise RuntimeError(f"Other error:\n{captured.err}")
+
+            setattr(test_cls, f"{name}_multi_threaded", multi_threaded_test_fn)
+
+        return test_cls
+    return decorator
+
+
+@multi_threaded(num_workers=4, num_runs=10)
+@add_existing_tests(test_prefix="_original_test")
+class TestAllMultiThreaded:
+    """Multi-threaded runs of the `_original_test*` methods.
+
+    Methods named `_original_test*` (defined below or attached by
+    `add_existing_tests`) do not match pytest's default `test*` pattern;
+    `multi_threaded` adds a `test*_multi_threaded` method for each of them.
+    """
+
+    @pytest.fixture(scope="class", autouse=True)
+    def cleanup_output_folder(self):
+        # autouse so it runs without being requested: once all tests of this
+        # class have finished, remove the temporary folder that holds the
+        # copied test modules. (Not named `teardown`, which older pytest
+        # versions treat as a nose-style per-test hook.)
+        yield
+        self.output_folder.cleanup()
+
+    def _original_test_create_context(self):
+        with Context() as ctx:
+            print(ctx._get_live_count())
+            print(ctx._get_live_module_count())
+            print(ctx._get_live_operation_count())
+            print(ctx._get_live_operation_objects())
+            print(ctx._get_context_again() is ctx)
+            print(ctx._clear_live_operations())
+
+    def _original_test_create_module_with_consts(self):
+        py_values = [123, 234, 345]
+        with Context():
+            module = Module.create(loc=Location.file("foo.txt", 0, 0))
+
+            dtype = IntegerType.get_signless(64)
+            # One i64 constant per value, each with its own named location.
+            for loc_name, value in zip(("a", "b", "c"), py_values):
+                with InsertionPoint(module.body), Location.name(loc_name):
+                    arith.constant(dtype, value)
>From 93b63b7ec7a79b40488ebf77986600fc17a41af5 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Thu, 14 Nov 2024 13:57:04 +0100
Subject: [PATCH 3/3] [skip-ci] More tests and added a lock to
_cext.register_operation
---
mlir/lib/Bindings/Python/MainModule.cpp | 9 +++-
mlir/test/python/multithreaded_tests.py | 65 ++++++++++++++++++++++---
2 files changed, 66 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index de713e7031a01e..1cf01e4d894db0 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -73,9 +73,14 @@ PYBIND11_MODULE(_mlir, m, py::mod_gil_not_used()) {
[dialectClass, replace](py::object opClass) -> py::object {
std::string operationName =
opClass.attr("OPERATION_NAME").cast<std::string>();
- PyGlobals::get().registerOperationImpl(operationName, opClass,
- replace);
+ // Use PyGlobals::withInstance instead of PyGlobals::get()
+ // to prevent data race in multi-threaded context
+          // Error raised in ir/operation.py testKnownOpView test
+ PyGlobals::withInstance([&](PyGlobals& instance) {
+ instance.registerOperationImpl(operationName, opClass, replace);
+ return 0;
+ });
// Dict-stuff the new opClass by name onto the dialect class.
py::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
index fa545861cb0d7d..87d779ee303454 100644
--- a/mlir/test/python/multithreaded_tests.py
+++ b/mlir/test/python/multithreaded_tests.py
@@ -1,5 +1,6 @@
import concurrent.futures
import functools
+import gc
import importlib.util
import sys
import threading
@@ -42,9 +43,55 @@ def copy_and_update(src_filepath: Path, dst_filepath: Path):
writer.write(src_line)
+def run(f):
+    """Call the test function `f` directly and return it."""
+    f()
+    return f
+
+
+def constructAndPrintInModule(f):
+    """Call `f` under a fresh Context, an unknown Location and an
+    InsertionPoint on a new module's body, then print the module; return `f`.
+    """
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+    return f
+
+
+def run_with_context_and_location(f):
+    """Call `f` under a fresh Context and an unknown Location; return `f`."""
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        f()
+    return f
+
+
test_modules = [
-    "execution_engine",
-    # "pass_manager",
+    # (test module path relative to this folder, runner called with each test).
+    # Trailing Pass/Fail notes are the author's observed multi-threaded status.
+    ("execution_engine", run), # Fail
+    ("pass_manager", run), # Fail
+
+    # Dialects tests
+    ("dialects/affine", constructAndPrintInModule), # Fail
+    ("dialects/vector", run_with_context_and_location), # Fail
+
+    # IR tests
+    ("ir/affine_expr", run), # Pass
+    ("ir/affine_map", run), # Pass
+    ("ir/array_attributes", run), # Pass
+    ("ir/attributes", run), # Pass
+    ("ir/blocks", run), # Pass
+    ("ir/builtin_types", run), # Pass
+    ("ir/context_managers", run), # Pass
+    ("ir/debug", run), # Fail
+    ("ir/diagnostic_handler", run), # Fail
+    ("ir/dialects", run), # Fail
+    ("ir/exception", run), # Fail
+    ("ir/insertion_point", run), # Pass
+    ("ir/integer_set", run), # Pass
+    ("ir/location", run), # Pass
+    ("ir/module", run), # Pass but may fail randomly on mlirOperationDump in testParseSuccess
+    ("ir/operation", run), # Pass
+    ("ir/symbol_table", run), # Pass
+    ("ir/value", run), # Fail/Crash
+
]
@@ -54,7 +101,7 @@ def decorator(test_cls):
test_cls.output_folder = tempfile.TemporaryDirectory()
output_folder = Path(test_cls.output_folder.name)
- for test_module_name in test_modules:
+ for test_module_name, exec_fn in test_modules:
src_filepath = this_folder / f"{test_module_name}.py"
dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
if not dst_filepath.parent.exists():
@@ -66,8 +113,8 @@ def decorator(test_cls):
obj = getattr(test_mod, attr_name)
if callable(obj):
test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}"
- def wrapped_test_fn(*args, __test_fn__=obj, **kwargs):
- __test_fn__()
+ def wrapped_test_fn(self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs):
+ __exec_fn__(__test_fn__)
setattr(test_cls, test_name, wrapped_test_fn)
return test_cls
@@ -99,6 +146,10 @@ def closure():
for _ in range(num_runs):
__test_fn__(self, *args, **kwargs)
+ barrier.wait()
+ gc.collect()
+ assert Context._get_live_count() == 0
+
with concurrent.futures.ThreadPoolExecutor(
max_workers=num_workers
) as executor:
@@ -114,7 +165,9 @@ def closure():
if "ThreadSanitizer" in captured.err:
raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured.err}")
else:
- raise RuntimeError(f"Other error:\n{captured.err}")
+ pass
+ # There are tests that write to stderr, we should ignore them
+ # raise RuntimeError(f"Other error:\n{captured.err}")
setattr(test_cls, f"{name}_multi_threaded", multi_threaded_test_fn)
More information about the Mlir-commits
mailing list