[Mlir-commits] [mlir] fa19ef7 - [mlir][python] Clear PyOperations instead of invalidating them. (#70044)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 24 22:18:00 PDT 2023
Author: Ingo Müller
Date: 2023-10-25T07:17:56+02:00
New Revision: fa19ef7a10869bf0f8325681be111f7d97b2544e
URL: https://github.com/llvm/llvm-project/commit/fa19ef7a10869bf0f8325681be111f7d97b2544e
DIFF: https://github.com/llvm/llvm-project/commit/fa19ef7a10869bf0f8325681be111f7d97b2544e.diff
LOG: [mlir][python] Clear PyOperations instead of invalidating them. (#70044)
`PyOperations` are Python-level handles to `Operation *` instances. When
the latter are modified by C++, the former need to be invalidated.
#69746 implements such an invalidation mechanism by setting all
`PyReferences` to `invalid`. However, that is not enough: they also need
to be removed from the `liveOperations` map since other parts of the
code (such as `PyOperation::createDetached`) assume that this map only
contains valid refs.
This is required to actually solve the issue in #69730.
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/lib/Bindings/Python/Pass.cpp
mlir/test/python/pass_manager.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 7b121d4df328641..5659230a03d8ce3 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -705,12 +705,12 @@ typedef enum MlirWalkOrder {
} MlirWalkOrder;
/// Operation walker type. The handler is passed an (opaque) reference to an
-/// operation a pointer to a `userData`.
+/// operation and a pointer to a `userData`.
typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
/// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
/// `*userData` is passed to the callback as well and can be used to tunnel some
-/// some context or other data into the callback.
+/// context or other data into the callback.
MLIR_CAPI_EXPORTED
void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
void *userData, MlirWalkOrder walkOrder);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index a8ea1a381edb96e..7cfea31dbb2e80c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -635,9 +635,32 @@ size_t PyMlirContext::clearLiveOperations() {
return numInvalidated;
}
-void PyMlirContext::setOperationInvalid(MlirOperation op) {
- if (liveOperations.contains(op.ptr))
- liveOperations[op.ptr].second->setInvalid();
+void PyMlirContext::clearOperation(MlirOperation op) {
+ auto it = liveOperations.find(op.ptr);
+ if (it != liveOperations.end()) {
+ it->second.second->setInvalid();
+ liveOperations.erase(it);
+ }
+}
+
+void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
+ typedef struct {
+ PyOperation &rootOp;
+ bool rootSeen;
+ } callBackData;
+ callBackData data{op.getOperation(), false};
+ // Clear (invalidate + drop from liveOperations) every op nested under `op`,
+ // but not `op` itself: the walk is pre-order, so `op` is visited first.
+ MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
+ void *userData) {
+ callBackData *data = static_cast<callBackData *>(userData);
+ if (LLVM_LIKELY(data->rootSeen))
+ data->rootOp.getOperation().getContext()->clearOperation(op);
+ else
+ data->rootSeen = true;
+ };
+ mlirOperationWalk(op.getOperation(), invalidatingCallback,
+ static_cast<void *>(&data), MlirWalkPreOrder);
+}
size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 26292885711a4e4..01ee4975d0e9a91 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -37,6 +37,7 @@ class PyMlirContext;
class DefaultingPyMlirContext;
class PyModule;
class PyOperation;
+class PyOperationBase;
class PyType;
class PySymbolTable;
class PyValue;
@@ -209,10 +210,15 @@ class PyMlirContext {
/// place.
size_t clearLiveOperations();
- /// Sets an operation invalid. This is useful for when some non-bindings
- /// code destroys the operation and the bindings need to made aware. For
- /// example, in the case when pass manager is run.
- void setOperationInvalid(MlirOperation op);
+ /// Removes an operation from the live operations map and sets it invalid.
+ /// This is useful when some non-bindings code destroys the operation and
+ /// the bindings need to be made aware of it, for example when the pass
+ /// manager is run.
+ void clearOperation(MlirOperation op);
+
+ /// Clears all operations nested inside (but not including) the given op
+ /// using `clearOperation(MlirOperation)`.
+ void clearOperationsInside(PyOperationBase &op);
/// Gets the count of live modules associated with this context.
/// Used for testing.
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 2175cea79960ca6..588a8e25414c657 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -119,25 +119,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
[](PyPassManager &passManager, PyOperationBase &op,
bool invalidateOps) {
if (invalidateOps) {
- typedef struct {
- PyOperation &rootOp;
- bool rootSeen;
- } callBackData;
- callBackData data{op.getOperation(), false};
- // Mark all ops below the op that the passmanager will be rooted
- // at (but not op itself - note the preorder) as invalid.
- MlirOperationWalkCallback invalidatingCallback =
- [](MlirOperation op, void *userData) {
- callBackData *data = static_cast<callBackData *>(userData);
- if (LLVM_LIKELY(data->rootSeen))
- data->rootOp.getOperation()
- .getContext()
- ->setOperationInvalid(op);
- else
- data->rootSeen = true;
- };
- mlirOperationWalk(op.getOperation(), invalidatingCallback,
- static_cast<void *>(&data), MlirWalkPreOrder);
+ op.getOperation().getContext()->clearOperationsInside(op);
}
// Actually run the pass manager.
PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index e7f79ddc75113e0..0face028b73ff1d 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -176,6 +176,14 @@ def testRunPipelineError():
@run
def testPostPassOpInvalidation():
with Context() as ctx:
+ log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())
+
+ # CHECK: invalidate_ops=False
+ log("invalidate_ops=False")
+
+ # CHECK: live ops: 0
+ log_op_count()
+
module = ModuleOp.parse(
"""
module {
@@ -188,8 +196,8 @@ def testPostPassOpInvalidation():
"""
)
- # CHECK: invalidate_ops=False
- log("invalidate_ops=False")
+ # CHECK: live ops: 1
+ log_op_count()
outer_const_op = module.body.operations[0]
# CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
@@ -206,6 +214,9 @@ def testPostPassOpInvalidation():
# CHECK: %[[VAL1]] = arith.constant 10 : i64
log(inner_const_op)
+ # CHECK: live ops: 4
+ log_op_count()
+
PassManager.parse("builtin.module(canonicalize)").run(
module, invalidate_ops=False
)
@@ -222,6 +233,9 @@ def testPostPassOpInvalidation():
# CHECK: invalidate_ops=True
log("invalidate_ops=True")
+ # CHECK: live ops: 4
+ log_op_count()
+
module = ModuleOp.parse(
"""
module {
@@ -237,7 +251,14 @@ def testPostPassOpInvalidation():
func_op = module.body.operations[1]
inner_const_op = func_op.body.blocks[0].operations[0]
+ # CHECK: live ops: 4
+ log_op_count()
+
PassManager.parse("builtin.module(canonicalize)").run(module)
+
+ # CHECK: live ops: 1
+ log_op_count()
+
try:
log(func_op)
except RuntimeError as e:
More information about the Mlir-commits
mailing list