[Mlir-commits] [mlir] 6b0bed7 - [MLIR] [Python] Add a method to clear live operations map

John Demme llvmlistbot at llvm.org
Tue Apr 19 15:14:46 PDT 2022


Author: John Demme
Date: 2022-04-19T15:14:09-07:00
New Revision: 6b0bed7ea563624622c3c1fb1a3c90cd32c78db6

URL: https://github.com/llvm/llvm-project/commit/6b0bed7ea563624622c3c1fb1a3c90cd32c78db6
DIFF: https://github.com/llvm/llvm-project/commit/6b0bed7ea563624622c3c1fb1a3c90cd32c78db6.diff

LOG: [MLIR] [Python] Add a method to clear live operations map

Introduce a method on PyMlirContext (and plumb it through to Python) to
invalidate all of the operations in the live operations map and clear
it. Since Python has no notion of private data, an end-developer could
reach into some third-party API which uses the MLIR Python API (and
which behaves correctly with regard to holding references) and grab a
reference to an MLIR Python Operation, which keeps that operation from
being removed from the live operations map. This allows the API
developer to clear the map whenever it calls C++ code which could
delete operations, protecting itself from its users.
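To make the failure mode concrete, here is a pure-Python toy model. It is not the MLIR bindings: `ToyContext`, `ToyOperation` and the integer "addresses" are hypothetical. It assumes that, like the C++ map, an entry lives exactly as long as some Python reference to the handle does, and it approximates that rule with `weakref.WeakValueDictionary`. `clear_live_operations` follows the same shape as `PyMlirContext::clearLiveOperations` in this commit: it invalidates every cached handle, then empties the map.

```python
import gc
import weakref


class ToyOperation:
    """Hypothetical stand-in for a Python handle wrapping a native operation."""

    def __init__(self, address: int) -> None:
        self.address = address  # the native "pointer" this handle wraps
        self.valid = True

    def set_invalid(self) -> None:
        self.valid = False


class ToyContext:
    """Hypothetical model of a context that uniques handles by address.

    A WeakValueDictionary approximates the real map's lifetime rule: an entry
    disappears once no Python code holds a reference to the handle.
    """

    def __init__(self) -> None:
        self.live_operations = weakref.WeakValueDictionary()

    def get_operation(self, address: int) -> ToyOperation:
        # Reuse the cached handle for this address, or create and cache one.
        op = self.live_operations.get(address)
        if op is None:
            op = ToyOperation(address)
            self.live_operations[address] = op
        return op

    def get_live_operation_count(self) -> int:
        return len(self.live_operations)

    def clear_live_operations(self) -> int:
        # Invalidate every cached handle, then empty the map and report how
        # many entries were dropped.
        live = list(self.live_operations.values())
        for op in live:
            op.set_invalid()
        self.live_operations.clear()
        return len(live)


ctx = ToyContext()

# A short-lived handle leaves the map as soon as its last reference goes away.
ctx.get_operation(0x2000)
gc.collect()
assert ctx.get_live_operation_count() == 0

# A handle the end-user holds onto keeps its entry alive indefinitely.
held = ctx.get_operation(0x1000)
assert ctx.get_live_operation_count() == 1

# Suppose native code now deletes that operation and a new one is allocated at
# the same address. Without clearing, get_operation(0x1000) would hand back the
# stale `held`. Clearing first invalidates it and forces a fresh handle.
assert ctx.clear_live_operations() == 1
fresh = ctx.get_operation(0x1000)
assert not held.valid and fresh.valid and fresh is not held
```

The user's reference is not revoked: `held` still exists. It is only marked invalid and detached from the map, so later lookups can no longer return it.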

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D123895

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/test/python/ir/module.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 25391ebd0a581..d1877a11b8154 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -505,6 +505,14 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
 
 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
 
+size_t PyMlirContext::clearLiveOperations() {
+  for (auto &op : liveOperations)
+    op.second.second->setInvalid();
+  size_t numInvalidated = liveOperations.size();
+  liveOperations.clear();
+  return numInvalidated;
+}
+
 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
 
 pybind11::object PyMlirContext::contextEnter() {
@@ -2208,6 +2216,7 @@ void mlir::python::populateIRCore(py::module &m) {
              return ref.releaseObject();
            })
       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+      .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyMlirContext::getCapsule)

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 2046ce0c16552..371157a569665 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -201,6 +201,12 @@ class PyMlirContext {
   /// Used for testing.
   size_t getLiveOperationCount();
 
+  /// Clears the live operations map, returning the number of entries which were
+  /// invalidated. To be used as a safety mechanism so that API end-users can't
+  /// corrupt the map by holding references they shouldn't have accessed in the
+  /// first place.
+  size_t clearLiveOperations();
+
   /// Gets the count of live modules associated with this context.
   /// Used for testing.
   size_t getLiveModuleCount();
@@ -575,6 +581,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   /// parent context's live operations map, and sets the valid bit false.
   void erase();
 
+  /// Invalidate the operation.
+  void setInvalid() { valid = false; }
+
   /// Clones this operation.
   pybind11::object clone(const pybind11::object &ip);
 

diff  --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py
index 76358eb434c3b..adc27a2879d22 100644
--- a/mlir/test/python/ir/module.py
+++ b/mlir/test/python/ir/module.py
@@ -104,6 +104,16 @@ def testModuleOperation():
   assert ctx._get_live_operation_count() == 1
   assert op1 is op2
 
+  # Test live operation clearing.
+  op1 = module.operation
+  assert ctx._get_live_operation_count() == 1
+  num_invalidated = ctx._clear_live_operations()
+  assert num_invalidated == 1
+  assert ctx._get_live_operation_count() == 0
+  op1 = None
+  gc.collect()
+  op1 = module.operation
+
   # Ensure that if module is de-referenced, the operations are still valid.
   module = None
   gc.collect()


        

