[Mlir-commits] [mlir] e2c49a4 - [mlir python] Add locking around PyMlirContext::liveOperations. (#122720)
llvmlistbot at llvm.org
Mon Jan 13 07:49:29 PST 2025
Author: Peter Hawkins
Date: 2025-01-13T17:49:25+02:00
New Revision: e2c49a45da31522d91e2e7b12bbc0901b0519384
URL: https://github.com/llvm/llvm-project/commit/e2c49a45da31522d91e2e7b12bbc0901b0519384
DIFF: https://github.com/llvm/llvm-project/commit/e2c49a45da31522d91e2e7b12bbc0901b0519384.diff
LOG: [mlir python] Add locking around PyMlirContext::liveOperations. (#122720)
In JAX, I observed a race between two PyOperation destructors from
different threads updating the same `liveOperations` map, despite not
intentionally sharing the context between different threads. Since I
don't think we can be completely sure when GC happens and on which
thread, it seems safest simply to add locking here.
We may also want to explicitly support sharing a context between threads
in the future, which would require this change or something similar.
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 463ebdebb3f3f6..53806ca9f04a49 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -677,29 +677,44 @@ size_t PyMlirContext::getLiveCount() {
return getLiveContexts().size();
}
-size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
+size_t PyMlirContext::getLiveOperationCount() {
+ nb::ft_lock_guard lock(liveOperationsMutex);
+ return liveOperations.size();
+}
std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
std::vector<PyOperation *> liveObjects;
+ nb::ft_lock_guard lock(liveOperationsMutex);
for (auto &entry : liveOperations)
liveObjects.push_back(entry.second.second);
return liveObjects;
}
size_t PyMlirContext::clearLiveOperations() {
- for (auto &op : liveOperations)
+
+ LiveOperationMap operations;
+ {
+ nb::ft_lock_guard lock(liveOperationsMutex);
+ std::swap(operations, liveOperations);
+ }
+ for (auto &op : operations)
op.second.second->setInvalid();
- size_t numInvalidated = liveOperations.size();
- liveOperations.clear();
+ size_t numInvalidated = operations.size();
return numInvalidated;
}
void PyMlirContext::clearOperation(MlirOperation op) {
- auto it = liveOperations.find(op.ptr);
- if (it != liveOperations.end()) {
- it->second.second->setInvalid();
+ PyOperation *py_op;
+ {
+ nb::ft_lock_guard lock(liveOperationsMutex);
+ auto it = liveOperations.find(op.ptr);
+ if (it == liveOperations.end()) {
+ return;
+ }
+ py_op = it->second.second;
liveOperations.erase(it);
}
+ py_op->setInvalid();
}
void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
@@ -1183,7 +1198,6 @@ PyOperation::~PyOperation() {
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
- auto &liveOperations = contextRef->liveOperations;
// Create.
PyOperation *unownedOperation =
new PyOperation(std::move(contextRef), operation);
@@ -1195,19 +1209,22 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
if (parentKeepAlive) {
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
}
- liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
return PyOperationRef(unownedOperation, std::move(pyRef));
}
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
+ nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
auto &liveOperations = contextRef->liveOperations;
auto it = liveOperations.find(operation.ptr);
if (it == liveOperations.end()) {
// Create.
- return createInstance(std::move(contextRef), operation,
- std::move(parentKeepAlive));
+ PyOperationRef result = createInstance(std::move(contextRef), operation,
+ std::move(parentKeepAlive));
+ liveOperations[operation.ptr] =
+ std::make_pair(result.getObject(), result.get());
+ return result;
}
// Use existing.
PyOperation *existing = it->second.second;
@@ -1218,13 +1235,15 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
+ nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
auto &liveOperations = contextRef->liveOperations;
assert(liveOperations.count(operation.ptr) == 0 &&
"cannot create detached operation that already exists");
(void)liveOperations;
-
PyOperationRef created = createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
+ liveOperations[operation.ptr] =
+ std::make_pair(created.getObject(), created.get());
created->attached = false;
return created;
}
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index f5fbb6c61b57e2..d1fb4308dbb77c 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -277,6 +277,9 @@ class PyMlirContext {
// attempt to access it will raise an error.
using LiveOperationMap =
llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
+ nanobind::ft_mutex liveOperationsMutex;
+
+ // Guarded by liveOperationsMutex in free-threading mode.
LiveOperationMap liveOperations;
bool emitErrorDiagnostics = false;
More information about the Mlir-commits mailing list