[Mlir-commits] [mlir] Use liveOperationsMutex in ~PyOperation and lock first liveOperationsMutex and then opMutex (PR #130612)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 10 07:22:48 PDT 2025


https://github.com/vfdev-5 created https://github.com/llvm/llvm-project/pull/130612

Description:

TODO: write a comprehensive description of why this change is needed.

- Use liveOperationsMutex in ~PyOperation and lock first liveOperationsMutex and then opMutex

Data race report: https://gist.github.com/vfdev-5/02bb822a0475d782da60815604ef30da


>From b31cf08684c74756eb9896fc66d2e2a357e57dec Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Mon, 24 Feb 2025 14:11:04 +0000
Subject: [PATCH] Use liveOperationsMutex in ~PyOperation and lock first
 liveOperationsMutex and then opMutex

---
 mlir/lib/Bindings/Python/IRCore.cpp     | 122 ++++++++++++++++++++----
 mlir/lib/Bindings/Python/IRModule.h     |  16 +++-
 mlir/test/python/multithreaded_tests.py |  39 ++++++++
 3 files changed, 153 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b13a429d4a3c0..517f13df0f978 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -713,20 +713,21 @@ size_t PyMlirContext::clearLiveOperations() {
   return numInvalidated;
 }
 
-void PyMlirContext::clearOperation(MlirOperation op) {
-  PyOperation *py_op;
-  {
-    nb::ft_lock_guard lock(liveOperationsMutex);
-    auto it = liveOperations.find(op.ptr);
-    if (it == liveOperations.end()) {
-      return;
-    }
-    py_op = it->second.second;
-    liveOperations.erase(it);
+void PyMlirContext::_clearOperationLocked(MlirOperation op) {
+  auto it = liveOperations.find(op.ptr);
+  if (it == liveOperations.end()) {
+    return;
   }
+  PyOperation *py_op = it->second.second;
+  liveOperations.erase(it);
   py_op->setInvalid();
 }
 
+void PyMlirContext::clearOperation(MlirOperation op) {
+  nb::ft_lock_guard lock(liveOperationsMutex);
+  _clearOperationLocked(op);
+}
+
 void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
   typedef struct {
     PyOperation &rootOp;
@@ -752,6 +753,22 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
   clearOperationsInside(opRef->getOperation());
 }
 
+void _clearOperationAndInsideHelper(
+  PyOperation &op, MlirOperationWalkCallback invalidatingCallback
+) {
+  mlirOperationWalk(op, invalidatingCallback, &op.getContext(), MlirWalkPreOrder);
+}
+
+void PyMlirContext::_clearOperationAndInsideLocked(PyOperationBase &op) {
+  MlirOperationWalkCallback invalidatingCallbackLocked = [](MlirOperation op,
+                                                      void *userData) {
+    PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
+    contextRef->_clearOperationLocked(op);
+    return MlirWalkResult::MlirWalkResultAdvance;
+  };
+  _clearOperationAndInsideHelper(op.getOperation(), invalidatingCallbackLocked);
+}
+
 void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
   MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
                                                       void *userData) {
@@ -759,8 +776,7 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
     contextRef->clearOperation(op);
     return MlirWalkResult::MlirWalkResultAdvance;
   };
-  mlirOperationWalk(op.getOperation(), invalidatingCallback,
-                    &op.getOperation().getContext(), MlirWalkPreOrder);
+  _clearOperationAndInsideHelper(op.getOperation(), invalidatingCallback);
 }
 
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
@@ -1189,19 +1205,25 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
     : BaseContextObject(std::move(contextRef)), operation(operation) {}
 
 PyOperation::~PyOperation() {
-  // If the operation has already been invalidated there is nothing to do.
-  if (!valid)
-    return;
+  // Holding liveOperationsMutex serializes ~PyOperation with
+  // PyOperation::forOperation, which may need to invalidate an existing PyOperation.
+  nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
+  {
+    nb::ft_lock_guard lock2(opMutex);
+    if (!valid)
+      return;
+  }
 
   // Otherwise, invalidate the operation and remove it from live map when it is
   // attached.
   if (isAttached()) {
-    getContext()->clearOperation(*this);
+    getContext()->_clearOperationLocked(operation);
   } else {
     // And destroy it when it is detached, i.e. owned by Python, in which case
     // all nested operations must be invalidated at removed from the live map as
     // well.
-    erase();
+    getContext()->_clearOperationAndInsideLocked(*this);
+    mlirOperationDestroy(operation);
   }
 }
 
@@ -1234,6 +1256,41 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
   return unownedOperation;
 }
 
+
+bool _nb_try_inc_ref(PyObject *obj) {
+    // See https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
+    uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
+    local += 1;
+    if (local == 0) {
+        // immortal
+        return true;
+    }
+    if (_Py_IsOwnedByCurrentThread(obj)) {
+        _Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
+#ifdef Py_REF_DEBUG
+        _Py_INCREF_IncRefTotal();
+#endif
+        return true;
+    }
+    Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
+    for (;;) {
+        // If the shared refcount is zero and the object is either merged
+        // or may not have weak references, then we cannot incref it.
+        if (shared == 0 || shared == _Py_REF_MERGED) {
+            return false;
+        }
+
+        if (_Py_atomic_compare_exchange_ssize(
+                &obj->ob_ref_shared, &shared, shared + (1 << _Py_REF_SHARED_SHIFT))) {
+#ifdef Py_REF_DEBUG
+            _Py_INCREF_IncRefTotal();
+#endif
+            return true;
+        }
+    }
+}
+
+
 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
                                          MlirOperation operation,
                                          nb::object parentKeepAlive) {
@@ -1250,8 +1307,31 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
   }
   // Use existing.
   PyOperation *existing = it->second.second;
-  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
-  return PyOperationRef(existing, std::move(pyRef));
+  nb::object pyRef = nb::steal(it->second.first);
+
+  // Check whether pyRef is already in the process of being destroyed, in which
+  // case incrementing its refcount won't keep it from deletion.
+  // If, after incrementing the reference count, its value is 1,
+  // the Python object is being removed and ~PyOperation is about to be called.
+  // In that case, we should create a new PyOperationRef.
+  if (_nb_try_inc_ref(pyRef.ptr())) {
+    return PyOperationRef(existing, std::move(pyRef));
+  }
+
+  // Lock ordering: liveOperationsMutex must be locked first, then opMutex.
+  // We take existing->opMutex here to serialize
+  // ~PyOperation with the code below.
+  nb::ft_lock_guard lock2(existing->opMutex);
+
+  // Invalidate existing
+  existing->valid = false;
+
+  // Create.
+  PyOperationRef result = createInstance(std::move(contextRef), operation,
+                                          std::move(parentKeepAlive));
+  liveOperations[operation.ptr] =
+      std::make_pair(result.getObject(), result.get());
+  return result;
 }
 
 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
@@ -1282,7 +1362,8 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
   return PyOperation::createDetached(std::move(contextRef), op);
 }
 
-void PyOperation::checkValid() const {
+void PyOperation::checkValid() {
+  nb::ft_lock_guard lock(opMutex);
   if (!valid) {
     throw std::runtime_error("the operation has been invalidated");
   }
@@ -2305,6 +2386,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
   // The operation is also erased, so we must invalidate it. There may be Python
   // references to this operation so we don't want to delete it from the list of
   // live operations here.
+  nb::ft_lock_guard lock(symbol.getOperation().opMutex);
   symbol.getOperation().valid = false;
 }
 
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index dd6e7ef912374..d8d3afdd25ab3 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -253,6 +253,9 @@ class PyMlirContext {
   struct ErrorCapture;
 
 private:
+  void _clearOperationLocked(MlirOperation op);
+  void _clearOperationAndInsideLocked(PyOperationBase &op);
+
   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
   // preserving the relationship that an MlirContext maps to a single
   // PyMlirContext wrapper. This could be replaced in the future with an
@@ -646,8 +649,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   }
 
   /// Gets the backing operation.
-  operator MlirOperation() const { return get(); }
-  MlirOperation get() const {
+  operator MlirOperation() { return get(); }
+  MlirOperation get() {
     checkValid();
     return operation;
   }
@@ -665,7 +668,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
     assert(attached && "operation already detached");
     attached = false;
   }
-  void checkValid() const;
+  void checkValid();
 
   /// Gets the owning block or raises an exception if the operation has no
   /// owning block.
@@ -700,7 +703,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   void erase();
 
   /// Invalidate the operation.
-  void setInvalid() { valid = false; }
+  void setInvalid() {
+    nanobind::ft_lock_guard lock(opMutex);
+    valid = false;
+  }
 
   /// Clones this operation.
   nanobind::object clone(const nanobind::object &ip);
@@ -724,6 +730,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   bool attached = true;
   bool valid = true;
 
+  nanobind::ft_mutex opMutex;
+
   friend class PyOperationBase;
   friend class PySymbolTable;
 };
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
index 6e1a668346872..75ce3756bbe8d 100644
--- a/mlir/test/python/multithreaded_tests.py
+++ b/mlir/test/python/multithreaded_tests.py
@@ -511,6 +511,45 @@ def _original_test_create_module_with_consts(self):
             with InsertionPoint(module.body), Location.name("c"):
                 arith.constant(dtype, py_values[2])
 
+    def test_check_pyoperation_race(self):
+        num_workers = 40
+        num_runs = 20
+
+        barrier = threading.Barrier(num_workers)
+
+        def check_op(op):
+            op_name = op.operation.name
+
+        def walk_operations(op):
+            check_op(op)
+            for region in op.operation.regions:
+                for block in region:
+                    for op in block:
+                        walk_operations(op)
+
+        with Context():
+            mlir_module = Module.parse(
+                """
+    module @jit_sin attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+    func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = ""}) {
+        return %arg0 : tensor<f32>
+    }
+    }
+                """
+            )
+
+        def closure():
+            barrier.wait()
+
+            for _ in range(num_runs):
+                walk_operations(mlir_module)
+
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = []
+            for i in range(num_workers):
+                futures.append(executor.submit(closure))
+            assert len(list(f.result() for f in futures)) == num_workers
+
 
 if __name__ == "__main__":
     # Do not run the tests on CPython with GIL



More information about the Mlir-commits mailing list