[Mlir-commits] [mlir] 6a4f664 - [MLIR][Python] restore `liveModuleMap` (#158506)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Sep 14 21:45:34 PDT 2025


Author: Maksim Levental
Date: 2025-09-15T06:45:30+02:00
New Revision: 6a4f66476ff59a32898891345bc07547e71028ec

URL: https://github.com/llvm/llvm-project/commit/6a4f66476ff59a32898891345bc07547e71028ec
DIFF: https://github.com/llvm/llvm-project/commit/6a4f66476ff59a32898891345bc07547e71028ec.diff

LOG: [MLIR][Python] restore `liveModuleMap` (#158506)

There are cases where the same underlying module ends up with multiple
`PyModule` wrappers (e.g., `PyModule::createFromCapsule` calls
`PyModule::forModule`, which previously created a fresh owning wrapper every
time). When those `PyModule`s are garbage collected, `mlirModuleDestroy` is
called multiple times for the same underlying `mlir::Module` (i.e., a double
free). So we do need a "liveness map" for modules.

Note: if `type_caster<MlirModule>::from_cpp` did not exist, we could guarantee
that this never happened, except when users explicitly called `PyModule::createFromCapsule`.

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/test/python/ir/module.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8273a9346e5dd..10360e448858c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() {
 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
     : BaseContextObject(std::move(contextRef)), module(module) {}
 
-PyModule::~PyModule() { mlirModuleDestroy(module); }
+PyModule::~PyModule() {
+  nb::gil_scoped_acquire acquire;
+  auto &liveModules = getContext()->liveModules;
+  assert(liveModules.count(module.ptr) == 1 &&
+         "destroying module not in live map");
+  liveModules.erase(module.ptr);
+  mlirModuleDestroy(module);
+}
 
 PyModuleRef PyModule::forModule(MlirModule module) {
   MlirContext context = mlirModuleGetContext(module);
   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
 
-  // Create.
-  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
-  // Note that the default return value policy on cast is `automatic_reference`,
-  // which means "does not take ownership, does not call delete/dtor".
-  // We use `take_ownership`, which means "Python will call the C++ destructor
-  // and delete operator when the Python wrapper is garbage collected", because
-  // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
-  // etc).
-  nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
-  unownedModule->handle = pyRef;
-  return PyModuleRef(unownedModule, std::move(pyRef));
+  nb::gil_scoped_acquire acquire;
+  auto &liveModules = contextRef->liveModules;
+  auto it = liveModules.find(module.ptr);
+  if (it == liveModules.end()) {
+    // Create.
+    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+    // The default return value policy on cast is automatic_reference, which
+    // does not take ownership (delete would never be called), so explicitly
+    // request take_ownership: MlirModule wraps OwningOpRef<ModuleOp>.
+    nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
+    unownedModule->handle = pyRef;
+    liveModules[module.ptr] =
+        std::make_pair(unownedModule->handle, unownedModule);
+    return PyModuleRef(unownedModule, std::move(pyRef));
+  }
+  // Use existing.
+  PyModule *existing = it->second.second;
+  nb::object pyRef = nb::borrow<nb::object>(it->second.first);
+  return PyModuleRef(existing, std::move(pyRef));
 }
 
 nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
   return PyInsertionPoint{block, std::move(nextOpRef)};
 }
 
+size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
+
 nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
   return PyThreadContextEntry::pushInsertionPoint(insertPoint);
 }
@@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
              return ref.releaseObject();
            })
+      .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
       .def("__enter__", &PyMlirContext::contextEnter)

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1d1ff29533f98..28b885f136fe0 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -218,6 +218,10 @@ class PyMlirContext {
   /// Gets the count of live context objects. Used for testing.
   static size_t getLiveCount();
 
+  /// Gets the count of live modules associated with this context.
+  /// Used for testing.
+  size_t getLiveModuleCount();
+
   /// Enter and exit the context manager.
   static nanobind::object contextEnter(nanobind::object context);
   void contextExit(const nanobind::object &excType,
@@ -244,6 +248,14 @@ class PyMlirContext {
   static nanobind::ft_mutex live_contexts_mutex;
   static LiveContextMap &getLiveContexts();
 
+  // Interns all live modules associated with this context, keyed by the
+  // underlying MlirModule pointer, so that PyModule::forModule returns the
+  // existing Python wrapper instead of creating a second owner (which would
+  // double-free). Entries are removed in ~PyModule before mlirModuleDestroy.
+  using LiveModuleMap =
+      llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
+  LiveModuleMap liveModules;
+
   bool emitErrorDiagnostics = false;
 
   MlirContext context;

diff  --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index ad4c9340a6c82..33959bea9ffb6 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -121,6 +121,7 @@ def testRoundtripBinary():
 def testModuleOperation():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
+    assert ctx._get_live_module_count() == 1
     op1 = module.operation
     # CHECK: module @successfulParse
     print(op1)
@@ -145,6 +146,7 @@ def testModuleOperation():
     op1 = None
     op2 = None
     gc.collect()
+    assert ctx._get_live_module_count() == 0
 
 
 # CHECK-LABEL: TEST: testModuleCapsule
@@ -152,17 +154,17 @@ def testModuleOperation():
 def testModuleCapsule():
     ctx = Context()
     module = Module.parse(r"""module @successfulParse {}""", ctx)
+    assert ctx._get_live_module_count() == 1
     # CHECK: "mlir.ir.Module._CAPIPtr"
     module_capsule = module._CAPIPtr
     print(module_capsule)
     module_dup = Module._CAPICreate(module_capsule)
-    assert module is not module_dup
+    assert module is module_dup
     assert module == module_dup
-    module._clear_mlir_module()
-    assert module != module_dup
     assert module_dup.context is ctx
     # Gc and verify destructed.
     module = None
     module_capsule = None
     module_dup = None
     gc.collect()
+    assert ctx._get_live_module_count() == 0


        


More information about the Mlir-commits mailing list