[Mlir-commits] [mlir] 6a4f664 - [MLIR][Python] restore `liveModuleMap` (#158506)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Sep 14 21:45:34 PDT 2025
Author: Maksim Levental
Date: 2025-09-15T06:45:30+02:00
New Revision: 6a4f66476ff59a32898891345bc07547e71028ec
URL: https://github.com/llvm/llvm-project/commit/6a4f66476ff59a32898891345bc07547e71028ec
DIFF: https://github.com/llvm/llvm-project/commit/6a4f66476ff59a32898891345bc07547e71028ec.diff
LOG: [MLIR][Python] restore `liveModuleMap` (#158506)
There are cases where the same module can have multiple references (via
`PyModule::forModule` via `PyModule::createFromCapsule`) and thus when
`PyModule`s get gc'd `mlirModuleDestroy` can get called multiple times
for the same actual underlying `mlir::Module` (i.e., double free). So we
do actually need a "liveness map" for modules.
Note, if `type_caster<MlirModule>::from_cpp` weren't a thing, we could guarantee
this never happened except explicitly when users called `PyModule::createFromCapsule`.
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/test/python/ir/module.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 8273a9346e5dd..10360e448858c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1079,23 +1079,38 @@ PyLocation &DefaultingPyLocation::resolve() {
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
-PyModule::~PyModule() { mlirModuleDestroy(module); }
+PyModule::~PyModule() {
+ nb::gil_scoped_acquire acquire;
+ auto &liveModules = getContext()->liveModules;
+ assert(liveModules.count(module.ptr) == 1 &&
+ "destroying module not in live map");
+ liveModules.erase(module.ptr);
+ mlirModuleDestroy(module);
+}
PyModuleRef PyModule::forModule(MlirModule module) {
MlirContext context = mlirModuleGetContext(module);
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
- // Create.
- PyModule *unownedModule = new PyModule(std::move(contextRef), module);
- // Note that the default return value policy on cast is `automatic_reference`,
- // which means "does not take ownership, does not call delete/dtor".
- // We use `take_ownership`, which means "Python will call the C++ destructor
- // and delete operator when the Python wrapper is garbage collected", because
- // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
- // etc).
- nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
- unownedModule->handle = pyRef;
- return PyModuleRef(unownedModule, std::move(pyRef));
+ nb::gil_scoped_acquire acquire;
+ auto &liveModules = contextRef->liveModules;
+ auto it = liveModules.find(module.ptr);
+ if (it == liveModules.end()) {
+ // Create.
+ PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+ // Note that the default return value policy on cast is automatic_reference,
+ // which does not take ownership (delete will not be called).
+ // Just be explicit.
+ nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
+ unownedModule->handle = pyRef;
+ liveModules[module.ptr] =
+ std::make_pair(unownedModule->handle, unownedModule);
+ return PyModuleRef(unownedModule, std::move(pyRef));
+ }
+ // Use existing.
+ PyModule *existing = it->second.second;
+ nb::object pyRef = nb::borrow<nb::object>(it->second.first);
+ return PyModuleRef(existing, std::move(pyRef));
}
nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -2084,6 +2099,8 @@ PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) {
return PyInsertionPoint{block, std::move(nextOpRef)};
}
+size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
+
nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) {
return PyThreadContextEntry::pushInsertionPoint(insertPoint);
}
@@ -2923,6 +2940,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
+ .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def("__enter__", &PyMlirContext::contextEnter)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 1d1ff29533f98..28b885f136fe0 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -218,6 +218,10 @@ class PyMlirContext {
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();
+ /// Gets the count of live modules associated with this context.
+ /// Used for testing.
+ size_t getLiveModuleCount();
+
/// Enter and exit the context manager.
static nanobind::object contextEnter(nanobind::object context);
void contextExit(const nanobind::object &excType,
@@ -244,6 +248,14 @@ class PyMlirContext {
static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();
+ // Interns all live modules associated with this context. Modules tracked
+ // in this map are valid. When a module is invalidated, it is removed
+ // from this map, and while it still exists as an instance, any
+ // attempt to access it will raise an error.
+ using LiveModuleMap =
+ llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
+ LiveModuleMap liveModules;
+
bool emitErrorDiagnostics = false;
MlirContext context;
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index ad4c9340a6c82..33959bea9ffb6 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -121,6 +121,7 @@ def testRoundtripBinary():
def testModuleOperation():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
+ assert ctx._get_live_module_count() == 1
op1 = module.operation
# CHECK: module @successfulParse
print(op1)
@@ -145,6 +146,7 @@ def testModuleOperation():
op1 = None
op2 = None
gc.collect()
+ assert ctx._get_live_module_count() == 0
# CHECK-LABEL: TEST: testModuleCapsule
@@ -152,17 +154,17 @@ def testModuleOperation():
def testModuleCapsule():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
+ assert ctx._get_live_module_count() == 1
# CHECK: "mlir.ir.Module._CAPIPtr"
module_capsule = module._CAPIPtr
print(module_capsule)
module_dup = Module._CAPICreate(module_capsule)
- assert module is not module_dup
+ assert module is module_dup
assert module == module_dup
- module._clear_mlir_module()
- assert module != module_dup
assert module_dup.context is ctx
# Gc and verify destructed.
module = None
module_capsule = None
module_dup = None
gc.collect()
+ assert ctx._get_live_module_count() == 0
More information about the Mlir-commits
mailing list