[Mlir-commits] [mlir] 7abb0ff - Add Operation to python bindings.

Stella Laurenzo llvmlistbot at llvm.org
Wed Sep 23 07:58:58 PDT 2020


Author: Stella Laurenzo
Date: 2020-09-23T07:57:50-07:00
New Revision: 7abb0ff7e0419a9554d77e9108cb7da670b7471c

URL: https://github.com/llvm/llvm-project/commit/7abb0ff7e0419a9554d77e9108cb7da670b7471c
DIFF: https://github.com/llvm/llvm-project/commit/7abb0ff7e0419a9554d77e9108cb7da670b7471c.diff

LOG: Add Operation to python bindings.

* Fixes a rather egregious bug stemming from the inability to return arbitrary objects from py::init (it was causing multiple py::objects to alias a single native instance).
* Makes Modules and Operations referenceable types so that they can be reliably depended on.
* Uniques python operation instances within a context. Opens the door for further accounting.
* Next I will retrofit region and block to be dependent on the operation, and I will attempt to model the API to avoid detached regions/blocks, which will simplify things a lot (in that world, only operations can be detached).
* Added quite a bit of test coverage to check for leaks and reference issues.
* Supersedes: https://reviews.llvm.org/D87213

Differential Revision: https://reviews.llvm.org/D87958
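
The py::init aliasing bug above comes from handing pybind11 a native instance that may already have a Python wrapper. In pure Python, the usual remedy (as the comment on `createNewContextForInit` in `IRModules.h` notes) is to override `__new__` so construction returns the pooled wrapper. The following is a minimal sketch of that idiom only, not the bindings' code: `PooledContext` is hypothetical, and a plain integer stands in for `MlirContext.ptr`.

```python
import weakref


class PooledContext:
    """Hypothetical wrapper: at most one live Python object per native pointer."""

    # Non-owning table from native pointer to live wrapper, analogous to the
    # live-context map; entries disappear when the wrapper is collected.
    _live = weakref.WeakValueDictionary()

    def __new__(cls, ptr):
        existing = cls._live.get(ptr)
        if existing is not None:
            # Hand back the pooled wrapper instead of creating an alias.
            return existing
        instance = super().__new__(cls)
        instance.ptr = ptr
        cls._live[ptr] = instance
        return instance

    def __init__(self, ptr):
        # __init__ also runs for recycled instances, so it must not reset
        # state; all setup happens in __new__.
        pass

    @classmethod
    def live_count(cls):
        return len(cls._live)


def demo():
    first = PooledContext(0x1234)
    second = PooledContext(0x1234)
    # Both constructions yield the same wrapper, so no aliasing occurs.
    assert first is second
    assert PooledContext.live_count() == 1
```

pybind11 offers no such `__new__` hook, which is why the commit instead adds a dedicated entry point that always builds a fresh, and therefore unaliased, context.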

Added: 
    

Modified: 
    mlir/docs/Bindings/Python.md
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/test/Bindings/Python/ir_attributes.py
    mlir/test/Bindings/Python/ir_location.py
    mlir/test/Bindings/Python/ir_module.py
    mlir/test/Bindings/Python/ir_operation.py
    mlir/test/Bindings/Python/ir_types.py

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 8d9cee5e88ca..782b46f503ea 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -161,6 +161,28 @@ issues that arise when combining RTTI-based modules (which pybind derived things
 are) with non-RTTI polymorphic C++ code (the default compilation mode of LLVM).
 
 
+### Ownership in the Core IR
+
+There are several top-level types in the core IR that are strongly owned by their Python-side reference:
+
+* `PyContext` (`mlir.ir.Context`)
+* `PyModule` (`mlir.ir.Module`)
+* `PyOperation` (`mlir.ir.Operation`) - but with caveats
+
+All other objects are dependent. All objects maintain a back-reference (keep-alive) to their closest containing top-level object. Further, dependent objects fall into two categories: a) uniqued (which live for the lifetime of the context) and b) mutable. Mutable objects need additional machinery for keeping track of when the C++ instance that backs their Python object is no longer valid (typically due to some specific mutation of the IR, deletion, or bulk operation).
+
+#### Operation hierarchy
+
+As mentioned above, `PyOperation` is special because it can exist in either a top-level or dependent state. The life-cycle is unidirectional: operations can be created detached (top-level), and once added to another operation, they are dependent for the remainder of their lifetime. The situation is more complicated when an operation is added to a transitive parent that is still detached, which necessitates further accounting at such transition points: all such children are initially added to the IR with their outermost detached operation as parent, and once that operation is added to an attached operation, they need to be re-parented to the containing module.
+
+Due to the validity and parenting accounting needs, `PyOperation` is the owner for regions and blocks and needs to be a top-level type that we can count on not aliasing. This lets us do things like selectively invalidating instances when mutations occur without worrying that there is some alias to the same operation in the hierarchy. Operations are also the only entities that are allowed to be in a detached state, and they are interned at the context level so that there is never more than one Python `mlir.ir.Operation` object for a unique `MlirOperation`, regardless of how it is obtained.
+
+The C/C++ API allows for Region/Block to also be detached, but it simplifies the ownership model a lot to eliminate that possibility in this API, allowing the Region/Block to be completely dependent on its owning operation for accounting. The aliasing of Python `Region`/`Block` instances to underlying `MlirRegion`/`MlirBlock` is considered benign and these objects are not interned in the context (unlike operations).
+
+If we ever want to re-introduce detached regions/blocks, we could do so with a new "DetachedRegion" class or similar and still avoid the complexity of accounting. With the current design, we avoid needing a global live list for regions and blocks. We may eventually need an op-local one (TBD), depending on how hard it is to guarantee how mutations interact with their Python peer objects. We can cross that bridge easily when we get there.
+
+`Module`, when used purely from the Python API, can't alias anyway, so we can use it as a top-level ref type without a live list for interning. If the API ever changes such that this cannot be guaranteed (e.g. by letting you marshal a natively defined Module in), then it would need a live table too.
+
 ## Style
 
 In general, for the core parts of MLIR, the Python bindings should be largely
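
As a rough model of the ownership rules in the new "Ownership in the Core IR" section (an illustration, not the bindings' implementation), the sketch below uses plain Python classes. A top-level `Context` counts its live instances, and a dependent `Attribute` keeps a strong back-reference to it, so dropping the user's `ctx` variable does not destroy the context while dependents exist. Both class names are hypothetical stand-ins for `mlir.ir.Context` and its uniqued dependents.

```python
import gc


class Context:
    """Hypothetical top-level owner, strongly owned by its Python references."""

    live_count = 0

    def __init__(self):
        Context.live_count += 1

    def __del__(self):
        # Stands in for destroying the native context (mlirContextDestroy).
        Context.live_count -= 1

    def parse_attr(self, text):
        return Attribute(self, text)


class Attribute:
    """Hypothetical uniqued dependent: lives as long as its context does."""

    def __init__(self, context, text):
        # Keep-alive back-reference to the closest containing top-level object.
        self.context = context
        self.text = text

    def __str__(self):
        return self.text


def demo():
    ctx = Context()
    attr = ctx.parse_attr('"hello"')
    ctx = None
    gc.collect()
    # The attribute still pins the context, so it remains usable.
    assert Context.live_count == 1
    assert str(attr) == '"hello"'
    attr = None
    gc.collect()
    assert Context.live_count == 0
```

This mirrors what the updated `ir_attributes.py` and `ir_location.py` tests check against the real bindings: they clear `ctx`, force a collection, and still print the dependent object.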

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index d7a0bd8ec1a9..66e975e3ea56 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -174,13 +174,12 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
 // PyMlirContext
 //------------------------------------------------------------------------------
 
-PyMlirContext *PyMlirContextRef::release() {
-  object.release();
-  return &referrent;
+PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
+  py::gil_scoped_acquire acquire;
+  auto &liveContexts = getLiveContexts();
+  liveContexts[context.ptr] = this;
 }
 
-PyMlirContext::PyMlirContext(MlirContext context) : context(context) {}
-
 PyMlirContext::~PyMlirContext() {
   // Note that the only public way to construct an instance is via the
   // forContext method, which always puts the associated handle into
@@ -190,6 +189,11 @@ PyMlirContext::~PyMlirContext() {
   mlirContextDestroy(context);
 }
 
+PyMlirContext *PyMlirContext::createNewContextForInit() {
+  MlirContext context = mlirContextCreate();
+  return new PyMlirContext(context);
+}
+
 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
   py::gil_scoped_acquire acquire;
   auto &liveContexts = getLiveContexts();
@@ -198,14 +202,13 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
     // Create.
     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
     py::object pyRef = py::cast(unownedContextWrapper);
-    unownedContextWrapper->handle = pyRef;
-    liveContexts[context.ptr] = std::make_pair(pyRef, unownedContextWrapper);
-    return PyMlirContextRef(*unownedContextWrapper, std::move(pyRef));
-  } else {
-    // Use existing.
-    py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
-    return PyMlirContextRef(*it->second.second, std::move(pyRef));
+    assert(pyRef && "cast to py::object failed");
+    liveContexts[context.ptr] = unownedContextWrapper;
+    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
   }
+  // Use existing.
+  py::object pyRef = py::cast(it->second);
+  return PyMlirContextRef(it->second, std::move(pyRef));
 }
 
 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
@@ -215,8 +218,99 @@ PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
 
 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
 
+size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
+
+//------------------------------------------------------------------------------
+// PyModule
+//------------------------------------------------------------------------------
+
+PyModuleRef PyModule::create(PyMlirContextRef contextRef, MlirModule module) {
+  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+  // Note that the default return value policy on cast is automatic_reference,
+  // which does not take ownership (delete will not be called).
+  // Just be explicit.
+  py::object pyRef =
+      py::cast(unownedModule, py::return_value_policy::take_ownership);
+  unownedModule->handle = pyRef;
+  return PyModuleRef(unownedModule, std::move(pyRef));
+}
+
+//------------------------------------------------------------------------------
+// PyOperation
+//------------------------------------------------------------------------------
+
+PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
+    : BaseContextObject(std::move(contextRef)), operation(operation) {}
+
+PyOperation::~PyOperation() {
+  auto &liveOperations = getContext()->liveOperations;
+  assert(liveOperations.count(operation.ptr) == 1 &&
+         "destroying operation not in live map");
+  liveOperations.erase(operation.ptr);
+  if (!isAttached()) {
+    mlirOperationDestroy(operation);
+  }
+}
+
+PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
+                                           MlirOperation operation,
+                                           py::object parentKeepAlive) {
+  auto &liveOperations = contextRef->liveOperations;
+  // Create.
+  PyOperation *unownedOperation =
+      new PyOperation(std::move(contextRef), operation);
+  // Note that the default return value policy on cast is automatic_reference,
+  // which does not take ownership (delete will not be called).
+  // Just be explicit.
+  py::object pyRef =
+      py::cast(unownedOperation, py::return_value_policy::take_ownership);
+  unownedOperation->handle = pyRef;
+  if (parentKeepAlive) {
+    unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
+  }
+  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
+  return PyOperationRef(unownedOperation, std::move(pyRef));
+}
+
+PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
+                                         MlirOperation operation,
+                                         py::object parentKeepAlive) {
+  auto &liveOperations = contextRef->liveOperations;
+  auto it = liveOperations.find(operation.ptr);
+  if (it == liveOperations.end()) {
+    // Create.
+    return createInstance(std::move(contextRef), operation,
+                          std::move(parentKeepAlive));
+  }
+  // Use existing.
+  PyOperation *existing = it->second.second;
+  assert(existing->parentKeepAlive.is(parentKeepAlive));
+  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
+  return PyOperationRef(existing, std::move(pyRef));
+}
+
+PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
+                                           MlirOperation operation,
+                                           py::object parentKeepAlive) {
+  auto &liveOperations = contextRef->liveOperations;
+  assert(liveOperations.count(operation.ptr) == 0 &&
+         "cannot create detached operation that already exists");
+  (void)liveOperations;
+
+  PyOperationRef created = createInstance(std::move(contextRef), operation,
+                                          std::move(parentKeepAlive));
+  created->attached = false;
+  return created;
+}
+
+void PyOperation::checkValid() {
+  if (!valid) {
+    throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
+  }
+}
+
 //------------------------------------------------------------------------------
-// PyBlock, PyRegion, and PyOperation.
+// PyBlock, PyRegion.
 //------------------------------------------------------------------------------
 
 void PyRegion::attachToParent() {
@@ -865,29 +959,27 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
 void mlir::python::populateIRSubmodule(py::module &m) {
   // Mapping of MlirContext
   py::class_<PyMlirContext>(m, "Context")
-      .def(py::init<>([]() {
-        MlirContext context = mlirContextCreate();
-        auto contextRef = PyMlirContext::forContext(context);
-        return contextRef.release();
-      }))
+      .def(py::init<>(&PyMlirContext::createNewContextForInit))
       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
       .def("_get_context_again",
            [](PyMlirContext &self) {
-             auto ref = PyMlirContext::forContext(self.get());
-             return ref.release();
+             PyMlirContextRef ref = PyMlirContext::forContext(self.get());
+             return ref.releaseObject();
            })
+      .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
       .def(
           "parse_module",
-          [](PyMlirContext &self, const std::string module) {
-            auto moduleRef = mlirModuleCreateParse(self.get(), module.c_str());
+          [](PyMlirContext &self, const std::string moduleAsm) {
+            MlirModule module =
+                mlirModuleCreateParse(self.get(), moduleAsm.c_str());
             // TODO: Rework error reporting once diagnostic engine is exposed
             // in C API.
-            if (mlirModuleIsNull(moduleRef)) {
+            if (mlirModuleIsNull(module)) {
               throw SetPyError(
                   PyExc_ValueError,
                   "Unable to parse module assembly (see diagnostics)");
             }
-            return PyModule(self.getRef(), moduleRef);
+            return PyModule::create(self.getRef(), module).releaseObject();
           },
           kContextParseDocstring)
       .def(
@@ -975,16 +1067,25 @@ void mlir::python::populateIRSubmodule(py::module &m) {
 
   // Mapping of Module
   py::class_<PyModule>(m, "Module")
+      .def_property_readonly(
+          "operation",
+          [](PyModule &self) {
+            return PyOperation::forOperation(self.getContext(),
+                                             mlirModuleGetOperation(self.get()),
+                                             self.getRef().releaseObject())
+                .releaseObject();
+          },
+          "Accesses the module as an operation")
       .def(
           "dump",
           [](PyModule &self) {
-            mlirOperationDump(mlirModuleGetOperation(self.module));
+            mlirOperationDump(mlirModuleGetOperation(self.get()));
           },
           kDumpDocstring)
       .def(
           "__str__",
           [](PyModule &self) {
-            auto operation = mlirModuleGetOperation(self.module);
+            MlirOperation operation = mlirModuleGetOperation(self.get());
             PyPrintAccumulator printAccum;
             mlirOperationPrint(operation, printAccum.getCallback(),
                                printAccum.getUserData());
@@ -992,6 +1093,31 @@ void mlir::python::populateIRSubmodule(py::module &m) {
           },
           kOperationStrDunderDocstring);
 
+  // Mapping of Operation.
+  py::class_<PyOperation>(m, "Operation")
+      .def_property_readonly(
+          "first_region",
+          [](PyOperation &self) {
+            self.checkValid();
+            if (mlirOperationGetNumRegions(self.get()) == 0) {
+              throw SetPyError(PyExc_IndexError, "Operation has no regions");
+            }
+            return PyRegion(self.getContext()->get(),
+                            mlirOperationGetRegion(self.get(), 0),
+                            /*detached=*/false);
+          },
+          py::keep_alive<0, 1>(), "Gets the operation's first region")
+      .def(
+          "__str__",
+          [](PyOperation &self) {
+            self.checkValid();
+            PyPrintAccumulator printAccum;
+            mlirOperationPrint(self.get(), printAccum.getCallback(),
+                               printAccum.getUserData());
+            return printAccum.join();
+          },
+          kTypeStrDunderDocstring);
+
   // Mapping of PyRegion.
   py::class_<PyRegion>(m, "Region")
       .def(
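
`PyOperation::forOperation` consults the context's `liveOperations` map before creating a wrapper, and `~PyOperation` erases the entry, so each `MlirOperation` has at most one live Python object per context. The following is a minimal Python model of that interning. It assumes integers stand in for `MlirOperation.ptr` and uses a `weakref.WeakValueDictionary` in place of the non-owning `pybind11::handle` entries; `InterningContext` and `OperationPeer` are hypothetical names.

```python
import weakref


class OperationPeer:
    """Hypothetical Python peer for a native operation pointer."""

    def __init__(self, context, ptr, parent_keep_alive=None):
        # Strong reference to the context, like BaseContextObject.
        self.context = context
        self.ptr = ptr
        # Mirrors parentKeepAlive: pins the parent (e.g. a Module) while this
        # wrapper is alive.
        self.parent_keep_alive = parent_keep_alive


class InterningContext:
    """Hypothetical context holding a per-context table of live operations."""

    def __init__(self):
        # Non-owning: entries vanish when a wrapper is collected, much as
        # ~PyOperation erases itself from liveOperations.
        self._live_operations = weakref.WeakValueDictionary()

    def for_operation(self, ptr, parent_keep_alive=None):
        existing = self._live_operations.get(ptr)
        if existing is not None:
            # Same invariant as the assert in forOperation: the keep-alive
            # must match on every lookup of the same operation.
            assert existing.parent_keep_alive is parent_keep_alive
            return existing
        created = OperationPeer(self, ptr, parent_keep_alive)
        self._live_operations[ptr] = created
        return created

    def live_operation_count(self):
        return len(self._live_operations)
```

The weak-value table is a Python-side simplification. The C++ code stores a raw handle alongside the `PyOperation *` and removes the entry explicitly in the destructor. The observable behavior is the same one `testModuleOperation` checks: repeated lookups return the identical object, and the live count drops to zero once all references are gone.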

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index fa52c3979359..a7f6ee2425ad 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -19,32 +19,61 @@ namespace python {
 
 class PyMlirContext;
 class PyModule;
+class PyOperation;
 
-/// Holds a C++ PyMlirContext and associated py::object, making it convenient
-/// to have an auto-releasing C++-side keep-alive reference to the context.
-/// The reference to the PyMlirContext is a simple C++ reference and the
-/// py::object holds the reference count which keeps it alive.
-class PyMlirContextRef {
+/// Template for a reference to a concrete type which captures a python
+/// reference to its underlying python object.
+template <typename T>
+class PyObjectRef {
 public:
-  PyMlirContextRef(PyMlirContext &referrent, pybind11::object object)
-      : referrent(referrent), object(std::move(object)) {}
-  ~PyMlirContextRef() {}
+  PyObjectRef(T *referrent, pybind11::object object)
+      : referrent(referrent), object(std::move(object)) {
+    assert(this->referrent &&
+           "cannot construct PyObjectRef with null referrent");
+    assert(this->object && "cannot construct PyObjectRef with null object");
+  }
+  PyObjectRef(PyObjectRef &&other)
+      : referrent(other.referrent), object(std::move(other.object)) {
+    other.referrent = nullptr;
+    assert(!other.object);
+  }
+  PyObjectRef(const PyObjectRef &other)
+      : referrent(other.referrent), object(other.object /* copies */) {}
+  ~PyObjectRef() {}
+
+  int getRefCount() {
+    if (!object)
+      return 0;
+    return object.ref_count();
+  }
 
-  /// Releases the object held by this instance, causing its reference count
-  /// to remain artifically inflated by one. This must be used to return
-  /// the referenced PyMlirContext from a function. Otherwise, the destructor
-  /// of this reference would be called prior to the default take_ownership
-  /// policy assuming that the reference count has been transferred to it.
-  PyMlirContext *release();
+  /// Releases the object held by this instance, returning it.
+  /// This is the proper thing to return from a function that wants to return
+  /// the reference. Note that this does not work from initializers.
+  pybind11::object releaseObject() {
+    assert(referrent && object);
+    referrent = nullptr;
+    auto stolen = std::move(object);
+    return stolen;
+  }
 
-  PyMlirContext &operator->() { return referrent; }
-  pybind11::object getObject() { return object; }
+  T *operator->() {
+    assert(referrent && object);
+    return referrent;
+  }
+  pybind11::object getObject() {
+    assert(referrent && object);
+    return object;
+  }
+  operator bool() const { return referrent && object; }
 
 private:
-  PyMlirContext &referrent;
+  T *referrent;
   pybind11::object object;
 };
 
+using PyMlirContextRef = PyObjectRef<PyMlirContext>;
+
 /// Wrapper around MlirContext.
 class PyMlirContext {
 public:
@@ -52,6 +81,16 @@ class PyMlirContext {
   PyMlirContext(const PyMlirContext &) = delete;
   PyMlirContext(PyMlirContext &&) = delete;
 
+  /// For the case of a python __init__ (py::init) method, pybind11 is quite
+  /// strict about needing to return a pointer that is not yet associated to
+  /// a py::object. Since the forContext() method acts like a pool, possibly
+  /// returning a recycled context, it does not satisfy this need. The usual
+  /// way in python to accomplish such a thing is to override __new__, but
+  /// that is also not supported by pybind11. Instead, we use this entry
+  /// point which always constructs a fresh context (which cannot alias an
+  /// existing one because it is fresh).
+  static PyMlirContext *createNewContextForInit();
+
   /// Returns a context reference for the singleton PyMlirContext wrapper for
   /// the given context.
   static PyMlirContextRef forContext(MlirContext context);
@@ -63,29 +102,37 @@ class PyMlirContext {
   /// Gets a strong reference to this context, which will ensure it is kept
   /// alive for the life of the reference.
   PyMlirContextRef getRef() {
-    return PyMlirContextRef(
-        *this, pybind11::reinterpret_borrow<pybind11::object>(handle));
+    return PyMlirContextRef(this, pybind11::cast(this));
   }
 
   /// Gets the count of live context objects. Used for testing.
   static size_t getLiveCount();
 
+  /// Gets the count of live operations associated with this context.
+  /// Used for testing.
+  size_t getLiveOperationCount();
+
 private:
   PyMlirContext(MlirContext context);
-
   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
   // preserving the relationship that an MlirContext maps to a single
   // PyMlirContext wrapper. This could be replaced in the future with an
   // extension mechanism on the MlirContext for stashing user pointers.
   // Note that this holds a handle, which does not imply ownership.
   // Mappings will be removed when the context is destructed.
-  using LiveContextMap =
-      llvm::DenseMap<void *, std::pair<pybind11::handle, PyMlirContext *>>;
+  using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
   static LiveContextMap &getLiveContexts();
 
+  // Interns all live operations associated with this context. Operations
+  // tracked in this map are valid. When an operation is invalidated, it is
+  // removed from this map, and while it still exists as an instance, any
+  // attempt to access it will raise an error.
+  using LiveOperationMap =
+      llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
+  LiveOperationMap liveOperations;
+
   MlirContext context;
-  // The handle is set as part of lookup with forContext() (post construction).
-  pybind11::handle handle;
+  friend class PyOperation;
 };
 
 /// Base class for all objects that directly or indirectly depend on an
@@ -94,7 +141,10 @@ class PyMlirContext {
 /// Immutable objects that depend on a context extend this directly.
 class BaseContextObject {
 public:
-  BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {}
+  BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
+    assert(this->contextRef &&
+           "context object constructed with null context ref");
+  }
 
   /// Accesses the context reference.
   PyMlirContextRef &getContext() { return contextRef; }
@@ -112,22 +162,90 @@ class PyLocation : public BaseContextObject {
 };
 
 /// Wrapper around MlirModule.
+/// This is the top-level, user-owned object that contains regions/ops/blocks.
+class PyModule;
+using PyModuleRef = PyObjectRef<PyModule>;
 class PyModule : public BaseContextObject {
 public:
-  PyModule(PyMlirContextRef contextRef, MlirModule module)
-      : BaseContextObject(std::move(contextRef)), module(module) {}
+  /// Creates a reference to the module
+  static PyModuleRef create(PyMlirContextRef contextRef, MlirModule module);
   PyModule(PyModule &) = delete;
-  PyModule(PyModule &&other)
-      : BaseContextObject(std::move(other.getContext())) {
-    module = other.module;
-    other.module.ptr = nullptr;
-  }
   ~PyModule() {
     if (module.ptr)
       mlirModuleDestroy(module);
   }
 
+  /// Gets the backing MlirModule.
+  MlirModule get() { return module; }
+
+  /// Gets a strong reference to this module.
+  PyModuleRef getRef() {
+    return PyModuleRef(this,
+                       pybind11::reinterpret_borrow<pybind11::object>(handle));
+  }
+
+private:
+  PyModule(PyMlirContextRef contextRef, MlirModule module)
+      : BaseContextObject(std::move(contextRef)), module(module) {}
   MlirModule module;
+  pybind11::handle handle;
+};
+
+/// Wrapper around MlirOperation.
+/// Operations exist in either an attached (dependent) or detached (top-level)
+/// state. In the detached state (as on creation), an operation is owned by
+/// the creator and its lifetime extends either until its reference count
+/// drops to zero or it is attached to a parent, at which point its lifetime
+/// is bounded by its top-level parent reference.
+class PyOperation;
+using PyOperationRef = PyObjectRef<PyOperation>;
+class PyOperation : public BaseContextObject {
+public:
+  ~PyOperation();
+  /// Returns a PyOperation for the given MlirOperation, optionally associating
+  /// it with a parentKeepAlive (which must match on all such calls for the
+  /// same operation).
+  static PyOperationRef
+  forOperation(PyMlirContextRef contextRef, MlirOperation operation,
+               pybind11::object parentKeepAlive = pybind11::object());
+
+  /// Creates a detached operation. The operation must not be associated with
+  /// any existing live operation.
+  static PyOperationRef
+  createDetached(PyMlirContextRef contextRef, MlirOperation operation,
+                 pybind11::object parentKeepAlive = pybind11::object());
+
+  /// Gets the backing operation.
+  MlirOperation get() {
+    checkValid();
+    return operation;
+  }
+
+  PyOperationRef getRef() {
+    return PyOperationRef(
+        this, pybind11::reinterpret_borrow<pybind11::object>(handle));
+  }
+
+  bool isAttached() { return attached; }
+  void checkValid();
+
+private:
+  PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
+  static PyOperationRef createInstance(PyMlirContextRef contextRef,
+                                       MlirOperation operation,
+                                       pybind11::object parentKeepAlive);
+
+  MlirOperation operation;
+  pybind11::handle handle;
+  // Keeps the parent alive, regardless of whether it is an Operation or
+  // Module.
+  // TODO: As implemented, this facility is only sufficient for modeling the
+  // trivial module parent back-reference. Generalize this to also account for
+  // transitions from detached to attached and address TODOs in the
+  // ir_operation.py regarding testing corresponding lifetime guarantees.
+  pybind11::object parentKeepAlive;
+  bool attached = true;
+  bool valid = true;
 };
 
 /// Wrapper around an MlirRegion.
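
The `attached` and `valid` flags on `PyOperation` encode a one-way life cycle. A detached operation owns its native storage and destroys it when the wrapper goes away, an attached one defers to its parent, and an invalidated instance raises on access through `checkValid()`. The following is a hypothetical Python sketch of that state machine. The `attach_to` and `invalidate` methods and the `destroyed` list are illustrative assumptions, since this commit does not yet implement attaching or invalidation (see the TODO on `parentKeepAlive`).

```python
class OperationLifecycle:
    """Hypothetical model of PyOperation's attached/valid flags."""

    # Records pointers that would have been passed to mlirOperationDestroy;
    # stands in for the C API call.
    destroyed = []

    def __init__(self, ptr, attached=True):
        self.ptr = ptr
        self.attached = attached
        self.valid = True
        self.parent = None

    def check_valid(self):
        if not self.valid:
            raise RuntimeError("the operation has been invalidated")

    def get(self):
        # Like PyOperation::get(), every access goes through the validity check.
        self.check_valid()
        return self.ptr

    def attach_to(self, parent):
        # One-way transition: detached -> attached, never back.
        self.check_valid()
        if self.attached:
            raise ValueError("operation is already attached")
        self.attached = True
        # Once attached, the parent is responsible for the native storage.
        self.parent = parent

    def invalidate(self):
        # E.g. after an IR mutation deletes the underlying operation.
        self.valid = False

    def __del__(self):
        # Only a detached operation owns (and must destroy) its native storage.
        if not self.attached:
            OperationLifecycle.destroyed.append(self.ptr)
```

Keeping the transition unidirectional is what lets the design treat operations as the only detachable entity, so regions and blocks never need their own ownership state.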

diff  --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index 328dfb40b972..a2fd50056bf0 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -1,16 +1,21 @@
 # RUN: %PYTHON %s | FileCheck %s
 
+import gc
 import mlir
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  gc.collect()
+  assert mlir.ir.Context._get_live_count() == 0
 
 
 # CHECK-LABEL: TEST: testParsePrint
 def testParsePrint():
   ctx = mlir.ir.Context()
   t = ctx.parse_attr('"hello"')
+  ctx = None
+  gc.collect()
   # CHECK: "hello"
   print(str(t))
   # CHECK: Attribute("hello")

diff  --git a/mlir/test/Bindings/Python/ir_location.py b/mlir/test/Bindings/Python/ir_location.py
index a24962ad476d..ac42c61a0723 100644
--- a/mlir/test/Bindings/Python/ir_location.py
+++ b/mlir/test/Bindings/Python/ir_location.py
@@ -1,15 +1,21 @@
 # RUN: %PYTHON %s | FileCheck %s
 
+import gc
 import mlir
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  gc.collect()
+  assert mlir.ir.Context._get_live_count() == 0
+
 
 # CHECK-LABEL: TEST: testUnknown
 def testUnknown():
   ctx = mlir.ir.Context()
   loc = ctx.get_unknown_location()
+  ctx = None
+  gc.collect()
   # CHECK: unknown str: loc(unknown)
   print("unknown str:", str(loc))
   # CHECK: unknown repr: loc(unknown)
@@ -22,6 +28,8 @@ def testUnknown():
 def testFileLineCol():
   ctx = mlir.ir.Context()
   loc = ctx.get_file_location("foo.txt", 123, 56)
+  ctx = None
+  gc.collect()
   # CHECK: file str: loc("foo.txt":123:56)
   print("file str:", str(loc))
   # CHECK: file repr: loc("foo.txt":123:56)
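
Each updated test's `run()` helper now calls `gc.collect()` and asserts that `mlir.ir.Context._get_live_count()` is zero, which turns any leaked context into a test failure. The same pattern could be factored into a reusable decorator. The sketch below is hypothetical. It takes the live-count function as a parameter, so it can be exercised without `mlir`; with the real bindings it would be applied as `@leak_checked(mlir.ir.Context._get_live_count)`.

```python
import functools
import gc


def leak_checked(get_live_count):
    """Hypothetical decorator: run a test, collect garbage, assert no leaks."""

    def decorator(test):
        @functools.wraps(test)
        def wrapper(*args, **kwargs):
            # Same label format the FileCheck CHECK-LABEL lines match on.
            print("\nTEST:", test.__name__)
            result = test(*args, **kwargs)
            # Break any reference cycles before counting survivors.
            gc.collect()
            live = get_live_count()
            assert live == 0, f"{test.__name__} leaked {live} live object(s)"
            return result

        return wrapper

    return decorator
```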

diff  --git a/mlir/test/Bindings/Python/ir_module.py b/mlir/test/Bindings/Python/ir_module.py
index 3e7a53995a37..614e1af8b8e7 100644
--- a/mlir/test/Bindings/Python/ir_module.py
+++ b/mlir/test/Bindings/Python/ir_module.py
@@ -1,10 +1,14 @@
 # RUN: %PYTHON %s | FileCheck %s
 
+import gc
 import mlir
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  gc.collect()
+  assert mlir.ir.Context._get_live_count() == 0
+
 
 # Verify successful parse.
 # CHECK-LABEL: TEST: testParseSuccess
@@ -12,6 +16,9 @@ def run(f):
 def testParseSuccess():
   ctx = mlir.ir.Context()
   module = ctx.parse_module(r"""module @successfulParse {}""")
+  print("CLEAR CONTEXT")
+  ctx = None  # Ensure that module captures the context.
+  gc.collect()
   module.dump()  # Just outputs to stderr. Verifies that it functions.
   print(str(module))
 
@@ -47,3 +54,33 @@ def testRoundtripUnicode():
   print(str(module))
 
 run(testRoundtripUnicode)
+
+
+# Tests that module.operation works and correctly interns instances.
+# CHECK-LABEL: TEST: testModuleOperation
+def testModuleOperation():
+  ctx = mlir.ir.Context()
+  module = ctx.parse_module(r"""module @successfulParse {}""")
+  op1 = module.operation
+  assert ctx._get_live_operation_count() == 1
+  # CHECK: module @successfulParse
+  print(op1)
+
+  # Ensure that operations are the same on multiple calls.
+  op2 = module.operation
+  assert ctx._get_live_operation_count() == 1
+  assert op1 is op2
+
+  # Ensure that if module is de-referenced, the operations are still valid.
+  module = None
+  gc.collect()
+  print(op1)
+
+  # Collect and verify lifetime.
+  op1 = None
+  op2 = None
+  gc.collect()
+  print("LIVE OPERATIONS:", ctx._get_live_operation_count())
+  assert ctx._get_live_operation_count() == 0
+
+run(testModuleOperation)

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index c4246844f690..9c4c33a10ab8 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -1,10 +1,13 @@
 # RUN: %PYTHON %s | FileCheck %s
 
+import gc
 import mlir
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  gc.collect()
+  assert mlir.ir.Context._get_live_count() == 0
 
 
 # CHECK-LABEL: TEST: testDetachedRegionBlock

diff  --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index 4710bee27e37..b80cbebb10e2 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -1,16 +1,21 @@
 # RUN: %PYTHON %s | FileCheck %s
 
+import gc
 import mlir
 
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  gc.collect()
+  assert mlir.ir.Context._get_live_count() == 0
 
 
 # CHECK-LABEL: TEST: testParsePrint
 def testParsePrint():
   ctx = mlir.ir.Context()
   t = ctx.parse_type("i32")
+  ctx = None
+  gc.collect()
   # CHECK: i32
   print(str(t))
   # CHECK: Type(i32)
