[Mlir-commits] [mlir] Added free-threading CPython mode support in MLIR Python bindings (PR #107103)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Sep 22 13:45:56 PDT 2024
vfdev-5 wrote:
@stellaraccident Thanks! I have a few questions about the current Python bindings tests, which use lit.
My local tests are rather simple, and I would like to improve them:
```python
import typer
import threading
import concurrent.futures
import mlir.dialects.arith as arith
from mlir.ir import Context, Location, Module, IntegerType, F64Type, InsertionPoint
def mt_run(fn, num_threads, args=(), kwargs=None):
    """Call ``fn(*args, **kwargs)`` concurrently from ``num_threads`` threads.

    Every worker first waits on a shared barrier, so the calls to ``fn``
    start at (roughly) the same moment and overlap as much as possible.

    Args:
        fn: Callable invoked once in each worker thread.
        num_threads: Number of worker threads, and therefore of calls.
        args: Positional arguments forwarded to ``fn``.
        kwargs: Keyword arguments forwarded to ``fn``; ``None`` means none.
            (Defaults to ``None`` instead of a shared mutable ``{}``.)

    Returns:
        A list holding one ``fn`` result per thread, in submission order.

    Raises:
        Whatever ``fn`` raised in a worker thread is re-raised here.
    """
    if kwargs is None:
        kwargs = {}
    barrier = threading.Barrier(num_threads)

    def closure():
        # Block until all workers are ready so the calls really run together.
        barrier.wait()
        return fn(*args, **kwargs)

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = [executor.submit(closure) for _ in range(num_threads)]
        # future.result() re-raises an exception from the worker, so a
        # failing call makes the whole run fail.
        return [f.result() for f in futures]
def func():
    """Build a module with three i64 ``arith.constant`` ops and print it.

    The module itself is created with a file location. Each constant is
    emitted under its own ``Location.name`` ("a", "b", "c") through a fresh
    ``InsertionPoint`` on the module body.

    Returns:
        The textual MLIR of the module.
    """
    py_values = [123, 234, 345]
    with Context():
        module = Module.create(loc=Location.file("foo.txt", 0, 0))
        dtype = IntegerType.get_signless(64)
        # One loop replaces three copy-pasted blocks that differed only in
        # the location name and the constant value.
        for loc_name, value in zip(("a", "b", "c"), py_values):
            with InsertionPoint(module.body), Location.name(loc_name):
                arith.constant(dtype, value)
        return str(module)
def func2():
    """Build a module containing one i64 ``arith.constant`` (123).

    The file location is installed as the ambient default location, so
    ``Module.create()`` picks it up without an explicit ``loc`` argument.

    Returns:
        The textual MLIR of the module.
    """
    candidates = [123, 234, 345]
    first_value = candidates[0]
    with Context(), Location.file("foo.txt", 0, 0):
        mlir_module = Module.create()
        with InsertionPoint(mlir_module.body):
            i64 = IntegerType.get_signless(64)
            arith.constant(i64, first_value)
        return str(mlir_module)
def test(func, num_threads=10, expected_first=False):
    """Run ``func`` concurrently and check each result against a serial run.

    Args:
        func: Zero-argument callable whose results are compared with ``==``.
        num_threads: How many threads call ``func`` at the same time.
        expected_first: If truthy, compute the serial reference before the
            concurrent run; otherwise compute it afterwards.

    Raises:
        AssertionError: if any concurrent result differs from the reference.
    """
    def reference():
        value = func()
        print("Expected MLIR:", value)
        return value

    if expected_first:
        expected_mlir = reference()
    outputs = mt_run(func, num_threads=num_threads)
    if not expected_first:
        expected_mlir = reference()
    for idx, got in enumerate(outputs):
        assert got == expected_mlir, (idx, got, expected_mlir)
def main(
    n: int = 100,
    name: str = "test",
    nt: int = 10,
    ef: bool = False,
):
    """Repeatedly run a multi-threaded MLIR construction check.

    Args:
        n: Number of repetitions.
        name: Scenario to run: "test" selects ``func``, "test2" ``func2``.
        nt: Number of threads per repetition.
        ef: Compute the expected output before (True) or after (False)
            the concurrent run.
    """
    scenarios = {"test": func, "test2": func2}
    chosen = scenarios[name]
    for iteration in range(n):
        print("- Count: ", iteration)
        test(chosen, num_threads=nt, expected_first=ef)


if __name__ == "__main__":
    # typer builds the command-line interface from main's parameters.
    typer.run(main)
```
Ideally, we would make the existing tests run under multi-threaded execution, either through a manual implementation or with a tool such as https://github.com/Quansight-Labs/pytest-run-parallel.
It seems that lit runs the tests and checks their stdout output, which may not always work correctly under multi-threaded execution...
> we should come up with a convention to protect those with a global mutex. I know there is an idiom for this in CPython itself, but is there a common thing done for pybind/extensions yet?
Yes, there is an example in pybind11 for that:
https://github.com/pybind/pybind11/blob/1f8b4a7f1a1c5cc9bd6e0d63fe15540e6c458b24/include/pybind11/detail/internals.h#L645-L649
I applied a similar approach to `getLiveContexts` (locally):
```c++
#ifdef Py_GIL_DISABLED
// Returns the single function-local PyMutex used (by withLiveContexts) to
// serialize access to the live-contexts registry in free-threaded builds.
// Static storage means the mutex starts zero-initialized, i.e. unlocked.
static PyMutex &getLock() {
  static PyMutex liveContextsLock{};
  return liveContextsLock;
}
#endif
// Invokes `cb` with the live-contexts registry and returns whatever `cb`
// returns (including `void`).
//
// In free-threaded builds (Py_GIL_DISABLED) the call runs while holding the
// mutex from getLock(). Otherwise no explicit lock is taken; presumably the
// GIL serializes access there -- NOTE(review): confirm.
template <typename F>
static inline auto withLiveContexts(const F &cb)
    -> decltype(cb(getLiveContexts())) {
  auto &liveContexts = getLiveContexts();
#ifdef Py_GIL_DISABLED
  // RAII guard: the mutex is released even if `cb` throws. The previous
  // manual Lock/Unlock pair leaked the lock on exceptions, which would
  // deadlock the next caller.
  struct LockGuard {
    PyMutex &mutex;
    explicit LockGuard(PyMutex &m) : mutex(m) { PyMutex_Lock(&mutex); }
    ~LockGuard() { PyMutex_Unlock(&mutex); }
    LockGuard(const LockGuard &) = delete;
    LockGuard &operator=(const LockGuard &) = delete;
  } guard(getLock());
#endif
  // Returning the call directly (instead of storing it in `auto result`)
  // also compiles when `cb` returns void.
  return cb(liveContexts);
}
```
https://github.com/llvm/llvm-project/pull/107103
More information about the Mlir-commits
mailing list