[Mlir-commits] [mlir] [MLIR:Python] Fix race on PyOperations. (PR #139721)
Peter Hawkins
llvmlistbot at llvm.org
Wed May 14 08:36:16 PDT 2025
https://github.com/hawkinsp updated https://github.com/llvm/llvm-project/pull/139721
From 0edf176e417bd7af95fb6227290fdfe3ee1496d8 Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins at google.com>
Date: Tue, 13 May 2025 01:52:45 +0000
Subject: [PATCH] [MLIR:Python] Fix race on PyOperations.
Joint work with @vfdev-5
We found the following TSAN race report in JAX's CI:
https://github.com/jax-ml/jax/issues/28551
```
WARNING: ThreadSanitizer: data race (pid=35893)
Read of size 1 at 0x7fffca320cb9 by thread T57 (mutexes: read M0):
#0 mlir::python::PyOperation::checkValid() const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1300:8 (libjax_common.so+0x41e8b1d) (BuildId: 55242ad732cdae54)
#1 mlir::python::populateIRCore(nanobind::module_&)::$_57::operator()(mlir::python::PyOperationBase&) const /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:3221:40 (libjax_common.so+0x41e8b1d)
#2 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::operator()(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) const /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:275:24 (libjax_common.so+0x41e8b1d)
#3 _object* nanobind::detail::func_create<true, true, mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef, mlir::python::PyOperationBase&, 0ul, nanobind::is_method, nanobind::is_getter, nanobind::rv_policy>(mlir::python::populateIRCore(nanobind::module_&)::$_57&, MlirStringRef (*)(mlir::python::PyOperationBase&), std::integer_sequence<unsigned long, 0ul>, nanobind::is_method const&, nanobind::is_getter const&, nanobind::rv_policy const&)::'lambda'(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*)::__invoke(void*, _object**, unsigned char*, nanobind::rv_policy, nanobind::detail::cleanup_list*) /proc/self/cwd/external/nanobind/include/nanobind/nb_func.h:219:14 (libjax_common.so+0x41e8b1d)
...
Previous write of size 1 at 0x7fffca320cb9 by thread T56 (mutexes: read M0):
#0 mlir::python::PyOperation::setInvalid() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRModule.h:729:29 (libjax_common.so+0x419f012) (BuildId: 55242ad732cdae54)
#1 mlir::python::PyMlirContext::clearOperation(MlirOperation) /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:741:10 (libjax_common.so+0x419f012)
#2 mlir::python::PyOperation::~PyOperation() /proc/self/cwd/external/llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:1213:19 (libjax_common.so+0x41a414b) (BuildId: 55242ad732cdae54)
#3 void nanobind::detail::wrap_destruct<mlir::python::PyOperation>(void*) /proc/self/cwd/external/nanobind/include/nanobind/nb_class.h:245:21 (libjax_common.so+0x41ecf21) (BuildId: 55242ad732cdae54)
#4 nanobind::detail::inst_dealloc(_object*) /proc/self/cwd/external/nanobind/src/nb_type.cpp:255:13 (libjax_common.so+0x3284136) (BuildId: 55242ad732cdae54)
#5 _Py_Dealloc /project/cpython/Objects/object.c:3025:5 (python3.14+0x2a2422) (BuildId: 6051e096a967bdf49efb15da94a67d8eff710a9b)
#6 _Py_MergeZeroLocalRefcount /project/cpython/Objects/object.c (python3.14+0x2a2422)
#7 Py_DECREF(_object*) /proc/self/cwd/external/python_x86_64-unknown-linux-gnu-freethreaded/include/python3.14t/refcount.h:387:13 (libjax_common.so+0x41aaadc) (BuildId: 55242ad732cdae54)
...
```
At the simplest level, the `valid` field of a PyOperation must be
protected by a lock, because it may be concurrently accessed from
multiple threads. Much more interesting, however, is how we get into the
situation described by the two stack traces above in the first place.
The scenario that triggers this is the following:
* thread T56 holds the last Python reference on a PyOperation, and
decides to release it.
* After T56 starts to release its reference, but before T56 removes the
PyOperation from the liveOperations map, a second thread T57 comes
along and looks up the same MlirOperation in the liveOperations map.
* Finding the operation to be present, thread T57 increments the
reference count of that PyOperation and returns it to the caller.
This is illegal! Python is in the process of calling the destructor of
that object, and once an object is in that state it cannot be safely
revived.
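To make the window concrete, here is a hedged sketch of the racy lookup pattern. The names are simplified stand-ins (not the actual MLIR sources): the map plays the role of PyMlirContext::liveOperations and the helper the role of PyOperation::forOperation.
```
// Hedged illustration of the race, not the actual MLIR code.
#include <Python.h>
#include <unordered_map>

static std::unordered_map<void *, PyObject *> liveOperations; // borrowed handles

// Thread T57: find an existing wrapper and hand out a strong reference.
PyObject *forOperationUnsafe(void *key) {
  auto it = liveOperations.find(key);
  if (it == liveOperations.end())
    return nullptr;
  // Racy: thread T56 may already be running ~PyOperation() for this object
  // (its refcount hit zero), so a plain Py_INCREF here revives an object
  // whose destruction has already begun.
  Py_INCREF(it->second);
  return it->second;
}
```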
To fix this, whenever we increment the reference count of a PyOperation
that we found via the liveOperations map and to which we only hold a
non-owning reference, we must use the Python 3.14+ API
`PyUnstable_TryIncRef`, which exists precisely for this
scenario (https://github.com/python/cpython/issues/128844). That API
does not exist under Python 3.13, so in that case we need a backport of
it, for which we use the same backport that both nanobind and pybind11
use.
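Continuing the sketch above, the safe lookup on CPython 3.14+ would look roughly like this. It assumes the wrapper had PyUnstable_EnableTryIncRef called on it at creation time (as the patch does in createInstance); on 3.13 the backported tryIncRef plays the same role.
```
// Hedged sketch of the fixed lookup, CPython 3.14+ path only.
PyObject *forOperationSafe(void *key) {
  auto it = liveOperations.find(key);
  if (it != liveOperations.end() && PyUnstable_TryIncRef(it->second))
    return it->second; // acquired a strong reference; the wrapper is alive
  // Either no entry exists or the cached wrapper is being destroyed by
  // another thread: fall through and create (and intern) a fresh wrapper.
  return nullptr; // stand-in for "create a new PyOperation wrapper"
}
```
In the actual patch, when tryIncRef fails the stale entry is also marked invalid before a fresh wrapper replaces it in the liveOperations map.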
Fixes https://github.com/jax-ml/jax/issues/28551
---
mlir/lib/Bindings/Python/IRCore.cpp | 165 +++++++++++++++++++-----
mlir/lib/Bindings/Python/IRModule.h | 35 +++--
mlir/test/python/multithreaded_tests.py | 46 +++++++
3 files changed, 203 insertions(+), 43 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b5720b7ad8b21..a5cc7deb021a7 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -635,6 +635,75 @@ class PyOpOperandIterator {
MlirOpOperand opOperand;
};
+#if !defined(Py_GIL_DISABLED)
+inline void enableTryIncRef(nb::handle obj) noexcept {}
+inline bool tryIncRef(nb::handle obj) noexcept {
+ if (Py_REFCNT(obj.ptr()) > 0) {
+ Py_INCREF(obj.ptr());
+ return true;
+ }
+ return false;
+}
+
+#elif PY_VERSION_HEX >= 0x030E00A5
+
+// CPython 3.14 provides an unstable API for these.
+inline void enableTryIncRef(nb::handle obj) noexcept {
+ PyUnstable_EnableTryIncRef(obj.ptr());
+}
+inline bool tryIncRef(nb::handle obj) noexcept {
+ return PyUnstable_TryIncRef(obj.ptr());
+}
+
+#else
+
+// For CPython 3.13 there is no API for this, and so we must implement our own.
+// This code originates from https://github.com/wjakob/nanobind/pull/865/files.
+void enableTryIncRef(nb::handle h) noexcept {
+ // Since this is called during object construction, we know that we have
+ // the only reference to the object and can use a non-atomic write.
+ PyObject *obj = h.ptr();
+ assert(obj->ob_ref_shared == 0);
+ obj->ob_ref_shared = _Py_REF_MAYBE_WEAKREF;
+}
+
+bool tryIncRef(nb::handle h) noexcept {
+ PyObject *obj = h.ptr();
+ // See
+ // https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
+ uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
+ local += 1;
+ if (local == 0) {
+ // immortal
+ return true;
+ }
+ if (_Py_IsOwnedByCurrentThread(obj)) {
+ _Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
+#ifdef Py_REF_DEBUG
+ _Py_INCREF_IncRefTotal();
+#endif
+ return true;
+ }
+ Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
+ for (;;) {
+ // If the shared refcount is zero and the object is either merged
+ // or may not have weak references, then we cannot incref it.
+ if (shared == 0 || shared == _Py_REF_MERGED) {
+ return false;
+ }
+
+ if (_Py_atomic_compare_exchange_ssize(&obj->ob_ref_shared, &shared,
+ shared +
+ (1 << _Py_REF_SHARED_SHIFT))) {
+#ifdef Py_REF_DEBUG
+ _Py_INCREF_IncRefTotal();
+#endif
+ return true;
+ }
+ }
+}
+#endif
+
} // namespace
//------------------------------------------------------------------------------
@@ -706,11 +775,17 @@ size_t PyMlirContext::getLiveOperationCount() {
return liveOperations.size();
}
-std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
- std::vector<PyOperation *> liveObjects;
+std::vector<nb::object> PyMlirContext::getLiveOperationObjects() {
+ std::vector<nb::object> liveObjects;
nb::ft_lock_guard lock(liveOperationsMutex);
- for (auto &entry : liveOperations)
- liveObjects.push_back(entry.second.second);
+ for (auto &entry : liveOperations) {
+ // It is not safe to unconditionally increment the reference count here
+ // because an operation that is in the process of being deleted by another
+ // thread may still be present in the map.
+ if (tryIncRef(entry.second.first)) {
+ liveObjects.push_back(nb::steal(entry.second.first));
+ }
+ }
return liveObjects;
}
@@ -720,25 +795,26 @@ size_t PyMlirContext::clearLiveOperations() {
{
nb::ft_lock_guard lock(liveOperationsMutex);
std::swap(operations, liveOperations);
+ for (auto &op : operations)
+ op.second.second->setInvalidLocked();
}
- for (auto &op : operations)
- op.second.second->setInvalid();
size_t numInvalidated = operations.size();
return numInvalidated;
}
-void PyMlirContext::clearOperation(MlirOperation op) {
- PyOperation *py_op;
- {
- nb::ft_lock_guard lock(liveOperationsMutex);
- auto it = liveOperations.find(op.ptr);
- if (it == liveOperations.end()) {
- return;
- }
- py_op = it->second.second;
- liveOperations.erase(it);
+void PyMlirContext::clearOperationLocked(MlirOperation op) {
+ auto it = liveOperations.find(op.ptr);
+ if (it == liveOperations.end()) {
+ return;
}
- py_op->setInvalid();
+ PyOperation *py_op = it->second.second;
+ py_op->setInvalidLocked();
+ liveOperations.erase(it);
+}
+
+void PyMlirContext::clearOperation(MlirOperation op) {
+ nb::ft_lock_guard lock(liveOperationsMutex);
+ clearOperationLocked(op);
}
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
@@ -766,11 +842,11 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
clearOperationsInside(opRef->getOperation());
}
-void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
+void PyMlirContext::clearOperationAndInsideLocked(PyOperationBase &op) {
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
void *userData) {
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
- contextRef->clearOperation(op);
+ contextRef->clearOperationLocked(op);
return MlirWalkResult::MlirWalkResultAdvance;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
@@ -1203,6 +1279,8 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
: BaseContextObject(std::move(contextRef)), operation(operation) {}
PyOperation::~PyOperation() {
+ PyMlirContextRef context = getContext();
+ nb::ft_lock_guard lock(context->liveOperationsMutex);
// If the operation has already been invalidated there is nothing to do.
if (!valid)
return;
@@ -1210,12 +1288,14 @@ PyOperation::~PyOperation() {
// Otherwise, invalidate the operation and remove it from live map when it is
// attached.
if (isAttached()) {
- getContext()->clearOperation(*this);
+ // Since the operation was valid, we know that it is this object that is
+ // present in the map, not some other object.
+ context->liveOperations.erase(operation.ptr);
} else {
// And destroy it when it is detached, i.e. owned by Python, in which case
// all nested operations must be invalidated at removed from the live map as
// well.
- erase();
+ eraseLocked();
}
}
@@ -1241,6 +1321,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
// Create.
PyOperationRef unownedOperation =
makeObjectRef<PyOperation>(std::move(contextRef), operation);
+ enableTryIncRef(unownedOperation.getObject());
unownedOperation->handle = unownedOperation.getObject();
if (parentKeepAlive) {
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
@@ -1254,18 +1335,26 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
auto &liveOperations = contextRef->liveOperations;
auto it = liveOperations.find(operation.ptr);
- if (it == liveOperations.end()) {
- // Create.
- PyOperationRef result = createInstance(std::move(contextRef), operation,
- std::move(parentKeepAlive));
- liveOperations[operation.ptr] =
- std::make_pair(result.getObject(), result.get());
- return result;
+ if (it != liveOperations.end()) {
+ PyOperation *existing = it->second.second;
+ nb::handle pyRef = it->second.first;
+
+ // Try to increment the reference count of the existing entry. This can fail
+ // if the object is in the process of being destroyed by another thread.
+ if (tryIncRef(pyRef)) {
+ return PyOperationRef(existing, nb::steal<nb::object>(pyRef));
+ }
+
+ // Mark the existing entry as invalid, since we are about to replace it.
+ existing->setInvalidLocked();
}
- // Use existing.
- PyOperation *existing = it->second.second;
- nb::object pyRef = nb::borrow<nb::object>(it->second.first);
- return PyOperationRef(existing, std::move(pyRef));
+
+ // Create a new wrapper object.
+ PyOperationRef result = createInstance(std::move(contextRef), operation,
+ std::move(parentKeepAlive));
+ liveOperations[operation.ptr] =
+ std::make_pair(result.getObject(), result.get());
+ return result;
}
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
@@ -1297,6 +1386,7 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
}
void PyOperation::checkValid() const {
+ nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
if (!valid) {
throw std::runtime_error("the operation has been invalidated");
}
@@ -1638,12 +1728,17 @@ nb::object PyOperation::createOpView() {
return nb::cast(PyOpView(getRef().getObject()));
}
-void PyOperation::erase() {
+void PyOperation::eraseLocked() {
checkValid();
- getContext()->clearOperationAndInside(*this);
+ getContext()->clearOperationAndInsideLocked(*this);
mlirOperationDestroy(operation);
}
+void PyOperation::erase() {
+ nb::ft_lock_guard lock(getContext()->liveOperationsMutex);
+ eraseLocked();
+}
+
namespace {
/// CRTP base class for Python MLIR values that subclass Value and should be
/// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2324,7 +2419,7 @@ void PySymbolTable::erase(PyOperationBase &symbol) {
// The operation is also erased, so we must invalidate it. There may be Python
// references to this operation so we don't want to delete it from the list of
// live operations here.
- symbol.getOperation().valid = false;
+ symbol.getOperation().setInvalid();
}
void PySymbolTable::dunderDel(const std::string &name) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9befcce725bb7..36f97001bfd43 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -83,7 +83,7 @@ class PyObjectRef {
}
T *get() { return referrent; }
- T *operator->() {
+ T *operator->() const {
assert(referrent && object);
return referrent;
}
@@ -229,7 +229,7 @@ class PyMlirContext {
static size_t getLiveCount();
/// Get a list of Python objects which are still in the live context map.
- std::vector<PyOperation *> getLiveOperationObjects();
+ std::vector<nanobind::object> getLiveOperationObjects();
/// Gets the count of live operations associated with this context.
/// Used for testing.
@@ -254,9 +254,10 @@ class PyMlirContext {
void clearOperationsInside(PyOperationBase &op);
void clearOperationsInside(MlirOperation op);
- /// Clears the operaiton _and_ all operations inside using
- /// `clearOperation(MlirOperation)`.
- void clearOperationAndInside(PyOperationBase &op);
+ /// Clears the operation _and_ all operations inside using
+ /// `clearOperation(MlirOperation)`. Requires that liveOperations mutex is
+ /// held.
+ void clearOperationAndInsideLocked(PyOperationBase &op);
/// Gets the count of live modules associated with this context.
/// Used for testing.
@@ -278,6 +279,9 @@ class PyMlirContext {
struct ErrorCapture;
private:
+ // Similar to clearOperation, but requires the liveOperations mutex to be held
+ void clearOperationLocked(MlirOperation op);
+
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
// preserving the relationship that an MlirContext maps to a single
// PyMlirContext wrapper. This could be replaced in the future with an
@@ -302,6 +306,9 @@ class PyMlirContext {
// attempt to access it will raise an error.
using LiveOperationMap =
llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
+
+ // liveOperationsMutex guards both liveOperations and the valid field of
+ // PyOperation objects in free-threading mode.
nanobind::ft_mutex liveOperationsMutex;
// Guarded by liveOperationsMutex in free-threading mode.
@@ -336,6 +343,7 @@ class BaseContextObject {
}
/// Accesses the context reference.
+ const PyMlirContextRef &getContext() const { return contextRef; }
PyMlirContextRef &getContext() { return contextRef; }
private:
@@ -725,12 +733,17 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
/// parent context's live operations map, and sets the valid bit false.
void erase();
- /// Invalidate the operation.
- void setInvalid() { valid = false; }
-
/// Clones this operation.
nanobind::object clone(const nanobind::object &ip);
+ /// Invalidate the operation.
+ void setInvalid() {
+ nanobind::ft_lock_guard lock(getContext()->liveOperationsMutex);
+ setInvalidLocked();
+ }
+ /// Like setInvalid(), but requires the liveOperations mutex to be held.
+ void setInvalidLocked() { valid = false; }
+
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
private:
@@ -738,6 +751,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
MlirOperation operation,
nanobind::object parentKeepAlive);
+ // Like erase(), but requires the caller to hold the liveOperationsMutex.
+ void eraseLocked();
+
MlirOperation operation;
nanobind::handle handle;
// Keeps the parent alive, regardless of whether it is an Operation or
@@ -748,6 +764,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
// ir_operation.py regarding testing corresponding lifetime guarantees.
nanobind::object parentKeepAlive;
bool attached = true;
+
+ // Guarded by 'context->liveOperationsMutex'. Valid objects must be present
+ // in context->liveOperations.
bool valid = true;
friend class PyOperationBase;
diff --git a/mlir/test/python/multithreaded_tests.py b/mlir/test/python/multithreaded_tests.py
index 6e1a668346872..0c74e6c5d74f4 100644
--- a/mlir/test/python/multithreaded_tests.py
+++ b/mlir/test/python/multithreaded_tests.py
@@ -40,6 +40,7 @@
import importlib.util
import os
import sys
+import textwrap
import threading
import tempfile
import unittest
@@ -512,6 +513,51 @@ def _original_test_create_module_with_consts(self):
arith.constant(dtype, py_values[2])
+ def test_check_pyoperation_race(self):
+ # Regression test for a race where:
+ # * one thread is in the process of destroying a PyOperation,
+ # * while simultaneously another thread looks up the PyOperation in
+ # the liveOperations map and attempts to increase its reference count.
+ # It is illegal to attempt to revive an object that is in the process of
+ # being deleted, and this was producing races and heap use-after-frees.
+ num_workers = 40
+ num_runs = 20
+
+ barrier = threading.Barrier(num_workers)
+
+ def walk_operations(op):
+ _ = op.operation.name
+ for region in op.operation.regions:
+ for block in region:
+ for op in block:
+ walk_operations(op)
+
+ with Context():
+ mlir_module = Module.parse(
+ textwrap.dedent(
+ """
+ module @m {
+ func.func public @main(%arg0: tensor<f32>) -> (tensor<f32>) {
+ return %arg0 : tensor<f32>
+ }
+ }
+ """
+ )
+ )
+
+ def closure():
+ barrier.wait()
+
+ for _ in range(num_runs):
+ walk_operations(mlir_module)
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+ futures = []
+ for i in range(num_workers):
+ futures.append(executor.submit(closure))
+ assert len(list(f.result() for f in futures)) == num_workers
+
+
if __name__ == "__main__":
# Do not run the tests on CPython with GIL
if hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled():