[Mlir-commits] [mlir] Added FT locks around PyOperation valid flag to prevent a data race (PR #126709)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 11 02:22:13 PST 2025
https://github.com/vfdev-5 created https://github.com/llvm/llvm-project/pull/126709
Description:
- Added free-threading (FT) locks around the PyOperation `valid` flag to prevent a data race
- Added a test
Data race report: https://gist.github.com/vfdev-5/02bb822a0475d782da60815604ef30da
>From 11af4ceab4cb77cd6f5d18c9e6cbc7cef2717614 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Tue, 11 Feb 2025 10:20:06 +0000
Subject: [PATCH] Added FT locks around PyOperation valid flag to prevent a
data race Description: - Added FT locks around PyOperation valid flag to
prevent a data race - Added a test
Data race report: https://gist.github.com/vfdev-5/02bb822a0475d782da60815604ef30da
---
mlir/lib/Bindings/Python/IRCore.cpp | 3 +-
mlir/lib/Bindings/Python/IRModule.h | 14 ++++++---
mlir/test/python/multithreaded_tests.py | 42 ++++++++++++++++++++++++-
3 files changed, 53 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 47a85c2a486fd46..f080b79d7a56949 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1279,7 +1279,8 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
return PyOperation::createDetached(std::move(contextRef), op);
}
-void PyOperation::checkValid() const {
+void PyOperation::checkValid() {
+ nb::ft_lock_guard lock(validFlagMutex);
if (!valid) {
throw std::runtime_error("the operation has been invalidated");
}
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index dd6e7ef9123746b..fd2b3a066273c12 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -646,8 +646,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
}
/// Gets the backing operation.
- operator MlirOperation() const { return get(); }
- MlirOperation get() const {
+ operator MlirOperation() { return get(); }
+ MlirOperation get() {
checkValid();
return operation;
}
@@ -665,7 +665,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
assert(attached && "operation already detached");
attached = false;
}
- void checkValid() const;
+ void checkValid();
/// Gets the owning block or raises an exception if the operation has no
/// owning block.
@@ -700,7 +700,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
void erase();
/// Invalidate the operation.
- void setInvalid() { valid = false; }
+ void setInvalid() {
+ nanobind::ft_lock_guard lock(validFlagMutex);
+ valid = false;
+ }
/// Clones this operation.
nanobind::object clone(const nanobind::object &ip);
@@ -726,6 +729,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
friend class PyOperationBase;
friend class PySymbolTable;
+
+ // FT mutex guarding the `valid` flag (read in checkValid(), written in setInvalid()).
+ nanobind::ft_mutex validFlagMutex;
};
/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
index 6e1a66834687291..bc99144ef790157 100644
--- a/mlir/test/python/multithreaded_tests.py
+++ b/mlir/test/python/multithreaded_tests.py
@@ -475,7 +475,7 @@ def closure():
@multi_threaded(
- num_workers=10,
+ num_workers=32,
num_runs=20,
skip_tests=TESTS_TO_SKIP,
xfail_tests=TESTS_TO_XFAIL,
@@ -511,6 +511,46 @@ def _original_test_create_module_with_consts(self):
with InsertionPoint(module.body), Location.name("c"):
arith.constant(dtype, py_values[2])
+ def test_check_pyoperation_race(self):
+ num_workers = 40
+ num_runs = 20
+
+ barrier = threading.Barrier(num_workers)
+
+ def check_op(op):
+ op_name = op.operation.name
+
+ def walk_operations(op):
+ check_op(op)
+ for region in op.operation.regions:
+ for block in region:
+ for op in block:
+ walk_operations(op)
+
+ with Context():
+ mlir_module = Module.parse(
+ """
+ module @jit_sin attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+ func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = ""}) {
+ return %arg0 : tensor<f32>
+ }
+ }
+ """
+ )
+
+ def closure():
+ barrier.wait()
+
+ for _ in range(num_runs):
+ walk_operations(mlir_module)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = []
+ for i in range(num_workers):
+ futures.append(executor.submit(closure))
+ assert len(list(f.result() for f in futures)) == num_workers
+
+
if __name__ == "__main__":
# Do not run the tests on CPython with GIL
More information about the Mlir-commits
mailing list