[Mlir-commits] [mlir] Added FT locks around PyOperation valid flag to prevent a data race (PR #126709)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 11 02:22:48 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: vfdev (vfdev-5)
<details>
<summary>Changes</summary>
Description:
- Added free-threading (FT) locks around the PyOperation valid flag to prevent a data race
- Added a test
cc @hawkinsp
Data race report: https://gist.github.com/vfdev-5/02bb822a0475d782da60815604ef30da
---
Full diff: https://github.com/llvm/llvm-project/pull/126709.diff
3 Files Affected:
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+2-1)
- (modified) mlir/lib/Bindings/Python/IRModule.h (+10-4)
- (modified) mlir/test/python/multithreaded_tests.py (+41-1)
``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 47a85c2a486fd46..f080b79d7a56949 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1279,7 +1279,8 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
return PyOperation::createDetached(std::move(contextRef), op);
}
-void PyOperation::checkValid() const {
+void PyOperation::checkValid() {
+ nb::ft_lock_guard lock(validFlagMutex);
if (!valid) {
throw std::runtime_error("the operation has been invalidated");
}
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index dd6e7ef9123746b..fd2b3a066273c12 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -646,8 +646,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
}
/// Gets the backing operation.
- operator MlirOperation() const { return get(); }
- MlirOperation get() const {
+ operator MlirOperation() { return get(); }
+ MlirOperation get() {
checkValid();
return operation;
}
@@ -665,7 +665,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
assert(attached && "operation already detached");
attached = false;
}
- void checkValid() const;
+ void checkValid();
/// Gets the owning block or raises an exception if the operation has no
/// owning block.
@@ -700,7 +700,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
void erase();
/// Invalidate the operation.
- void setInvalid() { valid = false; }
+ void setInvalid() {
+ nanobind::ft_lock_guard lock(validFlagMutex);
+ valid = false;
+ }
/// Clones this operation.
nanobind::object clone(const nanobind::object &ip);
@@ -726,6 +729,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
friend class PyOperationBase;
friend class PySymbolTable;
+
+ // FT mutex to protect checkValid() method
+ nanobind::ft_mutex validFlagMutex;
};
/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
index 6e1a66834687291..bc99144ef790157 100644
--- a/mlir/test/python/multithreaded_tests.py
+++ b/mlir/test/python/multithreaded_tests.py
@@ -475,7 +475,7 @@ def closure():
@multi_threaded(
- num_workers=10,
+ num_workers=32,
num_runs=20,
skip_tests=TESTS_TO_SKIP,
xfail_tests=TESTS_TO_XFAIL,
@@ -511,6 +511,46 @@ def _original_test_create_module_with_consts(self):
with InsertionPoint(module.body), Location.name("c"):
arith.constant(dtype, py_values[2])
+ def test_check_pyoperation_race(self):
+ num_workers = 40
+ num_runs = 20
+
+ barrier = threading.Barrier(num_workers)
+
+ def check_op(op):
+ op_name = op.operation.name
+
+ def walk_operations(op):
+ check_op(op)
+ for region in op.operation.regions:
+ for block in region:
+ for op in block:
+ walk_operations(op)
+
+ with Context():
+ mlir_module = Module.parse(
+ """
+ module @jit_sin attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+ func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = ""}) {
+ return %arg0 : tensor<f32>
+ }
+ }
+ """
+ )
+
+ def closure():
+ barrier.wait()
+
+ for _ in range(num_runs):
+ walk_operations(mlir_module)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = []
+ for i in range(num_workers):
+ futures.append(executor.submit(closure))
+ assert len(list(f.result() for f in futures)) == num_workers
+
+
if __name__ == "__main__":
# Do not run the tests on CPython with GIL
``````````
</details>
https://github.com/llvm/llvm-project/pull/126709
More information about the Mlir-commits mailing list