[Mlir-commits] [mlir] Enabled freethreading support in MLIR python bindings (PR #122684)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 13 02:23:55 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: vfdev (vfdev-5)

<details>
<summary>Changes</summary>

Relands the reverted PR https://github.com/llvm/llvm-project/pull/107103 with fixes for Python 3.8.

cc @<!-- -->jpienaar 

---

Patch is 39.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122684.diff


10 Files Affected:

- (modified) mlir/cmake/modules/AddMLIRPython.cmake (+20-1) 
- (modified) mlir/docs/Bindings/Python.md (+40) 
- (modified) mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp (+3-1) 
- (modified) mlir/lib/Bindings/Python/Globals.h (+11-1) 
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+27-4) 
- (modified) mlir/lib/Bindings/Python/IRModule.cpp (+16-2) 
- (modified) mlir/lib/Bindings/Python/IRModule.h (+1) 
- (modified) mlir/lib/Bindings/Python/MainModule.cpp (+2-7) 
- (modified) mlir/python/requirements.txt (+2-1) 
- (added) mlir/test/python/multithreaded_tests.py (+518) 


``````````diff
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 717a503468a85d..0679db9cf93e19 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -668,12 +668,31 @@ function(add_mlir_python_extension libname extname)
   elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
     nanobind_add_module(${libname}
       NB_DOMAIN mlir
+      FREE_THREADED
       ${ARG_SOURCES}
     )
 
     if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
       # Avoids warnings from upstream nanobind.
-      target_compile_options(nanobind-static
+      set(nanobind_target "nanobind-static")
+      if (NOT TARGET ${nanobind_target})
+        # Get correct nanobind target name: nanobind-static-ft or something else
+        # It is set by nanobind_add_module function according to the passed options
+        get_property(all_targets DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY BUILDSYSTEM_TARGETS)
+
+        # Iterate over the list of targets
+        foreach(target ${all_targets})
+          # Check if the target name matches the given string
+          if("${target}" MATCHES "nanobind-")
+            set(nanobind_target "${target}")
+          endif()
+        endforeach()
+
+        if (NOT TARGET ${nanobind_target})
+          message(FATAL_ERROR "Could not find nanobind target to set compile options to")
+        endif()
+      endif()
+      target_compile_options(${nanobind_target}
         PRIVATE
           -Wno-cast-qual
           -Wno-zero-length-array
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 32df3310d811d7..b8bd0f507a5108 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1187,3 +1187,43 @@ or nanobind and
 utilities to connect to the rest of Python API. The bindings can be located in a
 separate module or in the same module as attributes and types, and
 loaded along with the dialect.
+
+## Free-threading (No-GIL) support
+
+Free-threading (or no-GIL) support refers to the CPython interpreter (>=3.13) with the Global Interpreter Lock made optional. For details, see [PEP-703](https://peps.python.org/pep-0703/) and the [Python free-threading guide](https://py-free-threading.github.io/).
+
+The MLIR Python bindings are free-threading compatible, with some exceptions (listed below), in the following sense: it is safe to work in multiple threads with **independent** contexts. The example below shows safe usage:
+
+```python
+# python3.13t example.py
+import concurrent.futures
+
+import mlir.dialects.arith as arith
+from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
+
+
+def func(py_value):
+    with Context() as ctx:
+        module = Module.create(loc=Location.file("foo.txt", 0, 0))
+
+        dtype = IntegerType.get_signless(64)
+        with InsertionPoint(module.body), Location.name("a"):
+            arith.constant(dtype, py_value)
+
+    return module
+
+
+num_workers = 8
+with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+    futures = []
+    for i in range(num_workers):
+        futures.append(executor.submit(func, i))
+    assert len(list(f.result() for f in futures)) == num_workers
+```
+
+The exceptions to the free-threading compatibility:
+- IR printing is unsafe, e.g. when using `PassManager` with `PassManager.enable_ir_printing()` which calls thread-unsafe `llvm::raw_ostream`.
+- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
+- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
+- Usage of `mlir.dialects.transform.interpreter` is unsafe.
+- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
\ No newline at end of file
diff --git a/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
index 397db4c20e7432..dd3c4c2945cca8 100644
--- a/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
+++ b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
@@ -12,9 +12,11 @@
 #include "Standalone-c/Dialects.h"
 #include "mlir/Bindings/Python/PybindAdaptors.h"
 
+namespace py = pybind11;
+
 using namespace mlir::python::adaptors;
 
-PYBIND11_MODULE(_standaloneDialectsPybind11, m) {
+PYBIND11_MODULE(_standaloneDialectsPybind11, m, py::mod_gil_not_used()) {
   //===--------------------------------------------------------------------===//
   // standalone dialect
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 0ec522d14f74bd..826a34a5351765 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -24,6 +24,7 @@ namespace mlir {
 namespace python {
 
 /// Globals that are always accessible once the extension has been initialized.
+/// Methods of this class are thread-safe.
 class PyGlobals {
 public:
   PyGlobals();
@@ -37,12 +38,18 @@ class PyGlobals {
 
   /// Get and set the list of parent modules to search for dialect
   /// implementation classes.
-  std::vector<std::string> &getDialectSearchPrefixes() {
+  std::vector<std::string> getDialectSearchPrefixes() {
+    nanobind::ft_lock_guard lock(mutex);
     return dialectSearchPrefixes;
   }
   void setDialectSearchPrefixes(std::vector<std::string> newValues) {
+    nanobind::ft_lock_guard lock(mutex);
     dialectSearchPrefixes.swap(newValues);
   }
+  void addDialectSearchPrefix(std::string value) {
+    nanobind::ft_lock_guard lock(mutex);
+    dialectSearchPrefixes.push_back(std::move(value));
+  }
 
   /// Loads a python module corresponding to the given dialect namespace.
   /// No-ops if the module has already been loaded or is not found. Raises
@@ -109,6 +116,9 @@ class PyGlobals {
 
 private:
   static PyGlobals *instance;
+
+  nanobind::ft_mutex mutex;
+
   /// Module name prefixes to search under for dialect implementation modules.
   std::vector<std::string> dialectSearchPrefixes;
   /// Map of dialect namespace to external dialect class object.
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 453d4f7c7e8bca..463ebdebb3f3f6 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -243,9 +243,15 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,
 
 /// Wrapper for the global LLVM debugging flag.
 struct PyGlobalDebugFlag {
-  static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
+  static void set(nb::object &o, bool enable) {
+    nb::ft_lock_guard lock(mutex);
+    mlirEnableGlobalDebug(enable);
+  }
 
-  static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
+  static bool get(const nb::object &) {
+    nb::ft_lock_guard lock(mutex);
+    return mlirIsGlobalDebugEnabled();
+  }
 
   static void bind(nb::module_ &m) {
     // Debug flags.
@@ -255,6 +261,7 @@ struct PyGlobalDebugFlag {
         .def_static(
             "set_types",
             [](const std::string &type) {
+              nb::ft_lock_guard lock(mutex);
               mlirSetGlobalDebugType(type.c_str());
             },
             "types"_a, "Sets specific debug types to be produced by LLVM")
@@ -263,11 +270,17 @@ struct PyGlobalDebugFlag {
           pointers.reserve(types.size());
           for (const std::string &str : types)
             pointers.push_back(str.c_str());
+          nb::ft_lock_guard lock(mutex);
           mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
         });
   }
+
+private:
+  static nb::ft_mutex mutex;
 };
 
+nb::ft_mutex PyGlobalDebugFlag::mutex;
+
 struct PyAttrBuilderMap {
   static bool dunderContains(const std::string &attributeKind) {
     return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
@@ -606,6 +619,7 @@ class PyOpOperandIterator {
 
 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
   nb::gil_scoped_acquire acquire;
+  nb::ft_lock_guard lock(live_contexts_mutex);
   auto &liveContexts = getLiveContexts();
   liveContexts[context.ptr] = this;
 }
@@ -615,7 +629,10 @@ PyMlirContext::~PyMlirContext() {
   // forContext method, which always puts the associated handle into
   // liveContexts.
   nb::gil_scoped_acquire acquire;
-  getLiveContexts().erase(context.ptr);
+  {
+    nb::ft_lock_guard lock(live_contexts_mutex);
+    getLiveContexts().erase(context.ptr);
+  }
   mlirContextDestroy(context);
 }
 
@@ -632,6 +649,7 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
 
 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
   nb::gil_scoped_acquire acquire;
+  nb::ft_lock_guard lock(live_contexts_mutex);
   auto &liveContexts = getLiveContexts();
   auto it = liveContexts.find(context.ptr);
   if (it == liveContexts.end()) {
@@ -647,12 +665,17 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
   return PyMlirContextRef(it->second, std::move(pyRef));
 }
 
+nb::ft_mutex PyMlirContext::live_contexts_mutex;
+
 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
   static LiveContextMap liveContexts;
   return liveContexts;
 }
 
-size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
+size_t PyMlirContext::getLiveCount() {
+  nb::ft_lock_guard lock(live_contexts_mutex);
+  return getLiveContexts().size();
+}
 
 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
 
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index f7bf77e5a7e043..e600f1bbd44932 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -38,8 +38,11 @@ PyGlobals::PyGlobals() {
 PyGlobals::~PyGlobals() { instance = nullptr; }
 
 bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
-  if (loadedDialectModules.contains(dialectNamespace))
-    return true;
+  {
+    nb::ft_lock_guard lock(mutex);
+    if (loadedDialectModules.contains(dialectNamespace))
+      return true;
+  }
   // Since re-entrancy is possible, make a copy of the search prefixes.
   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
   nb::object loaded = nb::none();
@@ -62,12 +65,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
     return false;
   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
   // may have occurred, which may do anything.
+  nb::ft_lock_guard lock(mutex);
   loadedDialectModules.insert(dialectNamespace);
   return true;
 }
 
 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
                                          nb::callable pyFunc, bool replace) {
+  nb::ft_lock_guard lock(mutex);
   nb::object &found = attributeBuilderMap[attributeKind];
   if (found && !replace) {
     throw std::runtime_error((llvm::Twine("Attribute builder for '") +
@@ -81,6 +86,7 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
 
 void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
                                    nb::callable typeCaster, bool replace) {
+  nb::ft_lock_guard lock(mutex);
   nb::object &found = typeCasterMap[mlirTypeID];
   if (found && !replace)
     throw std::runtime_error("Type caster is already registered with caster: " +
@@ -90,6 +96,7 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
 
 void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
                                     nb::callable valueCaster, bool replace) {
+  nb::ft_lock_guard lock(mutex);
   nb::object &found = valueCasterMap[mlirTypeID];
   if (found && !replace)
     throw std::runtime_error("Value caster is already registered: " +
@@ -99,6 +106,7 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
 
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                     nb::object pyClass) {
+  nb::ft_lock_guard lock(mutex);
   nb::object &found = dialectClassMap[dialectNamespace];
   if (found) {
     throw std::runtime_error((llvm::Twine("Dialect namespace '") +
@@ -110,6 +118,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
 
 void PyGlobals::registerOperationImpl(const std::string &operationName,
                                       nb::object pyClass, bool replace) {
+  nb::ft_lock_guard lock(mutex);
   nb::object &found = operationClassMap[operationName];
   if (found && !replace) {
     throw std::runtime_error((llvm::Twine("Operation '") + operationName +
@@ -121,6 +130,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
 
 std::optional<nb::callable>
 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
+  nb::ft_lock_guard lock(mutex);
   const auto foundIt = attributeBuilderMap.find(attributeKind);
   if (foundIt != attributeBuilderMap.end()) {
     assert(foundIt->second && "attribute builder is defined");
@@ -133,6 +143,7 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
                                                         MlirDialect dialect) {
   // Try to load dialect module.
   (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
+  nb::ft_lock_guard lock(mutex);
   const auto foundIt = typeCasterMap.find(mlirTypeID);
   if (foundIt != typeCasterMap.end()) {
     assert(foundIt->second && "type caster is defined");
@@ -145,6 +156,7 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
                                                          MlirDialect dialect) {
   // Try to load dialect module.
   (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
+  nb::ft_lock_guard lock(mutex);
   const auto foundIt = valueCasterMap.find(mlirTypeID);
   if (foundIt != valueCasterMap.end()) {
     assert(foundIt->second && "value caster is defined");
@@ -158,6 +170,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
   // Make sure dialect module is loaded.
   if (!loadDialectModule(dialectNamespace))
     return std::nullopt;
+  nb::ft_lock_guard lock(mutex);
   const auto foundIt = dialectClassMap.find(dialectNamespace);
   if (foundIt != dialectClassMap.end()) {
     assert(foundIt->second && "dialect class is defined");
@@ -175,6 +188,7 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
   if (!loadDialectModule(dialectNamespace))
     return std::nullopt;
 
+  nb::ft_lock_guard lock(mutex);
   auto foundIt = operationClassMap.find(operationName);
   if (foundIt != operationClassMap.end()) {
     assert(foundIt->second && "OpView is defined");
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 8fb32a225e65f1..f5fbb6c61b57e2 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -260,6 +260,7 @@ class PyMlirContext {
   // Note that this holds a handle, which does not imply ownership.
   // Mappings will be removed when the context is destructed.
   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
+  static nanobind::ft_mutex live_contexts_mutex;
   static LiveContextMap &getLiveContexts();
 
   // Interns all live modules associated with this context. Modules tracked
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 7c4064262012ef..6f49431006605a 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -30,12 +30,8 @@ NB_MODULE(_mlir, m) {
       .def_prop_rw("dialect_search_modules",
                    &PyGlobals::getDialectSearchPrefixes,
                    &PyGlobals::setDialectSearchPrefixes)
-      .def(
-          "append_dialect_search_prefix",
-          [](PyGlobals &self, std::string moduleName) {
-            self.getDialectSearchPrefixes().push_back(std::move(moduleName));
-          },
-          "module_name"_a)
+      .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
+           "module_name"_a)
       .def(
           "_check_dialect_module_loaded",
           [](PyGlobals &self, const std::string &dialectNamespace) {
@@ -76,7 +72,6 @@ NB_MODULE(_mlir, m) {
                   nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
               PyGlobals::get().registerOperationImpl(operationName, opClass,
                                                      replace);
-
               // Dict-stuff the new opClass by name onto the dialect class.
               nb::object opClassName = opClass.attr("__name__");
               dialectClass.attr(opClassName) = opClass;
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index f240d6ef944ec7..1a0075e829aef2 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -2,4 +2,5 @@ nanobind>=2.4, <3.0
 numpy>=1.19.5, <=2.1.2
 pybind11>=2.10.0, <=2.13.6
 PyYAML>=5.4.0, <=6.0.1
-ml_dtypes>=0.1.0, <=0.5.0   # provides several NumPy dtype extensions, including the bf16
+ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13"   # provides several NumPy dtype extensions, including the bf16
+ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13"
\ No newline at end of file
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
new file mode 100644
index 00000000000000..6e1a6683468729
--- /dev/null
+++ b/mlir/test/python/multithreaded_tests.py
@@ -0,0 +1,518 @@
+# RUN: %PYTHON %s
+"""
+This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN.
+Tests can be run using pytest:
+```bash
+python3.13t -mpytest -vvv multithreaded_tests.py
+```
+
+IMPORTANT. These tests do not check correctness; they only execute the existing tests in a multi-threaded context,
+passing if TSAN reports no warnings and failing otherwise.
+
+
+Details on the generated tests and execution:
+1) Multi-threaded execution: all generated tests are executed independently by
+a pool of threads, running each test multiple times, see @multi_threaded for details
+
+2) Tests generation: we use existing tests: test/python/ir/*.py,
+test/python/dialects/*.py, etc to generate multi-threaded tests.
+In detail, we perform the following:
+a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`.
+b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method.
+c) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py.
+In order to import the test file as python module, we remove all executing functions, like
+`@run` or `run(testMethod)`. See `copy_and_update` and `add_existing_tests` methods for details.
+
+
+Observed warnings reported by TSAN.
+
+CPython and free-threading known data-races:
+1) ctypes related races: https://github.com/python/cpython/issues/127945
+2) LLVM related data-races, llvm::raw_ostream is not thread-safe
+- mlir pass manager
+- dialects/transform_interpreter.py
+- ir/diagnostic_handler.py
+- ir/module.py
+3) Dialect gpu module-to-binary method is unsafe
+"""
+import concurrent.futures
+import gc
+import importlib.util
+import os
+import sys
+import threading
+import tempfile
+import unittest
+
+from contextlib import contextmanager
+from functools import partial
+from pathlib import Path
+from typing import Optional, List
+
+import mlir.dialects.arith as arith
+from mlir.dialects import transform
+from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
+
+
+def import_from_path(module_name: str, file_path: Path):
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module)
+    return module
+
+
+def copy_and_update(src_filepath: Path, dst_filepath: Path):
+    # We should remove all calls like `run(testMethod)`
+    with open(src_filepath, "r") as reader, open(dst_filepath, "w") as writer:
+        while True:
+            src_line = reader.readline()
+            if len(src_line) == 0:
+                break
+            skip_lines = [
+                "run(",
+                "@run",
+                "@constructAndPrintInModule",
+                "run_apply_patterns(",
+        ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/122684


More information about the Mlir-commits mailing list