[Mlir-commits] [mlir] [mlir][py] invalidate nested operations when parent is deleted (PR #93339)

Oleksandr Alex Zinenko llvmlistbot at llvm.org
Fri May 24 12:56:31 PDT 2024


https://github.com/ftynse created https://github.com/llvm/llvm-project/pull/93339

When an operation is erased in Python, its children may still be in the "live" list inside the Python bindings. If a newly allocated operation later happens to reuse the same pointer address, this triggers an assertion in the bindings. That assertion is spurious because the erased operations aren't actually live. Make sure we remove the child operations from the "live" list when erasing the parent.

This also concentrates responsibility over the removal from the "live" list and invalidation in a single place.

Note that this requires the IR to be sufficiently structurally valid so that a walk through it can succeed. If this invariant was broken by, e.g., a C++ pass called from Python, there isn't much we can do.
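
The failure mode and the fix described above can be illustrated with a small pure-Python model. This is only a sketch: `ToyOp`, `ToyContext`, and `demo` are hypothetical names that do not exist in the MLIR bindings, and integer "addresses" stand in for the C++ pointers used as keys of the live map. The assertion message is the one quoted in the test added by this patch.

```python
# Toy model only: none of these names exist in the MLIR Python bindings.
class ToyOp:
    """Stands in for a Python wrapper around an operation at some address."""

    def __init__(self, address, children=()):
        self.address = address
        self.children = list(children)
        self.valid = True


class ToyContext:
    """Keeps a "live" map from address to wrapper, like the bindings do."""

    def __init__(self):
        self.live = {}

    def track(self, op):
        # Models the bindings' assertion: an address can be tracked only once.
        if op.address in self.live:
            raise AssertionError(
                "cannot create detached operation that already exists")
        self.live[op.address] = op
        return op

    def clear_operation(self, op):
        # Invalidates one operation; does NOT touch nested operations.
        self.live.pop(op.address, None)
        op.valid = False

    def clear_operation_and_inside(self, op):
        # Pre-order walk: clear the operation, then everything nested in it.
        self.clear_operation(op)
        for child in op.children:
            self.clear_operation_and_inside(child)


def demo(cascade):
    ctx = ToyContext()
    child = ctx.track(ToyOp(address=0x20))
    parent = ctx.track(ToyOp(address=0x10, children=[child]))
    # Erase the parent, with or without the cascading cleanup.
    if cascade:
        ctx.clear_operation_and_inside(parent)
    else:
        ctx.clear_operation(parent)
    # Pretend the allocator hands out the erased child's address again.
    try:
        ctx.track(ToyOp(address=0x20))
    except AssertionError:
        return "stale entry collided"
    return "ok"
```

In this model, `demo(cascade=False)` hits the stale child entry, while `demo(cascade=True)` does not, because the walk removed the child from the live map before its address was reused. The recursive walk also shows why the IR must be structurally valid: the cleanup can only reach operations that are still reachable from the parent.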

From 8c3727a3ccf21afce0ca1fb8cba3baa01a31f68d Mon Sep 17 00:00:00 2001
From: Alex Zinenko <ftynse at gmail.com>
Date: Fri, 24 May 2024 21:51:20 +0200
Subject: [PATCH] [mlir][py] invalidate nested operations when parent is
 deleted

When an operation is erased in Python, its children may still be in the "live"
list inside Python bindings. After this, if some of the newly allocated
operations happen to reuse the same pointer address, this will trigger an
assertion in the bindings. This assertion would be incorrect because the
operations aren't actually live. Make sure we remove the children operations
from the "live" list when erasing the parent.

This also concentrates responsibility over the removal from the "live" list and
invalidation in a single place.

Note that this requires the IR to be sufficiently structurally valid so a walk
through it can succeed. If this invariant was broken by, e.g., a C++ pass
called from Python, there isn't much we can do.
---
 mlir/lib/Bindings/Python/IRCore.cpp | 35 ++++++++++++++--------
 mlir/lib/Bindings/Python/IRModule.h |  7 +++++
 mlir/test/python/live_operations.py | 46 +++++++++++++++++++++++++++++
 3 files changed, 75 insertions(+), 13 deletions(-)
 create mode 100644 mlir/test/python/live_operations.py

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 01678a9719f90..f03c540d618cd 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -684,6 +684,17 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) {
   clearOperationsInside(opRef->getOperation());
 }
 
+void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
+  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
+                                                      void *userData) {
+    PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
+    contextRef->clearOperation(op);
+    return MlirWalkResult::MlirWalkResultAdvance;
+  };
+  mlirOperationWalk(op.getOperation(), invalidatingCallback,
+                    &op.getOperation().getContext(), MlirWalkPreOrder);
+}
+
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 pybind11::object PyMlirContext::contextEnter() {
@@ -1112,12 +1123,16 @@ PyOperation::~PyOperation() {
   // If the operation has already been invalidated there is nothing to do.
   if (!valid)
     return;
-  auto &liveOperations = getContext()->liveOperations;
-  assert(liveOperations.count(operation.ptr) == 1 &&
-         "destroying operation not in live map");
-  liveOperations.erase(operation.ptr);
-  if (!isAttached()) {
-    mlirOperationDestroy(operation);
+
+  // Otherwise, invalidate the operation and remove it from live map when it is
+  // attached.
+  if (isAttached()) {
+    getContext()->clearOperation(*this);
+  } else {
+    // And destroy it when it is detached, i.e. owned by Python, in which case
+    // all nested operations must be invalidated and removed from the live map
+    // as well.
+    erase();
   }
 }
 
@@ -1527,14 +1542,8 @@ py::object PyOperation::createOpView() {
 
 void PyOperation::erase() {
   checkValid();
-  // TODO: Fix memory hazards when erasing a tree of operations for which a deep
-  // Python reference to a child operation is live. All children should also
-  // have their `valid` bit set to false.
-  auto &liveOperations = getContext()->liveOperations;
-  if (liveOperations.count(operation.ptr))
-    liveOperations.erase(operation.ptr);
+  getContext()->clearOperationAndInside(*this);
   mlirOperationDestroy(operation);
-  valid = false;
 }
 
 //------------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index b038a0c54d29b..8c34c11f70950 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -218,6 +218,8 @@ class PyMlirContext {
   /// This is useful for when some non-bindings code destroys the operation and
   /// the bindings need to made aware. For example, in the case when pass
   /// manager is run.
+  ///
+  /// Note that this does *NOT* clear the nested operations.
   void clearOperation(MlirOperation op);
 
   /// Clears all operations nested inside the given op using
@@ -225,6 +227,10 @@ class PyMlirContext {
   void clearOperationsInside(PyOperationBase &op);
   void clearOperationsInside(MlirOperation op);
 
+  /// Clears the operation _and_ all operations inside using
+  /// `clearOperation(MlirOperation)`.
+  void clearOperationAndInside(PyOperationBase &op);
+
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
   size_t getLiveModuleCount();
@@ -246,6 +252,7 @@ class PyMlirContext {
 
 private:
   PyMlirContext(MlirContext context);
+
   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
   // preserving the relationship that an MlirContext maps to a single
   // PyMlirContext wrapper. This could be replaced in the future with an
diff --git a/mlir/test/python/live_operations.py b/mlir/test/python/live_operations.py
new file mode 100644
index 0000000000000..892ed1715f6c7
--- /dev/null
+++ b/mlir/test/python/live_operations.py
@@ -0,0 +1,46 @@
+# RUN: %PYTHON %s
+# It is sufficient that this doesn't assert.
+
+from mlir.ir import *
+
+
+def createDetachedModule():
+    module = Module.create()
+    with InsertionPoint(module.body):
+        # TODO: Python bindings are currently unaware that modules are also
+        # operations, so having a module erased won't trigger the cascading
+        # removal of live operations (#93337). Use a non-module operation
+        # instead.
+        nested = Operation.create("test.some_operation", regions=1)
+
+        # When the operation is detached from parent, it is considered to be
+        # owned by Python. It will therefore be erased when the Python object
+        # is destroyed.
+        nested.detach_from_parent()
+
+        # However, we create and maintain references to operations within
+        # `nested`. These references keep the corresponding operations in the
+        # "live" list even if they have been erased in C++, making them
+        # "zombie". If the C++ allocator reuses one of the addresses previously
+        # used for a now-"zombie" operation, this used to result in an
+        # assertion "cannot create detached operation that already exists" from
+        # the bindings code. Erasing the detached operation should result in
+        # removing all nested operations from the live list.
+        #
+        # Note that the assertion is not guaranteed since it depends on the
+        # behavior of the allocator on the C++ side, so this test may fail
+        # intermittently.
+        with InsertionPoint(nested.regions[0].blocks.append()):
+            a = [Operation.create("test.some_other_operation") for i in range(100)]
+    return a
+
+
+def createManyDetachedModules():
+    with Context() as ctx, Location.unknown():
+        ctx.allow_unregistered_dialects = True
+        for j in range(100):
+            a = createDetachedModule()
+
+
+if __name__ == "__main__":
+    createManyDetachedModules()


