[Mlir-commits] [mlir] Enabled freethreading support in MLIR python bindings (PR #122684)
llvmlistbot at llvm.org
Mon Jan 13 02:23:55 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: vfdev (vfdev-5)
<details>
<summary>Changes</summary>
Reland the reverted https://github.com/llvm/llvm-project/pull/107103 with fixes for Python 3.8.
cc @jpienaar
---
Patch is 39.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122684.diff
10 Files Affected:
- (modified) mlir/cmake/modules/AddMLIRPython.cmake (+20-1)
- (modified) mlir/docs/Bindings/Python.md (+40)
- (modified) mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp (+3-1)
- (modified) mlir/lib/Bindings/Python/Globals.h (+11-1)
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+27-4)
- (modified) mlir/lib/Bindings/Python/IRModule.cpp (+16-2)
- (modified) mlir/lib/Bindings/Python/IRModule.h (+1)
- (modified) mlir/lib/Bindings/Python/MainModule.cpp (+2-7)
- (modified) mlir/python/requirements.txt (+2-1)
- (added) mlir/test/python/multithreaded_tests.py (+518)
``````````diff
diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 717a503468a85d..0679db9cf93e19 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -668,12 +668,31 @@ function(add_mlir_python_extension libname extname)
elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
nanobind_add_module(${libname}
NB_DOMAIN mlir
+ FREE_THREADED
${ARG_SOURCES}
)
if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
# Avoids warnings from upstream nanobind.
- target_compile_options(nanobind-static
+ set(nanobind_target "nanobind-static")
+ if (NOT TARGET ${nanobind_target})
+ # Get correct nanobind target name: nanobind-static-ft or something else
+ # It is set by nanobind_add_module function according to the passed options
+ get_property(all_targets DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY BUILDSYSTEM_TARGETS)
+
+ # Iterate over the list of targets
+ foreach(target ${all_targets})
+ # Check if the target name matches the given string
+ if("${target}" MATCHES "nanobind-")
+ set(nanobind_target "${target}")
+ endif()
+ endforeach()
+
+ if (NOT TARGET ${nanobind_target})
+ message(FATAL_ERROR "Could not find nanobind target to set compile options to")
+ endif()
+ endif()
+ target_compile_options(${nanobind_target}
PRIVATE
-Wno-cast-qual
-Wno-zero-length-array
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 32df3310d811d7..b8bd0f507a5108 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1187,3 +1187,43 @@ or nanobind and
utilities to connect to the rest of Python API. The bindings can be located in a
separate module or in the same module as attributes and types, and
loaded along with the dialect.
+
+## Free-threading (No-GIL) support
+
+Free-threading (no-GIL) support refers to the CPython interpreter (>=3.13) with the Global Interpreter Lock made optional. For details on the topic, see [PEP-703](https://peps.python.org/pep-0703/) and the [Python free-threading guide](https://py-free-threading.github.io/).
+
+The MLIR Python bindings are free-threading compatible, with the exceptions discussed below, in the following sense: it is safe to work from multiple threads as long as each thread uses an **independent** context. Below is an example of safe usage:
+
+```python
+# python3.13t example.py
+import concurrent.futures
+
+import mlir.dialects.arith as arith
+from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
+
+
+def func(py_value):
+ with Context() as ctx:
+ module = Module.create(loc=Location.file("foo.txt", 0, 0))
+
+ dtype = IntegerType.get_signless(64)
+ with InsertionPoint(module.body), Location.name("a"):
+ arith.constant(dtype, py_value)
+
+ return module
+
+
+num_workers = 8
+with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = []
+ for i in range(num_workers):
+ futures.append(executor.submit(func, i))
+ assert len(list(f.result() for f in futures)) == num_workers
+```
+
+The exceptions to free-threading compatibility are:
+- IR printing is unsafe, e.g. when using `PassManager` with `PassManager.enable_ir_printing()`, which calls the thread-unsafe `llvm::raw_ostream`.
+- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
+- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
+- Usage of `mlir.dialects.transform.interpreter` is unsafe.
+- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
\ No newline at end of file
diff --git a/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
index 397db4c20e7432..dd3c4c2945cca8 100644
--- a/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
+++ b/mlir/examples/standalone/python/StandaloneExtensionPybind11.cpp
@@ -12,9 +12,11 @@
#include "Standalone-c/Dialects.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
+namespace py = pybind11;
+
using namespace mlir::python::adaptors;
-PYBIND11_MODULE(_standaloneDialectsPybind11, m) {
+PYBIND11_MODULE(_standaloneDialectsPybind11, m, py::mod_gil_not_used()) {
//===--------------------------------------------------------------------===//
// standalone dialect
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 0ec522d14f74bd..826a34a5351765 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -24,6 +24,7 @@ namespace mlir {
namespace python {
/// Globals that are always accessible once the extension has been initialized.
+/// Methods of this class are thread-safe.
class PyGlobals {
public:
PyGlobals();
@@ -37,12 +38,18 @@ class PyGlobals {
/// Get and set the list of parent modules to search for dialect
/// implementation classes.
- std::vector<std::string> &getDialectSearchPrefixes() {
+ std::vector<std::string> getDialectSearchPrefixes() {
+ nanobind::ft_lock_guard lock(mutex);
return dialectSearchPrefixes;
}
void setDialectSearchPrefixes(std::vector<std::string> newValues) {
+ nanobind::ft_lock_guard lock(mutex);
dialectSearchPrefixes.swap(newValues);
}
+ void addDialectSearchPrefix(std::string value) {
+ nanobind::ft_lock_guard lock(mutex);
+ dialectSearchPrefixes.push_back(std::move(value));
+ }
/// Loads a python module corresponding to the given dialect namespace.
/// No-ops if the module has already been loaded or is not found. Raises
@@ -109,6 +116,9 @@ class PyGlobals {
private:
static PyGlobals *instance;
+
+ nanobind::ft_mutex mutex;
+
/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to external dialect class object.
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 453d4f7c7e8bca..463ebdebb3f3f6 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -243,9 +243,15 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,
/// Wrapper for the global LLVM debugging flag.
struct PyGlobalDebugFlag {
- static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
+ static void set(nb::object &o, bool enable) {
+ nb::ft_lock_guard lock(mutex);
+ mlirEnableGlobalDebug(enable);
+ }
- static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
+ static bool get(const nb::object &) {
+ nb::ft_lock_guard lock(mutex);
+ return mlirIsGlobalDebugEnabled();
+ }
static void bind(nb::module_ &m) {
// Debug flags.
@@ -255,6 +261,7 @@ struct PyGlobalDebugFlag {
.def_static(
"set_types",
[](const std::string &type) {
+ nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugType(type.c_str());
},
"types"_a, "Sets specific debug types to be produced by LLVM")
@@ -263,11 +270,17 @@ struct PyGlobalDebugFlag {
pointers.reserve(types.size());
for (const std::string &str : types)
pointers.push_back(str.c_str());
+ nb::ft_lock_guard lock(mutex);
mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
});
}
+
+private:
+ static nb::ft_mutex mutex;
};
+nb::ft_mutex PyGlobalDebugFlag::mutex;
+
struct PyAttrBuilderMap {
static bool dunderContains(const std::string &attributeKind) {
return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
@@ -606,6 +619,7 @@ class PyOpOperandIterator {
PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
nb::gil_scoped_acquire acquire;
+ nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
liveContexts[context.ptr] = this;
}
@@ -615,7 +629,10 @@ PyMlirContext::~PyMlirContext() {
// forContext method, which always puts the associated handle into
// liveContexts.
nb::gil_scoped_acquire acquire;
- getLiveContexts().erase(context.ptr);
+ {
+ nb::ft_lock_guard lock(live_contexts_mutex);
+ getLiveContexts().erase(context.ptr);
+ }
mlirContextDestroy(context);
}
@@ -632,6 +649,7 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
nb::gil_scoped_acquire acquire;
+ nb::ft_lock_guard lock(live_contexts_mutex);
auto &liveContexts = getLiveContexts();
auto it = liveContexts.find(context.ptr);
if (it == liveContexts.end()) {
@@ -647,12 +665,17 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
return PyMlirContextRef(it->second, std::move(pyRef));
}
+nb::ft_mutex PyMlirContext::live_contexts_mutex;
+
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
static LiveContextMap liveContexts;
return liveContexts;
}
-size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
+size_t PyMlirContext::getLiveCount() {
+ nb::ft_lock_guard lock(live_contexts_mutex);
+ return getLiveContexts().size();
+}
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index f7bf77e5a7e043..e600f1bbd44932 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -38,8 +38,11 @@ PyGlobals::PyGlobals() {
PyGlobals::~PyGlobals() { instance = nullptr; }
bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
- if (loadedDialectModules.contains(dialectNamespace))
- return true;
+ {
+ nb::ft_lock_guard lock(mutex);
+ if (loadedDialectModules.contains(dialectNamespace))
+ return true;
+ }
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
nb::object loaded = nb::none();
@@ -62,12 +65,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
return false;
// Note: Iterator cannot be shared from prior to loading, since re-entrancy
// may have occurred, which may do anything.
+ nb::ft_lock_guard lock(mutex);
loadedDialectModules.insert(dialectNamespace);
return true;
}
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
nb::callable pyFunc, bool replace) {
+ nb::ft_lock_guard lock(mutex);
nb::object &found = attributeBuilderMap[attributeKind];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
@@ -81,6 +86,7 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
nb::callable typeCaster, bool replace) {
+ nb::ft_lock_guard lock(mutex);
nb::object &found = typeCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Type caster is already registered with caster: " +
@@ -90,6 +96,7 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
nb::callable valueCaster, bool replace) {
+ nb::ft_lock_guard lock(mutex);
nb::object &found = valueCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Value caster is already registered: " +
@@ -99,6 +106,7 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
nb::object pyClass) {
+ nb::ft_lock_guard lock(mutex);
nb::object &found = dialectClassMap[dialectNamespace];
if (found) {
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
@@ -110,6 +118,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
void PyGlobals::registerOperationImpl(const std::string &operationName,
nb::object pyClass, bool replace) {
+ nb::ft_lock_guard lock(mutex);
nb::object &found = operationClassMap[operationName];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
@@ -121,6 +130,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
std::optional<nb::callable>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
+ nb::ft_lock_guard lock(mutex);
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
assert(foundIt->second && "attribute builder is defined");
@@ -133,6 +143,7 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
+ nb::ft_lock_guard lock(mutex);
const auto foundIt = typeCasterMap.find(mlirTypeID);
if (foundIt != typeCasterMap.end()) {
assert(foundIt->second && "type caster is defined");
@@ -145,6 +156,7 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
+ nb::ft_lock_guard lock(mutex);
const auto foundIt = valueCasterMap.find(mlirTypeID);
if (foundIt != valueCasterMap.end()) {
assert(foundIt->second && "value caster is defined");
@@ -158,6 +170,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
// Make sure dialect module is loaded.
if (!loadDialectModule(dialectNamespace))
return std::nullopt;
+ nb::ft_lock_guard lock(mutex);
const auto foundIt = dialectClassMap.find(dialectNamespace);
if (foundIt != dialectClassMap.end()) {
assert(foundIt->second && "dialect class is defined");
@@ -175,6 +188,7 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
if (!loadDialectModule(dialectNamespace))
return std::nullopt;
+ nb::ft_lock_guard lock(mutex);
auto foundIt = operationClassMap.find(operationName);
if (foundIt != operationClassMap.end()) {
assert(foundIt->second && "OpView is defined");
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 8fb32a225e65f1..f5fbb6c61b57e2 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -260,6 +260,7 @@ class PyMlirContext {
// Note that this holds a handle, which does not imply ownership.
// Mappings will be removed when the context is destructed.
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
+ static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();
// Interns all live modules associated with this context. Modules tracked
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 7c4064262012ef..6f49431006605a 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -30,12 +30,8 @@ NB_MODULE(_mlir, m) {
.def_prop_rw("dialect_search_modules",
&PyGlobals::getDialectSearchPrefixes,
&PyGlobals::setDialectSearchPrefixes)
- .def(
- "append_dialect_search_prefix",
- [](PyGlobals &self, std::string moduleName) {
- self.getDialectSearchPrefixes().push_back(std::move(moduleName));
- },
- "module_name"_a)
+ .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
+ "module_name"_a)
.def(
"_check_dialect_module_loaded",
[](PyGlobals &self, const std::string &dialectNamespace) {
@@ -76,7 +72,6 @@ NB_MODULE(_mlir, m) {
nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
PyGlobals::get().registerOperationImpl(operationName, opClass,
replace);
-
// Dict-stuff the new opClass by name onto the dialect class.
nb::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;
diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index f240d6ef944ec7..1a0075e829aef2 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -2,4 +2,5 @@ nanobind>=2.4, <3.0
numpy>=1.19.5, <=2.1.2
pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1
-ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16
+ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16
+ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13"
\ No newline at end of file
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
new file mode 100644
index 00000000000000..6e1a6683468729
--- /dev/null
+++ b/mlir/test/python/multithreaded_tests.py
@@ -0,0 +1,518 @@
+# RUN: %PYTHON %s
+"""
+This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN.
+Tests can be run using pytest:
+```bash
+python3.13t -mpytest -vvv multithreaded_tests.py
+```
+
+IMPORTANT: the tests do not check correctness; they only execute the existing tests in a multi-threaded context,
+passing if TSAN reports no warnings and failing otherwise.
+
+
+Details on the generated tests and execution:
+1) Multi-threaded execution: all generated tests are executed independently by
+a pool of threads, each test running multiple times; see @multi_threaded for details.
+
+2) Test generation: we use existing tests (test/python/ir/*.py,
+test/python/dialects/*.py, etc.) to generate multi-threaded tests.
+In detail, we perform the following:
+a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`.
+b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method.
+c) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py.
+To import the test file as a Python module, we remove all calls that execute tests, such as
+`@run` or `run(testMethod)`. See the `copy_and_update` and `add_existing_tests` methods for details.
+
+
+Warnings reported by TSAN:
+
+Known data races (CPython and free-threading):
+1) ctypes-related races: https://github.com/python/cpython/issues/127945
+2) LLVM-related data races (llvm::raw_ostream is not thread-safe):
+- mlir pass manager
+- dialects/transform_interpreter.py
+- ir/diagnostic_handler.py
+- ir/module.py
+3) The gpu dialect's module-to-binary method is unsafe
+"""
+import concurrent.futures
+import gc
+import importlib.util
+import os
+import sys
+import threading
+import tempfile
+import unittest
+
+from contextlib import contextmanager
+from functools import partial
+from pathlib import Path
+from typing import Optional, List
+
+import mlir.dialects.arith as arith
+from mlir.dialects import transform
+from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
+
+
+def import_from_path(module_name: str, file_path: Path):
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[module_name] = module
+ spec.loader.exec_module(module)
+ return module
+
+
+def copy_and_update(src_filepath: Path, dst_filepath: Path):
+ # We should remove all calls like `run(testMethod)`
+ with open(src_filepath, "r") as reader, open(dst_filepath, "w") as writer:
+ while True:
+ src_line = reader.readline()
+ if len(src_line) == 0:
+ break
+ skip_lines = [
+ "run(",
+ "@run",
+ "@constructAndPrintInModule",
+ "run_apply_patterns(",
+ ...
[truncated]
``````````
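The documentation hunk defines free-threading as CPython (>=3.13) with the GIL made optional. As a minimal sketch using only the standard library, the following checks whether the running interpreter is a free-threaded build (`sysconfig.get_config_var("Py_GIL_DISABLED")`) and whether the GIL is currently enabled (`sys._is_gil_enabled()`, a private helper available from CPython 3.13). The function name `describe_gil` is hypothetical. The distinction matters because, on a free-threaded build, CPython re-enables the GIL when it imports an extension module that has not declared free-threading support; declaring that support appears to be the purpose of the `FREE_THREADED` option to `nanobind_add_module` and of `py::mod_gil_not_used()` in the pybind11 example in this patch.

```python
import sys
import sysconfig


def describe_gil() -> dict:
    """Hypothetical helper: report free-threaded build status and current GIL state."""
    # Py_GIL_DISABLED is 1 on free-threaded builds (e.g. python3.13t); 0 or None otherwise.
    free_threaded_build = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
    # sys._is_gil_enabled() exists only on CPython 3.13+; older interpreters always have a GIL.
    is_gil_enabled = getattr(sys, "_is_gil_enabled", None)
    gil_enabled = bool(is_gil_enabled()) if is_gil_enabled is not None else True
    return {"free_threaded_build": free_threaded_build, "gil_enabled": gil_enabled}


if __name__ == "__main__":
    print(describe_gil())
```

Calling it under `python3.13t` before and after importing the MLIR bindings is one way to check that importing them does not turn the GIL back on.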
</details>
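The `PyGlobals` hunks follow a common pattern. Every access to the shared registries happens under one mutex. The getter returns a copy instead of a reference, which is why `addDialectSearchPrefix` replaces the lambda that pushed into the returned vector. `loadDialectModule` releases the lock while importing, because the import can re-enter the registry. Below is a hypothetical pure-Python analogue of that pattern, not the bindings' actual code. `DialectRegistry` and its `loader` callback are invented for illustration, and `threading.Lock` stands in for `nanobind::ft_mutex`.

```python
import threading
from typing import Callable, List, Set


class DialectRegistry:
    """Hypothetical analogue of the PyGlobals locking pattern (not the real bindings)."""

    def __init__(self, loader: Callable[["DialectRegistry", str], bool]) -> None:
        self._lock = threading.Lock()  # stands in for nanobind::ft_mutex
        self._search_prefixes: List[str] = []
        self._loaded: Set[str] = set()
        # Stands in for importing "<prefix>.<namespace>"; returns True on success.
        self._loader = loader

    def get_search_prefixes(self) -> List[str]:
        # Return a copy, like returning std::vector by value, so callers can
        # never mutate the shared list without holding the lock.
        with self._lock:
            return list(self._search_prefixes)

    def set_search_prefixes(self, values: List[str]) -> None:
        with self._lock:
            self._search_prefixes = list(values)

    def add_search_prefix(self, prefix: str) -> None:
        # Mutation goes through a locked method instead of a returned reference.
        with self._lock:
            self._search_prefixes.append(prefix)

    def is_loaded(self, namespace: str) -> bool:
        with self._lock:
            return namespace in self._loaded

    def load_dialect_module(self, namespace: str) -> bool:
        # Fast path: check under the lock (is_loaded releases it on return).
        if self.is_loaded(namespace):
            return True
        # The lock is NOT held while loading: the loader may re-enter this
        # registry, and threading.Lock is not reentrant, so holding it here
        # would deadlock.
        if not self._loader(self, namespace):
            return False
        with self._lock:
            self._loaded.add(namespace)
        return True
```

Because the lock is dropped during loading, two threads may both run the loader for the same namespace. Recording the result in a set makes that duplicate work harmless, and in the real bindings Python's per-module import locks should serialize the underlying import itself.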
https://github.com/llvm/llvm-project/pull/122684
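The test docstring says each generated test is run several times by a pool of threads via `@multi_threaded`, whose definition falls in the truncated part of the patch. The sketch below shows one way such a harness could work. `run_concurrently` and its parameters are hypothetical and are not the PR's decorator. A `threading.Barrier` releases all workers at once to maximize overlap, which gives a TSAN-instrumented interpreter a better chance to observe races.

```python
import concurrent.futures
import threading
from typing import Any, Callable, List


def run_concurrently(fn: Callable[[], Any], num_workers: int = 8, num_runs: int = 4) -> List[Any]:
    """Hypothetical harness: call `fn` num_runs times in each of num_workers threads.

    Exceptions raised by `fn` are re-raised in the caller through Future.result().
    """
    # Every worker blocks on the barrier until all num_workers threads have
    # arrived, so the calls to `fn` start at (roughly) the same moment.
    barrier = threading.Barrier(num_workers)

    def worker() -> List[Any]:
        barrier.wait()
        return [fn() for _ in range(num_runs)]

    # max_workers == num_workers so that every submitted worker gets its own
    # thread; with fewer threads the barrier could never be satisfied.
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(worker) for _ in range(num_workers)]
        results: List[Any] = []
        for future in futures:
            results.extend(future.result())
    return results
```

With the bindings installed, `fn` could be the `func` from the documentation example with its argument bound via `functools.partial`, so that every call creates its own `Context`, which is the usage the docs describe as safe.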
More information about the Mlir-commits mailing list