[Mlir-commits] [mlir] Added FT locks around PyOperation valid flag to prevent a data race (PR #126709)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 11 02:22:13 PST 2025


https://github.com/vfdev-5 created https://github.com/llvm/llvm-project/pull/126709

Description:
- Added free-threading (FT) locks around the `PyOperation` `valid` flag to prevent a data race
- Added a test

Data race report: https://gist.github.com/vfdev-5/02bb822a0475d782da60815604ef30da

>From 11af4ceab4cb77cd6f5d18c9e6cbc7cef2717614 Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Tue, 11 Feb 2025 10:20:06 +0000
Subject: [PATCH] Added FT locks around PyOperation valid flag to prevent a
 data race Description: - Added FT locks around PyOperation valid flag to
 prevent a data race - Added a test

Data race report: https://gist.github.com/vfdev-5/02bb822a0475d782da60815604ef30da
---
 mlir/lib/Bindings/Python/IRCore.cpp     |  3 +-
 mlir/lib/Bindings/Python/IRModule.h     | 14 ++++++---
 mlir/test/python/multithreaded_tests.py | 42 ++++++++++++++++++++++++-
 3 files changed, 53 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 47a85c2a486fd46..f080b79d7a56949 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1279,7 +1279,8 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
   return PyOperation::createDetached(std::move(contextRef), op);
 }
 
-void PyOperation::checkValid() const {
+void PyOperation::checkValid() {
+  nb::ft_lock_guard lock(validFlagMutex);
   if (!valid) {
     throw std::runtime_error("the operation has been invalidated");
   }
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index dd6e7ef9123746b..fd2b3a066273c12 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -646,8 +646,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   }
 
   /// Gets the backing operation.
-  operator MlirOperation() const { return get(); }
-  MlirOperation get() const {
+  operator MlirOperation() { return get(); }
+  MlirOperation get() {
     checkValid();
     return operation;
   }
@@ -665,7 +665,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
     assert(attached && "operation already detached");
     attached = false;
   }
-  void checkValid() const;
+  void checkValid();
 
   /// Gets the owning block or raises an exception if the operation has no
   /// owning block.
@@ -700,7 +700,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   void erase();
 
   /// Invalidate the operation.
-  void setInvalid() { valid = false; }
+  void setInvalid() {
+    nanobind::ft_lock_guard lock(validFlagMutex);
+    valid = false;
+  }
 
   /// Clones this operation.
   nanobind::object clone(const nanobind::object &ip);
@@ -726,6 +729,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
 
   friend class PyOperationBase;
   friend class PySymbolTable;
+
+  // FT mutex to protect checkValid() method
+  nanobind::ft_mutex validFlagMutex;
 };
 
 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
index 6e1a66834687291..bc99144ef790157 100644
--- a/mlir/test/python/multithreaded_tests.py
+++ b/mlir/test/python/multithreaded_tests.py
@@ -475,7 +475,7 @@ def closure():
 
 
 @multi_threaded(
-    num_workers=10,
+    num_workers=32,
     num_runs=20,
     skip_tests=TESTS_TO_SKIP,
     xfail_tests=TESTS_TO_XFAIL,
@@ -511,6 +511,46 @@ def _original_test_create_module_with_consts(self):
             with InsertionPoint(module.body), Location.name("c"):
                 arith.constant(dtype, py_values[2])
 
+    def test_check_pyoperation_race(self):
+        num_workers = 40
+        num_runs = 20
+
+        barrier = threading.Barrier(num_workers)
+
+        def check_op(op):
+            op_name = op.operation.name
+
+        def walk_operations(op):
+            check_op(op)
+            for region in op.operation.regions:
+                for block in region:
+                    for op in block:
+                        walk_operations(op)
+
+        with Context():
+            mlir_module = Module.parse(
+                """
+    module @jit_sin attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+    func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = ""}) {
+        return %arg0 : tensor<f32>
+    }
+    }
+                """
+            )
+
+        def closure():
+            barrier.wait()
+
+            for _ in range(num_runs):
+                walk_operations(mlir_module)
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = []
+            for i in range(num_workers):
+                futures.append(executor.submit(closure))
+            assert len(list(f.result() for f in futures)) == num_workers
+
+
 
 if __name__ == "__main__":
     # Do not run the tests on CPython with GIL



More information about the Mlir-commits mailing list