[Mlir-commits] [mlir] 3f1486f - Revert "Added free-threading CPython mode support in MLIR Python bindings (#107103)"

Jacques Pienaar llvmlistbot at llvm.org
Sun Jan 12 10:31:32 PST 2025


Author: Jacques Pienaar
Date: 2025-01-12T18:30:42Z
New Revision: 3f1486f08e0dd64136fb7f50e38cd618dd0255d2

URL: https://github.com/llvm/llvm-project/commit/3f1486f08e0dd64136fb7f50e38cd618dd0255d2
DIFF: https://github.com/llvm/llvm-project/commit/3f1486f08e0dd64136fb7f50e38cd618dd0255d2.diff

LOG: Revert "Added free-threading CPython mode support in MLIR Python bindings (#107103)"

Breaks on Python 3.8; rolling back to avoid breakage while a fix is prepared.

This reverts commit 9dee7c44491635ec9037b90050bcdbd3d5291e38.

Added: 
    

Modified: 
    mlir/cmake/modules/AddMLIRPython.cmake
    mlir/docs/Bindings/Python.md
    mlir/lib/Bindings/Python/Globals.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/python/requirements.txt

Removed: 
    mlir/test/python/multithreaded_tests.py


################################################################################
diff  --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake
index 0679db9cf93e19..717a503468a85d 100644
--- a/mlir/cmake/modules/AddMLIRPython.cmake
+++ b/mlir/cmake/modules/AddMLIRPython.cmake
@@ -668,31 +668,12 @@ function(add_mlir_python_extension libname extname)
   elseif(ARG_PYTHON_BINDINGS_LIBRARY STREQUAL "nanobind")
     nanobind_add_module(${libname}
       NB_DOMAIN mlir
-      FREE_THREADED
       ${ARG_SOURCES}
     )
 
     if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL)
       # Avoids warnings from upstream nanobind.
-      set(nanobind_target "nanobind-static")
-      if (NOT TARGET ${nanobind_target})
-        # Get correct nanobind target name: nanobind-static-ft or something else
-        # It is set by nanobind_add_module function according to the passed options
-        get_property(all_targets DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY BUILDSYSTEM_TARGETS)
-
-        # Iterate over the list of targets
-        foreach(target ${all_targets})
-          # Check if the target name matches the given string
-          if("${target}" MATCHES "nanobind-")
-            set(nanobind_target "${target}")
-          endif()
-        endforeach()
-
-        if (NOT TARGET ${nanobind_target})
-          message(FATAL_ERROR "Could not find nanobind target to set compile options to")
-        endif()
-      endif()
-      target_compile_options(${nanobind_target}
+      target_compile_options(nanobind-static
         PRIVATE
           -Wno-cast-qual
           -Wno-zero-length-array

diff  --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index b8bd0f507a5108..32df3310d811d7 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -1187,43 +1187,3 @@ or nanobind and
 utilities to connect to the rest of Python API. The bindings can be located in a
 separate module or in the same module as attributes and types, and
 loaded along with the dialect.
-
-## Free-threading (No-GIL) support
-
-Free-threading or no-GIL support refers to CPython interpreter (>=3.13) with Global Interpreter Lock made optional. For details on the topic, please check [PEP-703](https://peps.python.org/pep-0703/) and this [Python free-threading guide](https://py-free-threading.github.io/).
-
-MLIR Python bindings are free-threading compatible with exceptions (discussed below) in the following sense: it is safe to work in multiple threads with **independent** contexts. Below we show an example code of safe usage:
-
-```python
-# python3.13t example.py
-import concurrent.futures
-
-import mlir.dialects.arith as arith
-from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
-
-
-def func(py_value):
-    with Context() as ctx:
-        module = Module.create(loc=Location.file("foo.txt", 0, 0))
-
-        dtype = IntegerType.get_signless(64)
-        with InsertionPoint(module.body), Location.name("a"):
-            arith.constant(dtype, py_value)
-
-    return module
-
-
-num_workers = 8
-with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
-    futures = []
-    for i in range(num_workers):
-        futures.append(executor.submit(func, i))
-    assert len(list(f.result() for f in futures)) == num_workers
-```
-
-The exceptions to the free-threading compatibility:
-- IR printing is unsafe, e.g. when using `PassManager` with `PassManager.enable_ir_printing()` which calls thread-unsafe `llvm::raw_ostream`.
-- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
-- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
-- Usage of `mlir.dialects.transform.interpreter` is unsafe.
-- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
\ No newline at end of file

diff  --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 826a34a5351765..0ec522d14f74bd 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -24,7 +24,6 @@ namespace mlir {
 namespace python {
 
 /// Globals that are always accessible once the extension has been initialized.
-/// Methods of this class are thread-safe.
 class PyGlobals {
 public:
   PyGlobals();
@@ -38,18 +37,12 @@ class PyGlobals {
 
   /// Get and set the list of parent modules to search for dialect
   /// implementation classes.
-  std::vector<std::string> getDialectSearchPrefixes() {
-    nanobind::ft_lock_guard lock(mutex);
+  std::vector<std::string> &getDialectSearchPrefixes() {
     return dialectSearchPrefixes;
   }
   void setDialectSearchPrefixes(std::vector<std::string> newValues) {
-    nanobind::ft_lock_guard lock(mutex);
     dialectSearchPrefixes.swap(newValues);
   }
-  void addDialectSearchPrefix(std::string value) {
-    nanobind::ft_lock_guard lock(mutex);
-    dialectSearchPrefixes.push_back(std::move(value));
-  }
 
   /// Loads a python module corresponding to the given dialect namespace.
   /// No-ops if the module has already been loaded or is not found. Raises
@@ -116,9 +109,6 @@ class PyGlobals {
 
 private:
   static PyGlobals *instance;
-
-  nanobind::ft_mutex mutex;
-
   /// Module name prefixes to search under for dialect implementation modules.
   std::vector<std::string> dialectSearchPrefixes;
   /// Map of dialect namespace to external dialect class object.

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 463ebdebb3f3f6..453d4f7c7e8bca 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -243,15 +243,9 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes,
 
 /// Wrapper for the global LLVM debugging flag.
 struct PyGlobalDebugFlag {
-  static void set(nb::object &o, bool enable) {
-    nb::ft_lock_guard lock(mutex);
-    mlirEnableGlobalDebug(enable);
-  }
+  static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
 
-  static bool get(const nb::object &) {
-    nb::ft_lock_guard lock(mutex);
-    return mlirIsGlobalDebugEnabled();
-  }
+  static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); }
 
   static void bind(nb::module_ &m) {
     // Debug flags.
@@ -261,7 +255,6 @@ struct PyGlobalDebugFlag {
         .def_static(
             "set_types",
             [](const std::string &type) {
-              nb::ft_lock_guard lock(mutex);
               mlirSetGlobalDebugType(type.c_str());
             },
             "types"_a, "Sets specific debug types to be produced by LLVM")
@@ -270,17 +263,11 @@ struct PyGlobalDebugFlag {
           pointers.reserve(types.size());
           for (const std::string &str : types)
             pointers.push_back(str.c_str());
-          nb::ft_lock_guard lock(mutex);
           mlirSetGlobalDebugTypes(pointers.data(), pointers.size());
         });
   }
-
-private:
-  static nb::ft_mutex mutex;
 };
 
-nb::ft_mutex PyGlobalDebugFlag::mutex;
-
 struct PyAttrBuilderMap {
   static bool dunderContains(const std::string &attributeKind) {
     return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value();
@@ -619,7 +606,6 @@ class PyOpOperandIterator {
 
 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
   nb::gil_scoped_acquire acquire;
-  nb::ft_lock_guard lock(live_contexts_mutex);
   auto &liveContexts = getLiveContexts();
   liveContexts[context.ptr] = this;
 }
@@ -629,10 +615,7 @@ PyMlirContext::~PyMlirContext() {
   // forContext method, which always puts the associated handle into
   // liveContexts.
   nb::gil_scoped_acquire acquire;
-  {
-    nb::ft_lock_guard lock(live_contexts_mutex);
-    getLiveContexts().erase(context.ptr);
-  }
+  getLiveContexts().erase(context.ptr);
   mlirContextDestroy(context);
 }
 
@@ -649,7 +632,6 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) {
 
 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
   nb::gil_scoped_acquire acquire;
-  nb::ft_lock_guard lock(live_contexts_mutex);
   auto &liveContexts = getLiveContexts();
   auto it = liveContexts.find(context.ptr);
   if (it == liveContexts.end()) {
@@ -665,17 +647,12 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
   return PyMlirContextRef(it->second, std::move(pyRef));
 }
 
-nb::ft_mutex PyMlirContext::live_contexts_mutex;
-
 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
   static LiveContextMap liveContexts;
   return liveContexts;
 }
 
-size_t PyMlirContext::getLiveCount() {
-  nb::ft_lock_guard lock(live_contexts_mutex);
-  return getLiveContexts().size();
-}
+size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
 
 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
 

diff  --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index e600f1bbd44932..f7bf77e5a7e043 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -38,11 +38,8 @@ PyGlobals::PyGlobals() {
 PyGlobals::~PyGlobals() { instance = nullptr; }
 
 bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
-  {
-    nb::ft_lock_guard lock(mutex);
-    if (loadedDialectModules.contains(dialectNamespace))
-      return true;
-  }
+  if (loadedDialectModules.contains(dialectNamespace))
+    return true;
   // Since re-entrancy is possible, make a copy of the search prefixes.
   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
   nb::object loaded = nb::none();
@@ -65,14 +62,12 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
     return false;
   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
   // may have occurred, which may do anything.
-  nb::ft_lock_guard lock(mutex);
   loadedDialectModules.insert(dialectNamespace);
   return true;
 }
 
 void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
                                          nb::callable pyFunc, bool replace) {
-  nb::ft_lock_guard lock(mutex);
   nb::object &found = attributeBuilderMap[attributeKind];
   if (found && !replace) {
     throw std::runtime_error((llvm::Twine("Attribute builder for '") +
@@ -86,7 +81,6 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
 
 void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
                                    nb::callable typeCaster, bool replace) {
-  nb::ft_lock_guard lock(mutex);
   nb::object &found = typeCasterMap[mlirTypeID];
   if (found && !replace)
     throw std::runtime_error("Type caster is already registered with caster: " +
@@ -96,7 +90,6 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
 
 void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
                                     nb::callable valueCaster, bool replace) {
-  nb::ft_lock_guard lock(mutex);
   nb::object &found = valueCasterMap[mlirTypeID];
   if (found && !replace)
     throw std::runtime_error("Value caster is already registered: " +
@@ -106,7 +99,6 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
 
 void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                     nb::object pyClass) {
-  nb::ft_lock_guard lock(mutex);
   nb::object &found = dialectClassMap[dialectNamespace];
   if (found) {
     throw std::runtime_error((llvm::Twine("Dialect namespace '") +
@@ -118,7 +110,6 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
 
 void PyGlobals::registerOperationImpl(const std::string &operationName,
                                       nb::object pyClass, bool replace) {
-  nb::ft_lock_guard lock(mutex);
   nb::object &found = operationClassMap[operationName];
   if (found && !replace) {
     throw std::runtime_error((llvm::Twine("Operation '") + operationName +
@@ -130,7 +121,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
 
 std::optional<nb::callable>
 PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
-  nb::ft_lock_guard lock(mutex);
   const auto foundIt = attributeBuilderMap.find(attributeKind);
   if (foundIt != attributeBuilderMap.end()) {
     assert(foundIt->second && "attribute builder is defined");
@@ -143,7 +133,6 @@ std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
                                                         MlirDialect dialect) {
   // Try to load dialect module.
   (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
-  nb::ft_lock_guard lock(mutex);
   const auto foundIt = typeCasterMap.find(mlirTypeID);
   if (foundIt != typeCasterMap.end()) {
     assert(foundIt->second && "type caster is defined");
@@ -156,7 +145,6 @@ std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
                                                          MlirDialect dialect) {
   // Try to load dialect module.
   (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
-  nb::ft_lock_guard lock(mutex);
   const auto foundIt = valueCasterMap.find(mlirTypeID);
   if (foundIt != valueCasterMap.end()) {
     assert(foundIt->second && "value caster is defined");
@@ -170,7 +158,6 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
   // Make sure dialect module is loaded.
   if (!loadDialectModule(dialectNamespace))
     return std::nullopt;
-  nb::ft_lock_guard lock(mutex);
   const auto foundIt = dialectClassMap.find(dialectNamespace);
   if (foundIt != dialectClassMap.end()) {
     assert(foundIt->second && "dialect class is defined");
@@ -188,7 +175,6 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
   if (!loadDialectModule(dialectNamespace))
     return std::nullopt;
 
-  nb::ft_lock_guard lock(mutex);
   auto foundIt = operationClassMap.find(operationName);
   if (foundIt != operationClassMap.end()) {
     assert(foundIt->second && "OpView is defined");

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index f5fbb6c61b57e2..8fb32a225e65f1 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -260,7 +260,6 @@ class PyMlirContext {
   // Note that this holds a handle, which does not imply ownership.
   // Mappings will be removed when the context is destructed.
   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
-  static nanobind::ft_mutex live_contexts_mutex;
   static LiveContextMap &getLiveContexts();
 
   // Interns all live modules associated with this context. Modules tracked

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 6f49431006605a..7c4064262012ef 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -30,8 +30,12 @@ NB_MODULE(_mlir, m) {
       .def_prop_rw("dialect_search_modules",
                    &PyGlobals::getDialectSearchPrefixes,
                    &PyGlobals::setDialectSearchPrefixes)
-      .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix,
-           "module_name"_a)
+      .def(
+          "append_dialect_search_prefix",
+          [](PyGlobals &self, std::string moduleName) {
+            self.getDialectSearchPrefixes().push_back(std::move(moduleName));
+          },
+          "module_name"_a)
       .def(
           "_check_dialect_module_loaded",
           [](PyGlobals &self, const std::string &dialectNamespace) {
@@ -72,6 +76,7 @@ NB_MODULE(_mlir, m) {
                   nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
               PyGlobals::get().registerOperationImpl(operationName, opClass,
                                                      replace);
+
               // Dict-stuff the new opClass by name onto the dialect class.
               nb::object opClassName = opClass.attr("__name__");
               dialectClass.attr(opClassName) = opClass;

diff  --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt
index 259e679f510f70..f240d6ef944ec7 100644
--- a/mlir/python/requirements.txt
+++ b/mlir/python/requirements.txt
@@ -2,4 +2,4 @@ nanobind>=2.4, <3.0
 numpy>=1.19.5, <=2.1.2
 pybind11>=2.10.0, <=2.13.6
 PyYAML>=5.4.0, <=6.0.1
-ml_dtypes>=0.5.0, <=0.6.0   # provides several NumPy dtype extensions, including the bf16
+ml_dtypes>=0.1.0, <=0.5.0   # provides several NumPy dtype extensions, including the bf16

diff  --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
deleted file mode 100644
index 2df75e2e1b90ca..00000000000000
--- a/mlir/test/python/multithreaded_tests.py
+++ /dev/null
@@ -1,531 +0,0 @@
-# RUN: %PYTHON %s
-"""
-This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN.
-Tests can be run using pytest:
-```bash
-python3.13t -mpytest -vvv multithreaded_tests.py
-```
-
-IMPORTANT. Running tests are not checking the correctness, but just the execution of the tests in multi-threaded context
-and passing if no warnings reported by TSAN and failing otherwise.
-
-
-Details on the generated tests and execution:
-1) Multi-threaded execution: all generated tests are executed independently by
-a pool of threads, running each test multiple times, see @multi_threaded for details
-
-2) Tests generation: we use existing tests: test/python/ir/*.py,
-test/python/dialects/*.py, etc to generate multi-threaded tests.
-In details, we perform the following:
-a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`.
-b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method.
-c) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py.
-In order to import the test file as python module, we remove all executing functions, like
-`@run` or `run(testMethod)`. See `copy_and_update` and `add_existing_tests` methods for details.
-
-
-Observed warnings reported by TSAN.
-
-CPython and free-threading known data-races:
-1) ctypes related races: https://github.com/python/cpython/issues/127945
-2) LLVM related data-races, llvm::raw_ostream is not thread-safe
-- mlir pass manager
-- dialects/transform_interpreter.py
-- ir/diagnostic_handler.py
-- ir/module.py
-3) Dialect gpu module-to-binary method is unsafe
-"""
-import concurrent.futures
-import gc
-import importlib.util
-import os
-import sys
-import threading
-import tempfile
-import unittest
-
-from contextlib import contextmanager
-from functools import partial
-from pathlib import Path
-from typing import Optional
-
-import mlir.dialects.arith as arith
-from mlir.dialects import transform
-from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
-
-
-def import_from_path(module_name: str, file_path: Path):
-    spec = importlib.util.spec_from_file_location(module_name, file_path)
-    module = importlib.util.module_from_spec(spec)
-    sys.modules[module_name] = module
-    spec.loader.exec_module(module)
-    return module
-
-
-def copy_and_update(src_filepath: Path, dst_filepath: Path):
-    # We should remove all calls like `run(testMethod)`
-    with open(src_filepath, "r") as reader, open(dst_filepath, "w") as writer:
-        while True:
-            src_line = reader.readline()
-            if len(src_line) == 0:
-                break
-            skip_lines = [
-                "run(",
-                "@run",
-                "@constructAndPrintInModule",
-                "run_apply_patterns(",
-                "@run_apply_patterns",
-                "@test_in_context",
-                "@construct_and_print_in_module",
-            ]
-            if any(src_line.startswith(line) for line in skip_lines):
-                continue
-            writer.write(src_line)
-
-
-# Helper run functions
-# They are copied from the test modules (e.g. run function in execution_engine.py)
-def run(test_function):
-    # Generic run tests function used by dialects and ir test modules
-    test_function()
-
-
-def run_with_context_and_location(test_function):
-    # run tests function with a context and a location
-    # used by the following test modules:
-    # - dialects/transform_gpu_ext,
-    # - dialects/vector
-    # - dialects/gpu/*
-    with Context(), Location.unknown():
-        test_function()
-    return test_function
-
-
-def run_with_insertion_point_and_context_arg(test_function):
-    # run tests function used by dialects/index_dialect test module
-    with Context() as ctx, Location.unknown():
-        module = Module.create()
-        with InsertionPoint(module.body):
-            test_function(ctx)
-
-
-def run_with_insertion_point(test_function):
-    # Used by a lot of dialects test modules
-    with Context(), Location.unknown():
-        module = Module.create()
-        with InsertionPoint(module.body):
-            test_function()
-    return test_function
-
-
-def run_with_insertion_point_and_module_arg(test_function):
-    # Used by dialects/transform test module
-    with Context(), Location.unknown():
-        module = Module.create()
-        with InsertionPoint(module.body):
-            test_function(module)
-    return test_function
-
-
-def run_with_insertion_point_all_unreg_dialects(test_function):
-    # Used by dialects/cf test module
-    with Context() as ctx, Location.unknown():
-        ctx.allow_unregistered_dialects = True
-        module = Module.create()
-        with InsertionPoint(module.body):
-            test_function()
-    return test_function
-
-
-def run_apply_patterns(test_function):
-    # Used by dialects/transform_tensor_ext test module
-    with Context(), Location.unknown():
-        module = Module.create()
-        with InsertionPoint(module.body):
-            sequence = transform.SequenceOp(
-                transform.FailurePropagationMode.Propagate,
-                [],
-                transform.AnyOpType.get(),
-            )
-            with InsertionPoint(sequence.body):
-                apply = transform.ApplyPatternsOp(sequence.bodyTarget)
-                with InsertionPoint(apply.patterns):
-                    test_function()
-                transform.YieldOp()
-        print(module)
-    return test_function
-
-
-def run_transform_tensor_ext(test_function):
-    # Used by test modules:
-    # - dialects/transform_gpu_ext
-    # - dialects/transform_sparse_tensor_ext
-    # - dialects/transform_tensor_ext
-    with Context(), Location.unknown():
-        module = Module.create()
-        with InsertionPoint(module.body):
-            sequence = transform.SequenceOp(
-                transform.FailurePropagationMode.Propagate,
-                [],
-                transform.AnyOpType.get(),
-            )
-            with InsertionPoint(sequence.body):
-                test_function(sequence.bodyTarget)
-                transform.YieldOp()
-        print(module)
-    return test_function
-
-
-def run_transform_structured_ext(test_function):
-    # Used by dialects/transform_structured_ext test module
-    with Context(), Location.unknown():
-        module = Module.create()
-        with InsertionPoint(module.body):
-            test_function()
-        module.operation.verify()
-        print(module)
-    return test_function
-
-
-def run_construct_and_print_in_module(test_function):
-    # Used by test modules:
-    # - integration/dialects/pdl
-    # - integration/dialects/transform
-    with Context(), Location.unknown():
-        module = Module.create()
-        with InsertionPoint(module.body):
-            module = test_function(module)
-        if module is not None:
-            print(module)
-    return test_function
-
-
-TEST_MODULES = [
-    ("execution_engine", run),
-    ("pass_manager", run),
-    ("dialects/affine", run_with_insertion_point),
-    ("dialects/func", run_with_insertion_point),
-    ("dialects/arith_dialect", run),
-    ("dialects/arith_llvm", run),
-    ("dialects/async_dialect", run),
-    ("dialects/builtin", run),
-    ("dialects/cf", run_with_insertion_point_all_unreg_dialects),
-    ("dialects/complex_dialect", run),
-    ("dialects/func", run_with_insertion_point),
-    ("dialects/index_dialect", run_with_insertion_point_and_context_arg),
-    ("dialects/llvm", run_with_insertion_point),
-    ("dialects/math_dialect", run),
-    ("dialects/memref", run),
-    ("dialects/ml_program", run_with_insertion_point),
-    ("dialects/nvgpu", run_with_insertion_point),
-    ("dialects/nvvm", run_with_insertion_point),
-    ("dialects/ods_helpers", run),
-    ("dialects/openmp_ops", run_with_insertion_point),
-    ("dialects/pdl_ops", run_with_insertion_point),
-    # ("dialects/python_test", run),  # TODO: Need to pass pybind11 or nanobind argv
-    ("dialects/quant", run),
-    ("dialects/rocdl", run_with_insertion_point),
-    ("dialects/scf", run_with_insertion_point),
-    ("dialects/shape", run),
-    ("dialects/spirv_dialect", run),
-    ("dialects/tensor", run),
-    # ("dialects/tosa", ),  # Nothing to test
-    ("dialects/transform_bufferization_ext", run_with_insertion_point),
-    # ("dialects/transform_extras", ),  # Needs a more complicated execution schema
-    ("dialects/transform_gpu_ext", run_transform_tensor_ext),
-    (
-        "dialects/transform_interpreter",
-        run_with_context_and_location,
-        ["print_", "transform_options", "failed", "include"],
-    ),
-    (
-        "dialects/transform_loop_ext",
-        run_with_insertion_point,
-        ["loopOutline"],
-    ),
-    ("dialects/transform_memref_ext", run_with_insertion_point),
-    ("dialects/transform_nvgpu_ext", run_with_insertion_point),
-    ("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext),
-    ("dialects/transform_structured_ext", run_transform_structured_ext),
-    ("dialects/transform_tensor_ext", run_transform_tensor_ext),
-    (
-        "dialects/transform_vector_ext",
-        run_apply_patterns,
-        ["configurable_patterns"],
-    ),
-    ("dialects/transform", run_with_insertion_point_and_module_arg),
-    ("dialects/vector", run_with_context_and_location),
-    ("dialects/gpu/dialect", run_with_context_and_location),
-    ("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location),
-    ("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location),
-    ("dialects/linalg/ops", run),
-    # TO ADD: No proper tests in this dialects/linalg/opsdsl/*
-    # ("dialects/linalg/opsdsl/*", ...),
-    ("dialects/sparse_tensor/dialect", run),
-    ("dialects/sparse_tensor/passes", run),
-    ("integration/dialects/pdl", run_construct_and_print_in_module),
-    ("integration/dialects/transform", run_construct_and_print_in_module),
-    ("integration/dialects/linalg/opsrun", run),
-    ("ir/affine_expr", run),
-    ("ir/affine_map", run),
-    ("ir/array_attributes", run),
-    ("ir/attributes", run),
-    ("ir/blocks", run),
-    ("ir/builtin_types", run),
-    ("ir/context_managers", run),
-    ("ir/debug", run),
-    ("ir/diagnostic_handler", run),
-    ("ir/dialects", run),
-    ("ir/exception", run),
-    ("ir/insertion_point", run),
-    ("ir/integer_set", run),
-    ("ir/location", run),
-    ("ir/module", run),
-    ("ir/operation", run),
-    ("ir/symbol_table", run),
-    ("ir/value", run),
-]
-
-TESTS_TO_SKIP = [
-    "test_execution_engine__testNanoTime_multi_threaded",  # testNanoTime can't run in multiple threads, even with GIL
-    "test_execution_engine__testSharedLibLoad_multi_threaded",  # testSharedLibLoad can't run in multiple threads, even with GIL
-    "test_dialects_arith_dialect__testArithValue_multi_threaded",  # RuntimeError: Value caster is already registered: <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>, even with GIL
-    "test_ir_dialects__testAppendPrefixSearchPath_multi_threaded",  # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL. Strange usage of static PyGlobals vs python exposed _cext.globals
-    "test_ir_value__testValueCasters_multi_threaded",  # RuntimeError: Value caster is already registered: <function testValueCasters.<locals>.dont_cast_int, even with GIL
-    # tests indirectly calling thread-unsafe llvm::raw_ostream
-    "test_execution_engine__testInvalidModule_multi_threaded",  # mlirExecutionEngineCreate calls thread-unsafe llvm::raw_ostream
-    "test_pass_manager__testPrintIrAfterAll_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
-    "test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded",  # IRPrinterInstrumentation::runBeforePass calls thread-unsafe llvm::raw_ostream
-    "test_pass_manager__testPrintIrLargeLimitElements_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
-    "test_pass_manager__testPrintIrTree_multi_threaded",  # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream
-    "test_pass_manager__testRunPipeline_multi_threaded",  # PrintOpStatsPass::printSummary calls thread-unsafe llvm::raw_ostream
-    "test_dialects_transform_interpreter__include_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
-    "test_dialects_transform_interpreter__transform_options_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream
-    "test_dialects_transform_interpreter__print_self_multi_threaded",  # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) call thread-unsafe llvm::raw_ostream
-    "test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded",  # mlirEmitError calls thread-unsafe llvm::raw_ostream
-    "test_ir_module__testParseSuccess_multi_threaded",  # mlirOperationDump calls thread-unsafe llvm::raw_ostream
-    # False-positive TSAN detected race in llvm::RuntimeDyldELF::registerEHFrames()
-    # Details: https://github.com/llvm/llvm-project/pull/107103/files#r1905726947
-    "test_execution_engine__testCapsule_multi_threaded",
-    "test_execution_engine__testDumpToObjectFile_multi_threaded",
-]
-
-TESTS_TO_XFAIL = [
-    # execution_engine tests:
-    # - ctypes related data-races: https://github.com/python/cpython/issues/127945
-    "test_execution_engine__testBF16Memref_multi_threaded",
-    "test_execution_engine__testBasicCallback_multi_threaded",
-    "test_execution_engine__testComplexMemrefAdd_multi_threaded",
-    "test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded",
-    "test_execution_engine__testDynamicMemrefAdd2D_multi_threaded",
-    "test_execution_engine__testF16MemrefAdd_multi_threaded",
-    "test_execution_engine__testF8E5M2Memref_multi_threaded",
-    "test_execution_engine__testInvokeFloatAdd_multi_threaded",
-    "test_execution_engine__testInvokeVoid_multi_threaded",  # a ctypes race
-    "test_execution_engine__testMemrefAdd_multi_threaded",
-    "test_execution_engine__testRankedMemRefCallback_multi_threaded",
-    "test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded",
-    "test_execution_engine__testUnrankedMemRefCallback_multi_threaded",
-    "test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded",
-    # dialects tests
-    "test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded",  # Related to ctypes data races
-    "test_dialects_transform_interpreter__print_other_multi_threaded",  # Fatal Python error: Aborted or mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
-    "test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded",  # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers variable mutation
-    "test_dialects_gpu_module-to-binary-nvvm__testGPUToASMBin_multi_threaded",
-    "test_dialects_gpu_module-to-binary-nvvm__testGPUToLLVMBin_multi_threaded",
-    "test_dialects_gpu_module-to-binary-rocdl__testGPUToLLVMBin_multi_threaded",
-    # integration tests
-    "test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded",  # Related to ctypes data races
-    "test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded",  # Related to ctypes data races
-    "test_integration_dialects_linalg_opsrun__test_fill_builtin_multi_threaded",  # ctypes
-    "test_integration_dialects_linalg_opsrun__test_fill_generic_multi_threaded",  # ctypes
-    "test_integration_dialects_linalg_opsrun__test_fill_rng_builtin_multi_threaded",  # ctypes
-    "test_integration_dialects_linalg_opsrun__test_fill_rng_generic_multi_threaded",  # ctypes
-    "test_integration_dialects_linalg_opsrun__test_max_pooling_builtin_multi_threaded",  # ctypes
-    "test_integration_dialects_linalg_opsrun__test_max_pooling_generic_multi_threaded",  # ctypes
-    "test_integration_dialects_linalg_opsrun__test_min_pooling_builtin_multi_threaded",  # ctypes
-    "test_integration_dialects_linalg_opsrun__test_min_pooling_generic_multi_threaded",  # ctypes
-]
-
-
-def add_existing_tests(test_modules, test_prefix: str = "_original_test"):
-    def decorator(test_cls):
-        this_folder = Path(__file__).parent.absolute()
-        test_cls.output_folder = tempfile.TemporaryDirectory()
-        output_folder = Path(test_cls.output_folder.name)
-
-        for test_mod_info in test_modules:
-            # test_mod_info is a tuple of size 2 or 3:
-            # (test_module_str, run_test_function) or (test_module_str, run_test_function, test_name_patterns_list)
-            # For example:
-            # - ("ir/value", run) or
-            # - ("dialects/transform_loop_ext", run_with_insertion_point, ["loopOutline"])
-            assert isinstance(test_mod_info, tuple) and len(test_mod_info) in (2, 3)
-            if len(test_mod_info) == 2:
-                test_module_name, exec_fn = test_mod_info
-                test_pattern = None
-            else:
-                test_module_name, exec_fn, test_pattern = test_mod_info
-
-            src_filepath = this_folder / f"{test_module_name}.py"
-            dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
-            if not dst_filepath.parent.exists():
-                dst_filepath.parent.mkdir(parents=True)
-            copy_and_update(src_filepath, dst_filepath)
-            test_mod = import_from_path(test_module_name, dst_filepath)
-            for attr_name in dir(test_mod):
-                is_test_fn = test_pattern is None and attr_name.startswith("test")
-                is_test_fn |= test_pattern is not None and any(
-                    [p in attr_name for p in test_pattern]
-                )
-                if is_test_fn:
-                    obj = getattr(test_mod, attr_name)
-                    if callable(obj):
-                        test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}"
-
-                        def wrapped_test_fn(
-                            self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs
-                        ):
-                            __exec_fn__(__test_fn__)
-
-                        setattr(test_cls, test_name, wrapped_test_fn)
-        return test_cls
-
-    return decorator
-
-
-@contextmanager
-def _capture_output(fp):
-    # Inspired from jax test_utils.py capture_stderr method
-    # ``None`` means nothing has not been captured yet.
-    captured = None
-
-    def get_output() -> str:
-        if captured is None:
-            raise ValueError("get_output() called while the context is active.")
-        return captured
-
-    with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as f:
-        original_fd = os.dup(fp.fileno())
-        os.dup2(f.fileno(), fp.fileno())
-        try:
-            yield get_output
-        finally:
-            # Python also has its own buffers, make sure everything is flushed.
-            fp.flush()
-            os.fsync(fp.fileno())
-            f.seek(0)
-            captured = f.read()
-            os.dup2(original_fd, fp.fileno())
-
-
-capture_stdout = partial(_capture_output, sys.stdout)
-capture_stderr = partial(_capture_output, sys.stderr)
-
-
-def multi_threaded(
-    num_workers: int,
-    num_runs: int = 5,
-    skip_tests: Optional[list[str]] = None,
-    xfail_tests: Optional[list[str]] = None,
-    test_prefix: str = "_original_test",
-    multithreaded_test_postfix: str = "_multi_threaded",
-):
-    """Decorator that runs a test in a multi-threaded environment."""
-
-    def decorator(test_cls):
-        for name, test_fn in test_cls.__dict__.copy().items():
-            if not (name.startswith(test_prefix) and callable(test_fn)):
-                continue
-
-            name = f"test{name[len(test_prefix):]}"
-            if skip_tests is not None:
-                if any(
-                    test_name.replace(multithreaded_test_postfix, "") in name
-                    for test_name in skip_tests
-                ):
-                    continue
-
-            def multi_threaded_test_fn(self, *args, __test_fn__=test_fn, **kwargs):
-                with capture_stdout(), capture_stderr() as get_output:
-                    barrier = threading.Barrier(num_workers)
-
-                    def closure():
-                        barrier.wait()
-                        for _ in range(num_runs):
-                            __test_fn__(self, *args, **kwargs)
-
-                    with concurrent.futures.ThreadPoolExecutor(
-                        max_workers=num_workers
-                    ) as executor:
-                        futures = []
-                        for _ in range(num_workers):
-                            futures.append(executor.submit(closure))
-                        # We should call future.result() to re-raise an exception if test has
-                        # failed
-                        assert len(list(f.result() for f in futures)) == num_workers
-
-                    gc.collect()
-                    assert Context._get_live_count() == 0
-
-                captured = get_output()
-                if len(captured) > 0 and "ThreadSanitizer" in captured:
-                    raise RuntimeError(
-                        f"ThreadSanitizer reported warnings:\n{captured}"
-                    )
-
-            test_new_name = f"{name}{multithreaded_test_postfix}"
-            if xfail_tests is not None and test_new_name in xfail_tests:
-                multi_threaded_test_fn = unittest.expectedFailure(
-                    multi_threaded_test_fn
-                )
-
-            setattr(test_cls, test_new_name, multi_threaded_test_fn)
-
-        return test_cls
-
-    return decorator
-
-
-@multi_threaded(
-    num_workers=10,
-    num_runs=20,
-    skip_tests=TESTS_TO_SKIP,
-    xfail_tests=TESTS_TO_XFAIL,
-)
-@add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test")
-class TestAllMultiThreaded(unittest.TestCase):
-    @classmethod
-    def tearDownClass(cls):
-        if hasattr(cls, "output_folder"):
-            cls.output_folder.cleanup()
-
-    def _original_test_create_context(self):
-        with Context() as ctx:
-            print(ctx._get_live_count())
-            print(ctx._get_live_module_count())
-            print(ctx._get_live_operation_count())
-            print(ctx._get_live_operation_objects())
-            print(ctx._get_context_again() is ctx)
-            print(ctx._clear_live_operations())
-
-    def _original_test_create_module_with_consts(self):
-        py_values = [123, 234, 345]
-        with Context() as ctx:
-            module = Module.create(loc=Location.file("foo.txt", 0, 0))
-
-            dtype = IntegerType.get_signless(64)
-            with InsertionPoint(module.body), Location.name("a"):
-                arith.constant(dtype, py_values[0])
-
-            with InsertionPoint(module.body), Location.name("b"):
-                arith.constant(dtype, py_values[1])
-
-            with InsertionPoint(module.body), Location.name("c"):
-                arith.constant(dtype, py_values[2])
-
-
-if __name__ == "__main__":
-    # Do not run the tests on CPython with GIL
-    if hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled():
-        unittest.main()


        


More information about the Mlir-commits mailing list