[Mlir-commits] [mlir] [MLIR][Python] remove `liveOperations` (PR #155114)

Maksim Levental llvmlistbot at llvm.org
Tue Aug 26 12:21:26 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/155114

>From 9af755cd4794434439074e2e80e5ac0a43aad78a Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 17 May 2024 22:36:41 -0500
Subject: [PATCH 1/7] [mlir][python] wip remove liveOperations

---
 mlir/include/mlir-c/IR.h                      |   2 +
 mlir/lib/Bindings/Python/IRCore.cpp           | 174 +++---------------
 mlir/lib/Bindings/Python/IRModule.h           |  53 ------
 mlir/lib/Bindings/Python/Pass.cpp             |   8 +-
 .../Bindings/Python/TransformInterpreter.cpp  |   1 -
 mlir/lib/CAPI/IR/IR.cpp                       |   4 +
 mlir/test/python/ir/module.py                 |  22 +--
 mlir/test/python/ir/operation.py              |   3 +-
 mlir/test/python/ir/symbol_table.py           |   8 -
 mlir/test/python/pass_manager.py              |  27 +--
 10 files changed, 42 insertions(+), 260 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 71c7d4378677f..d05f91d7e3b12 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -415,6 +415,8 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
 /// The returned module is null when the input operation was not a ModuleOp.
 MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
 
+MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs);
+
 //===----------------------------------------------------------------------===//
 // Operation state.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 15889ddabd2c4..789891f495217 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -702,84 +702,6 @@ size_t PyMlirContext::getLiveCount() {
   return getLiveContexts().size();
 }
 
-size_t PyMlirContext::getLiveOperationCount() {
-  nb::ft_lock_guard lock(liveOperationsMutex);
-  return liveOperations.size();
-}
-
-std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
-  std::vector<PyOperation *> liveObjects;
-  nb::ft_lock_guard lock(liveOperationsMutex);
-  for (auto &entry : liveOperations)
-    liveObjects.push_back(entry.second.second);
-  return liveObjects;
-}
-
-size_t PyMlirContext::clearLiveOperations() {
-
-  LiveOperationMap operations;
-  {
-    nb::ft_lock_guard lock(liveOperationsMutex);
-    std::swap(operations, liveOperations);
-  }
-  for (auto &op : operations)
-    op.second.second->setInvalid();
-  size_t numInvalidated = operations.size();
-  return numInvalidated;
-}
-
-void PyMlirContext::clearOperation(MlirOperation op) {
-  PyOperation *pyOp;
-  {
-    nb::ft_lock_guard lock(liveOperationsMutex);
-    auto it = liveOperations.find(op.ptr);
-    if (it == liveOperations.end()) {
-      return;
-    }
-    pyOp = it->second.second;
-    liveOperations.erase(it);
-  }
-  pyOp->setInvalid();
-}
-
-void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
-  using callBackData = struct {
-    PyOperation &rootOp;
-    bool rootSeen;
-  };
-  callBackData data{op.getOperation(), false};
-  // Mark all ops below the op that the passmanager will be rooted
-  // at (but not op itself - note the preorder) as invalid.
-  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
-                                                      void *userData) {
-    callBackData *data = static_cast<callBackData *>(userData);
-    if (LLVM_LIKELY(data->rootSeen))
-      data->rootOp.getOperation().getContext()->clearOperation(op);
-    else
-      data->rootSeen = true;
-    return MlirWalkResult::MlirWalkResultAdvance;
-  };
-  mlirOperationWalk(op.getOperation(), invalidatingCallback,
-                    static_cast<void *>(&data), MlirWalkPreOrder);
-}
-void PyMlirContext::clearOperationsInside(MlirOperation op) {
-  PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
-  clearOperationsInside(opRef->getOperation());
-}
-
-void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
-  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
-                                                      void *userData) {
-    PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
-    contextRef->clearOperation(op);
-    return MlirWalkResult::MlirWalkResultAdvance;
-  };
-  mlirOperationWalk(op.getOperation(), invalidatingCallback,
-                    &op.getOperation().getContext(), MlirWalkPreOrder);
-}
-
-size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
-
 nb::object PyMlirContext::contextEnter(nb::object context) {
   return PyThreadContextEntry::pushContext(context);
 }
@@ -1151,38 +1073,20 @@ PyLocation &DefaultingPyLocation::resolve() {
 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
     : BaseContextObject(std::move(contextRef)), module(module) {}
 
-PyModule::~PyModule() {
-  nb::gil_scoped_acquire acquire;
-  auto &liveModules = getContext()->liveModules;
-  assert(liveModules.count(module.ptr) == 1 &&
-         "destroying module not in live map");
-  liveModules.erase(module.ptr);
-  mlirModuleDestroy(module);
-}
+PyModule::~PyModule() { mlirModuleDestroy(module); }
 
 PyModuleRef PyModule::forModule(MlirModule module) {
   MlirContext context = mlirModuleGetContext(module);
   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
 
-  nb::gil_scoped_acquire acquire;
-  auto &liveModules = contextRef->liveModules;
-  auto it = liveModules.find(module.ptr);
-  if (it == liveModules.end()) {
-    // Create.
-    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
-    // Note that the default return value policy on cast is automatic_reference,
-    // which does not take ownership (delete will not be called).
-    // Just be explicit.
-    nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
-    unownedModule->handle = pyRef;
-    liveModules[module.ptr] =
-        std::make_pair(unownedModule->handle, unownedModule);
-    return PyModuleRef(unownedModule, std::move(pyRef));
-  }
-  // Use existing.
-  PyModule *existing = it->second.second;
-  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
-  return PyModuleRef(existing, std::move(pyRef));
+  // Create.
+  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+  // Note that the default return value policy on cast is automatic_reference,
+  // which does not take ownership (delete will not be called).
+  // Just be explicit.
+  nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
+  unownedModule->handle = pyRef;
+  return PyModuleRef(unownedModule, std::move(pyRef));
 }
 
 nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -1207,15 +1111,11 @@ PyOperation::~PyOperation() {
   // If the operation has already been invalidated there is nothing to do.
   if (!valid)
     return;
-
-  // Otherwise, invalidate the operation and remove it from live map when it is
-  // attached.
-  if (isAttached()) {
-    getContext()->clearOperation(*this);
-  } else {
-    // And destroy it when it is detached, i.e. owned by Python, in which case
-    // all nested operations must be invalidated at removed from the live map as
-    // well.
+  // Otherwise, invalidate the operation when it is attached.
+  if (isAttached())
+    setInvalid();
+  else {
+    // And destroy it when it is detached, i.e. owned by Python.
     erase();
   }
 }
@@ -1252,35 +1152,16 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
                                          MlirOperation operation,
                                          nb::object parentKeepAlive) {
-  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
-  auto &liveOperations = contextRef->liveOperations;
-  auto it = liveOperations.find(operation.ptr);
-  if (it == liveOperations.end()) {
-    // Create.
-    PyOperationRef result = createInstance(std::move(contextRef), operation,
-                                           std::move(parentKeepAlive));
-    liveOperations[operation.ptr] =
-        std::make_pair(result.getObject(), result.get());
-    return result;
-  }
-  // Use existing.
-  PyOperation *existing = it->second.second;
-  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
-  return PyOperationRef(existing, std::move(pyRef));
+  // Create.
+  return createInstance(std::move(contextRef), operation,
+                        std::move(parentKeepAlive));
 }
 
 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
                                            MlirOperation operation,
                                            nb::object parentKeepAlive) {
-  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
-  auto &liveOperations = contextRef->liveOperations;
-  assert(liveOperations.count(operation.ptr) == 0 &&
-         "cannot create detached operation that already exists");
-  (void)liveOperations;
   PyOperationRef created = createInstance(std::move(contextRef), operation,
                                           std::move(parentKeepAlive));
-  liveOperations[operation.ptr] =
-      std::make_pair(created.getObject(), created.get());
   created->attached = false;
   return created;
 }
@@ -1652,7 +1533,7 @@ nb::object PyOperation::createOpView() {
 
 void PyOperation::erase() {
   checkValid();
-  getContext()->clearOperationAndInside(*this);
+  setInvalid();
   mlirOperationDestroy(operation);
 }
 
@@ -3023,14 +2904,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
              return ref.releaseObject();
            })
-      .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
-      .def("_get_live_operation_objects",
-           &PyMlirContext::getLiveOperationObjects)
-      .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
-      .def("_clear_live_operations_inside",
-           nb::overload_cast<MlirOperation>(
-               &PyMlirContext::clearOperationsInside))
-      .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
       .def("__enter__", &PyMlirContext::contextEnter)
@@ -3428,7 +3301,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             // Defer to the operation's __str__.
             return self.attr("operation").attr("__str__")();
           },
-          kOperationStrDunderDocstring);
+          kOperationStrDunderDocstring)
+      .def(
+          "__eq__",
+          [](PyModule &self, PyModule &other) {
+            return mlirModuleEqual(self.get(), other.get());
+          },
+          "other"_a);
 
   //----------------------------------------------------------------------------
   // Mapping of Operation.
@@ -3440,7 +3319,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                    })
       .def("__eq__",
            [](PyOperationBase &self, PyOperationBase &other) {
-             return &self.getOperation() == &other.getOperation();
+             return mlirOperationEqual(self.getOperation().get(),
+                                       other.getOperation().get());
            })
       .def("__eq__",
            [](PyOperationBase &self, nb::object other) { return false; })
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 6617b41cc916c..553da2ef52880 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -218,40 +218,6 @@ class PyMlirContext {
   /// Gets the count of live context objects. Used for testing.
   static size_t getLiveCount();
 
-  /// Get a list of Python objects which are still in the live context map.
-  std::vector<PyOperation *> getLiveOperationObjects();
-
-  /// Gets the count of live operations associated with this context.
-  /// Used for testing.
-  size_t getLiveOperationCount();
-
-  /// Clears the live operations map, returning the number of entries which were
-  /// invalidated. To be used as a safety mechanism so that API end-users can't
-  /// corrupt by holding references they shouldn't have accessed in the first
-  /// place.
-  size_t clearLiveOperations();
-
-  /// Removes an operation from the live operations map and sets it invalid.
-  /// This is useful for when some non-bindings code destroys the operation and
-  /// the bindings need to made aware. For example, in the case when pass
-  /// manager is run.
-  ///
-  /// Note that this does *NOT* clear the nested operations.
-  void clearOperation(MlirOperation op);
-
-  /// Clears all operations nested inside the given op using
-  /// `clearOperation(MlirOperation)`.
-  void clearOperationsInside(PyOperationBase &op);
-  void clearOperationsInside(MlirOperation op);
-
-  /// Clears the operaiton _and_ all operations inside using
-  /// `clearOperation(MlirOperation)`.
-  void clearOperationAndInside(PyOperationBase &op);
-
-  /// Gets the count of live modules associated with this context.
-  /// Used for testing.
-  size_t getLiveModuleCount();
-
   /// Enter and exit the context manager.
   static nanobind::object contextEnter(nanobind::object context);
   void contextExit(const nanobind::object &excType,
@@ -278,25 +244,6 @@ class PyMlirContext {
   static nanobind::ft_mutex live_contexts_mutex;
   static LiveContextMap &getLiveContexts();
 
-  // Interns all live modules associated with this context. Modules tracked
-  // in this map are valid. When a module is invalidated, it is removed
-  // from this map, and while it still exists as an instance, any
-  // attempt to access it will raise an error.
-  using LiveModuleMap =
-      llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
-  LiveModuleMap liveModules;
-
-  // Interns all live operations associated with this context. Operations
-  // tracked in this map are valid. When an operation is invalidated, it is
-  // removed from this map, and while it still exists as an instance, any
-  // attempt to access it will raise an error.
-  using LiveOperationMap =
-      llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
-  nanobind::ft_mutex liveOperationsMutex;
-
-  // Guarded by liveOperationsMutex in free-threading mode.
-  LiveOperationMap liveOperations;
-
   bool emitErrorDiagnostics = false;
 
   MlirContext context;
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 20017e25b69bb..817479ee2421b 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -159,11 +159,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           "ValueError if the pipeline can't be parsed.")
       .def(
           "run",
-          [](PyPassManager &passManager, PyOperationBase &op,
-             bool invalidateOps) {
-            if (invalidateOps) {
-              op.getOperation().getContext()->clearOperationsInside(op);
-            }
+          [](PyPassManager &passManager, PyOperationBase &op) {
             // Actually run the pass manager.
             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
             MlirLogicalResult status = mlirPassManagerRunOnOp(
@@ -172,7 +168,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
               throw MLIRError("Failure while executing pass pipeline",
                               errors.take());
           },
-          "operation"_a, "invalidate_ops"_a = true,
+          "operation"_a,
           "Run the pass manager on the provided operation, raising an "
           "MLIRError on failure.")
       .def(
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index f9b0fed62778f..920bca886f617 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
         // root. This is awkward, but we don't have access to PyMlirContext
         // object here otherwise.
         nb::object obj = nb::cast(payloadRoot);
-        obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
 
         MlirLogicalResult result = mlirTransformApplyNamedSequence(
             payloadRoot, transformRoot, transformModule, options.options);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 8491553dab76f..c7069f0017b5d 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -465,6 +465,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) {
   return wrap(dyn_cast<ModuleOp>(unwrap(op)));
 }
 
+bool mlirModuleEqual(MlirModule lhs, MlirModule rhs) {
+  return unwrap(lhs) == unwrap(rhs);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation state API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index 6065e59fd6ed9..449e25d4edde2 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -121,27 +121,17 @@ def testRoundtripBinary():
 def testModuleOperation():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
-    assert ctx._get_live_module_count() == 1
     op1 = module.operation
-    assert ctx._get_live_operation_count() == 1
-    live_ops = ctx._get_live_operation_objects()
-    assert len(live_ops) == 1
-    assert live_ops[0] is op1
-    live_ops = None
     # CHECK: module @successfulParse
     print(op1)
 
     # Ensure that operations are the same on multiple calls.
     op2 = module.operation
-    assert ctx._get_live_operation_count() == 1
-    assert op1 is op2
+    assert not op1 is op2
+    assert op1 == op2
 
     # Test live operation clearing.
     op1 = module.operation
-    assert ctx._get_live_operation_count() == 1
-    num_invalidated = ctx._clear_live_operations()
-    assert num_invalidated == 1
-    assert ctx._get_live_operation_count() == 0
     op1 = None
     gc.collect()
     op1 = module.operation
@@ -155,9 +145,6 @@ def testModuleOperation():
     op1 = None
     op2 = None
     gc.collect()
-    print("LIVE OPERATIONS:", ctx._get_live_operation_count())
-    assert ctx._get_live_operation_count() == 0
-    assert ctx._get_live_module_count() == 0
 
 
 # CHECK-LABEL: TEST: testModuleCapsule
@@ -165,16 +152,15 @@ def testModuleOperation():
 def testModuleCapsule():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
-    assert ctx._get_live_module_count() == 1
     # CHECK: "mlir.ir.Module._CAPIPtr"
     module_capsule = module._CAPIPtr
     print(module_capsule)
     module_dup = Module._CAPICreate(module_capsule)
-    assert module is module_dup
+    assert not module is module_dup
+    assert module == module_dup
     assert module_dup.context is ctx
     # Gc and verify destructed.
     module = None
     module_capsule = None
     module_dup = None
     gc.collect()
-    assert ctx._get_live_module_count() == 0
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index bf16e3f75d60d..bb74b6bc5e5ed 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -907,7 +907,8 @@ def testCapsuleConversions():
         m_capsule = m._CAPIPtr
         assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
         m2 = Operation._CAPICreate(m_capsule)
-        assert m2 is m
+        assert not m2 is m
+        assert m2 == m
 
 
 # CHECK-LABEL: TEST: testOperationErase
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 8b6d7ea5a197d..7afd539271d21 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -56,14 +56,6 @@ def testSymbolTableInsert():
         print(m1)
         assert "bar" not in symbol_table
 
-        try:
-            print(bar)
-        except RuntimeError as e:
-            if "the operation has been invalidated" not in str(e):
-                raise
-        else:
-            assert False, "expected RuntimeError due to invalidated operation"
-
         qux = m2.body.operations[0]
         m1.body.append(qux)
         symbol_table.insert(qux)
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index e26d42bb32913..aea8803a57bc5 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -176,14 +176,6 @@ def testRunPipelineError():
 @run
 def testPostPassOpInvalidation():
     with Context() as ctx:
-        log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())
-
-        # CHECK: invalidate_ops=False
-        log("invalidate_ops=False")
-
-        # CHECK: live ops: 0
-        log_op_count()
-
         module = ModuleOp.parse(
             """
           module {
@@ -196,9 +188,6 @@ def testPostPassOpInvalidation():
         """
         )
 
-        # CHECK: live ops: 1
-        log_op_count()
-
         outer_const_op = module.body.operations[0]
         # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
         log(outer_const_op)
@@ -214,12 +203,7 @@ def testPostPassOpInvalidation():
         # CHECK: %[[VAL1]] = arith.constant 10 : i64
         log(inner_const_op)
 
-        # CHECK: live ops: 4
-        log_op_count()
-
-        PassManager.parse("builtin.module(canonicalize)").run(
-            module, invalidate_ops=False
-        )
+        PassManager.parse("builtin.module(canonicalize)").run(module)
         # CHECK: func.func @foo() {
         # CHECK:   return
         # CHECK: }
@@ -233,9 +217,6 @@ def testPostPassOpInvalidation():
         # CHECK: invalidate_ops=True
         log("invalidate_ops=True")
 
-        # CHECK: live ops: 4
-        log_op_count()
-
         module = ModuleOp.parse(
             """
           module {
@@ -251,14 +232,8 @@ def testPostPassOpInvalidation():
         func_op = module.body.operations[1]
         inner_const_op = func_op.body.blocks[0].operations[0]
 
-        # CHECK: live ops: 4
-        log_op_count()
-
         PassManager.parse("builtin.module(canonicalize)").run(module)
 
-        # CHECK: live ops: 1
-        log_op_count()
-
         try:
             log(func_op)
         except RuntimeError as e:

>From 809f0dc5c667ce39e72843e50e14ed3fc10dad8d Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 23 Aug 2025 14:30:14 -0400
Subject: [PATCH 2/7] "fix" testPostPassOpInvalidation

---
 mlir/test/python/pass_manager.py | 21 ---------------------
 1 file changed, 21 deletions(-)

diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index aea8803a57bc5..0896cd9784641 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -228,30 +228,9 @@ def testPostPassOpInvalidation():
           }
         """
         )
-        outer_const_op = module.body.operations[0]
-        func_op = module.body.operations[1]
-        inner_const_op = func_op.body.blocks[0].operations[0]
 
         PassManager.parse("builtin.module(canonicalize)").run(module)
 
-        try:
-            log(func_op)
-        except RuntimeError as e:
-            # CHECK: the operation has been invalidated
-            log(e)
-
-        try:
-            log(outer_const_op)
-        except RuntimeError as e:
-            # CHECK: the operation has been invalidated
-            log(e)
-
-        try:
-            log(inner_const_op)
-        except RuntimeError as e:
-            # CHECK: the operation has been invalidated
-            log(e)
-
         # CHECK: func.func @foo() {
         # CHECK:   return
         # CHECK: }

>From fd8a12aa687bc0cb021bd1de0a248e132744c8e3 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 23 Aug 2025 15:06:07 -0400
Subject: [PATCH 3/7] try to fix testModuleCapsule

---
 mlir/lib/Bindings/Python/IRCore.cpp | 1 +
 mlir/lib/Bindings/Python/IRModule.h | 2 ++
 mlir/test/python/ir/module.py       | 2 ++
 3 files changed, 5 insertions(+)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 789891f495217..b0dc9f17a76b1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3222,6 +3222,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
+      .def("_clear_mlir_module", &PyModule::clearMlirModule)
       .def_static(
           "parse",
           [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 553da2ef52880..932d46b5fd7ba 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -522,6 +522,8 @@ class PyModule : public BaseContextObject {
   /// is taken by calling this function.
   static nanobind::object createFromCapsule(nanobind::object capsule);
 
+  void clearMlirModule() { module = {nullptr}; }
+
 private:
   PyModule(PyMlirContextRef contextRef, MlirModule module);
   MlirModule module;
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index 449e25d4edde2..a552eaa662af4 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -158,6 +158,8 @@ def testModuleCapsule():
     module_dup = Module._CAPICreate(module_capsule)
     assert not module is module_dup
     assert module == module_dup
+    module._clear_mlir_module()
+    assert not module == module_dup
     assert module_dup.context is ctx
     # Gc and verify destructed.
     module = None

>From 8eb4d75524f2dd9dc2123a2d9f0bde0295b6d23d Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 23 Aug 2025 16:43:49 -0400
Subject: [PATCH 4/7] add check for Operation._CAPICreate

---
 mlir/test/python/ir/operation.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index bb74b6bc5e5ed..94f39c0fbd077 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -909,6 +909,11 @@ def testCapsuleConversions():
         m2 = Operation._CAPICreate(m_capsule)
         assert not m2 is m
         assert m2 == m
+        # Gc and verify destructed.
+        m = None
+        m_capsule = None
+        m2 = None
+        gc.collect()
 
 
 # CHECK-LABEL: TEST: testOperationErase

>From 8c722718b034cff71e7737afc22930a2ec5a1e6a Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 23 Aug 2025 17:17:36 -0400
Subject: [PATCH 5/7] update docs

---
 mlir/lib/Bindings/Python/IRCore.cpp | 9 ++++++++-
 mlir/lib/Bindings/Python/IRModule.h | 9 ++++-----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index b0dc9f17a76b1..7f31ea1a7b1c8 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -67,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails.
 See also: https://mlir.llvm.org/docs/LangRef/
 )";
 
+static const char kModuleCAPICreate[] =
+    R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr).
+Note this returns a new object BUT _clear_mlir_module(module) must be called to
+prevent double-frees (of the underlying mlir::Module).
+)";
+
 static const char kOperationCreateDocstring[] =
     R"(Creates a new operation.
 
@@ -3221,7 +3227,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   //----------------------------------------------------------------------------
   nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+           kModuleCAPICreate)
       .def("_clear_mlir_module", &PyModule::clearMlirModule)
       .def_static(
           "parse",
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 932d46b5fd7ba..0cc0459ebc9a0 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -495,8 +495,8 @@ class PyModule;
 using PyModuleRef = PyObjectRef<PyModule>;
 class PyModule : public BaseContextObject {
 public:
-  /// Returns a PyModule reference for the given MlirModule. This may return
-  /// a pre-existing or new object.
+  /// Returns a PyModule reference for the given MlirModule. This always returns
+  /// a new object.
   static PyModuleRef forModule(MlirModule module);
   PyModule(PyModule &) = delete;
   PyModule(PyMlirContext &&) = delete;
@@ -517,9 +517,8 @@ class PyModule : public BaseContextObject {
   nanobind::object getCapsule();
 
   /// Creates a PyModule from the MlirModule wrapped by a capsule.
-  /// Note that PyModule instances are uniqued, so the returned object
-  /// may be a pre-existing object. Ownership of the underlying MlirModule
-  /// is taken by calling this function.
+  /// Note this always returns a new object, so clearMlirModule() must be called
+  /// on one of the owners to prevent a double-free of the underlying module.
   static nanobind::object createFromCapsule(nanobind::object capsule);
 
   void clearMlirModule() { module = {nullptr}; }

>From 70d6b56d9f26928af5450c9efac62ff2694cf369 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 26 Aug 2025 14:58:22 -0400
Subject: [PATCH 6/7] comments

---
 mlir/include/mlir-c/IR.h                |  1 +
 mlir/lib/Bindings/Python/IRCore.cpp     | 10 ++++++----
 mlir/lib/Bindings/Python/MainModule.cpp |  4 ++++
 3 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index d05f91d7e3b12..e97369778b377 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -415,6 +415,7 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
 /// The returned module is null when the input operation was not a ModuleOp.
 MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
 
+/// Checks whether two modules refer to the same underlying module.
 MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs);
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 7f31ea1a7b1c8..8ab8901cdc41f 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1087,9 +1087,12 @@ PyModuleRef PyModule::forModule(MlirModule module) {
 
   // Create.
   PyModule *unownedModule = new PyModule(std::move(contextRef), module);
-  // Note that the default return value policy on cast is automatic_reference,
-  // which does not take ownership (delete will not be called).
-  // Just be explicit.
+  // Note that the default return value policy on cast is `automatic_reference`,
+  // which means "does not take ownership, does not call delete/dtor".
+  // We use `take_ownership`, which means "Python will call the C++ destructor
+  // and delete operator when the Python wrapper is garbage collected", because
+  // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
+  // etc).
   nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
   unownedModule->handle = pyRef;
   return PyModuleRef(unownedModule, std::move(pyRef));
@@ -1158,7 +1161,6 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
                                          MlirOperation operation,
                                          nb::object parentKeepAlive) {
-  // Create.
   return createInstance(std::move(contextRef), operation,
                         std::move(parentKeepAlive));
 }
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 278847e7ac7f5..d091d6a11ab11 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -139,4 +139,8 @@ NB_MODULE(_mlir, m) {
   auto passModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
   populatePassManagerSubmodule(passModule);
+
+  m.def("test_raise_exception", []() {
+    throw std::runtime_error("wtfbbq");
+  });
 }

>From 9343c1f1f2e7f7c5caef91c6f630870c7e139d39 Mon Sep 17 00:00:00 2001
From: makslevental <maksim.levental at gmail.com>
Date: Tue, 26 Aug 2025 15:15:40 -0400
Subject: [PATCH 7/7] update the docs

---
 mlir/docs/Bindings/Python.md            | 14 ++------------
 mlir/lib/Bindings/Python/MainModule.cpp |  4 ----
 2 files changed, 2 insertions(+), 16 deletions(-)

diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index bef9e7f54948d..031d494746bd1 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -216,13 +216,8 @@ added to an attached operation, they need to be re-parented to the containing
 module).
 
 Due to the validity and parenting accounting needs, `PyOperation` is the owner
-for regions and blocks and needs to be a top-level type that we can count on not
-aliasing. This let's us do things like selectively invalidating instances when
-mutations occur without worrying that there is some alias to the same operation
-in the hierarchy. Operations are also the only entity that are allowed to be in
-a detached state, and they are interned at the context level so that there is
-never more than one Python `mlir.ir.Operation` object for a unique
-`MlirOperation`, regardless of how it is obtained.
+for regions and blocks. Operations are also the only entities that are allowed
+to be in a detached state.
 
 The C/C++ API allows for Region/Block to also be detached, but it simplifies the
 ownership model a lot to eliminate that possibility in this API, allowing the
@@ -238,11 +233,6 @@ blocks. We may end up needing an op-local one at some point TBD, depending on
 how hard it is to guarantee how mutations interact with their Python peer
 objects. We can cross that bridge easily when we get there.
 
-Module, when used purely from the Python API, can't alias anyway, so we can use
-it as a top-level ref type without a live-list for interning. If the API ever
-changes such that this cannot be guaranteed (i.e. by letting you marshal a
-native-defined Module in), then there would need to be a live table for it too.
-
 ## User-level API
 
 ### Context Management
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index d091d6a11ab11..278847e7ac7f5 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -139,8 +139,4 @@ NB_MODULE(_mlir, m) {
   auto passModule =
       m.def_submodule("passmanager", "MLIR Pass Management Bindings");
   populatePassManagerSubmodule(passModule);
-
-  m.def("test_raise_exception", []() {
-    throw std::runtime_error("wtfbbq");
-  });
 }



More information about the Mlir-commits mailing list