[Mlir-commits] [mlir] Use liveOperationsMutex in ~PyOperation and lock first liveOperationsMutex and then opMutex (PR #130612)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 10 07:22:48 PDT 2025
https://github.com/vfdev-5 created https://github.com/llvm/llvm-project/pull/130612
Description:
TODO: write a comprehensive description of why this change is needed.
- Use liveOperationsMutex in ~PyOperation, and lock liveOperationsMutex first and then opMutex
Data race report: https://gist.github.com/vfdev-5/02bb822a0475d782da60815604ef30da
>From b31cf08684c74756eb9896fc66d2e2a357e57dec Mon Sep 17 00:00:00 2001
From: vfdev-5 <vfdev.5 at gmail.com>
Date: Mon, 24 Feb 2025 14:11:04 +0000
Subject: [PATCH] Use liveOperationsMutex in ~PyOperation and lock first
liveOperationsMutex and then opMutex
---
mlir/lib/Bindings/Python/IRCore.cpp | 122 ++++++++++++++++++++----
mlir/lib/Bindings/Python/IRModule.h | 16 +++-
mlir/test/python/multithreaded_tests.py | 39 ++++++++
3 files changed, 153 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b13a429d4a3c0..517f13df0f978 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -713,20 +713,21 @@ size_t PyMlirContext::clearLiveOperations() {
return numInvalidated;
}
-void PyMlirContext::clearOperation(MlirOperation op) {
- PyOperation *py_op;
- {
- nb::ft_lock_guard lock(liveOperationsMutex);
- auto it = liveOperations.find(op.ptr);
- if (it == liveOperations.end()) {
- return;
- }
- py_op = it->second.second;
- liveOperations.erase(it);
+void PyMlirContext::_clearOperationLocked(MlirOperation op) {
+ auto it = liveOperations.find(op.ptr);
+ if (it == liveOperations.end()) {
+ return;
}
+ PyOperation *py_op = it->second.second;
+ liveOperations.erase(it);
py_op->setInvalid();
}
+void PyMlirContext::clearOperation(MlirOperation op) {
+ nb::ft_lock_guard lock(liveOperationsMutex);
+ _clearOperationLocked(op);
+}
+
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
typedef struct {
PyOperation &rootOp;
@@ -752,6 +753,22 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
clearOperationsInside(opRef->getOperation());
}
+void _clearOperationAndInsideHelper(
+ PyOperation &op, MlirOperationWalkCallback invalidatingCallback
+) {
+ mlirOperationWalk(op, invalidatingCallback, &op.getContext(), MlirWalkPreOrder);
+}
+
+void PyMlirContext::_clearOperationAndInsideLocked(PyOperationBase &op) {
+ MlirOperationWalkCallback invalidatingCallbackLocked = [](MlirOperation op,
+ void *userData) {
+ PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
+ contextRef->_clearOperationLocked(op);
+ return MlirWalkResult::MlirWalkResultAdvance;
+ };
+ _clearOperationAndInsideHelper(op.getOperation(), invalidatingCallbackLocked);
+}
+
void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
void *userData) {
@@ -759,8 +776,7 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
contextRef->clearOperation(op);
return MlirWalkResult::MlirWalkResultAdvance;
};
- mlirOperationWalk(op.getOperation(), invalidatingCallback,
- &op.getOperation().getContext(), MlirWalkPreOrder);
+ _clearOperationAndInsideHelper(op.getOperation(), invalidatingCallback);
}
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
@@ -1189,19 +1205,25 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
: BaseContextObject(std::move(contextRef)), operation(operation) {}
PyOperation::~PyOperation() {
- // If the operation has already been invalidated there is nothing to do.
- if (!valid)
- return;
+ // This lock helps to serialize the access to ~PyOperation and PyOperation::forOperation
+ // when we should invalidate existing PyOperation
+ nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
+ {
+ nb::ft_lock_guard lock2(opMutex);
+ if (!valid)
+ return;
+ }
// Otherwise, invalidate the operation and remove it from live map when it is
// attached.
if (isAttached()) {
- getContext()->clearOperation(*this);
+ getContext()->_clearOperationLocked(operation);
} else {
// And destroy it when it is detached, i.e. owned by Python, in which case
// all nested operations must be invalidated at removed from the live map as
// well.
- erase();
+ getContext()->_clearOperationAndInsideLocked(*this);
+ mlirOperationDestroy(operation);
}
}
@@ -1234,6 +1256,41 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
return unownedOperation;
}
+
+bool _nb_try_inc_ref(PyObject *obj) {
+ // See https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
+ uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
+ local += 1;
+ if (local == 0) {
+ // immortal
+ return true;
+ }
+ if (_Py_IsOwnedByCurrentThread(obj)) {
+ _Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
+#ifdef Py_REF_DEBUG
+ _Py_INCREF_IncRefTotal();
+#endif
+ return true;
+ }
+ Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
+ for (;;) {
+ // If the shared refcount is zero and the object is either merged
+ // or may not have weak references, then we cannot incref it.
+ if (shared == 0 || shared == _Py_REF_MERGED) {
+ return false;
+ }
+
+ if (_Py_atomic_compare_exchange_ssize(
+ &obj->ob_ref_shared, &shared, shared + (1 << _Py_REF_SHARED_SHIFT))) {
+#ifdef Py_REF_DEBUG
+ _Py_INCREF_IncRefTotal();
+#endif
+ return true;
+ }
+ }
+}
+
+
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
@@ -1250,8 +1307,31 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
}
// Use existing.
PyOperation *existing = it->second.second;
- nb::object pyRef = nb::borrow<nb::object>(it->second.first);
- return PyOperationRef(existing, std::move(pyRef));
+ nb::object pyRef = nb::steal(it->second.first);
+
+ // Check whether pyRef is already in the process of being destroyed, in which
+ // case incrementing its refcount won't keep it from deletion.
+ // If, after incrementing the reference count, its value is 1,
+ // it means that the Python object is being removed and ~PyOperation will be called.
+ // Thus, we should create a new PyOperationRef.
+ if (_nb_try_inc_ref(pyRef.ptr())) {
+ return PyOperationRef(existing, std::move(pyRef));
+ }
+
+ // We should lock first liveOperationsMutex and then opMutex.
+ // We need to use existing->opMutex to serialize the
+ // access to ~PyOperation and the code below
+ nb::ft_lock_guard lock2(existing->opMutex);
+
+ // Invalidate existing
+ existing->valid = false;
+
+ // Create.
+ PyOperationRef result = createInstance(std::move(contextRef), operation,
+ std::move(parentKeepAlive));
+ liveOperations[operation.ptr] =
+ std::make_pair(result.getObject(), result.get());
+ return result;
}
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
@@ -1282,7 +1362,8 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
return PyOperation::createDetached(std::move(contextRef), op);
}
-void PyOperation::checkValid() const {
+void PyOperation::checkValid() {
+ nb::ft_lock_guard lock(opMutex);
if (!valid) {
throw std::runtime_error("the operation has been invalidated");
}
@@ -2305,6 +2386,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
// The operation is also erased, so we must invalidate it. There may be Python
// references to this operation so we don't want to delete it from the list of
// live operations here.
+ nb::ft_lock_guard lock(symbol.getOperation().opMutex);
symbol.getOperation().valid = false;
}
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index dd6e7ef912374..d8d3afdd25ab3 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -253,6 +253,9 @@ class PyMlirContext {
struct ErrorCapture;
private:
+ void _clearOperationLocked(MlirOperation op);
+ void _clearOperationAndInsideLocked(PyOperationBase &op);
+
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
// preserving the relationship that an MlirContext maps to a single
// PyMlirContext wrapper. This could be replaced in the future with an
@@ -646,8 +649,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
}
/// Gets the backing operation.
- operator MlirOperation() const { return get(); }
- MlirOperation get() const {
+ operator MlirOperation() { return get(); }
+ MlirOperation get() {
checkValid();
return operation;
}
@@ -665,7 +668,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
assert(attached && "operation already detached");
attached = false;
}
- void checkValid() const;
+ void checkValid();
/// Gets the owning block or raises an exception if the operation has no
/// owning block.
@@ -700,7 +703,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
void erase();
/// Invalidate the operation.
- void setInvalid() { valid = false; }
+ void setInvalid() {
+ nanobind::ft_lock_guard lock(opMutex);
+ valid = false;
+ }
/// Clones this operation.
nanobind::object clone(const nanobind::object &ip);
@@ -724,6 +730,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
bool attached = true;
bool valid = true;
+ nanobind::ft_mutex opMutex;
+
friend class PyOperationBase;
friend class PySymbolTable;
};
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
index 6e1a668346872..75ce3756bbe8d 100644
--- a/mlir/test/python/multithreaded_tests.py
+++ b/mlir/test/python/multithreaded_tests.py
@@ -511,6 +511,45 @@ def _original_test_create_module_with_consts(self):
with InsertionPoint(module.body), Location.name("c"):
arith.constant(dtype, py_values[2])
+ def test_check_pyoperation_race(self):
+ num_workers = 40
+ num_runs = 20
+
+ barrier = threading.Barrier(num_workers)
+
+ def check_op(op):
+ op_name = op.operation.name
+
+ def walk_operations(op):
+ check_op(op)
+ for region in op.operation.regions:
+ for block in region:
+ for op in block:
+ walk_operations(op)
+
+ with Context():
+ mlir_module = Module.parse(
+ """
+ module @jit_sin attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
+ func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = ""}) {
+ return %arg0 : tensor<f32>
+ }
+ }
+ """
+ )
+
+ def closure():
+ barrier.wait()
+
+ for _ in range(num_runs):
+ walk_operations(mlir_module)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = []
+ for i in range(num_workers):
+ futures.append(executor.submit(closure))
+ assert len(list(f.result() for f in futures)) == num_workers
+
if __name__ == "__main__":
# Do not run the tests on CPython with GIL
More information about the Mlir-commits
mailing list