[Mlir-commits] [mlir] [MLIR][Python] Add method for getting the live operation objects (PR #78663)

John Demme llvmlistbot at llvm.org
Thu Feb 8 09:23:38 PST 2024


https://github.com/teqdruid updated https://github.com/llvm/llvm-project/pull/78663

>From 8becc6ac237972ba9e9f1ccebf97362d87034348 Mon Sep 17 00:00:00 2001
From: John Demme <john.demme at microsoft.com>
Date: Fri, 19 Jan 2024 03:04:00 +0000
Subject: [PATCH] [MLIR][Python] Add method for getting the live operation
 objects

Currently, a method exists to get the count of the operation objects
which are still alive. This helps with sanity checking, but isn't
very useful for debugging. This new method returns the live operation
objects themselves, so callers can inspect them (e.g. to find out what
is keeping them alive).

This allows Python code like the following:

```
    gc.collect()
    live_ops = ir.Context.current._get_live_operation_objects()
    for op in live_ops:
      print(f"Warning: {op} is still live. Referrers:")
      # Note: the `live_ops` list itself is also reported as a referrer.
      for referrer in gc.get_referrers(op):
        print(f"  {referrer}")
```
---
 mlir/lib/Bindings/Python/IRCore.cpp      | 10 ++++++++++
 mlir/lib/Bindings/Python/IRModule.h      |  3 +++
 mlir/python/mlir/_mlir_libs/_mlir/ir.pyi |  1 +
 mlir/test/python/ir/module.py            |  4 ++++
 4 files changed, 18 insertions(+)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5412c3dec4b1b6..8a7951dc29fe5f 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -636,6 +636,14 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
 
 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
 
+std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
+  std::vector<PyOperation *> liveObjects;
+  liveObjects.reserve(liveOperations.size());
+  for (auto &entry : liveOperations)
+    liveObjects.push_back(entry.second.second);
+  return liveObjects;
+}
+
 size_t PyMlirContext::clearLiveOperations() {
   for (auto &op : liveOperations)
     op.second.second->setInvalid();
@@ -2546,6 +2554,8 @@ void mlir::python::populateIRCore(py::module &m) {
              return ref.releaseObject();
            })
       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+      .def("_get_live_operation_objects",
+           &PyMlirContext::getLiveOperationObjects)
       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 79b7e0c96188c1..48f39c939340d7 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -201,6 +201,9 @@ class PyMlirContext {
   /// Gets the count of live context objects. Used for testing.
   static size_t getLiveCount();
 
+  /// Gets the live operation objects of this context. Used for debugging.
+  std::vector<PyOperation *> getLiveOperationObjects();
+
   /// Gets the count of live operations associated with this context.
   /// Used for testing.
   size_t getLiveOperationCount();
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 57a85990f9bcf5..344abb64a57d23 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -985,6 +985,7 @@ class Context:
     def _get_context_again(self) -> Context: ...
     def _get_live_module_count(self) -> int: ...
     def _get_live_operation_count(self) -> int: ...
+    def _get_live_operation_objects(self) -> List[Operation]: ...  # debugging aid
     def append_dialect_registry(self, registry: DialectRegistry) -> None: ...
     def attach_diagnostic_handler(
         self, callback: Callable[[Diagnostic], bool]
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index a5c38a6b0b076e..ecafcb46af2175 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -105,6 +105,10 @@ def testModuleOperation():
     assert ctx._get_live_module_count() == 1
     op1 = module.operation
     assert ctx._get_live_operation_count() == 1
+    live_op_objects = ctx._get_live_operation_objects()
+    assert len(live_op_objects) == 1
+    assert live_op_objects[0] is op1
+    del live_op_objects  # Release the list so it no longer references op1.
     # CHECK: module @successfulParse
     print(op1)
 



More information about the Mlir-commits mailing list