[Mlir-commits] [mlir] Added free-threading CPython mode support in MLIR Python bindings (PR #107103)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Sep 22 13:45:56 PDT 2024


vfdev-5 wrote:

@stellaraccident Thanks! I have a few questions, though, about the current Python bindings tests, which use lit.

My local tests are rather simple, and I would like to improve them:
```python
import typer

import threading
import concurrent.futures

import mlir.dialects.arith as arith
from mlir.ir import Context, Location, Module, IntegerType, F64Type, InsertionPoint


def mt_run(fn, num_threads, args=(), kwargs=None):
    """Call ``fn(*args, **kwargs)`` concurrently from ``num_threads`` threads.

    Every worker waits on a shared barrier before calling ``fn``, so the calls
    start as close to simultaneously as possible. This maximizes the chance of
    exposing data races.

    Args:
        fn: Callable executed once in every worker thread.
        num_threads: Number of worker threads, which is also the number of
            calls to ``fn``.
        args: Positional arguments passed to every call.
        kwargs: Keyword arguments passed to every call. ``None`` means no
            keyword arguments.

    Returns:
        A list holding one return value of ``fn`` per thread, in submission
        order.

    Raises:
        Whatever ``fn`` raises. The first failing future, in submission order,
        re-raises its exception when its result is collected.
    """
    # Avoid a mutable default argument shared across calls.
    kwargs = {} if kwargs is None else kwargs
    barrier = threading.Barrier(num_threads)

    def closure():
        barrier.wait()
        return fn(*args, **kwargs)

    # max_workers must equal num_threads. Each worker parks on the barrier
    # until all num_threads callers have arrived, so a smaller pool would
    # deadlock.
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=num_threads
    ) as executor:
        futures = [executor.submit(closure) for _ in range(num_threads)]
        # future.result() re-raises any exception raised in the worker, so a
        # failing call fails the whole run.
        return [f.result() for f in futures]


def func():
    """Build a module holding three ``arith.constant`` ops and return its text.

    Each constant (123, 234, 345, in that order) is created under its own
    named location ("a", "b", "c"). Each one uses a fresh insertion point on
    the module body.

    Returns:
        The module printed as MLIR text.
    """
    py_values = [123, 234, 345]
    loc_names = ["a", "b", "c"]
    with Context():
        module = Module.create(loc=Location.file("foo.txt", 0, 0))

        dtype = IntegerType.get_signless(64)
        # One InsertionPoint/Location pair per op, equivalent to three
        # separate `with` statements executed in this order.
        for loc_name, value in zip(loc_names, py_values):
            with InsertionPoint(module.body), Location.name(loc_name):
                arith.constant(dtype, value)

    return str(module)


def func2():
    """Return the printed MLIR of a module containing one i64 constant (123).

    The module and the constant are both created while the file location
    ``foo.txt:0:0`` is the active location.
    """
    # Only the first of the values used by ``func`` is needed here.
    value = 123
    with Context(), Location.file("foo.txt", 0, 0):
        mod = Module.create()
        with InsertionPoint(mod.body):
            i64 = IntegerType.get_signless(64)
            arith.constant(i64, value)

    return str(mod)

def test(func, num_threads=10, expected_first=False):
    """Check that ``func`` gives the same output from many threads as from one.

    Args:
        func: Zero-argument callable, here one returning printed MLIR text.
        num_threads: Number of concurrent calls to ``func``.
        expected_first: If true, compute the single-threaded reference output
            before the multi-threaded run. Otherwise compute it afterwards.

    Raises:
        AssertionError: If any thread's output differs from the reference.
            The payload is ``(thread_index, output, expected)``. It is raised
            explicitly rather than via ``assert`` so the check still runs
            under ``python -O``.
    """

    def reference():
        # Single-threaded run whose output every threaded run must match.
        expected = func()
        print("Expected MLIR:", expected)
        return expected

    if expected_first:
        expected_mlir = reference()
        output_mlir_list = mt_run(func, num_threads=num_threads)
    else:
        output_mlir_list = mt_run(func, num_threads=num_threads)
        expected_mlir = reference()

    for i, output_mlir in enumerate(output_mlir_list):
        if output_mlir != expected_mlir:
            raise AssertionError((i, output_mlir, expected_mlir))


def main(
    n: int = 100,
    name: str = "test",
    nt: int = 10,
    ef: bool = False,
) -> None:
    """Repeat the selected multi-threaded IR-building check ``n`` times.

    Args:
        n: Number of repetitions.
        name: Scenario to run: ``"test"`` runs ``func``, ``"test2"`` runs
            ``func2``.
        nt: Number of threads used in each repetition.
        ef: If true, compute the expected output before the threaded run.
            Otherwise compute it after.

    Raises:
        KeyError: If ``name`` is not a known scenario. This is raised before
            any repetition runs.
    """
    scenarios = {"test": func, "test2": func2}
    selected = scenarios[name]
    for iteration in range(n):
        print("- Count: ", iteration)
        test(selected, num_threads=nt, expected_first=ef)


if __name__ == "__main__":
    # typer derives the command-line options from main's parameters.
    typer.run(main)

```

Ideally, we would make the existing tests run with multi-threaded execution, either with a hand-written harness or with a tool like https://github.com/Quansight-Labs/pytest-run-parallel.
It seems that lit runs the tests and checks their stdout output. That may not always work correctly with multi-threaded execution, since output from concurrent threads can interleave...

> we should come up with a convention to protect those with a global mutex. I know there is an idiom for this in CPython itself, but is there a common thing done for pybind/extensions yet?

Yes, there is an example in pybind11 for that:
https://github.com/pybind/pybind11/blob/1f8b4a7f1a1c5cc9bd6e0d63fe15540e6c458b24/include/pybind11/detail/internals.h#L645-L649

I applied a similar approach to `getLiveContexts` (locally):
```c++
#ifdef Py_GIL_DISABLED
  /// Returns the single mutex instance shared by every caller. Only compiled
  /// for free-threaded (Py_GIL_DISABLED) builds.
  static PyMutex &getLock() {
    // Function-local static: created once, on first use, with thread-safe
    // initialization. Static storage is zero-initialized, which CPython
    // documents as PyMutex's unlocked state.
    static PyMutex sharedMutex{};
    return sharedMutex;
  }
#endif

  /// Invokes `cb` with the live-contexts registry and returns its result.
  /// In free-threaded builds (Py_GIL_DISABLED) the call runs while holding
  /// getLock(). Note that anything `cb` returns by reference escapes the lock.
  template <typename F>
  static inline auto withLiveContexts(const F &cb)
      -> decltype(cb(getLiveContexts())) {
    auto &liveContexts = getLiveContexts();
#ifdef Py_GIL_DISABLED
    // RAII guard: the mutex is released on every exit path, including when
    // `cb` throws. A manual Lock/Unlock pair would leave it held and deadlock
    // all later callers.
    struct LockGuard {
      PyMutex &mutex;
      explicit LockGuard(PyMutex &m) : mutex(m) { PyMutex_Lock(&mutex); }
      ~LockGuard() { PyMutex_Unlock(&mutex); }
      LockGuard(const LockGuard &) = delete;
      LockGuard &operator=(const LockGuard &) = delete;
    };
    LockGuard guard(getLock());
#endif
    // Return the call directly rather than through an `auto` local. This
    // compiles for void callbacks, and it avoids returning a dangling
    // reference when `cb` returns a reference. The result is fully produced
    // before `guard` unlocks.
    return cb(liveContexts);
  }
```

https://github.com/llvm/llvm-project/pull/107103


More information about the Mlir-commits mailing list