[Mlir-commits] [mlir] Added free-threading CPython mode support in MLIR Python bindings (PR #107103)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 10 10:34:59 PST 2024


https://github.com/vfdev-5 updated https://github.com/llvm/llvm-project/pull/107103

>From ad296c90aa5c3957415aa53e8a07b8f2a129b4da Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Tue, 3 Sep 2024 15:02:12 +0200
Subject: [PATCH 1/4] Added free-threading CPython mode support in Python
 bindings - temporarily updated requirements

---
 mlir/lib/Bindings/Python/AsyncPasses.cpp           | 4 +++-
 mlir/lib/Bindings/Python/DialectGPU.cpp            | 2 +-
 mlir/lib/Bindings/Python/DialectLLVM.cpp           | 2 +-
 mlir/lib/Bindings/Python/DialectLinalg.cpp         | 2 +-
 mlir/lib/Bindings/Python/DialectNVGPU.cpp          | 2 +-
 mlir/lib/Bindings/Python/DialectPDL.cpp            | 2 +-
 mlir/lib/Bindings/Python/DialectQuant.cpp          | 2 +-
 mlir/lib/Bindings/Python/DialectSparseTensor.cpp   | 2 +-
 mlir/lib/Bindings/Python/DialectTransform.cpp      | 2 +-
 mlir/lib/Bindings/Python/ExecutionEngineModule.cpp | 2 +-
 mlir/lib/Bindings/Python/GPUPasses.cpp             | 4 +++-
 mlir/lib/Bindings/Python/LinalgPasses.cpp          | 4 +++-
 mlir/lib/Bindings/Python/MainModule.cpp            | 2 +-
 mlir/lib/Bindings/Python/RegisterEverything.cpp    | 2 +-
 mlir/lib/Bindings/Python/SparseTensorPasses.cpp    | 4 +++-
 mlir/lib/Bindings/Python/TransformInterpreter.cpp  | 2 +-
 mlir/python/requirements.txt                       | 6 ++++--
 17 files changed, 28 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Bindings/Python/AsyncPasses.cpp b/mlir/lib/Bindings/Python/AsyncPasses.cpp
index b611a758dbbb37..d34a164b6e30c2 100644
--- a/mlir/lib/Bindings/Python/AsyncPasses.cpp
+++ b/mlir/lib/Bindings/Python/AsyncPasses.cpp
@@ -11,11 +11,13 @@
 #include <pybind11/detail/common.h>
 #include <pybind11/pybind11.h>
 
+namespace py = pybind11;
+
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirAsyncPasses, m) {
+PYBIND11_MODULE(_mlirAsyncPasses, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Async Dialect Passes";
 
   // Register all Async passes on load.
diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp
index 560a54bcd15919..5acfad007c3055 100644
--- a/mlir/lib/Bindings/Python/DialectGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectGPU.cpp
@@ -23,7 +23,7 @@ using namespace mlir::python::adaptors;
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirDialectsGPU, m) {
+PYBIND11_MODULE(_mlirDialectsGPU, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR GPU Dialect";
   //===-------------------------------------------------------------------===//
   // AsyncTokenType
diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp
index cccf1370b8cc87..6981e5ed6b8427 100644
--- a/mlir/lib/Bindings/Python/DialectLLVM.cpp
+++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp
@@ -136,7 +136,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) {
       });
 }
 
-PYBIND11_MODULE(_mlirDialectsLLVM, m) {
+PYBIND11_MODULE(_mlirDialectsLLVM, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR LLVM Dialect";
 
   populateDialectLLVMSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp
index 2e54ebeb61fb10..118c4ab3f2f573 100644
--- a/mlir/lib/Bindings/Python/DialectLinalg.cpp
+++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp
@@ -21,7 +21,7 @@ static void populateDialectLinalgSubmodule(py::module m) {
       "op.");
 }
 
-PYBIND11_MODULE(_mlirDialectsLinalg, m) {
+PYBIND11_MODULE(_mlirDialectsLinalg, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Linalg dialect.";
 
   populateDialectLinalgSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
index 754e0a75b0abc7..4c962c403082cb 100644
--- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp
+++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp
@@ -34,7 +34,7 @@ static void populateDialectNVGPUSubmodule(const pybind11::module &m) {
       py::arg("ctx") = py::none());
 }
 
-PYBIND11_MODULE(_mlirDialectsNVGPU, m) {
+PYBIND11_MODULE(_mlirDialectsNVGPU, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR NVGPU dialect.";
 
   populateDialectNVGPUSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp
index 8d3f9a7ab1d6ac..e8542d5e777a65 100644
--- a/mlir/lib/Bindings/Python/DialectPDL.cpp
+++ b/mlir/lib/Bindings/Python/DialectPDL.cpp
@@ -100,7 +100,7 @@ void populateDialectPDLSubmodule(const pybind11::module &m) {
       py::arg("context") = py::none());
 }
 
-PYBIND11_MODULE(_mlirDialectsPDL, m) {
+PYBIND11_MODULE(_mlirDialectsPDL, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR PDL dialect.";
   populateDialectPDLSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp
index 9a871f2c122d12..e974ecb66effb8 100644
--- a/mlir/lib/Bindings/Python/DialectQuant.cpp
+++ b/mlir/lib/Bindings/Python/DialectQuant.cpp
@@ -309,7 +309,7 @@ static void populateDialectQuantSubmodule(const py::module &m) {
   });
 }
 
-PYBIND11_MODULE(_mlirDialectsQuant, m) {
+PYBIND11_MODULE(_mlirDialectsQuant, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Quantization dialect";
 
   populateDialectQuantSubmodule(m);
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index a730bf500be98c..00d8482a91df7a 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -143,7 +143,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
       });
 }
 
-PYBIND11_MODULE(_mlirDialectsSparseTensor, m) {
+PYBIND11_MODULE(_mlirDialectsSparseTensor, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR SparseTensor dialect.";
   populateDialectSparseTensorSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 6b57e652aa9d8b..df665dd66bdc28 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -117,7 +117,7 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
       "Get the type this ParamType is associated with.");
 }
 
-PYBIND11_MODULE(_mlirDialectsTransform, m) {
+PYBIND11_MODULE(_mlirDialectsTransform, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Transform dialect.";
   populateDialectTransformSubmodule(m);
 }
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index b3df30583fc963..ddd81d1e7d592e 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -64,7 +64,7 @@ class PyExecutionEngine {
 } // namespace
 
 /// Create the `mlir.execution_engine` module here.
-PYBIND11_MODULE(_mlirExecutionEngine, m) {
+PYBIND11_MODULE(_mlirExecutionEngine, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Execution Engine";
 
   //----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/GPUPasses.cpp b/mlir/lib/Bindings/Python/GPUPasses.cpp
index e276a3ce3a56a0..bfc43e99946bb9 100644
--- a/mlir/lib/Bindings/Python/GPUPasses.cpp
+++ b/mlir/lib/Bindings/Python/GPUPasses.cpp
@@ -11,11 +11,13 @@
 #include <pybind11/detail/common.h>
 #include <pybind11/pybind11.h>
 
+namespace py = pybind11;
+
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirGPUPasses, m) {
+PYBIND11_MODULE(_mlirGPUPasses, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR GPU Dialect Passes";
 
   // Register all GPU passes on load.
diff --git a/mlir/lib/Bindings/Python/LinalgPasses.cpp b/mlir/lib/Bindings/Python/LinalgPasses.cpp
index 3f230207a42114..e3d8f237e2bab3 100644
--- a/mlir/lib/Bindings/Python/LinalgPasses.cpp
+++ b/mlir/lib/Bindings/Python/LinalgPasses.cpp
@@ -10,11 +10,13 @@
 
 #include <pybind11/pybind11.h>
 
+namespace py = pybind11;
+
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirLinalgPasses, m) {
+PYBIND11_MODULE(_mlirLinalgPasses, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Linalg Dialect Passes";
 
   // Register all Linalg passes on load.
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 7c27021902de31..168a939ff3cd8a 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -22,7 +22,7 @@ using namespace mlir::python;
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlir, m) {
+PYBIND11_MODULE(_mlir, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Python Native Extension";
 
   py::class_<PyGlobals>(m, "_Globals", py::module_local())
diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp
index 6b2f6b0a6a3b86..5c5c6e32102712 100644
--- a/mlir/lib/Bindings/Python/RegisterEverything.cpp
+++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp
@@ -9,7 +9,7 @@
 #include "mlir-c/RegisterEverything.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 
-PYBIND11_MODULE(_mlirRegisterEverything, m) {
+PYBIND11_MODULE(_mlirRegisterEverything, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration";
 
   m.def("register_dialects", [](MlirDialectRegistry registry) {
diff --git a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
index 2a8e2b802df9c4..1bbdf2f3ccf4e5 100644
--- a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
+++ b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp
@@ -10,11 +10,13 @@
 
 #include <pybind11/pybind11.h>
 
+namespace py = pybind11;
+
 // -----------------------------------------------------------------------------
 // Module initialization.
 // -----------------------------------------------------------------------------
 
-PYBIND11_MODULE(_mlirSparseTensorPasses, m) {
+PYBIND11_MODULE(_mlirSparseTensorPasses, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR SparseTensor Dialect Passes";
 
   // Register all SparseTensor passes on load.
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index 0c8c0e0a965aa7..d55efe13319b79 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -100,7 +100,7 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
       py::arg("target"), py::arg("other"));
 }
 
-PYBIND11_MODULE(_mlirTransformInterpreter, m) {
+PYBIND11_MODULE(_mlirTransformInterpreter, m, py::mod_gil_not_used()) {
   m.doc() = "MLIR Transform dialect interpreter functionality.";
   populateTransformInterpreterSubmodule(m);
 }
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index ab8a9122919e19..449748bb02cc22 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -1,5 +1,7 @@
 nanobind>=2.0, <3.0
 numpy>=1.19.5, <=2.1.2
-pybind11>=2.10.0, <=2.13.6
+# pybind11>=2.14.0, <2.15.0
+# Temporarily set pybind11 version to master waiting the next release to 2.13.6
+pybind11 @ git+https://github.com/pybind/pybind11@master
 PyYAML>=5.4.0, <=6.0.1
-ml_dtypes>=0.1.0, <=0.5.0   # provides several NumPy dtype extensions, including the bf16
+ml_dtypes>=0.5.0, <=0.6.0   # provides several NumPy dtype extensions, including the bf16

>From ce8e9d006b20302a9597a1303c93d367d016e30e Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Fri, 20 Sep 2024 22:30:49 +0200
Subject: [PATCH 2/4] [skip-ci] Added lock on PyGlobals::get and PyMlirContext
 liveContexts WIP on adding multithreaded_tests

---
 mlir/docs/Bindings/Python.md                  |   6 +-
 .../python/StandaloneExtensionPybind11.cpp    |   4 +-
 mlir/lib/Bindings/Python/Globals.h            |  22 +++
 mlir/lib/Bindings/Python/IRCore.cpp           |  72 +++++---
 mlir/lib/Bindings/Python/IRModule.h           |  21 +++
 mlir/test/python/execution_engine.py          |   2 +-
 .../python/lib/PythonTestModulePybind11.cpp   |   2 +-
 mlir/test/python/multithreaded_tests.py       | 154 ++++++++++++++++++
 mlir/test/python/test_to_remove.py            |  46 ++++++
 9 files changed, 298 insertions(+), 31 deletions(-)
 create mode 100644 mlir/test/python/multithreaded_tests.py
 create mode 100644 mlir/test/python/test_to_remove.py

diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index a0bd1cac118bad..32df3310d811d7 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1035,7 +1035,7 @@ class ConstantOp(_ods_ir.OpView):
     ...
 ```
 
-expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`). 
+expects `value` to be a `TypedAttr` (e.g., `IntegerAttr` or `FloatAttr`).
 Thus, a natural extension is a builder that accepts a MLIR type and a Python value and instantiates the appropriate `TypedAttr`:
 
 ```python
@@ -1181,9 +1181,9 @@ make the passes available along with the dialect.
 Dialect functionality other than IR objects or passes, such as helper functions,
 can be exposed to Python similarly to attributes and types. C API is expected to
 exist for this functionality, which can then be wrapped using pybind11 and
-`[include/mlir/Bindings/Python/PybindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)`,
+[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h),
 or nanobind and
-`[include/mlir/Bindings/Python/NanobindAdaptors.h](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)`
+[`include/mlir/Bindings/Python/NanobindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)
 utilities to connect to the rest of Python API. The bindings can be located in a
 separate module or in the same module as attributes and types, and
 loaded along with the dialect.
diff --git a/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
index 397db4c20e7432..0bf8e150ee354c 100644
--- a/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
+++ b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
@@ -12,9 +12,11 @@
 #include "Standalone-c/Dialects.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 
+namespace py = pybind11;
+
 using namespace mlir::python::adaptors;
 
-PYBIND11_MODULE(_standaloneDialectsPybind11, m) {
+PYBIND11_MODULE(_standaloneDialects, m, py::mod_gil_not_used()) {
   //===--------------------------------------------------------------------===//
   // standalone dialect
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index a022067f5c7e57..05400608ba9ffa 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -36,6 +36,20 @@ class PyGlobals {
     return *instance;
   }
 
+  template<typename F>
+  static inline auto withInstance(const F& cb) -> decltype(cb(get())) {
+    auto &instance = get();
+#ifdef Py_GIL_DISABLED
+    auto &lock = getLock();
+    PyMutex_Lock(&lock);
+#endif
+    auto result = cb(instance);
+#ifdef Py_GIL_DISABLED
+    PyMutex_Unlock(&lock);
+#endif
+    return result;
+  }
+
   /// Get and set the list of parent modules to search for dialect
   /// implementation classes.
   std::vector<std::string> &getDialectSearchPrefixes() {
@@ -125,6 +139,14 @@ class PyGlobals {
   /// Set of dialect namespaces that we have attempted to import implementation
   /// modules for.
   llvm::StringSet<> loadedDialectModules;
+
+#ifdef Py_GIL_DISABLED
+  static PyMutex &getLock() {
+    static PyMutex lock;
+    return lock;
+  }
+#endif
+
 };
 
 } // namespace python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 3e96f8c60ba7cd..727cbc2e106d5b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -198,7 +198,9 @@ py::object classmethod(Func f, Args... args) {
 static py::object
 createCustomDialectWrapper(const std::string &dialectNamespace,
                            py::object dialectDescriptor) {
-  auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
+  auto dialectClass = PyGlobals::withInstance([&](PyGlobals& instance) {
+    return instance.lookupDialectClass(dialectNamespace);
+  });
   if (!dialectClass) {
     // Use the base class.
     return py::cast(PyDialect(std::move(dialectDescriptor)));
@@ -601,8 +603,10 @@ class PyOpOperandIterator {
 
 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
   py::gil_scoped_acquire acquire;
-  auto &liveContexts = getLiveContexts();
-  liveContexts[context.ptr] = this;
+  withLiveContexts([&](LiveContextMap& liveContexts) {
+    liveContexts[context.ptr] = this;
+    return this;
+  });
 }
 
 PyMlirContext::~PyMlirContext() {
@@ -610,7 +614,12 @@ PyMlirContext::~PyMlirContext() {
   // forContext method, which always puts the associated handle into
   // liveContexts.
   py::gil_scoped_acquire acquire;
-  getLiveContexts().erase(context.ptr);
+
+  withLiveContexts([&](LiveContextMap& liveContexts) {
+    liveContexts.erase(context.ptr);
+    return this;
+  });
+
   mlirContextDestroy(context);
 }
 
@@ -632,19 +641,20 @@ PyMlirContext *PyMlirContext::createNewContextForInit() {
 
 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
   py::gil_scoped_acquire acquire;
-  auto &liveContexts = getLiveContexts();
-  auto it = liveContexts.find(context.ptr);
-  if (it == liveContexts.end()) {
-    // Create.
-    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
-    py::object pyRef = py::cast(unownedContextWrapper);
-    assert(pyRef && "cast to py::object failed");
-    liveContexts[context.ptr] = unownedContextWrapper;
-    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
-  }
-  // Use existing.
-  py::object pyRef = py::cast(it->second);
-  return PyMlirContextRef(it->second, std::move(pyRef));
+  return withLiveContexts([&](LiveContextMap& liveContexts) {
+    auto it = liveContexts.find(context.ptr);
+    if (it == liveContexts.end()) {
+      // Create.
+      PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
+      py::object pyRef = py::cast(unownedContextWrapper);
+      assert(pyRef && "cast to py::object failed");
+      liveContexts[context.ptr] = unownedContextWrapper;
+      return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
+    }
+    // Use existing.
+    py::object pyRef = py::cast(it->second);
+    return PyMlirContextRef(it->second, std::move(pyRef));
+  });
 }
 
 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
@@ -652,7 +662,11 @@ PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
   return liveContexts;
 }
 
-size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
+size_t PyMlirContext::getLiveCount() {
+  return withLiveContexts([&](LiveContextMap& liveContexts) {
+    return liveContexts.size();
+  });
+}
 
 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
 
@@ -1556,8 +1570,10 @@ py::object PyOperation::createOpView() {
   checkValid();
   MlirIdentifier ident = mlirOperationGetName(get());
   MlirStringRef identStr = mlirIdentifierStr(ident);
-  auto operationCls = PyGlobals::get().lookupOperationClass(
-      StringRef(identStr.data, identStr.length));
+  auto operationCls = PyGlobals::withInstance([&](PyGlobals& instance){
+      return instance.lookupOperationClass(
+          StringRef(identStr.data, identStr.length));
+  });
   if (operationCls)
     return PyOpView::constructDerived(*operationCls, *getRef().get());
   return py::cast(PyOpView(getRef().getObject()));
@@ -2008,7 +2024,9 @@ pybind11::object PyValue::maybeDownCast() {
   assert(!mlirTypeIDIsNull(mlirTypeID) &&
          "mlirTypeID was expected to be non-null.");
   std::optional<pybind11::function> valueCaster =
-      PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
+      PyGlobals::withInstance([&](PyGlobals& instance) {
+        return instance.lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
+  });
   // py::return_value_policy::move means use std::move to move the return value
   // contents into a new instance that will be owned by Python.
   py::object thisObj = py::cast(this, py::return_value_policy::move);
@@ -3487,8 +3505,10 @@ void mlir::python::populateIRCore(py::module &m) {
         assert(!mlirTypeIDIsNull(mlirTypeID) &&
                "mlirTypeID was expected to be non-null.");
         std::optional<pybind11::function> typeCaster =
-            PyGlobals::get().lookupTypeCaster(mlirTypeID,
-                                              mlirAttributeGetDialect(self));
+            PyGlobals::withInstance([&](PyGlobals& instance){
+              return instance.lookupTypeCaster(mlirTypeID,
+                                               mlirAttributeGetDialect(self));
+            });
         if (!typeCaster)
           return py::cast(self);
         return typeCaster.value()(self);
@@ -3585,9 +3605,11 @@ void mlir::python::populateIRCore(py::module &m) {
              MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
              assert(!mlirTypeIDIsNull(mlirTypeID) &&
                     "mlirTypeID was expected to be non-null.");
-             std::optional<pybind11::function> typeCaster =
-                 PyGlobals::get().lookupTypeCaster(mlirTypeID,
+            std::optional<pybind11::function> typeCaster =
+                PyGlobals::withInstance([&](PyGlobals& instance){
+                  return instance.lookupTypeCaster(mlirTypeID,
                                                    mlirTypeGetDialect(self));
+                });
              if (!typeCaster)
                return py::cast(self);
              return typeCaster.value()(self);
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 172898cfda0c52..cbafed86c3ea85 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -263,6 +263,27 @@ class PyMlirContext {
   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
   static LiveContextMap &getLiveContexts();
 
+#ifdef Py_GIL_DISABLED
+  static PyMutex &getLock() {
+    static PyMutex lock;
+    return lock;
+  }
+#endif
+
+  template<typename F>
+  static inline auto withLiveContexts(const F& cb) -> decltype(cb(getLiveContexts())) {
+    auto &liveContexts = getLiveContexts();
+#ifdef Py_GIL_DISABLED
+    auto &lock = getLock();
+    PyMutex_Lock(&lock);
+#endif
+    auto result = cb(liveContexts);
+#ifdef Py_GIL_DISABLED
+    PyMutex_Unlock(&lock);
+#endif
+    return result;
+  }
+
   // Interns all live modules associated with this context. Modules tracked
   // in this map are valid. When a module is invalidated, it is removed
   // from this map, and while it still exists as an instance, any
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 7c375ce81de0eb..1435ae0f7569ed 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -306,7 +306,7 @@ def callback(a):
         log(arr)
 
     with Context():
-        # The module takes a subview of the argument memref, casts it to an unranked memref and 
+        # The module takes a subview of the argument memref, casts it to an unranked memref and
         # calls the callback with it.
         module = Module.parse(
             r"""
diff --git a/mlir/test/python/lib/PythonTestModulePybind11.cpp b/mlir/test/python/lib/PythonTestModulePybind11.cpp
index 94a5f5178d16e8..9b63ea07a0eaba 100644
--- a/mlir/test/python/lib/PythonTestModulePybind11.cpp
+++ b/mlir/test/python/lib/PythonTestModulePybind11.cpp
@@ -23,7 +23,7 @@ static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
          mlirTypeIsAInteger(mlirShapedTypeGetElementType(t));
 }
 
-PYBIND11_MODULE(_mlirPythonTestPybind11, m) {
+PYBIND11_MODULE(_mlirPythonTest, m, py::mod_gil_not_used()) {
   m.def(
       "register_python_test_dialect",
       [](MlirContext context, bool load) {
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
new file mode 100644
index 00000000000000..fa545861cb0d7d
--- /dev/null
+++ b/mlir/test/python/multithreaded_tests.py
@@ -0,0 +1,154 @@
+import concurrent.futures
+import functools
+import importlib.util
+import sys
+import threading
+import tempfile
+
+from collections import defaultdict
+from contextlib import redirect_stderr, redirect_stdout
+from pathlib import Path
+from typing import Optional
+
+import pytest
+
+import mlir.dialects.arith as arith
+from mlir.ir import Context, Location, Module, IntegerType, F64Type, InsertionPoint
+
+
+
+def import_from_path(module_name: str, file_path: Path):
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module)
+    return module
+
+
+def copy_and_update(src_filepath: Path, dst_filepath: Path):
+    # We should remove all calls like `run(testMethod)`
+    with open(src_filepath, "r") as reader, open(dst_filepath, "w") as writer:
+        while True:
+            src_line = reader.readline()
+            if len(src_line) == 0:
+                break
+            skip_lines = [
+                "run(",
+                "@run",
+                "@constructAndPrintInModule",
+            ]
+            if any(src_line.startswith(line) for line in skip_lines):
+                continue
+            writer.write(src_line)
+
+
+test_modules = [
+    "execution_engine",
+    # "pass_manager",
+]
+
+
+def add_existing_tests(test_prefix: str = "_original_test"):
+    def decorator(test_cls):
+        this_folder = Path(__file__).parent.absolute()
+        test_cls.output_folder = tempfile.TemporaryDirectory()
+        output_folder = Path(test_cls.output_folder.name)
+
+        for test_module_name in test_modules:
+            src_filepath = this_folder / f"{test_module_name}.py"
+            dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
+            if not dst_filepath.parent.exists():
+                dst_filepath.parent.mkdir(parents=True)
+            copy_and_update(src_filepath, dst_filepath)
+            test_mod = import_from_path(test_module_name, dst_filepath)
+            for attr_name in dir(test_mod):
+                if attr_name.startswith("test"):
+                    obj = getattr(test_mod, attr_name)
+                    if callable(obj):
+                        test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}"
+                        def wrapped_test_fn(*args, __test_fn__=obj, **kwargs):
+                            __test_fn__()
+
+                        setattr(test_cls, test_name, wrapped_test_fn)
+        return test_cls
+    return decorator
+
+
+def multi_threaded(
+    num_workers: int,
+    num_runs: int = 5,
+    skip_tests: Optional[list[str]] = None,
+    test_prefix: str = "_original_test",
+):
+    """Decorator that runs a test in a multi-threaded environment."""
+    def decorator(test_cls):
+        for name, test_fn in test_cls.__dict__.copy().items():
+            if not (name.startswith(test_prefix) and callable(test_fn)):
+                continue
+
+            name = f"test{name[len(test_prefix):]}"
+            if skip_tests is not None:
+                if any(test_name in name for test_name in skip_tests):
+                    continue
+
+            def multi_threaded_test_fn(self, capfd, *args, __test_fn__=test_fn, **kwargs):
+                barrier = threading.Barrier(num_workers)
+
+                def closure():
+                    barrier.wait()
+                    for _ in range(num_runs):
+                        __test_fn__(self, *args, **kwargs)
+
+                with concurrent.futures.ThreadPoolExecutor(
+                    max_workers=num_workers
+                ) as executor:
+                    futures = []
+                    for _ in range(num_workers):
+                        futures.append(executor.submit(closure))
+                    # We should call future.result() to re-raise an exception if test has
+                    # failed
+                    list(f.result() for f in futures)
+
+                captured = capfd.readouterr()
+                if len(captured.err) > 0:
+                    if "ThreadSanitizer" in captured.err:
+                        raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured.err}")
+                    else:
+                        raise RuntimeError(f"Other error:\n{captured.err}")
+
+            setattr(test_cls, f"{name}_multi_threaded", multi_threaded_test_fn)
+
+        return test_cls
+    return decorator
+
+
+ at multi_threaded(num_workers=4, num_runs=10)
+ at add_existing_tests(test_prefix="_original_test")
+class TestAllMultiThreaded:
+    @pytest.fixture(scope='class')
+    def teardown(self):
+        self.output_folder.cleanup()
+
+    def _original_test_create_context(self):
+        with Context() as ctx:
+            print(ctx._get_live_count())
+            print(ctx._get_live_module_count())
+            print(ctx._get_live_operation_count())
+            print(ctx._get_live_operation_objects())
+            print(ctx._get_context_again() is ctx)
+            print(ctx._clear_live_operations())
+
+    def _original_test_create_module_with_consts(self):
+        py_values = [123, 234, 345]
+        with Context() as ctx:
+            module = Module.create(loc=Location.file("foo.txt", 0, 0))
+
+            dtype = IntegerType.get_signless(64)
+            with InsertionPoint(module.body), Location.name("a"):
+                arith.constant(dtype, py_values[0])
+
+            with InsertionPoint(module.body), Location.name("b"):
+                arith.constant(dtype, py_values[1])
+
+            with InsertionPoint(module.body), Location.name("c"):
+                arith.constant(dtype, py_values[2])
diff --git a/mlir/test/python/test_to_remove.py b/mlir/test/python/test_to_remove.py
new file mode 100644
index 00000000000000..b2b2cc4b552fe9
--- /dev/null
+++ b/mlir/test/python/test_to_remove.py
@@ -0,0 +1,46 @@
+import concurrent.futures
+import threading
+import inspect
+
+
+def decorator(f):
+    # Introspect the callable for optional features.
+    sig = inspect.signature(f)
+    for param in sig.parameters.values():
+        pass
+
+    def emit_call_op(*call_args):
+        pass
+
+    wrapped = emit_call_op
+    return wrapped
+
+
+def test_dialects_vector_repro_3():
+    num_workers = 6
+    num_runs = 10
+    barrier = threading.Barrier(num_workers)
+
+    def closure():
+        barrier.wait()
+        for _ in range(num_runs):
+
+            @decorator
+            def print_vector(arg):
+                return 0
+
+        barrier.wait()
+
+    with concurrent.futures.ThreadPoolExecutor(
+        max_workers=num_workers
+    ) as executor:
+        futures = []
+        for _ in range(num_workers):
+            futures.append(executor.submit(closure))
+        # We should call future.result() to re-raise an exception if test has
+        # failed
+        assert len(list(f.result() for f in futures)) == num_workers
+
+
+if __name__ == "__main__":
+    test_dialects_vector_repro_3()
\ No newline at end of file

>From ba8ba2e94a65ba009ae51ed53da19060263dfa64 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Thu, 14 Nov 2024 13:57:04 +0100
Subject: [PATCH 3/4] [skip-ci] More tests and added a lock to
 _cext.register_operation

---
 mlir/lib/Bindings/Python/MainModule.cpp |  9 +++-
 mlir/test/python/multithreaded_tests.py | 65 ++++++++++++++++++++++---
 2 files changed, 66 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 168a939ff3cd8a..590d0590f51e4a 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -73,9 +73,14 @@ PYBIND11_MODULE(_mlir, m, py::mod_gil_not_used()) {
             [dialectClass, replace](py::type opClass) -> py::type {
               std::string operationName =
                   opClass.attr("OPERATION_NAME").cast<std::string>();
-              PyGlobals::get().registerOperationImpl(operationName, opClass,
-                                                     replace);
 
+              // Use PyGlobals::withInstance instead of PyGlobals::get()
+              // to prevent data race in multi-threaded context
+              // Error raised in ir/opeation.py testKnownOpView test
+              PyGlobals::withInstance([&](PyGlobals& instance) {
+                  instance.registerOperationImpl(operationName, opClass, replace);
+                  return 0;
+              });
               // Dict-stuff the new opClass by name onto the dialect class.
               py::object opClassName = opClass.attr("__name__");
               dialectClass.attr(opClassName) = opClass;
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
index fa545861cb0d7d..87d779ee303454 100644
--- a/mlir/test/python/multithreaded_tests.py
+++ b/mlir/test/python/multithreaded_tests.py
@@ -1,5 +1,6 @@
 import concurrent.futures
 import functools
+import gc
 import importlib.util
 import sys
 import threading
@@ -42,9 +43,55 @@ def copy_and_update(src_filepath: Path, dst_filepath: Path):
             writer.write(src_line)
 
 
+def run(f):
+    f()
+
+
+def constructAndPrintInModule(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+        print(module)
+
+
+def run_with_context_and_location(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        f()
+    return f
+
+
 test_modules = [
-    "execution_engine",
-    # "pass_manager",
+    ("execution_engine", run),  # Fail
+    ("pass_manager", run),  # Fail
+
+    # Dialects tests
+    ("dialects/affine", constructAndPrintInModule),  # Fail
+    ("dialects/vector", run_with_context_and_location),  # Fail
+
+    # IR tests
+    ("ir/affine_expr", run),  # Pass
+    ("ir/affine_map", run),  # Pass
+    ("ir/array_attributes", run),  # Pass
+    ("ir/attributes", run),  # Pass
+    ("ir/blocks", run),  # Pass
+    ("ir/builtin_types", run),  # Pass
+    ("ir/context_managers", run),  # Pass
+    ("ir/debug", run),  # Fail
+    ("ir/diagnostic_handler", run),  # Fail
+    ("ir/dialects", run),  # Fail
+    ("ir/exception", run),  # Fail
+    ("ir/insertion_point", run),  # Pass
+    ("ir/insertion_point", run),  # Pass
+    ("ir/integer_set", run),  # Pass
+    ("ir/location", run),  # Pass
+    ("ir/module", run),  # Pass but may fail randomly on mlirOperationDump in testParseSuccess
+    ("ir/operation", run),  # Pass
+    ("ir/symbol_table", run),  # Pass
+    ("ir/value", run),  # Fail/Crash
+
 ]
 
 
@@ -54,7 +101,7 @@ def decorator(test_cls):
         test_cls.output_folder = tempfile.TemporaryDirectory()
         output_folder = Path(test_cls.output_folder.name)
 
-        for test_module_name in test_modules:
+        for test_module_name, exec_fn in test_modules:
             src_filepath = this_folder / f"{test_module_name}.py"
             dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
             if not dst_filepath.parent.exists():
@@ -66,8 +113,8 @@ def decorator(test_cls):
                     obj = getattr(test_mod, attr_name)
                     if callable(obj):
                         test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}"
-                        def wrapped_test_fn(*args, __test_fn__=obj, **kwargs):
-                            __test_fn__()
+                        def wrapped_test_fn(self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs):
+                            __exec_fn__(__test_fn__)
 
                         setattr(test_cls, test_name, wrapped_test_fn)
         return test_cls
@@ -99,6 +146,10 @@ def closure():
                     for _ in range(num_runs):
                         __test_fn__(self, *args, **kwargs)
 
+                    barrier.wait()
+                    gc.collect()
+                    assert Context._get_live_count() == 0
+
                 with concurrent.futures.ThreadPoolExecutor(
                     max_workers=num_workers
                 ) as executor:
@@ -114,7 +165,9 @@ def closure():
                     if "ThreadSanitizer" in captured.err:
                         raise RuntimeError(f"ThreadSanitizer reported warnings:\n{captured.err}")
                     else:
-                        raise RuntimeError(f"Other error:\n{captured.err}")
+                        pass
+                        # There are tests that write to stderr, we should ignore them
+                        # raise RuntimeError(f"Other error:\n{captured.err}")
 
             setattr(test_cls, f"{name}_multi_threaded", multi_threaded_test_fn)
 

>From be524efe1559337aecf7c6794efda4d317b0362a Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Tue, 10 Dec 2024 19:33:19 +0100
Subject: [PATCH 4/4] [skip ci] Updated tests + labelled passing and failing
 tests

---
 mlir/python/requirements.txt            |   2 +-
 mlir/test/python/multithreaded_tests.py | 239 +++++++++++++++++++++---
 2 files changed, 219 insertions(+), 22 deletions(-)

diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index 449748bb02cc22..66c6168d78540d 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -3,5 +3,5 @@ numpy>=1.19.5, <=2.1.2
 # pybind11>=2.14.0, <2.15.0
 # Temporarily set pybind11 version to master waiting the next release to 2.13.6
 pybind11 @ git+https://github.com/pybind/pybind11@master
-PyYAML>=5.4.0, <=6.0.1
+PyYAML>=5.4.0, <=6.0.2
 ml_dtypes>=0.5.0, <=0.6.0   # provides several NumPy dtype extensions, including the bf16
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
index 87d779ee303454..0216f72c5f52d2 100644
--- a/mlir/test/python/multithreaded_tests.py
+++ b/mlir/test/python/multithreaded_tests.py
@@ -14,10 +14,10 @@
 import pytest
 
 import mlir.dialects.arith as arith
+from mlir.dialects import transform
 from mlir.ir import Context, Location, Module, IntegerType, F64Type, InsertionPoint
 
 
-
 def import_from_path(module_name: str, file_path: Path):
     spec = importlib.util.spec_from_file_location(module_name, file_path)
     module = importlib.util.module_from_spec(spec)
@@ -37,6 +37,10 @@ def copy_and_update(src_filepath: Path, dst_filepath: Path):
                 "run(",
                 "@run",
                 "@constructAndPrintInModule",
+                "run_apply_patterns(",
+                "@run_apply_patterns",
+                "@test_in_context",
+                "@construct_and_print_in_module",
             ]
             if any(src_line.startswith(line) for line in skip_lines):
                 continue
@@ -47,31 +51,217 @@ def run(f):
     f()
 
 
-def constructAndPrintInModule(f):
+def run_with_context_and_location(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        f()
+    return f
+
+
+def run_with_insertion_point(f):
+    print("\nTEST:", f.__name__)
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f(ctx)
+        print(module)
+
+
+def run_with_insertion_point_v2(f):
     print("\nTEST:", f.__name__)
     with Context(), Location.unknown():
         module = Module.create()
         with InsertionPoint(module.body):
             f()
         print(module)
+    return f
 
 
-def run_with_context_and_location(f):
+def run_with_insertion_point_v3(f):
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f(module)
+        print(module)
+    return f
+
+
+def run_with_insertion_point_v4(f):
     print("\nTEST:", f.__name__)
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        module = Module.create()
+        with InsertionPoint(module.body):
+            f()
+    return f
+
+
+def run_apply_patterns(f):
     with Context(), Location.unknown():
-        f()
+        module = Module.create()
+        with InsertionPoint(module.body):
+            sequence = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                [],
+                transform.AnyOpType.get(),
+            )
+            with InsertionPoint(sequence.body):
+                apply = transform.ApplyPatternsOp(sequence.bodyTarget)
+                with InsertionPoint(apply.patterns):
+                    f()
+                transform.YieldOp()
+        print("\nTEST:", f.__name__)
+        print(module)
+    return f
+
+
+def run_transform_tensor_ext(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            sequence = transform.SequenceOp(
+                transform.FailurePropagationMode.Propagate,
+                [],
+                transform.AnyOpType.get(),
+            )
+            with InsertionPoint(sequence.body):
+                f(sequence.bodyTarget)
+                transform.YieldOp()
+        print(module)
     return f
 
 
+def run_transform_structured_ext(f):
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        module.operation.verify()
+        print(module)
+    return f
+
+
+def run_construct_and_print_in_module(f):
+    print("\nTEST:", f.__name__)
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            module = f(module)
+        if module is not None:
+            print(module)
+    return f
+
+
+# Python 3.13.1 experimental free-threading build (tags/v3.13.1:06714517797, Dec 10 2024, 00:18:06) [Clang 15.0.7 ]
+# numpy      2.3.0.dev0
+# nanobind   2.5.0.dev1 /tmp/jax/nanobind
+# pybind11   master
 test_modules = [
-    ("execution_engine", run),  # Fail
+    # Failed tests,
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_execution_engine__testBF16Memref_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_execution_engine__testBasicCallback_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_execution_engine__testComplexMemrefAdd_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_execution_engine__testDynamicMemrefAdd2D_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_execution_engine__testF16MemrefAdd_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_execution_engine__testF8E5M2Memref_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_execution_engine__testInvalidModule_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_execution_engine__testInvokeFloatAdd_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_execution_engine__testMemrefAdd_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_execution_engine__testNanoTime_multi_threaded
+    ("execution_engine", run),  # Fail,
+
+    # Failed tests,
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_pass_manager__testPrintIrAfterAll_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_pass_manager__testPrintIrLargeLimitElements_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_pass_manager__testPrintIrTree_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_pass_manager__testRunPipeline_multi_threaded
     ("pass_manager", run),  # Fail
 
-    # Dialects tests
-    ("dialects/affine", constructAndPrintInModule),  # Fail
-    ("dialects/vector", run_with_context_and_location),  # Fail
+    # Dialects tests, 8 failed, 206 passed
+    # - Failing tests:
+    # - TSAN unrelatd: multithreaded_tests.py::TestAllMultiThreaded::test_dialects_arith_dialect__testArithValue_multi_threaded
+    #   RuntimeError: Value caster is already registered: <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_dialects_transform_structured_ext__testMatchInterfaceEnumReplaceAttributeBuilder_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_dialects_transform_interpreter__include_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_dialects_transform_interpreter__print_other_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_dialects_transform_interpreter__print_self_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_dialects_transform_interpreter__transform_options_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded
+    ("dialects/affine", run_with_insertion_point_v2),  # Pass
+    ("dialects/func", run_with_insertion_point_v2),  # Pass
+    ("dialects/arith_dialect", run),  # Fail
+    ("dialects/arith_llvm", run),  # Pass
+    ("dialects/async_dialect", run),  # Pass
+    ("dialects/builtin", run),  # Pass
+    ("dialects/cf", run_with_insertion_point_v4),  # Pass
+    ("dialects/complex_dialect", run),  # Pass
+    ("dialects/func", run_with_insertion_point_v2),  # Pass
+    ("dialects/index_dialect", run_with_insertion_point),  # Pass
+    ("dialects/llvm", run_with_insertion_point_v2),  # Pass
+    ("dialects/math_dialect", run),  # Pass
+    ("dialects/memref", run),  # Fail
+    ("dialects/ml_program", run_with_insertion_point_v2),  # Pass
+    ("dialects/nvgpu", run_with_insertion_point_v2),  # Pass
+    ("dialects/nvvm", run_with_insertion_point_v2),  # Pass
+    ("dialects/ods_helpers", run),  # Pass
+    ("dialects/openmp_ops", run_with_insertion_point_v2),  # Pass
+    ("dialects/pdl_ops", run_with_insertion_point_v2),  # Pass
+    # ("dialects/python_test", run),  # Need to pass pybind11 or nanobind argv
+    ("dialects/quant", run),  # Pass
+    ("dialects/rocdl", run_with_insertion_point_v2),  # Pass
+    ("dialects/scf", run_with_insertion_point_v2),  # Pass
+    ("dialects/shape", run),  # Pass
+    ("dialects/spirv_dialect", run),  # Pass
+    ("dialects/tensor", run),  # Pass
+    # ("dialects/tosa", ),  # Nothing to test
+    ("dialects/transform_bufferization_ext", run_with_insertion_point_v2),  # Pass
+    # ("dialects/transform_extras", ),  # Needs a more complicated execution schema
+    ("dialects/transform_gpu_ext", run_transform_tensor_ext),  # Pass
+    ("dialects/transform_interpreter", run_with_context_and_location, ["print_", "transform_options", "failed", "include"]),  # Fail
+    ("dialects/transform_loop_ext", run_with_insertion_point_v2, ["loopOutline"]),  # Pass
+    ("dialects/transform_memref_ext", run_with_insertion_point_v2),  # Pass
+    ("dialects/transform_nvgpu_ext", run_with_insertion_point_v2),  # Pass
+    ("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext),  # Pass
+    ("dialects/transform_structured_ext", run_transform_structured_ext),  # Fail
+    ("dialects/transform_tensor_ext", run_transform_tensor_ext),  # Pass
+    ("dialects/transform_vector_ext", run_apply_patterns, ["configurable_patterns"]),  # Pass
+    ("dialects/transform", run_with_insertion_point_v3),  # Pass
+    ("dialects/vector", run_with_context_and_location),  # Pass
+
+    ("dialects/gpu/dialect", run_with_context_and_location),  # Pass
+    ("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location),  # Pass
+    ("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location),  # Fail
+
+    ("dialects/linalg/ops", run),  # Pass
+    # TO ADD:
+    # ("dialects/linalg/opsdsl/*", run),  #
+
+    ("dialects/sparse_tensor/dialect", run),  # Pass
+    ("dialects/sparse_tensor/passes", run),  # Pass
+
+    # Integration tests, 2 failed, 11 passed
+    # - Failing tests:
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded
+    ("integration/dialects/pdl", run_construct_and_print_in_module),  # Pass
+    ("integration/dialects/transform", run_construct_and_print_in_module),  # Pass
+    ("integration/dialects/linalg/opsrun", run),  # Fail
 
     # IR tests
+    # - Failing tests:
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_ir_debug__testDebugDlag_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_ir_dialects__testAppendPrefixSearchPath_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_ir_module__testParseSuccess_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_ir_operation__testKnownOpView_multi_threaded
+    # - multithreaded_tests.py::TestAllMultiThreaded::test_ir_value__testValueCasters_multi_threaded
+    # - Crashed: multithreaded_tests.py::TestAllMultiThreaded::test_ir_value__testValuePrintAsOperand_multi_threaded
     ("ir/affine_expr", run),  # Pass
     ("ir/affine_map", run),  # Pass
     ("ir/array_attributes", run),  # Pass
@@ -82,16 +272,14 @@ def run_with_context_and_location(f):
     ("ir/debug", run),  # Fail
     ("ir/diagnostic_handler", run),  # Fail
     ("ir/dialects", run),  # Fail
-    ("ir/exception", run),  # Fail
-    ("ir/insertion_point", run),  # Pass
+    ("ir/exception", run),  # Pass
     ("ir/insertion_point", run),  # Pass
     ("ir/integer_set", run),  # Pass
     ("ir/location", run),  # Pass
     ("ir/module", run),  # Pass but may fail randomly on mlirOperationDump in testParseSuccess
-    ("ir/operation", run),  # Pass
+    ("ir/operation", run),  # Fail
     ("ir/symbol_table", run),  # Pass
     ("ir/value", run),  # Fail/Crash
-
 ]
 
 
@@ -101,7 +289,14 @@ def decorator(test_cls):
         test_cls.output_folder = tempfile.TemporaryDirectory()
         output_folder = Path(test_cls.output_folder.name)
 
-        for test_module_name, exec_fn in test_modules:
+        for test_mod_info in test_modules:
+            assert isinstance(test_mod_info, tuple) and len(test_mod_info) in (2, 3)
+            if len(test_mod_info) == 2:
+                test_module_name, exec_fn = test_mod_info
+                test_pattern = None
+            else:
+                test_module_name, exec_fn, test_pattern = test_mod_info
+
             src_filepath = this_folder / f"{test_module_name}.py"
             dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
             if not dst_filepath.parent.exists():
@@ -109,7 +304,9 @@ def decorator(test_cls):
             copy_and_update(src_filepath, dst_filepath)
             test_mod = import_from_path(test_module_name, dst_filepath)
             for attr_name in dir(test_mod):
-                if attr_name.startswith("test"):
+                is_test_fn = test_pattern is None and attr_name.startswith("test")
+                is_test_fn |= test_pattern is not None and any([p in attr_name for p in test_pattern])
+                if is_test_fn:
                     obj = getattr(test_mod, attr_name)
                     if callable(obj):
                         test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}"
@@ -146,10 +343,6 @@ def closure():
                     for _ in range(num_runs):
                         __test_fn__(self, *args, **kwargs)
 
-                    barrier.wait()
-                    gc.collect()
-                    assert Context._get_live_count() == 0
-
                 with concurrent.futures.ThreadPoolExecutor(
                     max_workers=num_workers
                 ) as executor:
@@ -158,7 +351,10 @@ def closure():
                         futures.append(executor.submit(closure))
                     # We should call future.result() to re-raise an exception if test has
                     # failed
-                    list(f.result() for f in futures)
+                    assert len(list(f.result() for f in futures)) == num_workers
+
+                gc.collect()
+                assert Context._get_live_count() == 0
 
                 captured = capfd.readouterr()
                 if len(captured.err) > 0:
@@ -175,12 +371,13 @@ def closure():
     return decorator
 
 
- at multi_threaded(num_workers=4, num_runs=10)
+ at multi_threaded(num_workers=6, num_runs=20)
 @add_existing_tests(test_prefix="_original_test")
 class TestAllMultiThreaded:
     @pytest.fixture(scope='class')
     def teardown(self):
-        self.output_folder.cleanup()
+        if hasattr(self, "output_folder"):
+            self.output_folder.cleanup()
 
     def _original_test_create_context(self):
         with Context() as ctx:



More information about the Mlir-commits mailing list