[Mlir-commits] [mlir] [MLIR][Python] remove `liveOperations` (PR #155114)

Maksim Levental llvmlistbot at llvm.org
Sat Aug 23 14:45:47 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/155114

>From c8027c4a50023ebac8606bb222d230b6ae0e86f9 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 17 May 2024 22:36:41 -0500
Subject: [PATCH 1/5] [mlir][python] wip remove liveOperations

---
 mlir/include/mlir-c/IR.h                      |   2 +
 mlir/lib/Bindings/Python/IRCore.cpp           | 173 +++---------------
 mlir/lib/Bindings/Python/IRModule.h           |  53 ------
 mlir/lib/Bindings/Python/Pass.cpp             |   8 +-
 .../Bindings/Python/TransformInterpreter.cpp  |   1 -
 mlir/lib/CAPI/IR/IR.cpp                       |   4 +
 mlir/test/python/ir/module.py                 |  22 +--
 mlir/test/python/ir/operation.py              |   3 +-
 mlir/test/python/ir/symbol_table.py           |   8 -
 mlir/test/python/pass_manager.py              |  27 +--
 10 files changed, 41 insertions(+), 260 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 71c7d4378677f..d05f91d7e3b12 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -415,6 +415,8 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
 /// The returned module is null when the input operation was not a ModuleOp.
 MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
 
+MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs);
+
 //===----------------------------------------------------------------------===//
 // Operation state.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4b3a06cbce854..99d0efb276311 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -702,84 +702,6 @@ size_t PyMlirContext::getLiveCount() {
   return getLiveContexts().size();
 }
 
-size_t PyMlirContext::getLiveOperationCount() {
-  nb::ft_lock_guard lock(liveOperationsMutex);
-  return liveOperations.size();
-}
-
-std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
-  std::vector<PyOperation *> liveObjects;
-  nb::ft_lock_guard lock(liveOperationsMutex);
-  for (auto &entry : liveOperations)
-    liveObjects.push_back(entry.second.second);
-  return liveObjects;
-}
-
-size_t PyMlirContext::clearLiveOperations() {
-
-  LiveOperationMap operations;
-  {
-    nb::ft_lock_guard lock(liveOperationsMutex);
-    std::swap(operations, liveOperations);
-  }
-  for (auto &op : operations)
-    op.second.second->setInvalid();
-  size_t numInvalidated = operations.size();
-  return numInvalidated;
-}
-
-void PyMlirContext::clearOperation(MlirOperation op) {
-  PyOperation *py_op;
-  {
-    nb::ft_lock_guard lock(liveOperationsMutex);
-    auto it = liveOperations.find(op.ptr);
-    if (it == liveOperations.end()) {
-      return;
-    }
-    py_op = it->second.second;
-    liveOperations.erase(it);
-  }
-  py_op->setInvalid();
-}
-
-void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
-  typedef struct {
-    PyOperation &rootOp;
-    bool rootSeen;
-  } callBackData;
-  callBackData data{op.getOperation(), false};
-  // Mark all ops below the op that the passmanager will be rooted
-  // at (but not op itself - note the preorder) as invalid.
-  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
-                                                      void *userData) {
-    callBackData *data = static_cast<callBackData *>(userData);
-    if (LLVM_LIKELY(data->rootSeen))
-      data->rootOp.getOperation().getContext()->clearOperation(op);
-    else
-      data->rootSeen = true;
-    return MlirWalkResult::MlirWalkResultAdvance;
-  };
-  mlirOperationWalk(op.getOperation(), invalidatingCallback,
-                    static_cast<void *>(&data), MlirWalkPreOrder);
-}
-void PyMlirContext::clearOperationsInside(MlirOperation op) {
-  PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
-  clearOperationsInside(opRef->getOperation());
-}
-
-void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
-  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
-                                                      void *userData) {
-    PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
-    contextRef->clearOperation(op);
-    return MlirWalkResult::MlirWalkResultAdvance;
-  };
-  mlirOperationWalk(op.getOperation(), invalidatingCallback,
-                    &op.getOperation().getContext(), MlirWalkPreOrder);
-}
-
-size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
-
 nb::object PyMlirContext::contextEnter(nb::object context) {
   return PyThreadContextEntry::pushContext(context);
 }
@@ -1151,38 +1073,20 @@ PyLocation &DefaultingPyLocation::resolve() {
 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
     : BaseContextObject(std::move(contextRef)), module(module) {}
 
-PyModule::~PyModule() {
-  nb::gil_scoped_acquire acquire;
-  auto &liveModules = getContext()->liveModules;
-  assert(liveModules.count(module.ptr) == 1 &&
-         "destroying module not in live map");
-  liveModules.erase(module.ptr);
-  mlirModuleDestroy(module);
-}
+PyModule::~PyModule() { mlirModuleDestroy(module); }
 
 PyModuleRef PyModule::forModule(MlirModule module) {
   MlirContext context = mlirModuleGetContext(module);
   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
 
-  nb::gil_scoped_acquire acquire;
-  auto &liveModules = contextRef->liveModules;
-  auto it = liveModules.find(module.ptr);
-  if (it == liveModules.end()) {
-    // Create.
-    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
-    // Note that the default return value policy on cast is automatic_reference,
-    // which does not take ownership (delete will not be called).
-    // Just be explicit.
-    nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
-    unownedModule->handle = pyRef;
-    liveModules[module.ptr] =
-        std::make_pair(unownedModule->handle, unownedModule);
-    return PyModuleRef(unownedModule, std::move(pyRef));
-  }
-  // Use existing.
-  PyModule *existing = it->second.second;
-  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
-  return PyModuleRef(existing, std::move(pyRef));
+  // Create.
+  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+  // Note that the default return value policy on cast is automatic_reference,
+  // which does not take ownership (delete will not be called).
+  // Just be explicit.
+  nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
+  unownedModule->handle = pyRef;
+  return PyModuleRef(unownedModule, std::move(pyRef));
 }
 
 nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -1207,15 +1111,11 @@ PyOperation::~PyOperation() {
   // If the operation has already been invalidated there is nothing to do.
   if (!valid)
     return;
-
-  // Otherwise, invalidate the operation and remove it from live map when it is
-  // attached.
-  if (isAttached()) {
-    getContext()->clearOperation(*this);
-  } else {
-    // And destroy it when it is detached, i.e. owned by Python, in which case
-    // all nested operations must be invalidated at removed from the live map as
-    // well.
+  // Otherwise, invalidate the operation when it is attached.
+  if (isAttached())
+    setInvalid();
+  else {
+    // And destroy it when it is detached, i.e. owned by Python.
     erase();
   }
 }
@@ -1252,35 +1152,16 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
                                          MlirOperation operation,
                                          nb::object parentKeepAlive) {
-  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
-  auto &liveOperations = contextRef->liveOperations;
-  auto it = liveOperations.find(operation.ptr);
-  if (it == liveOperations.end()) {
-    // Create.
-    PyOperationRef result = createInstance(std::move(contextRef), operation,
-                                           std::move(parentKeepAlive));
-    liveOperations[operation.ptr] =
-        std::make_pair(result.getObject(), result.get());
-    return result;
-  }
-  // Use existing.
-  PyOperation *existing = it->second.second;
-  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
-  return PyOperationRef(existing, std::move(pyRef));
+  // Create.
+  return createInstance(std::move(contextRef), operation,
+                        std::move(parentKeepAlive));
 }
 
 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
                                            MlirOperation operation,
                                            nb::object parentKeepAlive) {
-  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
-  auto &liveOperations = contextRef->liveOperations;
-  assert(liveOperations.count(operation.ptr) == 0 &&
-         "cannot create detached operation that already exists");
-  (void)liveOperations;
   PyOperationRef created = createInstance(std::move(contextRef), operation,
                                           std::move(parentKeepAlive));
-  liveOperations[operation.ptr] =
-      std::make_pair(created.getObject(), created.get());
   created->attached = false;
   return created;
 }
@@ -1652,7 +1533,6 @@ nb::object PyOperation::createOpView() {
 
 void PyOperation::erase() {
   checkValid();
-  getContext()->clearOperationAndInside(*this);
   mlirOperationDestroy(operation);
 }
 
@@ -3023,14 +2903,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
              return ref.releaseObject();
            })
-      .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
-      .def("_get_live_operation_objects",
-           &PyMlirContext::getLiveOperationObjects)
-      .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
-      .def("_clear_live_operations_inside",
-           nb::overload_cast<MlirOperation>(
-               &PyMlirContext::clearOperationsInside))
-      .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
       .def("__enter__", &PyMlirContext::contextEnter)
@@ -3428,7 +3300,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             // Defer to the operation's __str__.
             return self.attr("operation").attr("__str__")();
           },
-          kOperationStrDunderDocstring);
+          kOperationStrDunderDocstring)
+      .def(
+          "__eq__",
+          [](PyModule &self, PyModule &other) {
+            return mlirModuleEqual(self.get(), other.get());
+          },
+          "other"_a);
 
   //----------------------------------------------------------------------------
   // Mapping of Operation.
@@ -3440,7 +3318,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                    })
       .def("__eq__",
            [](PyOperationBase &self, PyOperationBase &other) {
-             return &self.getOperation() == &other.getOperation();
+             return mlirOperationEqual(self.getOperation().get(),
+                                       other.getOperation().get());
            })
       .def("__eq__",
            [](PyOperationBase &self, nb::object other) { return false; })
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index fa16ae3ce3294..027bfef5cba07 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -218,40 +218,6 @@ class PyMlirContext {
   /// Gets the count of live context objects. Used for testing.
   static size_t getLiveCount();
 
-  /// Get a list of Python objects which are still in the live context map.
-  std::vector<PyOperation *> getLiveOperationObjects();
-
-  /// Gets the count of live operations associated with this context.
-  /// Used for testing.
-  size_t getLiveOperationCount();
-
-  /// Clears the live operations map, returning the number of entries which were
-  /// invalidated. To be used as a safety mechanism so that API end-users can't
-  /// corrupt by holding references they shouldn't have accessed in the first
-  /// place.
-  size_t clearLiveOperations();
-
-  /// Removes an operation from the live operations map and sets it invalid.
-  /// This is useful for when some non-bindings code destroys the operation and
-  /// the bindings need to made aware. For example, in the case when pass
-  /// manager is run.
-  ///
-  /// Note that this does *NOT* clear the nested operations.
-  void clearOperation(MlirOperation op);
-
-  /// Clears all operations nested inside the given op using
-  /// `clearOperation(MlirOperation)`.
-  void clearOperationsInside(PyOperationBase &op);
-  void clearOperationsInside(MlirOperation op);
-
-  /// Clears the operaiton _and_ all operations inside using
-  /// `clearOperation(MlirOperation)`.
-  void clearOperationAndInside(PyOperationBase &op);
-
-  /// Gets the count of live modules associated with this context.
-  /// Used for testing.
-  size_t getLiveModuleCount();
-
   /// Enter and exit the context manager.
   static nanobind::object contextEnter(nanobind::object context);
   void contextExit(const nanobind::object &excType,
@@ -278,25 +244,6 @@ class PyMlirContext {
   static nanobind::ft_mutex live_contexts_mutex;
   static LiveContextMap &getLiveContexts();
 
-  // Interns all live modules associated with this context. Modules tracked
-  // in this map are valid. When a module is invalidated, it is removed
-  // from this map, and while it still exists as an instance, any
-  // attempt to access it will raise an error.
-  using LiveModuleMap =
-      llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
-  LiveModuleMap liveModules;
-
-  // Interns all live operations associated with this context. Operations
-  // tracked in this map are valid. When an operation is invalidated, it is
-  // removed from this map, and while it still exists as an instance, any
-  // attempt to access it will raise an error.
-  using LiveOperationMap =
-      llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
-  nanobind::ft_mutex liveOperationsMutex;
-
-  // Guarded by liveOperationsMutex in free-threading mode.
-  LiveOperationMap liveOperations;
-
   bool emitErrorDiagnostics = false;
 
   MlirContext context;
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 20017e25b69bb..817479ee2421b 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -159,11 +159,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           "ValueError if the pipeline can't be parsed.")
       .def(
           "run",
-          [](PyPassManager &passManager, PyOperationBase &op,
-             bool invalidateOps) {
-            if (invalidateOps) {
-              op.getOperation().getContext()->clearOperationsInside(op);
-            }
+          [](PyPassManager &passManager, PyOperationBase &op) {
             // Actually run the pass manager.
             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
             MlirLogicalResult status = mlirPassManagerRunOnOp(
@@ -172,7 +168,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
               throw MLIRError("Failure while executing pass pipeline",
                               errors.take());
           },
-          "operation"_a, "invalidate_ops"_a = true,
+          "operation"_a,
           "Run the pass manager on the provided operation, raising an "
           "MLIRError on failure.")
       .def(
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index f9b0fed62778f..920bca886f617 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
         // root. This is awkward, but we don't have access to PyMlirContext
         // object here otherwise.
         nb::object obj = nb::cast(payloadRoot);
-        obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
 
         MlirLogicalResult result = mlirTransformApplyNamedSequence(
             payloadRoot, transformRoot, transformModule, options.options);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 8491553dab76f..c7069f0017b5d 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -465,6 +465,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) {
   return wrap(dyn_cast<ModuleOp>(unwrap(op)));
 }
 
+bool mlirModuleEqual(MlirModule lhs, MlirModule rhs) {
+  return unwrap(lhs) == unwrap(rhs);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation state API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index 6065e59fd6ed9..449e25d4edde2 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -121,27 +121,17 @@ def testRoundtripBinary():
 def testModuleOperation():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
-    assert ctx._get_live_module_count() == 1
     op1 = module.operation
-    assert ctx._get_live_operation_count() == 1
-    live_ops = ctx._get_live_operation_objects()
-    assert len(live_ops) == 1
-    assert live_ops[0] is op1
-    live_ops = None
     # CHECK: module @successfulParse
     print(op1)
 
     # Ensure that operations are the same on multiple calls.
     op2 = module.operation
-    assert ctx._get_live_operation_count() == 1
-    assert op1 is op2
+    assert not op1 is op2
+    assert op1 == op2
 
     # Test live operation clearing.
     op1 = module.operation
-    assert ctx._get_live_operation_count() == 1
-    num_invalidated = ctx._clear_live_operations()
-    assert num_invalidated == 1
-    assert ctx._get_live_operation_count() == 0
     op1 = None
     gc.collect()
     op1 = module.operation
@@ -155,9 +145,6 @@ def testModuleOperation():
     op1 = None
     op2 = None
     gc.collect()
-    print("LIVE OPERATIONS:", ctx._get_live_operation_count())
-    assert ctx._get_live_operation_count() == 0
-    assert ctx._get_live_module_count() == 0
 
 
 # CHECK-LABEL: TEST: testModuleCapsule
@@ -165,16 +152,15 @@ def testModuleOperation():
 def testModuleCapsule():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
-    assert ctx._get_live_module_count() == 1
     # CHECK: "mlir.ir.Module._CAPIPtr"
     module_capsule = module._CAPIPtr
     print(module_capsule)
     module_dup = Module._CAPICreate(module_capsule)
-    assert module is module_dup
+    assert not module is module_dup
+    assert module == module_dup
     assert module_dup.context is ctx
     # Gc and verify destructed.
     module = None
     module_capsule = None
     module_dup = None
     gc.collect()
-    assert ctx._get_live_module_count() == 0
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index bf16e3f75d60d..bb74b6bc5e5ed 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -907,7 +907,8 @@ def testCapsuleConversions():
         m_capsule = m._CAPIPtr
         assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
         m2 = Operation._CAPICreate(m_capsule)
-        assert m2 is m
+        assert not m2 is m
+        assert m2 == m
 
 
 # CHECK-LABEL: TEST: testOperationErase
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 8b6d7ea5a197d..7afd539271d21 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -56,14 +56,6 @@ def testSymbolTableInsert():
         print(m1)
         assert "bar" not in symbol_table
 
-        try:
-            print(bar)
-        except RuntimeError as e:
-            if "the operation has been invalidated" not in str(e):
-                raise
-        else:
-            assert False, "expected RuntimeError due to invalidated operation"
-
         qux = m2.body.operations[0]
         m1.body.append(qux)
         symbol_table.insert(qux)
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index e26d42bb32913..aea8803a57bc5 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -176,14 +176,6 @@ def testRunPipelineError():
 @run
 def testPostPassOpInvalidation():
     with Context() as ctx:
-        log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())
-
-        # CHECK: invalidate_ops=False
-        log("invalidate_ops=False")
-
-        # CHECK: live ops: 0
-        log_op_count()
-
         module = ModuleOp.parse(
             """
           module {
@@ -196,9 +188,6 @@ def testPostPassOpInvalidation():
         """
         )
 
-        # CHECK: live ops: 1
-        log_op_count()
-
         outer_const_op = module.body.operations[0]
         # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
         log(outer_const_op)
@@ -214,12 +203,7 @@ def testPostPassOpInvalidation():
         # CHECK: %[[VAL1]] = arith.constant 10 : i64
         log(inner_const_op)
 
-        # CHECK: live ops: 4
-        log_op_count()
-
-        PassManager.parse("builtin.module(canonicalize)").run(
-            module, invalidate_ops=False
-        )
+        PassManager.parse("builtin.module(canonicalize)").run(module)
         # CHECK: func.func @foo() {
         # CHECK:   return
         # CHECK: }
@@ -233,9 +217,6 @@ def testPostPassOpInvalidation():
         # CHECK: invalidate_ops=True
         log("invalidate_ops=True")
 
-        # CHECK: live ops: 4
-        log_op_count()
-
         module = ModuleOp.parse(
             """
           module {
@@ -251,14 +232,8 @@ def testPostPassOpInvalidation():
         func_op = module.body.operations[1]
         inner_const_op = func_op.body.blocks[0].operations[0]
 
-        # CHECK: live ops: 4
-        log_op_count()
-
         PassManager.parse("builtin.module(canonicalize)").run(module)
 
-        # CHECK: live ops: 1
-        log_op_count()
-
         try:
             log(func_op)
         except RuntimeError as e:

>From 356aabb813f7afab2f66242563aa5c11bd713b48 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 23 Aug 2025 14:30:14 -0400
Subject: [PATCH 2/5] "fix" testPostPassOpInvalidation

---
 mlir/test/python/pass_manager.py | 21 ---------------------
 1 file changed, 21 deletions(-)

diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index aea8803a57bc5..0896cd9784641 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -228,30 +228,9 @@ def testPostPassOpInvalidation():
           }
         """
         )
-        outer_const_op = module.body.operations[0]
-        func_op = module.body.operations[1]
-        inner_const_op = func_op.body.blocks[0].operations[0]
 
         PassManager.parse("builtin.module(canonicalize)").run(module)
 
-        try:
-            log(func_op)
-        except RuntimeError as e:
-            # CHECK: the operation has been invalidated
-            log(e)
-
-        try:
-            log(outer_const_op)
-        except RuntimeError as e:
-            # CHECK: the operation has been invalidated
-            log(e)
-
-        try:
-            log(inner_const_op)
-        except RuntimeError as e:
-            # CHECK: the operation has been invalidated
-            log(e)
-
         # CHECK: func.func @foo() {
         # CHECK:   return
         # CHECK: }

>From 3c5352e6415d857737709638405dfd26b7bb4022 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 23 Aug 2025 15:06:07 -0400
Subject: [PATCH 3/5] try to fix testModuleCapsule

---
 mlir/lib/Bindings/Python/IRCore.cpp | 1 +
 mlir/lib/Bindings/Python/IRModule.h | 2 ++
 mlir/test/python/ir/module.py       | 2 ++
 3 files changed, 5 insertions(+)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 99d0efb276311..c3fb6c614b68e 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3221,6 +3221,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
+      .def("_clear_mlir_module", &PyModule::clearMlirModule)
       .def_static(
           "parse",
           [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 027bfef5cba07..0df4ccbdd23aa 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -522,6 +522,8 @@ class PyModule : public BaseContextObject {
   /// is taken by calling this function.
   static nanobind::object createFromCapsule(nanobind::object capsule);
 
+  void clearMlirModule() { module = {nullptr}; }
+
 private:
   PyModule(PyMlirContextRef contextRef, MlirModule module);
   MlirModule module;
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index 449e25d4edde2..a552eaa662af4 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -158,6 +158,8 @@ def testModuleCapsule():
     module_dup = Module._CAPICreate(module_capsule)
     assert not module is module_dup
     assert module == module_dup
+    module._clear_mlir_module()
+    assert not module == module_dup
     assert module_dup.context is ctx
     # Gc and verify destructed.
     module = None

>From 41a7dcf5962ba6cb49a1fa5198765e2ed38e000d Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 23 Aug 2025 16:43:49 -0400
Subject: [PATCH 4/5] add check for Operation._CAPICreate

---
 mlir/test/python/ir/operation.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index bb74b6bc5e5ed..94f39c0fbd077 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -909,6 +909,11 @@ def testCapsuleConversions():
         m2 = Operation._CAPICreate(m_capsule)
         assert not m2 is m
         assert m2 == m
+        # Gc and verify destructed.
+        m = None
+        m_capsule = None
+        m2 = None
+        gc.collect()
 
 
 # CHECK-LABEL: TEST: testOperationErase

>From 436609a9034886b260453e8dc9424c4f62ae4478 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 23 Aug 2025 17:17:36 -0400
Subject: [PATCH 5/5] update docs

---
 mlir/lib/Bindings/Python/IRModule.h | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 0df4ccbdd23aa..c1fdfd64ee1e7 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -495,8 +495,8 @@ class PyModule;
 using PyModuleRef = PyObjectRef<PyModule>;
 class PyModule : public BaseContextObject {
 public:
-  /// Returns a PyModule reference for the given MlirModule. This may return
-  /// a pre-existing or new object.
+  /// Returns a PyModule reference for the given MlirModule. This always returns
+  /// a new object.
   static PyModuleRef forModule(MlirModule module);
   PyModule(PyModule &) = delete;
   PyModule(PyMlirContext &&) = delete;
@@ -517,9 +517,8 @@ class PyModule : public BaseContextObject {
   nanobind::object getCapsule();
 
   /// Creates a PyModule from the MlirModule wrapped by a capsule.
-  /// Note that PyModule instances are uniqued, so the returned object
-  /// may be a pre-existing object. Ownership of the underlying MlirModule
-  /// is taken by calling this function.
+  /// Note this always returns a new PyModule that owns the MlirModule; call
+  /// clearMlirModule() on any other PyModule wrapping it to avoid double frees.
   static nanobind::object createFromCapsule(nanobind::object capsule);
 
   void clearMlirModule() { module = {nullptr}; }



More information about the Mlir-commits mailing list