[Mlir-commits] [mlir] 49745f8 - [mlir][python] Add `destroy` method to PyOperation.

Mike Urbach llvmlistbot at llvm.org
Wed Apr 28 18:30:12 PDT 2021


Author: Mike Urbach
Date: 2021-04-28T19:30:05-06:00
New Revision: 49745f87e61014ac2a9e93bcad1225c55695b9b7

URL: https://github.com/llvm/llvm-project/commit/49745f87e61014ac2a9e93bcad1225c55695b9b7
DIFF: https://github.com/llvm/llvm-project/commit/49745f87e61014ac2a9e93bcad1225c55695b9b7.diff

LOG: [mlir][python] Add `destroy` method to PyOperation.

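This adds a method, named `erase` (exposed to Python as `Operation.erase`), that
directly invokes `mlirOperationDestroy` on the MlirOperation wrapped by a
PyOperation. It also removes the operation from the context's live-operation
map and clears the PyOperation's `valid` bit; accessors now call `checkValid()`
first.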

Reviewed By: stellaraccident, mehdi_amini

Differential Revision: https://reviews.llvm.org/D101422

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/test/Bindings/Python/ir_operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 781e9aed66e9..160e35b21353 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -753,6 +753,9 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
     : BaseContextObject(std::move(contextRef)), operation(operation) {}
 
 PyOperation::~PyOperation() {
+  // If the operation has already been invalidated there is nothing to do.
+  if (!valid)
+    return;
   auto &liveOperations = getContext()->liveOperations;
   assert(liveOperations.count(operation.ptr) == 1 &&
          "destroying operation not in live map");
@@ -869,6 +872,7 @@ py::object PyOperationBase::getAsm(bool binary,
 }
 
 PyOperationRef PyOperation::getParentOperation() {
+  checkValid();
   if (!isAttached())
     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
   MlirOperation operation = mlirOperationGetParentOperation(get());
@@ -878,6 +882,7 @@ PyOperationRef PyOperation::getParentOperation() {
 }
 
 PyBlock PyOperation::getBlock() {
+  checkValid();
   PyOperationRef parentOperation = getParentOperation();
   MlirBlock block = mlirOperationGetBlock(get());
   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
@@ -885,6 +890,7 @@ PyBlock PyOperation::getBlock() {
 }
 
 py::object PyOperation::getCapsule() {
+  checkValid(); // erase() may already have destroyed the operation.
   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
 }
 
@@ -1032,6 +1038,7 @@ py::object PyOperation::create(
 }
 
 py::object PyOperation::createOpView() {
+  checkValid();
   MlirIdentifier ident = mlirOperationGetName(get());
   MlirStringRef identStr = mlirIdentifierStr(ident);
   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
@@ -1041,6 +1048,18 @@ py::object PyOperation::createOpView() {
   return py::cast(PyOpView(getRef().getObject()));
 }
 
+void PyOperation::erase() {
+  checkValid();
+  // TODO: Fix memory hazards when erasing a tree of operations for which a deep
+  // Python reference to a child operation is live. All children should also
+  // have their `valid` bit set to false.
+  // Unregister from the live map (erasing an absent key is a no-op), destroy
+  // the MlirOperation, then clear `valid` so the destructor does no cleanup.
+  getContext()->liveOperations.erase(operation.ptr);
+  mlirOperationDestroy(operation);
+  valid = false;
+}
+
 //------------------------------------------------------------------------------
 // PyOpView
 //------------------------------------------------------------------------------
@@ -2094,11 +2113,13 @@ void mlir::python::populateIRCore(py::module &m) {
                   py::arg("successors") = py::none(), py::arg("regions") = 0,
                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
                   kOperationCreateDocstring)
+      .def("erase", &PyOperation::erase)
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyOperation::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
       .def_property_readonly("name",
                              [](PyOperation &self) {
+                               self.checkValid();
                                MlirOperation operation = self.get();
                                MlirStringRef name = mlirIdentifierStr(
                                    mlirOperationGetName(operation));
@@ -2106,7 +2127,10 @@ void mlir::python::populateIRCore(py::module &m) {
                              })
       .def_property_readonly(
           "context",
-          [](PyOperation &self) { return self.getContext().getObject(); },
+          [](PyOperation &self) {
+            self.checkValid();
+            return self.getContext().getObject();
+          },
           "Context that owns the Operation")
       .def_property_readonly("opview", &PyOperation::createOpView);
 

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 292080d911d1..79c480e9446f 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -473,6 +473,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   /// Creates an OpView suitable for this operation.
   pybind11::object createOpView();
 
+  /// Destroys the underlying MlirOperation (via mlirOperationDestroy), removes
+  /// it from the owning context's live-operation map, and clears `valid`.
+  void erase();
+
 private:
   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
   static PyOperationRef createInstance(PyMlirContextRef contextRef,

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 746cd3e6ddbf..83e4a4fdfca6 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -646,3 +646,25 @@ def testCapsuleConversions():
     assert m2 is m
 
 run(testCapsuleConversions)
+
+# CHECK-LABEL: TEST: testOperationErase
+def testOperationErase():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    m = Module.create()
+    with InsertionPoint(m.body):
+      op = Operation.create("custom.op1")
+
+      # CHECK: "custom.op1"
+      print(m)
+      # Destroy op1 through its PyOperation; it must disappear from the module.
+      op.operation.erase()
+      # TODO(review): also assert that accessing `op.operation` now fails.
+      # CHECK-NOT: "custom.op1"
+      print(m)
+
+      # Ensure another operation can still be created after the erase.
+      Operation.create("custom.op2")
+
+run(testOperationErase)


        


More information about the Mlir-commits mailing list