[Mlir-commits] [mlir] [mlir][python] Clear PyOperations instead of invalidating them. (PR #70044)

Ingo Müller llvmlistbot at llvm.org
Tue Oct 24 07:23:12 PDT 2023


https://github.com/ingomueller-net updated https://github.com/llvm/llvm-project/pull/70044

>From 129659c4c3edfe6ddc080a2aa6e70c0e84c63478 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Tue, 24 Oct 2023 13:56:02 +0000
Subject: [PATCH 1/3] [mlir][python] Clear PyOperations instead of invalidating
 them.

`PyOperations` are Python-level handles to `Operation *` instances. When
the latter are modified by C++, the former need to be invalidated.
 #69746 implements such an invalidation mechanism by setting all
`PyReferences` to `invalid`. However, that is not enough: they also need
to be removed from the `liveOperations` map since other parts of the
code (such as `PyOperation::createDetached`) assume that that map only
contains valid refs.

This is required to actually solve the issue in #69730.
---
 mlir/lib/Bindings/Python/IRCore.cpp |  9 ++++++---
 mlir/lib/Bindings/Python/IRModule.h |  9 +++++----
 mlir/lib/Bindings/Python/Pass.cpp   |  5 ++---
 mlir/test/python/pass_manager.py    | 25 +++++++++++++++++++++++--
 4 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index a8ea1a381edb96e..5d936c2a5f0ed53 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -635,9 +635,12 @@ size_t PyMlirContext::clearLiveOperations() {
   return numInvalidated;
 }
 
-void PyMlirContext::setOperationInvalid(MlirOperation op) {
-  if (liveOperations.contains(op.ptr))
-    liveOperations[op.ptr].second->setInvalid();
+void PyMlirContext::clearOperation(MlirOperation op) {
+  auto it = liveOperations.find(op.ptr);
+  if (it != liveOperations.end()) {
+    it->second.second->setInvalid();
+    liveOperations.erase(it);
+  }
 }
 
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 26292885711a4e4..f62a64bceee5b30 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -209,10 +209,11 @@ class PyMlirContext {
   /// place.
   size_t clearLiveOperations();
 
-  /// Sets an operation invalid. This is useful for when some non-bindings
-  /// code destroys the operation and the bindings need to made aware. For
-  /// example, in the case when pass manager is run.
-  void setOperationInvalid(MlirOperation op);
+  /// Removes an operation from the live operations map and sets it invalid.
+  /// This is useful for when some non-bindings code destroys the operation and
+  /// the bindings need to be made aware. For example, in the case when pass
+  /// manager is run.
+  void clearOperation(MlirOperation op);
 
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index 2175cea79960ca6..b71661712e27c3e 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -130,9 +130,8 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
                   [](MlirOperation op, void *userData) {
                     callBackData *data = static_cast<callBackData *>(userData);
                     if (LLVM_LIKELY(data->rootSeen))
-                      data->rootOp.getOperation()
-                          .getContext()
-                          ->setOperationInvalid(op);
+                      data->rootOp.getOperation().getContext()->clearOperation(
+                          op);
                     else
                       data->rootSeen = true;
                   };
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
index e7f79ddc75113e0..0face028b73ff1d 100644
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -176,6 +176,14 @@ def testRunPipelineError():
 @run
 def testPostPassOpInvalidation():
     with Context() as ctx:
+        log_op_count = lambda: log("live ops:", ctx._get_live_operation_count())
+
+        # CHECK: invalidate_ops=False
+        log("invalidate_ops=False")
+
+        # CHECK: live ops: 0
+        log_op_count()
+
         module = ModuleOp.parse(
             """
           module {
@@ -188,8 +196,8 @@ def testPostPassOpInvalidation():
         """
         )
 
-        # CHECK: invalidate_ops=False
-        log("invalidate_ops=False")
+        # CHECK: live ops: 1
+        log_op_count()
 
         outer_const_op = module.body.operations[0]
         # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64
@@ -206,6 +214,9 @@ def testPostPassOpInvalidation():
         # CHECK: %[[VAL1]] = arith.constant 10 : i64
         log(inner_const_op)
 
+        # CHECK: live ops: 4
+        log_op_count()
+
         PassManager.parse("builtin.module(canonicalize)").run(
             module, invalidate_ops=False
         )
@@ -222,6 +233,9 @@ def testPostPassOpInvalidation():
         # CHECK: invalidate_ops=True
         log("invalidate_ops=True")
 
+        # CHECK: live ops: 4
+        log_op_count()
+
         module = ModuleOp.parse(
             """
           module {
@@ -237,7 +251,14 @@ def testPostPassOpInvalidation():
         func_op = module.body.operations[1]
         inner_const_op = func_op.body.blocks[0].operations[0]
 
+        # CHECK: live ops: 4
+        log_op_count()
+
         PassManager.parse("builtin.module(canonicalize)").run(module)
+
+        # CHECK: live ops: 1
+        log_op_count()
+
         try:
             log(func_op)
         except RuntimeError as e:

>From 537e4878b685f6aed5169b53e38b8a9f75037bb3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Tue, 24 Oct 2023 14:03:59 +0000
Subject: [PATCH 2/3] Fix some typos from #69746. (NFC)

---
 mlir/include/mlir-c/IR.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 7b121d4df328641..5659230a03d8ce3 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -705,12 +705,12 @@ typedef enum MlirWalkOrder {
 } MlirWalkOrder;
 
 /// Operation walker type. The handler is passed an (opaque) reference to an
-/// operation a pointer to a `userData`.
+/// operation and a pointer to a `userData`.
 typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData);
 
 /// Walks operation `op` in `walkOrder` and calls `callback` on that operation.
 /// `*userData` is passed to the callback as well and can be used to tunnel some
-/// some context or other data into the callback.
+/// context or other data into the callback.
 MLIR_CAPI_EXPORTED
 void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback,
                        void *userData, MlirWalkOrder walkOrder);

>From 2207fb135627368a0ff69c39bd1c86850edc683a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ingo=20M=C3=BCller?= <ingomueller at google.com>
Date: Tue, 24 Oct 2023 14:19:47 +0000
Subject: [PATCH 3/3] Move clearing of nested ops to `PyMlirContext`.

This allows other binding functions that also invalidate `PyOperation`s
to reuse the same implementation.
---
 mlir/lib/Bindings/Python/IRCore.cpp | 20 ++++++++++++++++++++
 mlir/lib/Bindings/Python/IRModule.h |  5 +++++
 mlir/lib/Bindings/Python/Pass.cpp   | 19 +------------------
 3 files changed, 26 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5d936c2a5f0ed53..7cfea31dbb2e80c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -643,6 +643,26 @@ void PyMlirContext::clearOperation(MlirOperation op) {
   }
 }
 
+void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
+  typedef struct {
+    PyOperation &rootOp;
+    bool rootSeen;
+  } callBackData;
+  callBackData data{op.getOperation(), false};
+  // Clear all ops nested below `op` (but not `op` itself - note the preorder)
+  // from the live operations map, marking their Python handles invalid.
+  MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
+                                                      void *userData) {
+    callBackData *data = static_cast<callBackData *>(userData);
+    if (LLVM_LIKELY(data->rootSeen))
+      data->rootOp.getOperation().getContext()->clearOperation(op);
+    else
+      data->rootSeen = true;
+  };
+  mlirOperationWalk(op.getOperation(), invalidatingCallback,
+                    static_cast<void *>(&data), MlirWalkPreOrder);
+}
+
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 pybind11::object PyMlirContext::contextEnter() {
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index f62a64bceee5b30..01ee4975d0e9a91 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -37,6 +37,7 @@ class PyMlirContext;
 class DefaultingPyMlirContext;
 class PyModule;
 class PyOperation;
+class PyOperationBase;
 class PyType;
 class PySymbolTable;
 class PyValue;
@@ -215,6 +216,10 @@ class PyMlirContext {
   /// manager is run.
   void clearOperation(MlirOperation op);
 
+  /// Clears all operations nested inside the given op using
+  /// `clearOperation(MlirOperation)`.
+  void clearOperationsInside(PyOperationBase &op);
+
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
   size_t getLiveModuleCount();
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index b71661712e27c3e..588a8e25414c657 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -119,24 +119,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
           [](PyPassManager &passManager, PyOperationBase &op,
              bool invalidateOps) {
             if (invalidateOps) {
-              typedef struct {
-                PyOperation &rootOp;
-                bool rootSeen;
-              } callBackData;
-              callBackData data{op.getOperation(), false};
-              // Mark all ops below the op that the passmanager will be rooted
-              // at (but not op itself - note the preorder) as invalid.
-              MlirOperationWalkCallback invalidatingCallback =
-                  [](MlirOperation op, void *userData) {
-                    callBackData *data = static_cast<callBackData *>(userData);
-                    if (LLVM_LIKELY(data->rootSeen))
-                      data->rootOp.getOperation().getContext()->clearOperation(
-                          op);
-                    else
-                      data->rootSeen = true;
-                  };
-              mlirOperationWalk(op.getOperation(), invalidatingCallback,
-                                static_cast<void *>(&data), MlirWalkPreOrder);
+              op.getOperation().getContext()->clearOperationsInside(op);
             }
             // Actually run the pass manager.
             PyMlirContext::ErrorCapture errors(op.getOperation().getContext());



More information about the Mlir-commits mailing list