[Mlir-commits] [mlir] b2a7369 - [MLIR][Python] remove `liveOperations` (#155114)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 1 21:53:37 PDT 2025


Author: Maksim Levental
Date: 2025-09-01T21:53:33-07:00
New Revision: b2a7369631d70df8bcd0c4ee304afa30a0657ba7

URL: https://github.com/llvm/llvm-project/commit/b2a7369631d70df8bcd0c4ee304afa30a0657ba7
DIFF: https://github.com/llvm/llvm-project/commit/b2a7369631d70df8bcd0c4ee304afa30a0657ba7.diff

LOG: [MLIR][Python] remove `liveOperations` (#155114)

Historical context: `PyMlirContext::liveOperations` was an optimization
meant to cut down on the number of Python object allocations and
(partially) a mechanism for updating validity of ops after
transformation, e.g. while walking/transforming the AST. See original
patch [here](https://reviews.llvm.org/D87958).

Inspired by a
[renewed](https://github.com/llvm/llvm-project/pull/139721#issuecomment-3217131918)
interest in https://github.com/llvm/llvm-project/pull/139721 (which has
become a little stale...)

In the previous go-around
(https://github.com/llvm/llvm-project/pull/92631) there were two issues,
both of which have now been resolved:

1. ops that were "fetched" under a root op which has been transformed
are no longer reported as invalid. We simply "[formally
forbid](https://github.com/llvm/llvm-project/pull/92631#issuecomment-2119397018)"
this;
2. `Module._CAPICreate(module_capsule)` must now be followed by a
`module._clear_mlir_module()` to prevent double-freeing of the actual
`ModuleOp` object (i.e. calling the dtor on the
`OwningOpRef<ModuleOp>`):

    ```python
    module = ...
    module_dup = Module._CAPICreate(module._CAPIPtr)
    module._clear_mlir_module()
    ```
- **the alternative choice** here is to remove the `Module._CAPICreate`
API altogether and replace it with something like `Module._move(module)`
which will do both `Module._CAPICreate` and `module._clear_mlir_module`.
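
To make the double-free hazard concrete, here is a toy pure-Python model. It
does not use the MLIR bindings: `NativeModule`, `ToyModule`, `demo` and every
method on them are hypothetical stand-ins, and `release()` is an explicit
substitute for the wrapper's destructor so that the behaviour is
deterministic. It is a sketch of the ownership hand-off only.

```python
class NativeModule:
    """Hypothetical stand-in for the natively owned ModuleOp."""

    def __init__(self):
        self.destroy_count = 0

    def destroy(self):
        # A real double-free would be undefined behaviour; here we just count.
        self.destroy_count += 1


class ToyModule:
    """Hypothetical wrapper that destroys its native module when released."""

    def __init__(self, native):
        self._native = native

    def capsule(self):
        # Models `module._CAPIPtr`: hands out the raw native handle.
        return self._native

    @classmethod
    def from_capsule(cls, native):
        # Models `Module._CAPICreate`: always builds a new wrapper.
        return cls(native)

    def clear(self):
        # Models `module._clear_mlir_module()`: give up ownership.
        self._native = None

    def release(self):
        # Models the wrapper's destructor.
        if self._native is not None:
            self._native.destroy()
            self._native = None

    @classmethod
    def move(cls, other):
        # Models the proposed `Module._move(module)`: create, then clear.
        dup = cls.from_capsule(other.capsule())
        other.clear()
        return dup


def demo(clear_original):
    """Return how many times the native module ends up being destroyed."""
    native = NativeModule()
    module = ToyModule(native)
    dup = ToyModule.from_capsule(module.capsule())
    if clear_original:
        module.clear()
    module.release()
    dup.release()
    return native.destroy_count


print(demo(clear_original=False))  # 2: both wrappers destroy the same module
print(demo(clear_original=True))   # 1: only the duplicate owns the module
```

In this model a combined `move` leaves no window in which two wrappers both
believe they own the module, which is the appeal of the alternative above.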

Note, the other approach I explored last year was a [weakref
system](https://github.com/llvm/llvm-project/pull/97340) for
`mlir::Operation` which would effectively hoist this `liveOperations`
thing into MLIR core. Possibly doable but I now believe it's a bad idea.

The other potentially breaking change concerns `is`, which checks object
identity rather than value equality: it will now report `False` for two
wrappers of the same operation because we always allocate new Python
objects (i.e. that's the whole point of this change). Users wanting to
check equality for `Operation` and `Module` should use `==`.
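
The difference between `is` and `==` can be illustrated with a toy pure-Python
model. It does not import the MLIR bindings: `ToyOperation` and `ToyParent`
are hypothetical names, and the plain `object()` handle merely stands in for
the underlying native operation. It assumes only what is described above,
namely that each access yields a fresh wrapper and that `==` compares the
wrapped operation.

```python
class ToyOperation:
    """Hypothetical wrapper; equality is defined by the wrapped handle."""

    def __init__(self, handle):
        self._handle = handle

    def __eq__(self, other):
        if not isinstance(other, ToyOperation):
            return NotImplemented
        # Two wrappers are equal when they wrap the very same handle.
        return self._handle is other._handle

    def __hash__(self):
        # Keep hashing consistent with __eq__.
        return id(self._handle)


class ToyParent:
    """Hypothetical owner that hands out a new wrapper on every access."""

    def __init__(self):
        self._handle = object()

    @property
    def operation(self):
        # No interning: each access allocates a new Python object.
        return ToyOperation(self._handle)


parent = ToyParent()
op1 = parent.operation
op2 = parent.operation

print(op1 is op2)  # False: two distinct Python objects
print(op1 == op2)  # True: both wrap the same handle
```

Code that previously relied on `op1 is op2` (for example as a cheap cache
check) would need to switch to `==` under this model.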

Added: 
    

Modified: 
    mlir/docs/Bindings/Python.md
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/Bindings/Python/Pass.cpp
    mlir/lib/Bindings/Python/TransformInterpreter.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/python/ir/module.py
    mlir/test/python/ir/operation.py
    mlir/test/python/ir/symbol_table.py
    mlir/test/python/pass_manager.py

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index bef9e7f54948d..98ac635aa4ee2 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -216,13 +216,28 @@ added to an attached operation, they need to be re-parented to the containing
 module).
 
 Due to the validity and parenting accounting needs, `PyOperation` is the owner
-for regions and blocks and needs to be a top-level type that we can count on not
-aliasing. This let's us do things like selectively invalidating instances when
-mutations occur without worrying that there is some alias to the same operation
-in the hierarchy. Operations are also the only entity that are allowed to be in
-a detached state, and they are interned at the context level so that there is
-never more than one Python `mlir.ir.Operation` object for a unique
-`MlirOperation`, regardless of how it is obtained.
+for regions and blocks. Operations are also the only entities which are allowed to be in
+a detached state.
+
+**Note**: Multiple `PyOperation` objects (i.e., the Python objects themselves) can alias a single `mlir::Operation`. 
+This means, for example, if you have `py_op1` and `py_op2` which wrap the same `mlir::Operation op` 
+and you somehow transform `op` (e.g., you run a pass on `op`) then walking the MLIR AST via either/or `py_op1`, `py_op2`
+will reflect the same MLIR AST. This is perfectly safe and supported. What is not supported is invalidating any 
+operation while there exist multiple Python objects wrapping that operation **and then manipulating those wrappers**. 
+For example if `py_op1` and `py_op2` wrap the same operation under a root `py_op3` and then `py_op3` is 
+transformed such that the operation referenced (by `py_op1`, `py_op2`) is erased. Then `py_op1`, `py_op2` 
+become "undefined" in a sense; manipulating them in any way is "formally forbidden". Note, this also applies to 
+`SymbolTable` mutation, which is considered a transformation of the root `SymbolTable`-supporting operation for the 
+purposes of the discussion here. Metaphorically, one can think of this similarly to how STL container iterators are invalidated once the container itself is changed. The "best practices" recommendation is to structure your code such that 
+
+1. First, query/manipulate various Python wrapper objects `py_op1`, `py_op2`, `py_op3`, etc.;
+2. Second, transform the AST/erase operations/etc. via a single root object;
+3. Invalidate all queried nodes (e.g., using `op._set_invalid()`).
+
+Ideally this should be done in a function body so that step (3) corresponds to the end of the function and there are no 
+risks of Python wrapper objects leaking/living longer than necessary. In summary, you should scope your changes based on 
+nesting i.e., change leaf nodes first before going up in hierarchy, and only in very rare cases query nested ops post
+modifying a parent op.
 
 The C/C++ API allows for Region/Block to also be detached, but it simplifies the
 ownership model a lot to eliminate that possibility in this API, allowing the
@@ -238,11 +253,6 @@ blocks. We may end up needing an op-local one at some point TBD, depending on
 how hard it is to guarantee how mutations interact with their Python peer
 objects. We can cross that bridge easily when we get there.
 
-Module, when used purely from the Python API, can't alias anyway, so we can use
-it as a top-level ref type without a live-list for interning. If the API ever
-changes such that this cannot be guaranteed (i.e. by letting you marshal a
-native-defined Module in), then there would need to be a live table for it too.
-
 ## User-level API
 
 ### Context Management
@@ -1229,4 +1239,4 @@ The exceptions to the free-threading compatibility:
 - Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
 - Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
 - Usage of `mlir.dialects.transform.interpreter` is unsafe.
-- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
\ No newline at end of file
+- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.

diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 71c7d4378677f..e97369778b377 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -415,6 +415,9 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
 /// The returned module is null when the input operation was not a ModuleOp.
 MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
 
+/// Checks if two modules are equal.
+MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs);
+
 //===----------------------------------------------------------------------===//
 // Operation state.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 15889ddabd2c4..2df2a73fd88ff 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -67,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails.
 See also: https://mlir.llvm.org/docs/LangRef/
 )";
 
+static const char kModuleCAPICreate[] =
+    R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr).
+Note this returns a new object BUT _clear_mlir_module(module) must be called to
+prevent double-frees (of the underlying mlir::Module).
+)";
+
 static const char kOperationCreateDocstring[] =
     R"(Creates a new operation.
 
@@ -702,84 +708,6 @@ size_t PyMlirContext::getLiveCount() {
   return getLiveContexts().size();
 }
 
-size_t PyMlirContext::getLiveOperationCount() {
-  nb::ft_lock_guard lock(liveOperationsMutex);
-  return liveOperations.size();
-}
-
-std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
-  std::vector<PyOperation *> liveObjects;
-  nb::ft_lock_guard lock(liveOperationsMutex);
-  for (auto &entry : liveOperations)
-    liveObjects.push_back(entry.second.second);
-  return liveObjects;
-}
-
-size_t PyMlirContext::clearLiveOperations() {
-
-  LiveOperationMap operations;
-  {
-    nb::ft_lock_guard lock(liveOperationsMutex);
-    std::swap(operations, liveOperations);
-  }
-  for (auto &op : operations)
-    op.second.second->setInvalid();
-  size_t numInvalidated = operations.size();
-  return numInvalidated;
-}
-
-void PyMlirContext::clearOperation(MlirOperation op) {
-  PyOperation *pyOp;
-  {
-    nb::ft_lock_guard lock(liveOperationsMutex);
-    auto it = liveOperations.find(op.ptr);
-    if (it == liveOperations.end()) {
-      return;
-    }
-    pyOp = it->second.second;
-    liveOperations.erase(it);
-  }
-  pyOp->setInvalid();
-}
-
-void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
-  using callBackData = struct {
-    PyOperation &rootOp;
-    bool rootSeen;
-  };
-  callBackData data{op.getOperation(), false};
-  // Mark all ops below the op that the passmanager will be rooted
-  // at (but not op itself - note the preorder) as invalid.
-  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
-                                                      void *userData) {
-    callBackData *data = static_cast<callBackData *>(userData);
-    if (LLVM_LIKELY(data->rootSeen))
-      data->rootOp.getOperation().getContext()->clearOperation(op);
-    else
-      data->rootSeen = true;
-    return MlirWalkResult::MlirWalkResultAdvance;
-  };
-  mlirOperationWalk(op.getOperation(), invalidatingCallback,
-                    static_cast<void *>(&data), MlirWalkPreOrder);
-}
-void PyMlirContext::clearOperationsInside(MlirOperation op) {
-  PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
-  clearOperationsInside(opRef->getOperation());
-}
-
-void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
-  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
-                                                      void *userData) {
-    PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
-    contextRef->clearOperation(op);
-    return MlirWalkResult::MlirWalkResultAdvance;
-  };
-  mlirOperationWalk(op.getOperation(), invalidatingCallback,
-                    &op.getOperation().getContext(), MlirWalkPreOrder);
-}
-
-size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
-
 nb::object PyMlirContext::contextEnter(nb::object context) {
   return PyThreadContextEntry::pushContext(context);
 }
@@ -1151,38 +1079,23 @@ PyLocation &DefaultingPyLocation::resolve() {
 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
     : BaseContextObject(std::move(contextRef)), module(module) {}
 
-PyModule::~PyModule() {
-  nb::gil_scoped_acquire acquire;
-  auto &liveModules = getContext()->liveModules;
-  assert(liveModules.count(module.ptr) == 1 &&
-         "destroying module not in live map");
-  liveModules.erase(module.ptr);
-  mlirModuleDestroy(module);
-}
+PyModule::~PyModule() { mlirModuleDestroy(module); }
 
 PyModuleRef PyModule::forModule(MlirModule module) {
   MlirContext context = mlirModuleGetContext(module);
   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
 
-  nb::gil_scoped_acquire acquire;
-  auto &liveModules = contextRef->liveModules;
-  auto it = liveModules.find(module.ptr);
-  if (it == liveModules.end()) {
-    // Create.
-    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
-    // Note that the default return value policy on cast is automatic_reference,
-    // which does not take ownership (delete will not be called).
-    // Just be explicit.
-    nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
-    unownedModule->handle = pyRef;
-    liveModules[module.ptr] =
-        std::make_pair(unownedModule->handle, unownedModule);
-    return PyModuleRef(unownedModule, std::move(pyRef));
-  }
-  // Use existing.
-  PyModule *existing = it->second.second;
-  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
-  return PyModuleRef(existing, std::move(pyRef));
+  // Create.
+  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+  // Note that the default return value policy on cast is `automatic_reference`,
+  // which means "does not take ownership, does not call delete/dtor".
+  // We use `take_ownership`, which means "Python will call the C++ destructor
+  // and delete operator when the Python wrapper is garbage collected", because
+  // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
+  // etc).
+  nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
+  unownedModule->handle = pyRef;
+  return PyModuleRef(unownedModule, std::move(pyRef));
 }
 
 nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -1207,15 +1120,11 @@ PyOperation::~PyOperation() {
   // If the operation has already been invalidated there is nothing to do.
   if (!valid)
     return;
-
-  // Otherwise, invalidate the operation and remove it from live map when it is
-  // attached.
-  if (isAttached()) {
-    getContext()->clearOperation(*this);
-  } else {
-    // And destroy it when it is detached, i.e. owned by Python, in which case
-    // all nested operations must be invalidated at removed from the live map as
-    // well.
+  // Otherwise, invalidate the operation when it is attached.
+  if (isAttached())
+    setInvalid();
+  else {
+    // And destroy it when it is detached, i.e. owned by Python.
     erase();
   }
 }
@@ -1252,35 +1161,15 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
                                          MlirOperation operation,
                                          nb::object parentKeepAlive) {
-  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
-  auto &liveOperations = contextRef->liveOperations;
-  auto it = liveOperations.find(operation.ptr);
-  if (it == liveOperations.end()) {
-    // Create.
-    PyOperationRef result = createInstance(std::move(contextRef), operation,
-                                           std::move(parentKeepAlive));
-    liveOperations[operation.ptr] =
-        std::make_pair(result.getObject(), result.get());
-    return result;
-  }
-  // Use existing.
-  PyOperation *existing = it->second.second;
-  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
-  return PyOperationRef(existing, std::move(pyRef));
+  return createInstance(std::move(contextRef), operation,
+                        std::move(parentKeepAlive));
 }
 
 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
                                            MlirOperation operation,
                                            nb::object parentKeepAlive) {
-  nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
-  auto &liveOperations = contextRef->liveOperations;
-  assert(liveOperations.count(operation.ptr) == 0 &&
-         "cannot create detached operation that already exists");
-  (void)liveOperations;
   PyOperationRef created = createInstance(std::move(contextRef), operation,
                                           std::move(parentKeepAlive));
-  liveOperations[operation.ptr] =
-      std::make_pair(created.getObject(), created.get());
   created->attached = false;
   return created;
 }
@@ -1652,7 +1541,7 @@ nb::object PyOperation::createOpView() {
 
 void PyOperation::erase() {
   checkValid();
-  getContext()->clearOperationAndInside(*this);
+  setInvalid();
   mlirOperationDestroy(operation);
 }
 
@@ -3023,14 +2912,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
              return ref.releaseObject();
            })
-      .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
-      .def("_get_live_operation_objects",
-           &PyMlirContext::getLiveOperationObjects)
-      .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
-      .def("_clear_live_operations_inside",
-           nb::overload_cast<MlirOperation>(
-               &PyMlirContext::clearOperationsInside))
-      .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
       .def("__enter__", &PyMlirContext::contextEnter)
@@ -3348,7 +3229,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   //----------------------------------------------------------------------------
   nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
-      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
+           kModuleCAPICreate)
+      .def("_clear_mlir_module", &PyModule::clearMlirModule)
       .def_static(
           "parse",
           [](const std::string &moduleAsm, DefaultingPyMlirContext context) {
@@ -3428,7 +3311,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             // Defer to the operation's __str__.
             return self.attr("operation").attr("__str__")();
           },
-          kOperationStrDunderDocstring);
+          kOperationStrDunderDocstring)
+      .def(
+          "__eq__",
+          [](PyModule &self, PyModule &other) {
+            return mlirModuleEqual(self.get(), other.get());
+          },
+          "other"_a);
 
   //----------------------------------------------------------------------------
   // Mapping of Operation.
@@ -3440,7 +3329,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                    })
       .def("__eq__",
            [](PyOperationBase &self, PyOperationBase &other) {
-             return &self.getOperation() == &other.getOperation();
+             return mlirOperationEqual(self.getOperation().get(),
+                                       other.getOperation().get());
            })
       .def("__eq__",
            [](PyOperationBase &self, nb::object other) { return false; })
@@ -3655,7 +3545,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           [](PyOperationBase &self) {
             return PyOpSuccessors(self.getOperation().getRef());
           },
-          "Returns the list of Operation successors.");
+          "Returns the list of Operation successors.")
+      .def("_set_invalid", &PyOperation::setInvalid,
+           "Invalidate the operation.");
 
   auto opViewClass =
       nb::class_<PyOpView, PyOperationBase>(m, "OpView")
@@ -3699,7 +3591,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
               [](PyOperationBase &self) {
                 return PyOpSuccessors(self.getOperation().getRef());
               },
-              "Returns the list of Operation successors.");
+              "Returns the list of Operation successors.")
+          .def(
+              "_set_invalid",
+              [](PyOpView &self) { self.getOperation().setInvalid(); },
+              "Invalidate the operation.");
   opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
   opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 6617b41cc916c..0cc0459ebc9a0 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -218,40 +218,6 @@ class PyMlirContext {
   /// Gets the count of live context objects. Used for testing.
   static size_t getLiveCount();
 
-  /// Get a list of Python objects which are still in the live context map.
-  std::vector<PyOperation *> getLiveOperationObjects();
-
-  /// Gets the count of live operations associated with this context.
-  /// Used for testing.
-  size_t getLiveOperationCount();
-
-  /// Clears the live operations map, returning the number of entries which were
-  /// invalidated. To be used as a safety mechanism so that API end-users can't
-  /// corrupt by holding references they shouldn't have accessed in the first
-  /// place.
-  size_t clearLiveOperations();
-
-  /// Removes an operation from the live operations map and sets it invalid.
-  /// This is useful for when some non-bindings code destroys the operation and
-  /// the bindings need to made aware. For example, in the case when pass
-  /// manager is run.
-  ///
-  /// Note that this does *NOT* clear the nested operations.
-  void clearOperation(MlirOperation op);
-
-  /// Clears all operations nested inside the given op using
-  /// `clearOperation(MlirOperation)`.
-  void clearOperationsInside(PyOperationBase &op);
-  void clearOperationsInside(MlirOperation op);
-
-  /// Clears the operaiton _and_ all operations inside using
-  /// `clearOperation(MlirOperation)`.
-  void clearOperationAndInside(PyOperationBase &op);
-
-  /// Gets the count of live modules associated with this context.
-  /// Used for testing.
-  size_t getLiveModuleCount();
-
   /// Enter and exit the context manager.
   static nanobind::object contextEnter(nanobind::object context);
   void contextExit(const nanobind::object &excType,
@@ -278,25 +244,6 @@ class PyMlirContext {
   static nanobind::ft_mutex live_contexts_mutex;
   static LiveContextMap &getLiveContexts();
 
-  // Interns all live modules associated with this context. Modules tracked
-  // in this map are valid. When a module is invalidated, it is removed
-  // from this map, and while it still exists as an instance, any
-  // attempt to access it will raise an error.
-  using LiveModuleMap =
-      llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
-  LiveModuleMap liveModules;
-
-  // Interns all live operations associated with this context. Operations
-  // tracked in this map are valid. When an operation is invalidated, it is
-  // removed from this map, and while it still exists as an instance, any
-  // attempt to access it will raise an error.
-  using LiveOperationMap =
-      llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
-  nanobind::ft_mutex liveOperationsMutex;
-
-  // Guarded by liveOperationsMutex in free-threading mode.
-  LiveOperationMap liveOperations;
-
   bool emitErrorDiagnostics = false;
 
   MlirContext context;
@@ -548,8 +495,8 @@ class PyModule;
 using PyModuleRef = PyObjectRef<PyModule>;
 class PyModule : public BaseContextObject {
 public:
-  /// Returns a PyModule reference for the given MlirModule. This may return
-  /// a pre-existing or new object.
+  /// Returns a PyModule reference for the given MlirModule. This always returns
+  /// a new object.
   static PyModuleRef forModule(MlirModule module);
   PyModule(PyModule &) = delete;
   PyModule(PyMlirContext &&) = delete;
@@ -570,11 +517,12 @@ class PyModule : public BaseContextObject {
   nanobind::object getCapsule();
 
   /// Creates a PyModule from the MlirModule wrapped by a capsule.
-  /// Note that PyModule instances are uniqued, so the returned object
-  /// may be a pre-existing object. Ownership of the underlying MlirModule
-  /// is taken by calling this function.
+  /// Note this returns a new object BUT clearMlirModule() must be called to
+  /// prevent double-frees (of the underlying mlir::Module).
   static nanobind::object createFromCapsule(nanobind::object capsule);
 
+  void clearMlirModule() { module = {nullptr}; }
+
 private:
   PyModule(PyMlirContextRef contextRef, MlirModule module);
   MlirModule module;

diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 1030dea7f364c..88e28dca76bb9 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -159,11 +159,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
           "ValueError if the pipeline can't be parsed.")
       .def(
           "run",
-          [](PyPassManager &passManager, PyOperationBase &op,
-             bool invalidateOps) {
-            if (invalidateOps) {
-              op.getOperation().getContext()->clearOperationsInside(op);
-            }
+          [](PyPassManager &passManager, PyOperationBase &op) {
             // Actually run the pass manager.
             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
             MlirLogicalResult status = mlirPassManagerRunOnOp(
@@ -172,7 +168,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
               throw MLIRError("Failure while executing pass pipeline",
                               errors.take());
           },
-          "operation"_a, "invalidate_ops"_a = true,
+          "operation"_a,
           "Run the pass manager on the provided operation, raising an "
           "MLIRError on failure.")
       .def(

diff  --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index f9b0fed62778f..920bca886f617 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
         // root. This is awkward, but we don't have access to PyMlirContext
         // object here otherwise.
         nb::object obj = nb::cast(payloadRoot);
-        obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
 
         MlirLogicalResult result = mlirTransformApplyNamedSequence(
             payloadRoot, transformRoot, transformModule, options.options);

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 8491553dab76f..c7069f0017b5d 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -465,6 +465,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) {
   return wrap(dyn_cast<ModuleOp>(unwrap(op)));
 }
 
+bool mlirModuleEqual(MlirModule lhs, MlirModule rhs) {
+  return unwrap(lhs) == unwrap(rhs);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation state API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index 6065e59fd6ed9..ad4c9340a6c82 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -121,27 +121,17 @@ def testRoundtripBinary():
 def testModuleOperation():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
-    assert ctx._get_live_module_count() == 1
     op1 = module.operation
-    assert ctx._get_live_operation_count() == 1
-    live_ops = ctx._get_live_operation_objects()
-    assert len(live_ops) == 1
-    assert live_ops[0] is op1
-    live_ops = None
     # CHECK: module @successfulParse
     print(op1)
 
     # Ensure that operations are the same on multiple calls.
     op2 = module.operation
-    assert ctx._get_live_operation_count() == 1
-    assert op1 is op2
+    assert op1 is not op2
+    assert op1 == op2
 
     # Test live operation clearing.
     op1 = module.operation
-    assert ctx._get_live_operation_count() == 1
-    num_invalidated = ctx._clear_live_operations()
-    assert num_invalidated == 1
-    assert ctx._get_live_operation_count() == 0
     op1 = None
     gc.collect()
     op1 = module.operation
@@ -155,9 +145,6 @@ def testModuleOperation():
     op1 = None
     op2 = None
     gc.collect()
-    print("LIVE OPERATIONS:", ctx._get_live_operation_count())
-    assert ctx._get_live_operation_count() == 0
-    assert ctx._get_live_module_count() == 0
 
 
 # CHECK-LABEL: TEST: testModuleCapsule
@@ -165,16 +152,17 @@ def testModuleOperation():
 def testModuleCapsule():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
-    assert ctx._get_live_module_count() == 1
     # CHECK: "mlir.ir.Module._CAPIPtr"
     module_capsule = module._CAPIPtr
     print(module_capsule)
     module_dup = Module._CAPICreate(module_capsule)
-    assert module is module_dup
+    assert module is not module_dup
+    assert module == module_dup
+    module._clear_mlir_module()
+    assert module != module_dup
     assert module_dup.context is ctx
     # Gc and verify destructed.
     module = None
     module_capsule = None
     module_dup = None
     gc.collect()
-    assert ctx._get_live_module_count() == 0

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index bf16e3f75d60d..7759b1797e3c3 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -907,7 +907,13 @@ def testCapsuleConversions():
         m_capsule = m._CAPIPtr
         assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
         m2 = Operation._CAPICreate(m_capsule)
-        assert m2 is m
+        assert m2 is not m
+        assert m2 == m
+        # Gc and verify destructed.
+        m = None
+        m_capsule = None
+        m2 = None
+        gc.collect()
 
 
 # CHECK-LABEL: TEST: testOperationErase

diff  --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 8b6d7ea5a197d..99d5fadfea10a 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -56,6 +56,7 @@ def testSymbolTableInsert():
         print(m1)
         assert "bar" not in symbol_table
 
+        bar._set_invalid()
         try:
             print(bar)
         except RuntimeError as e:

diff  --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index e26d42bb32913..5f92f5b52a09a 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -176,14 +176,6 @@ def testRunPipelineError():
 @run
 def testPostPassOpInvalidation():
     with Context() as ctx:
-        log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())
-
-        # CHECK: invalidate_ops=False
-        log("invalidate_ops=False")
-
-        # CHECK: live ops: 0
-        log_op_count()
-
         module = ModuleOp.parse(
             """
           module {
@@ -196,9 +188,6 @@ def testPostPassOpInvalidation():
         """
         )
 
-        # CHECK: live ops: 1
-        log_op_count()
-
         outer_const_op = module.body.operations[0]
         # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
         log(outer_const_op)
@@ -214,12 +203,7 @@ def testPostPassOpInvalidation():
         # CHECK: %[[VAL1]] = arith.constant 10 : i64
         log(inner_const_op)
 
-        # CHECK: live ops: 4
-        log_op_count()
-
-        PassManager.parse("builtin.module(canonicalize)").run(
-            module, invalidate_ops=False
-        )
+        PassManager.parse("builtin.module(canonicalize)").run(module)
         # CHECK: func.func @foo() {
         # CHECK:   return
         # CHECK: }
@@ -233,9 +217,6 @@ def testPostPassOpInvalidation():
         # CHECK: invalidate_ops=True
         log("invalidate_ops=True")
 
-        # CHECK: live ops: 4
-        log_op_count()
-
         module = ModuleOp.parse(
             """
           module {
@@ -247,30 +228,24 @@ def testPostPassOpInvalidation():
           }
         """
         )
-        outer_const_op = module.body.operations[0]
-        func_op = module.body.operations[1]
-        inner_const_op = func_op.body.blocks[0].operations[0]
-
-        # CHECK: live ops: 4
-        log_op_count()
 
         PassManager.parse("builtin.module(canonicalize)").run(module)
 
-        # CHECK: live ops: 1
-        log_op_count()
-
+        func_op._set_invalid()
         try:
             log(func_op)
         except RuntimeError as e:
             # CHECK: the operation has been invalidated
             log(e)
 
+        outer_const_op._set_invalid()
         try:
             log(outer_const_op)
         except RuntimeError as e:
             # CHECK: the operation has been invalidated
             log(e)
 
+        inner_const_op._set_invalid()
         try:
             log(inner_const_op)
         except RuntimeError as e:


        

