[Mlir-commits] [mlir] [mlir python] Add locking around PyMlirContext::liveOperations. (PR #122720)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 13 06:59:31 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Peter Hawkins (hawkinsp)

<details>
<summary>Changes</summary>

In JAX, I observed a race between two `PyOperation` destructors running on different threads and updating the same `liveOperations` map, even though the context was not intentionally shared between threads. Since I don't think we can be completely sure when GC happens or on which thread it runs, it seems safest to simply add locking here.

We may also want to explicitly support sharing a context between threads in the future, which would require this change or something similar.

---
Full diff: https://github.com/llvm/llvm-project/pull/122720.diff


2 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+31-12) 
- (modified) mlir/lib/Bindings/Python/IRModule.h (+3) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 463ebdebb3f3f6..53806ca9f04a49 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -677,29 +677,44 @@ size_t PyMlirContext::getLiveCount() {
   return getLiveContexts().size();
 }
 
-size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
+size_t PyMlirContext::getLiveOperationCount() {
+  nb::ft_lock_guard lock(liveOperationsMutex);
+  return liveOperations.size();
+}
 
 std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
   std::vector<PyOperation *> liveObjects;
+  nb::ft_lock_guard lock(liveOperationsMutex);
   for (auto &entry : liveOperations)
     liveObjects.push_back(entry.second.second);
   return liveObjects;
 }
 
 size_t PyMlirContext::clearLiveOperations() {
-  for (auto &op : liveOperations)
+
+  LiveOperationMap operations;
+  {
+    nb::ft_lock_guard lock(liveOperationsMutex);
+    std::swap(operations, liveOperations);
+  }
+  for (auto &op : operations)
     op.second.second->setInvalid();
-  size_t numInvalidated = liveOperations.size();
-  liveOperations.clear();
+  size_t numInvalidated = operations.size();
   return numInvalidated;
 }
 
 void PyMlirContext::clearOperation(MlirOperation op) {
-  auto it = liveOperations.find(op.ptr);
-  if (it != liveOperations.end()) {
-    it->second.second->setInvalid();
+  PyOperation *py_op;
+  {
+    nb::ft_lock_guard lock(liveOperationsMutex);
+    auto it = liveOperations.find(op.ptr);
+    if (it == liveOperations.end()) {
+      return;
+    }
+    py_op = it->second.second;
     liveOperations.erase(it);
   }
+  py_op->setInvalid();
 }
 
 void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
@@ -1183,7 +1198,6 @@ PyOperation::~PyOperation() {
 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
                                            MlirOperation operation,
                                            nb::object parentKeepAlive) {
-  auto &liveOperations = contextRef->liveOperations;
   // Create.
   PyOperation *unownedOperation =
       new PyOperation(std::move(contextRef), operation);
@@ -1195,19 +1209,22 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
   if (parentKeepAlive) {
     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
   }
-  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
   return PyOperationRef(unownedOperation, std::move(pyRef));
 }
 
 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
                                          MlirOperation operation,
                                          nb::object parentKeepAlive) {
+  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
   auto &liveOperations = contextRef->liveOperations;
   auto it = liveOperations.find(operation.ptr);
   if (it == liveOperations.end()) {
     // Create.
-    return createInstance(std::move(contextRef), operation,
-                          std::move(parentKeepAlive));
+    PyOperationRef result = createInstance(std::move(contextRef), operation,
+                                           std::move(parentKeepAlive));
+    liveOperations[operation.ptr] =
+        std::make_pair(result.getObject(), result.get());
+    return result;
   }
   // Use existing.
   PyOperation *existing = it->second.second;
@@ -1218,13 +1235,15 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
                                            MlirOperation operation,
                                            nb::object parentKeepAlive) {
+  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
   auto &liveOperations = contextRef->liveOperations;
   assert(liveOperations.count(operation.ptr) == 0 &&
          "cannot create detached operation that already exists");
   (void)liveOperations;
-
   PyOperationRef created = createInstance(std::move(contextRef), operation,
                                           std::move(parentKeepAlive));
+  liveOperations[operation.ptr] =
+      std::make_pair(created.getObject(), created.get());
   created->attached = false;
   return created;
 }
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index f5fbb6c61b57e2..d1fb4308dbb77c 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -277,6 +277,9 @@ class PyMlirContext {
   // attempt to access it will raise an error.
   using LiveOperationMap =
       llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
+  nanobind::ft_mutex liveOperationsMutex;
+
+  // Guarded by liveOperationsMutex in free-threading mode.
   LiveOperationMap liveOperations;
 
   bool emitErrorDiagnostics = false;

``````````

</details>


https://github.com/llvm/llvm-project/pull/122720


More information about the Mlir-commits mailing list