[Mlir-commits] [mlir] [MLIR][Python] Add method for getting the live operation objects (PR #78663)

llvmlistbot at llvm.org
Thu Jan 18 19:12:12 PST 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir

Author: John Demme (teqdruid)

<details>
<summary>Changes</summary>

Currently, `Context._get_live_operation_count()` returns the number of operation objects that are still alive. A count is useful for sanity checks, but it does not tell you which operations are still alive or what is holding on to them. The new `Context._get_live_operation_objects()` method returns the live operation objects themselves.

This allows Python code like the following:

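```python
import gc

from mlir import ir

with ir.Context():
    # ... build IR here; any operation objects still referenced stay live ...
    gc.collect()
    live_ops = ir.Context.current._get_live_operation_objects()
    for op in live_ops:
        print(f"Warning: {op} is still live. Referrers:")
        for referrer in gc.get_referrers(op):
            print(f"  {referrer}")
```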
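When running the loop above, note that `gc.get_referrers(op)` also lists the `live_ops` list itself, because that list holds a strong reference to every operation it returns. Below is a minimal sketch of a helper that filters it out. `report_live_operations` is a hypothetical name, not part of the bindings, and the only assumption about `ctx` is that it exposes the `_get_live_operation_objects()` method added in this PR.

```python
import gc
from typing import Any, Dict, List


def report_live_operations(ctx: Any) -> Dict[int, List[Any]]:
    """Hypothetical helper: map id(op) -> referrers for each still-live op.

    `ctx` is assumed to provide `_get_live_operation_objects()`, as
    `mlir.ir.Context` does with this PR applied.
    """
    gc.collect()  # Collect unreachable cycles so only truly live ops remain.
    live_ops = ctx._get_live_operation_objects()
    report: Dict[int, List[Any]] = {}
    for op in live_ops:
        # Skip the temporary list built above; it always refers to `op`.
        referrers = [r for r in gc.get_referrers(op) if r is not live_ops]
        report[id(op)] = referrers
        print(f"Warning: {op} is still live. Referrers:")
        for referrer in referrers:
            print(f"  {referrer}")
    return report
```

With the PR applied, it would be called as `report_live_operations(ir.Context.current)` inside an active context.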

---
Full diff: https://github.com/llvm/llvm-project/pull/78663.diff


4 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+9) 
- (modified) mlir/lib/Bindings/Python/IRModule.h (+3) 
- (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+1) 
- (modified) mlir/test/python/ir/module.py (+4) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5412c3dec4b1b6..8a7951dc29fe5f 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -636,6 +636,13 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
 
 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
 
+std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
+  std::vector<PyOperation *> liveObjects;
+  for (auto &entry : liveOperations)
+    liveObjects.push_back(entry.second.second);
+  return liveObjects;
+}
+
 size_t PyMlirContext::clearLiveOperations() {
   for (auto &op : liveOperations)
     op.second.second->setInvalid();
@@ -2546,6 +2553,8 @@ void mlir::python::populateIRCore(py::module &m) {
              return ref.releaseObject();
            })
       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+      .def("_get_live_operation_objects",
+           &PyMlirContext::getLiveOperationObjects)
       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 79b7e0c96188c1..48f39c939340d7 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -201,6 +201,9 @@ class PyMlirContext {
   /// Gets the count of live context objects. Used for testing.
   static size_t getLiveCount();
 
+  /// Gets a list of Python objects which are still in the live operation map.
+  std::vector<PyOperation *> getLiveOperationObjects();
+
   /// Gets the count of live operations associated with this context.
   /// Used for testing.
   size_t getLiveOperationCount();
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 57a85990f9bcf5..344abb64a57d23 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -985,6 +985,7 @@ class Context:
     def _get_context_again(self) -> Context: ...
     def _get_live_module_count(self) -> int: ...
     def _get_live_operation_count(self) -> int: ...
+    def _get_live_operation_objects(self) -> List[Operation]: ...
     def append_dialect_registry(self, registry: DialectRegistry) -> None: ...
     def attach_diagnostic_handler(
         self, callback: Callable[[Diagnostic], bool]
diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index a5c38a6b0b076e..ecafcb46af2175 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -105,6 +105,10 @@ def testModuleOperation():
     assert ctx._get_live_module_count() == 1
     op1 = module.operation
     assert ctx._get_live_operation_count() == 1
+    live_ops = ctx._get_live_operation_objects()
+    assert len(live_ops) == 1
+    assert live_ops[0] is op1
+    live_ops = None
     # CHECK: module @successfulParse
     print(op1)
 

``````````
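To see what the new C++ method does without reading the pybind11 code, the bookkeeping can be modeled in pure Python. This is a sketch under assumptions, not the bindings' implementation: `FakeOperation`, `LiveOperationMap`, and their methods are hypothetical names, the integer keys only stand in for whatever the C++ map is keyed by, and the `(handle, operation)` value pair is inferred from the `entry.second.second` and `op.second.second->setInvalid()` accesses in the IRCore.cpp hunk.

```python
from typing import Dict, List, Tuple


class FakeOperation:
    """Hypothetical stand-in for the C++ PyOperation wrapper."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return f"FakeOperation({self.name!r})"


class LiveOperationMap:
    """Hypothetical model of PyMlirContext's liveOperations map.

    Keys stand in for the native operation pointer (an assumption), and each
    value mirrors the (Python handle, PyOperation*) pair, so `value[1]` plays
    the role of `entry.second.second` in the C++ loop.
    """

    def __init__(self) -> None:
        self._live: Dict[int, Tuple[object, FakeOperation]] = {}

    def register(self, key: int, op: FakeOperation) -> None:
        # In Python the handle and the wrapper are the same object.
        self._live[key] = (op, op)

    def get_live_operation_count(self) -> int:
        # Analogue of getLiveOperationCount: just the map size.
        return len(self._live)

    def get_live_operation_objects(self) -> List[FakeOperation]:
        # Analogue of getLiveOperationObjects: copy out every wrapper.
        return [value[1] for value in self._live.values()]


live = LiveOperationMap()
live.register(0x1000, FakeOperation("builtin.module"))
assert live.get_live_operation_count() == len(live.get_live_operation_objects())
print(live.get_live_operation_objects())  # [FakeOperation('builtin.module')]
```

Because both methods read the same map, the count and the object list stay consistent; the updated `module.py` test checks that both report one operation. The sketch returns insertion order, which the real C++ map does not necessarily preserve, so callers should not rely on the order of `_get_live_operation_objects()`.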

</details>
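The updated test sets `live_ops = None` after its assertions. The returned list holds ordinary strong references, so while it exists it keeps the operation objects alive; clearing it presumably lets later checks in the test observe the operation being released. The effect can be reproduced with plain Python objects, where `FakeOperation` and `demo` are hypothetical names and `FakeOperation` stands in for `ir.Operation`:

```python
import gc
import weakref
from typing import Tuple


class FakeOperation:
    """Hypothetical stand-in for an ir.Operation wrapper object."""


def demo() -> Tuple[bool, bool]:
    op1 = FakeOperation()
    alive = weakref.ref(op1)  # Observe liveness without keeping op1 alive.

    live_ops = [op1]  # Like the list returned by _get_live_operation_objects().
    op1 = None
    gc.collect()
    held_by_list = alive() is not None  # The list alone keeps the object alive.

    live_ops = None  # Drop the last strong reference, as the test does.
    gc.collect()
    released = alive() is None
    return held_by_list, released


print(demo())  # (True, True) under CPython
```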


https://github.com/llvm/llvm-project/pull/78663

