[Mlir-commits] [mlir] [mlir][python] remove liveOperations (PR #155114)
Maksim Levental
llvmlistbot at llvm.org
Sat Aug 23 12:36:11 PDT 2025
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/155114
>From 2539ddf588230974b94c937430e9e19e7d313c74 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 17 May 2024 22:36:41 -0500
Subject: [PATCH 1/4] [mlir][python] wip remove liveOperations
---
mlir/include/mlir-c/IR.h | 2 +
mlir/lib/Bindings/Python/IRCore.cpp | 174 +++---------------
mlir/lib/Bindings/Python/IRModule.h | 53 ------
mlir/lib/Bindings/Python/Pass.cpp | 8 +-
.../Bindings/Python/TransformInterpreter.cpp | 1 -
mlir/lib/CAPI/IR/IR.cpp | 4 +
mlir/python/mlir/extras/passes.py | 140 ++++++++++++++
mlir/test/python/ir/module.py | 20 +-
mlir/test/python/ir/operation.py | 2 +-
mlir/test/python/ir/symbol_table.py | 8 -
mlir/test/python/pass_manager.py | 27 +--
11 files changed, 176 insertions(+), 263 deletions(-)
create mode 100644 mlir/python/mlir/extras/passes.py
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 71c7d4378677f..d20b46651383f 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -415,6 +415,8 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
/// The returned module is null when the input operation was not a ModuleOp.
MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
+MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule mod, MlirModule other);
+
//===----------------------------------------------------------------------===//
// Operation state.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4b3a06cbce854..2c643db10f349 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -702,84 +702,6 @@ size_t PyMlirContext::getLiveCount() {
return getLiveContexts().size();
}
-size_t PyMlirContext::getLiveOperationCount() {
- nb::ft_lock_guard lock(liveOperationsMutex);
- return liveOperations.size();
-}
-
-std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
- std::vector<PyOperation *> liveObjects;
- nb::ft_lock_guard lock(liveOperationsMutex);
- for (auto &entry : liveOperations)
- liveObjects.push_back(entry.second.second);
- return liveObjects;
-}
-
-size_t PyMlirContext::clearLiveOperations() {
-
- LiveOperationMap operations;
- {
- nb::ft_lock_guard lock(liveOperationsMutex);
- std::swap(operations, liveOperations);
- }
- for (auto &op : operations)
- op.second.second->setInvalid();
- size_t numInvalidated = operations.size();
- return numInvalidated;
-}
-
-void PyMlirContext::clearOperation(MlirOperation op) {
- PyOperation *py_op;
- {
- nb::ft_lock_guard lock(liveOperationsMutex);
- auto it = liveOperations.find(op.ptr);
- if (it == liveOperations.end()) {
- return;
- }
- py_op = it->second.second;
- liveOperations.erase(it);
- }
- py_op->setInvalid();
-}
-
-void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
- typedef struct {
- PyOperation &rootOp;
- bool rootSeen;
- } callBackData;
- callBackData data{op.getOperation(), false};
- // Mark all ops below the op that the passmanager will be rooted
- // at (but not op itself - note the preorder) as invalid.
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
- void *userData) {
- callBackData *data = static_cast<callBackData *>(userData);
- if (LLVM_LIKELY(data->rootSeen))
- data->rootOp.getOperation().getContext()->clearOperation(op);
- else
- data->rootSeen = true;
- return MlirWalkResult::MlirWalkResultAdvance;
- };
- mlirOperationWalk(op.getOperation(), invalidatingCallback,
- static_cast<void *>(&data), MlirWalkPreOrder);
-}
-void PyMlirContext::clearOperationsInside(MlirOperation op) {
- PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
- clearOperationsInside(opRef->getOperation());
-}
-
-void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
- void *userData) {
- PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
- contextRef->clearOperation(op);
- return MlirWalkResult::MlirWalkResultAdvance;
- };
- mlirOperationWalk(op.getOperation(), invalidatingCallback,
- &op.getOperation().getContext(), MlirWalkPreOrder);
-}
-
-size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
-
nb::object PyMlirContext::contextEnter(nb::object context) {
return PyThreadContextEntry::pushContext(context);
}
@@ -1151,38 +1073,20 @@ PyLocation &DefaultingPyLocation::resolve() {
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
-PyModule::~PyModule() {
- nb::gil_scoped_acquire acquire;
- auto &liveModules = getContext()->liveModules;
- assert(liveModules.count(module.ptr) == 1 &&
- "destroying module not in live map");
- liveModules.erase(module.ptr);
- mlirModuleDestroy(module);
-}
+PyModule::~PyModule() { mlirModuleDestroy(module); }
PyModuleRef PyModule::forModule(MlirModule module) {
MlirContext context = mlirModuleGetContext(module);
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
- nb::gil_scoped_acquire acquire;
- auto &liveModules = contextRef->liveModules;
- auto it = liveModules.find(module.ptr);
- if (it == liveModules.end()) {
- // Create.
- PyModule *unownedModule = new PyModule(std::move(contextRef), module);
- // Note that the default return value policy on cast is automatic_reference,
- // which does not take ownership (delete will not be called).
- // Just be explicit.
- nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
- unownedModule->handle = pyRef;
- liveModules[module.ptr] =
- std::make_pair(unownedModule->handle, unownedModule);
- return PyModuleRef(unownedModule, std::move(pyRef));
- }
- // Use existing.
- PyModule *existing = it->second.second;
- nb::object pyRef = nb::borrow<nb::object>(it->second.first);
- return PyModuleRef(existing, std::move(pyRef));
+ // Create.
+ PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+ // Note that the default return value policy on cast is automatic_reference,
+ // which does not take ownership (delete will not be called).
+ // Just be explicit.
+ nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
+ unownedModule->handle = pyRef;
+ return PyModuleRef(unownedModule, std::move(pyRef));
}
nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -1207,16 +1111,8 @@ PyOperation::~PyOperation() {
// If the operation has already been invalidated there is nothing to do.
if (!valid)
return;
-
- // Otherwise, invalidate the operation and remove it from live map when it is
- // attached.
- if (isAttached()) {
- getContext()->clearOperation(*this);
- } else {
- // And destroy it when it is detached, i.e. owned by Python, in which case
- // all nested operations must be invalidated at removed from the live map as
- // well.
- erase();
+ if (!isAttached()) {
+ mlirOperationDestroy(operation);
}
}
@@ -1246,41 +1142,22 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
if (parentKeepAlive) {
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
}
- return unownedOperation;
+ return PyOperationRef(unownedOperation, std::move(pyRef));
}
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
- nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
- auto &liveOperations = contextRef->liveOperations;
- auto it = liveOperations.find(operation.ptr);
- if (it == liveOperations.end()) {
- // Create.
- PyOperationRef result = createInstance(std::move(contextRef), operation,
- std::move(parentKeepAlive));
- liveOperations[operation.ptr] =
- std::make_pair(result.getObject(), result.get());
- return result;
- }
- // Use existing.
- PyOperation *existing = it->second.second;
- nb::object pyRef = nb::borrow<nb::object>(it->second.first);
- return PyOperationRef(existing, std::move(pyRef));
+ // Create.
+ return createInstance(std::move(contextRef), operation,
+ std::move(parentKeepAlive));
}
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
- nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
- auto &liveOperations = contextRef->liveOperations;
- assert(liveOperations.count(operation.ptr) == 0 &&
- "cannot create detached operation that already exists");
- (void)liveOperations;
PyOperationRef created = createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
- liveOperations[operation.ptr] =
- std::make_pair(created.getObject(), created.get());
created->attached = false;
return created;
}
@@ -1652,7 +1529,6 @@ nb::object PyOperation::createOpView() {
void PyOperation::erase() {
checkValid();
- getContext()->clearOperationAndInside(*this);
mlirOperationDestroy(operation);
}
@@ -2494,7 +2370,6 @@ class PyBlockArgumentList
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
public:
static constexpr const char *pyClassName = "BlockArgumentList";
- using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
intptr_t startIndex = 0, intptr_t length = -1,
@@ -3023,14 +2898,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
- .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
- .def("_get_live_operation_objects",
- &PyMlirContext::getLiveOperationObjects)
- .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
- .def("_clear_live_operations_inside",
- nb::overload_cast<MlirOperation>(
- &PyMlirContext::clearOperationsInside))
- .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def("__enter__", &PyMlirContext::contextEnter)
@@ -3428,7 +3295,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Defer to the operation's __str__.
return self.attr("operation").attr("__str__")();
},
- kOperationStrDunderDocstring);
+ kOperationStrDunderDocstring)
+ .def(
+ "__eq__",
+ [](PyModule &self, PyModule &other) {
+ return mlirModuleEqual(self.get(), other.get());
+ },
+ "other"_a);
//----------------------------------------------------------------------------
// Mapping of Operation.
@@ -3440,7 +3313,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
})
.def("__eq__",
[](PyOperationBase &self, PyOperationBase &other) {
- return &self.getOperation() == &other.getOperation();
+ return mlirOperationEqual(self.getOperation().get(),
+ other.getOperation().get());
})
.def("__eq__",
[](PyOperationBase &self, nb::object other) { return false; })
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index fa16ae3ce3294..027bfef5cba07 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -218,40 +218,6 @@ class PyMlirContext {
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();
- /// Get a list of Python objects which are still in the live context map.
- std::vector<PyOperation *> getLiveOperationObjects();
-
- /// Gets the count of live operations associated with this context.
- /// Used for testing.
- size_t getLiveOperationCount();
-
- /// Clears the live operations map, returning the number of entries which were
- /// invalidated. To be used as a safety mechanism so that API end-users can't
- /// corrupt by holding references they shouldn't have accessed in the first
- /// place.
- size_t clearLiveOperations();
-
- /// Removes an operation from the live operations map and sets it invalid.
- /// This is useful for when some non-bindings code destroys the operation and
- /// the bindings need to made aware. For example, in the case when pass
- /// manager is run.
- ///
- /// Note that this does *NOT* clear the nested operations.
- void clearOperation(MlirOperation op);
-
- /// Clears all operations nested inside the given op using
- /// `clearOperation(MlirOperation)`.
- void clearOperationsInside(PyOperationBase &op);
- void clearOperationsInside(MlirOperation op);
-
- /// Clears the operaiton _and_ all operations inside using
- /// `clearOperation(MlirOperation)`.
- void clearOperationAndInside(PyOperationBase &op);
-
- /// Gets the count of live modules associated with this context.
- /// Used for testing.
- size_t getLiveModuleCount();
-
/// Enter and exit the context manager.
static nanobind::object contextEnter(nanobind::object context);
void contextExit(const nanobind::object &excType,
@@ -278,25 +244,6 @@ class PyMlirContext {
static nanobind::ft_mutex live_contexts_mutex;
static LiveContextMap &getLiveContexts();
- // Interns all live modules associated with this context. Modules tracked
- // in this map are valid. When a module is invalidated, it is removed
- // from this map, and while it still exists as an instance, any
- // attempt to access it will raise an error.
- using LiveModuleMap =
- llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
- LiveModuleMap liveModules;
-
- // Interns all live operations associated with this context. Operations
- // tracked in this map are valid. When an operation is invalidated, it is
- // removed from this map, and while it still exists as an instance, any
- // attempt to access it will raise an error.
- using LiveOperationMap =
- llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
- nanobind::ft_mutex liveOperationsMutex;
-
- // Guarded by liveOperationsMutex in free-threading mode.
- LiveOperationMap liveOperations;
-
bool emitErrorDiagnostics = false;
MlirContext context;
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 20017e25b69bb..817479ee2421b 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -159,11 +159,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
"ValueError if the pipeline can't be parsed.")
.def(
"run",
- [](PyPassManager &passManager, PyOperationBase &op,
- bool invalidateOps) {
- if (invalidateOps) {
- op.getOperation().getContext()->clearOperationsInside(op);
- }
+ [](PyPassManager &passManager, PyOperationBase &op) {
// Actually run the pass manager.
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
MlirLogicalResult status = mlirPassManagerRunOnOp(
@@ -172,7 +168,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
throw MLIRError("Failure while executing pass pipeline",
errors.take());
},
- "operation"_a, "invalidate_ops"_a = true,
+ "operation"_a,
"Run the pass manager on the provided operation, raising an "
"MLIRError on failure.")
.def(
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index f9b0fed62778f..920bca886f617 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) {
// root. This is awkward, but we don't have access to PyMlirContext
// object here otherwise.
nb::object obj = nb::cast(payloadRoot);
- obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
MlirLogicalResult result = mlirTransformApplyNamedSequence(
payloadRoot, transformRoot, transformModule, options.options);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 8491553dab76f..7ea799c836889 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -465,6 +465,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) {
return wrap(dyn_cast<ModuleOp>(unwrap(op)));
}
+bool mlirModuleEqual(MlirModule mod, MlirModule other) {
+ return unwrap(mod) == unwrap(other);
+}
+
//===----------------------------------------------------------------------===//
// Operation state API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/extras/passes.py b/mlir/python/mlir/extras/passes.py
new file mode 100644
index 0000000000000..be26e42a99144
--- /dev/null
+++ b/mlir/python/mlir/extras/passes.py
@@ -0,0 +1,140 @@
+import contextlib
+import logging
+import os
+import sys
+import tempfile
+from contextlib import ExitStack
+from io import StringIO
+from typing import Optional, List, Union
+
+from ..ir import StringAttr, Module
+from ..passmanager import PassManager
+
+
+@contextlib.contextmanager
+def disable_multithreading(context=None):
+ from ..ir import Context
+
+ if context is None:
+ context = Context.current
+
+ context.enable_multithreading(False)
+ yield
+ context.enable_multithreading(True)
+
+
+logger = logging.getLogger(__name__)
+
+
+def get_module_name_for_debug_dump(module):
+ if "debug_module_name" not in module.operation.attributes:
+ return "UnnammedMLIRModule"
+ return StringAttr(module.operation.attributes["debug_module_name"]).value
+
+
+def run_pipeline(
+ module,
+ pipeline: Union[str, "Pipeline"],
+ description: Optional[str] = None,
+ enable_ir_printing=False,
+ print_pipeline=False,
+ verify=True,
+):
+ module = Module.parse(str(module))
+ if isinstance(pipeline, Pipeline):
+ pipeline = str(pipeline)
+ module_name = get_module_name_for_debug_dump(module)
+ try:
+ original_stderr = sys.stderr
+ sys.stderr = StringIO()
+ with ExitStack() as stack:
+ stack.enter_context(module.context)
+ asm_for_error_report = module.operation.get_asm(
+ large_elements_limit=10, enable_debug_info=True
+ )
+ pm = PassManager.parse(pipeline)
+ pm.enable_verifier(verify)
+ if print_pipeline:
+ print(pm)
+ if enable_ir_printing:
+ stack.enter_context(disable_multithreading())
+ pm.enable_ir_printing()
+ pm.run(module.operation)
+ except Exception as e:
+ print(e, file=sys.stderr)
+ filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
+ with open(filename, "w") as f:
+ f.write(asm_for_error_report)
+ debug_options = "-mlir-print-ir-after-all -mlir-disable-threading"
+ description = description or f"{module_name} compile"
+
+ message = f"""\
+ {description} failed with the following diagnostics:
+
+ {'*' * 80}
+ {sys.stderr.getvalue().strip()}
+ {'*' * 80}
+
+ For developers, the error can be reproduced with:
+ $ mlir-opt {debug_options} -pass-pipeline='{pipeline}' {filename}
+ """
+ trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")])
+ raise RuntimeError(trimmed_message)
+ finally:
+ sys.stderr = original_stderr
+
+ return module
+
+
+class Pipeline:
+ _pipeline: List[str] = []
+
+ def __init__(self, pipeline=None, wrapper=None):
+ if pipeline is None:
+ pipeline = []
+ self._pipeline = pipeline
+
+ def Nested(self, context, p: "Pipeline"):
+ self._pipeline.append(f"{context}({p.materialize(module=False)})")
+ return self
+
+ def Func(self, p: "Pipeline"):
+ return self.Nested("func.func", p)
+
+ def Spirv(self, p: "Pipeline"):
+ return self.Nested("spirv.module", p)
+
+ def Gpu(self, p: "Pipeline"):
+ assert isinstance(p, Pipeline)
+ return self.Nested("gpu.module", p)
+
+ def materialize(self, module=True):
+ pipeline_str = ",".join(self._pipeline)
+ if module:
+ pipeline_str = f"builtin.module({pipeline_str})"
+ logger.debug(f"{pipeline_str}")
+ return pipeline_str
+
+ def __str__(self):
+ return self.materialize()
+
+ def __iadd__(self, other: "Pipeline"):
+ self._pipeline.extend(other._pipeline)
+ return self
+
+ def __add__(self, other: "Pipeline"):
+ return Pipeline(self._pipeline + other._pipeline)
+
+ def add_pass(self, pass_name, **kwargs):
+ kwargs = {
+ k.replace("_", "-"): int(v) if isinstance(v, bool) else v
+ for k, v in kwargs.items()
+ if v is not None
+ }
+ if kwargs:
+ args_str = " ".join(f"{k}={v}" for k, v in kwargs.items())
+ pass_str = f"{pass_name}{{ {args_str} }}"
+ else:
+ pass_str = f"{pass_name}"
+ self._pipeline.append(pass_str)
+ return self
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index 6065e59fd6ed9..59b992d9226ee 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -121,27 +121,16 @@ def testRoundtripBinary():
def testModuleOperation():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
- assert ctx._get_live_module_count() == 1
op1 = module.operation
- assert ctx._get_live_operation_count() == 1
- live_ops = ctx._get_live_operation_objects()
- assert len(live_ops) == 1
- assert live_ops[0] is op1
- live_ops = None
# CHECK: module @successfulParse
print(op1)
# Ensure that operations are the same on multiple calls.
op2 = module.operation
- assert ctx._get_live_operation_count() == 1
- assert op1 is op2
+ assert op1 == op2
# Test live operation clearing.
op1 = module.operation
- assert ctx._get_live_operation_count() == 1
- num_invalidated = ctx._clear_live_operations()
- assert num_invalidated == 1
- assert ctx._get_live_operation_count() == 0
op1 = None
gc.collect()
op1 = module.operation
@@ -155,9 +144,6 @@ def testModuleOperation():
op1 = None
op2 = None
gc.collect()
- print("LIVE OPERATIONS:", ctx._get_live_operation_count())
- assert ctx._get_live_operation_count() == 0
- assert ctx._get_live_module_count() == 0
# CHECK-LABEL: TEST: testModuleCapsule
@@ -165,16 +151,14 @@ def testModuleOperation():
def testModuleCapsule():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
- assert ctx._get_live_module_count() == 1
# CHECK: "mlir.ir.Module._CAPIPtr"
module_capsule = module._CAPIPtr
print(module_capsule)
module_dup = Module._CAPICreate(module_capsule)
- assert module is module_dup
+ assert module == module_dup
assert module_dup.context is ctx
# Gc and verify destructed.
module = None
module_capsule = None
module_dup = None
gc.collect()
- assert ctx._get_live_module_count() == 0
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index bf16e3f75d60d..c95a97b9731a2 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -907,7 +907,7 @@ def testCapsuleConversions():
m_capsule = m._CAPIPtr
assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
m2 = Operation._CAPICreate(m_capsule)
- assert m2 is m
+ assert m2 == m
# CHECK-LABEL: TEST: testOperationErase
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 8b6d7ea5a197d..7afd539271d21 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -56,14 +56,6 @@ def testSymbolTableInsert():
print(m1)
assert "bar" not in symbol_table
- try:
- print(bar)
- except RuntimeError as e:
- if "the operation has been invalidated" not in str(e):
- raise
- else:
- assert False, "expected RuntimeError due to invalidated operation"
-
qux = m2.body.operations[0]
m1.body.append(qux)
symbol_table.insert(qux)
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index e26d42bb32913..aea8803a57bc5 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -176,14 +176,6 @@ def testRunPipelineError():
@run
def testPostPassOpInvalidation():
with Context() as ctx:
- log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())
-
- # CHECK: invalidate_ops=False
- log("invalidate_ops=False")
-
- # CHECK: live ops: 0
- log_op_count()
-
module = ModuleOp.parse(
"""
module {
@@ -196,9 +188,6 @@ def testPostPassOpInvalidation():
"""
)
- # CHECK: live ops: 1
- log_op_count()
-
outer_const_op = module.body.operations[0]
# CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
log(outer_const_op)
@@ -214,12 +203,7 @@ def testPostPassOpInvalidation():
# CHECK: %[[VAL1]] = arith.constant 10 : i64
log(inner_const_op)
- # CHECK: live ops: 4
- log_op_count()
-
- PassManager.parse("builtin.module(canonicalize)").run(
- module, invalidate_ops=False
- )
+ PassManager.parse("builtin.module(canonicalize)").run(module)
# CHECK: func.func @foo() {
# CHECK: return
# CHECK: }
@@ -233,9 +217,6 @@ def testPostPassOpInvalidation():
# CHECK: invalidate_ops=True
log("invalidate_ops=True")
- # CHECK: live ops: 4
- log_op_count()
-
module = ModuleOp.parse(
"""
module {
@@ -251,14 +232,8 @@ def testPostPassOpInvalidation():
func_op = module.body.operations[1]
inner_const_op = func_op.body.blocks[0].operations[0]
- # CHECK: live ops: 4
- log_op_count()
-
PassManager.parse("builtin.module(canonicalize)").run(module)
- # CHECK: live ops: 1
- log_op_count()
-
try:
log(func_op)
except RuntimeError as e:
>From c84d11741b2130b88c5460f882eb6f2bd5c58e2d Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 23 Aug 2025 14:30:14 -0400
Subject: [PATCH 2/4] "fix" testPostPassOpInvalidation
---
mlir/test/python/pass_manager.py | 21 ---------------------
1 file changed, 21 deletions(-)
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index aea8803a57bc5..0896cd9784641 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -228,30 +228,9 @@ def testPostPassOpInvalidation():
}
"""
)
- outer_const_op = module.body.operations[0]
- func_op = module.body.operations[1]
- inner_const_op = func_op.body.blocks[0].operations[0]
PassManager.parse("builtin.module(canonicalize)").run(module)
- try:
- log(func_op)
- except RuntimeError as e:
- # CHECK: the operation has been invalidated
- log(e)
-
- try:
- log(outer_const_op)
- except RuntimeError as e:
- # CHECK: the operation has been invalidated
- log(e)
-
- try:
- log(inner_const_op)
- except RuntimeError as e:
- # CHECK: the operation has been invalidated
- log(e)
-
# CHECK: func.func @foo() {
# CHECK: return
# CHECK: }
>From 6a1d89c78e796e0ad5d8990432902100dfb2e4eb Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 23 Aug 2025 15:06:07 -0400
Subject: [PATCH 3/4] try to fix testModuleCapsule
---
mlir/lib/Bindings/Python/IRCore.cpp | 16 ++++++++++++----
mlir/lib/Bindings/Python/IRModule.h | 6 +++++-
mlir/test/python/ir/module.py | 4 ++--
3 files changed, 19 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2c643db10f349..e164d6fbd4db2 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1089,10 +1089,12 @@ PyModuleRef PyModule::forModule(MlirModule module) {
return PyModuleRef(unownedModule, std::move(pyRef));
}
-nb::object PyModule::createFromCapsule(nb::object capsule) {
+nb::object PyModule::createFromCapsule(PyModule &self) {
+ nb::object capsule = self.getCapsule();
MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
if (mlirModuleIsNull(rawModule))
throw nb::python_error();
+ self.clearReferrent();
return forModule(rawModule).releaseObject();
}
@@ -1135,13 +1137,19 @@ PyObjectRef<T> makeObjectRef(Args &&...args) {
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
+
// Create.
- PyOperationRef unownedOperation =
- makeObjectRef<PyOperation>(std::move(contextRef), operation);
- unownedOperation->handle = unownedOperation.getObject();
+ PyOperation *unownedOperation =
+ new PyOperation(std::move(contextRef), operation);
+ // Note that the default return value policy on cast is automatic_reference,
+ // which does not take ownership (delete will not be called).
+ // Just be explicit.
+ nb::object pyRef = nb::cast(unownedOperation, nb::rv_policy::take_ownership);
+ unownedOperation->handle = pyRef;
if (parentKeepAlive) {
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
}
+
return PyOperationRef(unownedOperation, std::move(pyRef));
}
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 027bfef5cba07..fcfe8b3298c9b 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -72,6 +72,8 @@ class PyObjectRef {
return Py_REFCNT(object.ptr());
}
+ void clearReferrent() { referrent = nullptr; }
+
/// Releases the object held by this instance, returning it.
/// This is the proper thing to return from a function that wants to return
/// the reference. Note that this does not work from initializers.
@@ -520,7 +522,9 @@ class PyModule : public BaseContextObject {
/// Note that PyModule instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirModule
/// is taken by calling this function.
- static nanobind::object createFromCapsule(nanobind::object capsule);
+ static nanobind::object createFromCapsule(PyModule &self);
+
+ void clearReferrent() { module = {nullptr}; }
private:
PyModule(PyMlirContextRef contextRef, MlirModule module);
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index 59b992d9226ee..a2c60181aa0ec 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -154,8 +154,8 @@ def testModuleCapsule():
# CHECK: "mlir.ir.Module._CAPIPtr"
module_capsule = module._CAPIPtr
print(module_capsule)
- module_dup = Module._CAPICreate(module_capsule)
- assert module == module_dup
+ module_dup = Module._CAPICreate(module)
+ # assert module == module_dup
assert module_dup.context is ctx
# Gc and verify destructed.
module = None
>From b0e6bd6f77f82d1fde915b27617abf91248421e8 Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 23 Aug 2025 15:36:00 -0400
Subject: [PATCH 4/4] remove stray file
---
mlir/python/mlir/extras/passes.py | 140 ------------------------------
1 file changed, 140 deletions(-)
delete mode 100644 mlir/python/mlir/extras/passes.py
diff --git a/mlir/python/mlir/extras/passes.py b/mlir/python/mlir/extras/passes.py
deleted file mode 100644
index be26e42a99144..0000000000000
--- a/mlir/python/mlir/extras/passes.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import contextlib
-import logging
-import os
-import sys
-import tempfile
-from contextlib import ExitStack
-from io import StringIO
-from typing import Optional, List, Union
-
-from ..ir import StringAttr, Module
-from ..passmanager import PassManager
-
-
-@contextlib.contextmanager
-def disable_multithreading(context=None):
- from ..ir import Context
-
- if context is None:
- context = Context.current
-
- context.enable_multithreading(False)
- yield
- context.enable_multithreading(True)
-
-
-logger = logging.getLogger(__name__)
-
-
-def get_module_name_for_debug_dump(module):
- if "debug_module_name" not in module.operation.attributes:
- return "UnnammedMLIRModule"
- return StringAttr(module.operation.attributes["debug_module_name"]).value
-
-
-def run_pipeline(
- module,
- pipeline: Union[str, "Pipeline"],
- description: Optional[str] = None,
- enable_ir_printing=False,
- print_pipeline=False,
- verify=True,
-):
- module = Module.parse(str(module))
- if isinstance(pipeline, Pipeline):
- pipeline = str(pipeline)
- module_name = get_module_name_for_debug_dump(module)
- try:
- original_stderr = sys.stderr
- sys.stderr = StringIO()
- with ExitStack() as stack:
- stack.enter_context(module.context)
- asm_for_error_report = module.operation.get_asm(
- large_elements_limit=10, enable_debug_info=True
- )
- pm = PassManager.parse(pipeline)
- pm.enable_verifier(verify)
- if print_pipeline:
- print(pm)
- if enable_ir_printing:
- stack.enter_context(disable_multithreading())
- pm.enable_ir_printing()
- pm.run(module.operation)
- except Exception as e:
- print(e, file=sys.stderr)
- filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
- with open(filename, "w") as f:
- f.write(asm_for_error_report)
- debug_options = "-mlir-print-ir-after-all -mlir-disable-threading"
- description = description or f"{module_name} compile"
-
- message = f"""\
- {description} failed with the following diagnostics:
-
- {'*' * 80}
- {sys.stderr.getvalue().strip()}
- {'*' * 80}
-
- For developers, the error can be reproduced with:
- $ mlir-opt {debug_options} -pass-pipeline='{pipeline}' {filename}
- """
- trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")])
- raise RuntimeError(trimmed_message)
- finally:
- sys.stderr = original_stderr
-
- return module
-
-
-class Pipeline:
- _pipeline: List[str] = []
-
- def __init__(self, pipeline=None, wrapper=None):
- if pipeline is None:
- pipeline = []
- self._pipeline = pipeline
-
- def Nested(self, context, p: "Pipeline"):
- self._pipeline.append(f"{context}({p.materialize(module=False)})")
- return self
-
- def Func(self, p: "Pipeline"):
- return self.Nested("func.func", p)
-
- def Spirv(self, p: "Pipeline"):
- return self.Nested("spirv.module", p)
-
- def Gpu(self, p: "Pipeline"):
- assert isinstance(p, Pipeline)
- return self.Nested("gpu.module", p)
-
- def materialize(self, module=True):
- pipeline_str = ",".join(self._pipeline)
- if module:
- pipeline_str = f"builtin.module({pipeline_str})"
- logger.debug(f"{pipeline_str}")
- return pipeline_str
-
- def __str__(self):
- return self.materialize()
-
- def __iadd__(self, other: "Pipeline"):
- self._pipeline.extend(other._pipeline)
- return self
-
- def __add__(self, other: "Pipeline"):
- return Pipeline(self._pipeline + other._pipeline)
-
- def add_pass(self, pass_name, **kwargs):
- kwargs = {
- k.replace("_", "-"): int(v) if isinstance(v, bool) else v
- for k, v in kwargs.items()
- if v is not None
- }
- if kwargs:
- args_str = " ".join(f"{k}={v}" for k, v in kwargs.items())
- pass_str = f"{pass_name}{{ {args_str} }}"
- else:
- pass_str = f"{pass_name}"
- self._pipeline.append(pass_str)
- return self
More information about the Mlir-commits
mailing list