[Mlir-commits] [mlir] [mlir][python] Clear PyOperations instead of invalidating them. (PR #70044)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 24 07:03:24 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Ingo Müller (ingomueller-net)

<details>
<summary>Changes</summary>

`PyOperations` are Python-level handles to `Operation *` instances. When the latter are modified by C++, the former need to be invalidated. #69746 implements such an invalidation mechanism by setting all `PyReferences` to `invalid`. However, that is not enough: they also need to be removed from the `liveOperations` map, since other parts of the code (such as `PyOperation::createDetached`) assume that this map only contains valid refs.

This is required to actually solve the issue in #69730.

---
Full diff: https://github.com/llvm/llvm-project/pull/70044.diff


3 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+6-3) 
- (modified) mlir/lib/Bindings/Python/IRModule.h (+5-4) 
- (modified) mlir/test/python/pass_manager.py (+23-2) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index a8ea1a381edb96e..5d936c2a5f0ed53 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -635,9 +635,12 @@ size_t PyMlirContext::clearLiveOperations() {
   return numInvalidated;
 }
 
-void PyMlirContext::setOperationInvalid(MlirOperation op) {
-  if (liveOperations.contains(op.ptr))
-    liveOperations[op.ptr].second->setInvalid();
+void PyMlirContext::clearOperation(MlirOperation op) {
+  auto it = liveOperations.find(op.ptr);
+  if (it != liveOperations.end()) {
+    it->second.second->setInvalid();
+    liveOperations.erase(it);
+  }
 }
 
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 26292885711a4e4..f62a64bceee5b30 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -209,10 +209,11 @@ class PyMlirContext {
   /// place.
   size_t clearLiveOperations();
 
-  /// Sets an operation invalid. This is useful for when some non-bindings
-  /// code destroys the operation and the bindings need to made aware. For
-  /// example, in the case when pass manager is run.
-  void setOperationInvalid(MlirOperation op);
+  /// Removes an operation from the live operations map and sets it invalid.
+  /// This is useful for when some non-bindings code destroys the operation and
+  /// the bindings need to made aware. For example, in the case when pass
+  /// manager is run.
+  void clearOperation(MlirOperation op);
 
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index e7f79ddc75113e0..0face028b73ff1d 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -176,6 +176,14 @@ def testRunPipelineError():
 @run
 def testPostPassOpInvalidation():
     with Context() as ctx:
+        log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())
+
+        # CHECK: invalidate_ops=False
+        log("invalidate_ops=False")
+
+        # CHECK: live ops: 0
+        log_op_count()
+
         module = ModuleOp.parse(
             """
           module {
@@ -188,8 +196,8 @@ def testPostPassOpInvalidation():
         """
         )
 
-        # CHECK: invalidate_ops=False
-        log("invalidate_ops=False")
+        # CHECK: live ops: 1
+        log_op_count()
 
         outer_const_op = module.body.operations[0]
         # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
@@ -206,6 +214,9 @@ def testPostPassOpInvalidation():
         # CHECK: %[[VAL1]] = arith.constant 10 : i64
         log(inner_const_op)
 
+        # CHECK: live ops: 4
+        log_op_count()
+
         PassManager.parse("builtin.module(canonicalize)").run(
             module, invalidate_ops=False
         )
@@ -222,6 +233,9 @@ def testPostPassOpInvalidation():
         # CHECK: invalidate_ops=True
         log("invalidate_ops=True")
 
+        # CHECK: live ops: 4
+        log_op_count()
+
         module = ModuleOp.parse(
             """
           module {
@@ -237,7 +251,14 @@ def testPostPassOpInvalidation():
         func_op = module.body.operations[1]
         inner_const_op = func_op.body.blocks[0].operations[0]
 
+        # CHECK: live ops: 4
+        log_op_count()
+
         PassManager.parse("builtin.module(canonicalize)").run(module)
+
+        # CHECK: live ops: 1
+        log_op_count()
+
         try:
             log(func_op)
         except RuntimeError as e:

``````````

</details>


https://github.com/llvm/llvm-project/pull/70044


More information about the Mlir-commits mailing list