[Mlir-commits] [mlir] [mlir][python] wip remove liveOperations (PR #92631)
Maksim Levental
llvmlistbot at llvm.org
Sat May 18 11:16:54 PDT 2024
https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/92631
From be31d6ac8d6f2f490dbd6fe4a988877420f11d62 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Fri, 17 May 2024 22:36:41 -0500
Subject: [PATCH] [mlir][python] wip remove liveOperations
---
mlir/include/mlir-c/IR.h | 2 +
mlir/lib/Bindings/Python/IRCore.cpp | 138 +++---------------
mlir/lib/Bindings/Python/IRModule.h | 44 ------
mlir/lib/Bindings/Python/Pass.cpp | 8 +-
.../Bindings/Python/TransformInterpreter.cpp | 1 -
mlir/lib/CAPI/IR/IR.cpp | 4 +
mlir/test/python/ir/module.py | 27 ----
mlir/test/python/ir/symbol_table.py | 8 -
mlir/test/python/pass_manager.py | 27 +---
9 files changed, 31 insertions(+), 228 deletions(-)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 32abacf353133..f837b13c88811 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -323,6 +323,8 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
/// The returned module is null when the input operation was not a ModuleOp.
MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);
+MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule mod, MlirModule other);
+
//===----------------------------------------------------------------------===//
// Operation state.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 01678a9719f90..3c1d8456f690b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -634,58 +634,6 @@ PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
-size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
-
-std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
- std::vector<PyOperation *> liveObjects;
- for (auto &entry : liveOperations)
- liveObjects.push_back(entry.second.second);
- return liveObjects;
-}
-
-size_t PyMlirContext::clearLiveOperations() {
- for (auto &op : liveOperations)
- op.second.second->setInvalid();
- size_t numInvalidated = liveOperations.size();
- liveOperations.clear();
- return numInvalidated;
-}
-
-void PyMlirContext::clearOperation(MlirOperation op) {
- auto it = liveOperations.find(op.ptr);
- if (it != liveOperations.end()) {
- it->second.second->setInvalid();
- liveOperations.erase(it);
- }
-}
-
-void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
- typedef struct {
- PyOperation &rootOp;
- bool rootSeen;
- } callBackData;
- callBackData data{op.getOperation(), false};
- // Mark all ops below the op that the passmanager will be rooted
- // at (but not op itself - note the preorder) as invalid.
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
- void *userData) {
- callBackData *data = static_cast<callBackData *>(userData);
- if (LLVM_LIKELY(data->rootSeen))
- data->rootOp.getOperation().getContext()->clearOperation(op);
- else
- data->rootSeen = true;
- return MlirWalkResult::MlirWalkResultAdvance;
- };
- mlirOperationWalk(op.getOperation(), invalidatingCallback,
- static_cast<void *>(&data), MlirWalkPreOrder);
-}
-void PyMlirContext::clearOperationsInside(MlirOperation op) {
- PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
- clearOperationsInside(opRef->getOperation());
-}
-
-size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
-
pybind11::object PyMlirContext::contextEnter() {
return PyThreadContextEntry::pushContext(*this);
}
@@ -1055,39 +1003,21 @@ PyLocation &DefaultingPyLocation::resolve() {
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
-PyModule::~PyModule() {
- py::gil_scoped_acquire acquire;
- auto &liveModules = getContext()->liveModules;
- assert(liveModules.count(module.ptr) == 1 &&
- "destroying module not in live map");
- liveModules.erase(module.ptr);
- mlirModuleDestroy(module);
-}
+PyModule::~PyModule() { mlirModuleDestroy(module); }
PyModuleRef PyModule::forModule(MlirModule module) {
MlirContext context = mlirModuleGetContext(module);
PyMlirContextRef contextRef = PyMlirContext::forContext(context);
- py::gil_scoped_acquire acquire;
- auto &liveModules = contextRef->liveModules;
- auto it = liveModules.find(module.ptr);
- if (it == liveModules.end()) {
- // Create.
- PyModule *unownedModule = new PyModule(std::move(contextRef), module);
- // Note that the default return value policy on cast is automatic_reference,
- // which does not take ownership (delete will not be called).
- // Just be explicit.
- py::object pyRef =
- py::cast(unownedModule, py::return_value_policy::take_ownership);
- unownedModule->handle = pyRef;
- liveModules[module.ptr] =
- std::make_pair(unownedModule->handle, unownedModule);
- return PyModuleRef(unownedModule, std::move(pyRef));
- }
- // Use existing.
- PyModule *existing = it->second.second;
- py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
- return PyModuleRef(existing, std::move(pyRef));
+ // Create.
+ PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+ // Note that the default return value policy on cast is automatic_reference,
+ // which does not take ownership (delete will not be called).
+ // Just be explicit.
+ py::object pyRef =
+ py::cast(unownedModule, py::return_value_policy::take_ownership);
+ unownedModule->handle = pyRef;
+ return PyModuleRef(unownedModule, std::move(pyRef));
}
py::object PyModule::createFromCapsule(py::object capsule) {
@@ -1112,10 +1042,6 @@ PyOperation::~PyOperation() {
// If the operation has already been invalidated there is nothing to do.
if (!valid)
return;
- auto &liveOperations = getContext()->liveOperations;
- assert(liveOperations.count(operation.ptr) == 1 &&
- "destroying operation not in live map");
- liveOperations.erase(operation.ptr);
if (!isAttached()) {
mlirOperationDestroy(operation);
}
@@ -1124,7 +1050,6 @@ PyOperation::~PyOperation() {
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
MlirOperation operation,
py::object parentKeepAlive) {
- auto &liveOperations = contextRef->liveOperations;
// Create.
PyOperation *unownedOperation =
new PyOperation(std::move(contextRef), operation);
@@ -1137,34 +1062,20 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
if (parentKeepAlive) {
unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
}
- liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
return PyOperationRef(unownedOperation, std::move(pyRef));
}
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
MlirOperation operation,
py::object parentKeepAlive) {
- auto &liveOperations = contextRef->liveOperations;
- auto it = liveOperations.find(operation.ptr);
- if (it == liveOperations.end()) {
- // Create.
- return createInstance(std::move(contextRef), operation,
- std::move(parentKeepAlive));
- }
- // Use existing.
- PyOperation *existing = it->second.second;
- py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
- return PyOperationRef(existing, std::move(pyRef));
+ // Create.
+ return createInstance(std::move(contextRef), operation,
+ std::move(parentKeepAlive));
}
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
MlirOperation operation,
py::object parentKeepAlive) {
- auto &liveOperations = contextRef->liveOperations;
- assert(liveOperations.count(operation.ptr) == 0 &&
- "cannot create detached operation that already exists");
- (void)liveOperations;
-
PyOperationRef created = createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
created->attached = false;
@@ -1530,9 +1441,6 @@ void PyOperation::erase() {
// TODO: Fix memory hazards when erasing a tree of operations for which a deep
// Python reference to a child operation is live. All children should also
// have their `valid` bit set to false.
- auto &liveOperations = getContext()->liveOperations;
- if (liveOperations.count(operation.ptr))
- liveOperations.erase(operation.ptr);
mlirOperationDestroy(operation);
valid = false;
}
@@ -2274,7 +2182,6 @@ class PyBlockArgumentList
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
public:
static constexpr const char *pyClassName = "BlockArgumentList";
- using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;
PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
intptr_t startIndex = 0, intptr_t length = -1,
@@ -2598,14 +2505,6 @@ void mlir::python::populateIRCore(py::module &m) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
- .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
- .def("_get_live_operation_objects",
- &PyMlirContext::getLiveOperationObjects)
- .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
- .def("_clear_live_operations_inside",
- py::overload_cast<MlirOperation>(
- &PyMlirContext::clearOperationsInside))
- .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
@@ -2915,7 +2814,13 @@ void mlir::python::populateIRCore(py::module &m) {
// Defer to the operation's __str__.
return self.attr("operation").attr("__str__")();
},
- kOperationStrDunderDocstring);
+ kOperationStrDunderDocstring)
+ .def(
+ "__eq__",
+ [](PyModule &self, PyModule &other) {
+ return mlirModuleEqual(self.get(), other.get());
+ },
+ "other"_a);
//----------------------------------------------------------------------------
// Mapping of Operation.
@@ -2927,7 +2832,8 @@ void mlir::python::populateIRCore(py::module &m) {
})
.def("__eq__",
[](PyOperationBase &self, PyOperationBase &other) {
- return &self.getOperation() == &other.getOperation();
+ return mlirOperationEqual(self.getOperation().get(),
+ other.getOperation().get());
})
.def("__eq__",
[](PyOperationBase &self, py::object other) { return false; })
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index b038a0c54d29b..f9d102c598c22 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -201,34 +201,6 @@ class PyMlirContext {
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();
- /// Get a list of Python objects which are still in the live context map.
- std::vector<PyOperation *> getLiveOperationObjects();
-
- /// Gets the count of live operations associated with this context.
- /// Used for testing.
- size_t getLiveOperationCount();
-
- /// Clears the live operations map, returning the number of entries which were
- /// invalidated. To be used as a safety mechanism so that API end-users can't
- /// corrupt by holding references they shouldn't have accessed in the first
- /// place.
- size_t clearLiveOperations();
-
- /// Removes an operation from the live operations map and sets it invalid.
- /// This is useful for when some non-bindings code destroys the operation and
- /// the bindings need to made aware. For example, in the case when pass
- /// manager is run.
- void clearOperation(MlirOperation op);
-
- /// Clears all operations nested inside the given op using
- /// `clearOperation(MlirOperation)`.
- void clearOperationsInside(PyOperationBase &op);
- void clearOperationsInside(MlirOperation op);
-
- /// Gets the count of live modules associated with this context.
- /// Used for testing.
- size_t getLiveModuleCount();
-
/// Enter and exit the context manager.
pybind11::object contextEnter();
void contextExit(const pybind11::object &excType,
@@ -255,22 +227,6 @@ class PyMlirContext {
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
static LiveContextMap &getLiveContexts();
- // Interns all live modules associated with this context. Modules tracked
- // in this map are valid. When a module is invalidated, it is removed
- // from this map, and while it still exists as an instance, any
- // attempt to access it will raise an error.
- using LiveModuleMap =
- llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
- LiveModuleMap liveModules;
-
- // Interns all live operations associated with this context. Operations
- // tracked in this map are valid. When an operation is invalidated, it is
- // removed from this map, and while it still exists as an instance, any
- // attempt to access it will raise an error.
- using LiveOperationMap =
- llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
- LiveOperationMap liveOperations;
-
bool emitErrorDiagnostics = false;
MlirContext context;
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index a68421b61641f..0603f9299ecf4 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -117,11 +117,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"ValueError if the pipeline can't be parsed.")
.def(
"run",
- [](PyPassManager &passManager, PyOperationBase &op,
- bool invalidateOps) {
- if (invalidateOps) {
- op.getOperation().getContext()->clearOperationsInside(op);
- }
+ [](PyPassManager &passManager, PyOperationBase &op) {
// Actually run the pass manager.
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
MlirLogicalResult status = mlirPassManagerRunOnOp(
@@ -130,7 +126,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
throw MLIRError("Failure while executing pass pipeline",
errors.take());
},
- "operation"_a, "invalidate_ops"_a = true,
+ "operation"_a,
"Run the pass manager on the provided operation, raising an "
"MLIRError on failure.")
.def(
diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
index f6b4532b1b6be..cddaf3e70286a 100644
--- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp
+++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp
@@ -68,7 +68,6 @@ static void populateTransformInterpreterSubmodule(py::module &m) {
// root. This is awkward, but we don't have access to PyMlirContext
// object here otherwise.
py::object obj = py::cast(payloadRoot);
- obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot);
MlirLogicalResult result = mlirTransformApplyNamedSequence(
payloadRoot, transformRoot, transformModule, options.options);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index a72cd247e73f6..9f60280af14fe 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -332,6 +332,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) {
return wrap(dyn_cast<ModuleOp>(unwrap(op)));
}
+bool mlirModuleEqual(MlirModule mod, MlirModule other) {
+ return unwrap(mod) == unwrap(other);
+}
+
//===----------------------------------------------------------------------===//
// Operation state API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index ecafcb46af217..16efd4fdb61b6 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -102,51 +102,25 @@ def testRoundtripBinary():
def testModuleOperation():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
- assert ctx._get_live_module_count() == 1
op1 = module.operation
- assert ctx._get_live_operation_count() == 1
- live_ops = ctx._get_live_operation_objects()
- assert len(live_ops) == 1
- assert live_ops[0] is op1
- live_ops = None
# CHECK: module @successfulParse
print(op1)
# Ensure that operations are the same on multiple calls.
op2 = module.operation
- assert ctx._get_live_operation_count() == 1
assert op1 is op2
- # Test live operation clearing.
- op1 = module.operation
- assert ctx._get_live_operation_count() == 1
- num_invalidated = ctx._clear_live_operations()
- assert num_invalidated == 1
- assert ctx._get_live_operation_count() == 0
- op1 = None
- gc.collect()
- op1 = module.operation
-
# Ensure that if module is de-referenced, the operations are still valid.
module = None
gc.collect()
print(op1)
- # Collect and verify lifetime.
- op1 = None
- op2 = None
- gc.collect()
- print("LIVE OPERATIONS:", ctx._get_live_operation_count())
- assert ctx._get_live_operation_count() == 0
- assert ctx._get_live_module_count() == 0
-
# CHECK-LABEL: TEST: testModuleCapsule
@run
def testModuleCapsule():
ctx = Context()
module = Module.parse(r"""module @successfulParse {}""", ctx)
- assert ctx._get_live_module_count() == 1
# CHECK: "mlir.ir.Module._CAPIPtr"
module_capsule = module._CAPIPtr
print(module_capsule)
@@ -158,4 +132,3 @@ def testModuleCapsule():
module_capsule = None
module_dup = None
gc.collect()
- assert ctx._get_live_module_count() == 0
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 577721ab2111f..61c181c2efeb7 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -56,14 +56,6 @@ def testSymbolTableInsert():
print(m1)
assert "bar" not in symbol_table
- try:
- print(bar)
- except RuntimeError as e:
- if "the operation has been invalidated" not in str(e):
- raise
- else:
- assert False, "expected RuntimeError due to invalidated operation"
-
qux = m2.body.operations[0]
m1.body.append(qux)
symbol_table.insert(qux)
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index 43af80b53166c..9826f26f75890 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -176,14 +176,6 @@ def testRunPipelineError():
@run
def testPostPassOpInvalidation():
with Context() as ctx:
- log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())
-
- # CHECK: invalidate_ops=False
- log("invalidate_ops=False")
-
- # CHECK: live ops: 0
- log_op_count()
-
module = ModuleOp.parse(
"""
module {
@@ -196,9 +188,6 @@ def testPostPassOpInvalidation():
"""
)
- # CHECK: live ops: 1
- log_op_count()
-
outer_const_op = module.body.operations[0]
# CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
log(outer_const_op)
@@ -214,12 +203,7 @@ def testPostPassOpInvalidation():
# CHECK: %[[VAL1]] = arith.constant 10 : i64
log(inner_const_op)
- # CHECK: live ops: 4
- log_op_count()
-
- PassManager.parse("builtin.module(canonicalize)").run(
- module, invalidate_ops=False
- )
+ PassManager.parse("builtin.module(canonicalize)").run(module)
# CHECK: func.func @foo() {
# CHECK: return
# CHECK: }
@@ -233,9 +217,6 @@ def testPostPassOpInvalidation():
# CHECK: invalidate_ops=True
log("invalidate_ops=True")
- # CHECK: live ops: 4
- log_op_count()
-
module = ModuleOp.parse(
"""
module {
@@ -251,14 +232,8 @@ def testPostPassOpInvalidation():
func_op = module.body.operations[1]
inner_const_op = func_op.body.blocks[0].operations[0]
- # CHECK: live ops: 4
- log_op_count()
-
PassManager.parse("builtin.module(canonicalize)").run(module)
- # CHECK: live ops: 1
- log_op_count()
-
try:
log(func_op)
except RuntimeError as e:
More information about the Mlir-commits
mailing list