[Mlir-commits] [mlir] 85185b6 - First pass on MLIR python context lifetime management.

Stella Laurenzo llvmlistbot at llvm.org
Fri Sep 18 12:18:24 PDT 2020


Author: Stella Laurenzo
Date: 2020-09-18T12:17:50-07:00
New Revision: 85185b61b6371c29111611b8e3ac8d06403542c8

URL: https://github.com/llvm/llvm-project/commit/85185b61b6371c29111611b8e3ac8d06403542c8
DIFF: https://github.com/llvm/llvm-project/commit/85185b61b6371c29111611b8e3ac8d06403542c8.diff

LOG: First pass on MLIR python context lifetime management.

* Per thread https://llvm.discourse.group/t/revisiting-ownership-and-lifetime-in-the-python-bindings/1769
* Reworks contexts so it is always possible to get back to a py::object that holds the reference count for an arbitrary MlirContext.
* Retrofits some of the base classes to automatically take a reference to the context, eliminating keep_alives.
* More needs to be done, as discussed, when moving on to the operations/blocks/regions.

Differential Revision: https://reviews.llvm.org/D87886

Added: 
    mlir/test/Bindings/Python/context_lifecycle.py

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/CAPI/IR/IR.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 340f8c5d78ff..1cffa8a28d6a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -103,6 +103,9 @@ MlirLocation mlirLocationFileLineColGet(MlirContext context,
 /** Creates a location with unknown position owned by the given context. */
 MlirLocation mlirLocationUnknownGet(MlirContext context);
 
+/** Gets the context that a location was created with. */
+MlirContext mlirLocationGetContext(MlirLocation location);
+
 /** Prints a location by sending chunks of the string representation and
  * forwarding `userData to `callback`. Note that the callback may be called
  * several times with consecutive chunks of the string. */
@@ -119,6 +122,9 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location);
 /** Parses a module from the string and transfers ownership to the caller. */
 MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
 
+/** Gets the context that a module was created with. */
+MlirContext mlirModuleGetContext(MlirModule module);
+
 /** Checks whether a module is null. */
 inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
 
@@ -342,6 +348,9 @@ void mlirTypeDump(MlirType type);
 /** Parses an attribute. The attribute is owned by the context. */
 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
 
+/** Gets the context that an attribute was created with. */
+MlirContext mlirAttributeGetContext(MlirAttribute attribute);
+
 /** Checks whether an attribute is null. */
 inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
 

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 527c530518ca..d7a0bd8ec1a9 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -170,6 +170,51 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
 
 } // namespace
 
+//------------------------------------------------------------------------------
+// PyMlirContext
+//------------------------------------------------------------------------------
+
+PyMlirContext *PyMlirContextRef::release() {
+  object.release();
+  return &referrent;
+}
+
+PyMlirContext::PyMlirContext(MlirContext context) : context(context) {}
+
+PyMlirContext::~PyMlirContext() {
+  // Note that the only public way to construct an instance is via the
+  // forContext method, which always puts the associated handle into
+  // liveContexts.
+  py::gil_scoped_acquire acquire;
+  getLiveContexts().erase(context.ptr);
+  mlirContextDestroy(context);
+}
+
+PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
+  py::gil_scoped_acquire acquire;
+  auto &liveContexts = getLiveContexts();
+  auto it = liveContexts.find(context.ptr);
+  if (it == liveContexts.end()) {
+    // Create.
+    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
+    py::object pyRef = py::cast(unownedContextWrapper);
+    unownedContextWrapper->handle = pyRef;
+    liveContexts[context.ptr] = std::make_pair(pyRef, unownedContextWrapper);
+    return PyMlirContextRef(*unownedContextWrapper, std::move(pyRef));
+  } else {
+    // Use existing.
+    py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
+    return PyMlirContextRef(*it->second.second, std::move(pyRef));
+  }
+}
+
+PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
+  static LiveContextMap liveContexts;
+  return liveContexts;
+}
+
+size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
+
 //------------------------------------------------------------------------------
 // PyBlock, PyRegion, and PyOperation.
 //------------------------------------------------------------------------------
@@ -234,9 +279,10 @@ class PyConcreteAttribute : public BaseTy {
   using IsAFunctionTy = int (*)(MlirAttribute);
 
   PyConcreteAttribute() = default;
-  PyConcreteAttribute(MlirAttribute attr) : BaseTy(attr) {}
+  PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
+      : BaseTy(std::move(contextRef), attr) {}
   PyConcreteAttribute(PyAttribute &orig)
-      : PyConcreteAttribute(castFrom(orig)) {}
+      : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
 
   static MlirAttribute castFrom(PyAttribute &orig) {
     if (!DerivedTy::isaFunction(orig.attr)) {
@@ -269,18 +315,18 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
         "get",
         [](PyMlirContext &context, std::string value) {
           MlirAttribute attr =
-              mlirStringAttrGet(context.context, value.size(), &value[0]);
-          return PyStringAttribute(attr);
+              mlirStringAttrGet(context.get(), value.size(), &value[0]);
+          return PyStringAttribute(context.getRef(), attr);
         },
-        py::keep_alive<0, 1>(), "Gets a uniqued string attribute");
+        "Gets a uniqued string attribute");
     c.def_static(
         "get_typed",
         [](PyType &type, std::string value) {
           MlirAttribute attr =
               mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
-          return PyStringAttribute(attr);
+          return PyStringAttribute(type.getContext(), attr);
         },
-        py::keep_alive<0, 1>(),
+
         "Gets a uniqued string attribute associated to a type");
     c.def_property_readonly(
         "value",
@@ -315,8 +361,10 @@ class PyConcreteType : public BaseTy {
   using IsAFunctionTy = int (*)(MlirType);
 
   PyConcreteType() = default;
-  PyConcreteType(MlirType t) : BaseTy(t) {}
-  PyConcreteType(PyType &orig) : PyConcreteType(castFrom(orig)) {}
+  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
+      : BaseTy(std::move(contextRef), t) {}
+  PyConcreteType(PyType &orig)
+      : PyConcreteType(orig.getContext(), castFrom(orig)) {}
 
   static MlirType castFrom(PyType &orig) {
     if (!DerivedTy::isaFunction(orig.type)) {
@@ -348,24 +396,24 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
     c.def_static(
         "get_signless",
         [](PyMlirContext &context, unsigned width) {
-          MlirType t = mlirIntegerTypeGet(context.context, width);
-          return PyIntegerType(t);
+          MlirType t = mlirIntegerTypeGet(context.get(), width);
+          return PyIntegerType(context.getRef(), t);
         },
-        py::keep_alive<0, 1>(), "Create a signless integer type");
+        "Create a signless integer type");
     c.def_static(
         "get_signed",
         [](PyMlirContext &context, unsigned width) {
-          MlirType t = mlirIntegerTypeSignedGet(context.context, width);
-          return PyIntegerType(t);
+          MlirType t = mlirIntegerTypeSignedGet(context.get(), width);
+          return PyIntegerType(context.getRef(), t);
         },
-        py::keep_alive<0, 1>(), "Create a signed integer type");
+        "Create a signed integer type");
     c.def_static(
         "get_unsigned",
         [](PyMlirContext &context, unsigned width) {
-          MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
-          return PyIntegerType(t);
+          MlirType t = mlirIntegerTypeUnsignedGet(context.get(), width);
+          return PyIntegerType(context.getRef(), t);
         },
-        py::keep_alive<0, 1>(), "Create an unsigned integer type");
+        "Create an unsigned integer type");
     c.def_property_readonly(
         "width",
         [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
@@ -400,10 +448,10 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
 
   static void bindDerived(ClassTy &c) {
     c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirIndexTypeGet(context.context);
-            return PyIndexType(t);
+            MlirType t = mlirIndexTypeGet(context.get());
+            return PyIndexType(context.getRef(), t);
           }),
-          py::keep_alive<0, 1>(), "Create a index type.");
+          "Create a index type.");
   }
 };
 
@@ -416,10 +464,10 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
 
   static void bindDerived(ClassTy &c) {
     c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirBF16TypeGet(context.context);
-            return PyBF16Type(t);
+            MlirType t = mlirBF16TypeGet(context.get());
+            return PyBF16Type(context.getRef(), t);
           }),
-          py::keep_alive<0, 1>(), "Create a bf16 type.");
+          "Create a bf16 type.");
   }
 };
 
@@ -432,10 +480,10 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
 
   static void bindDerived(ClassTy &c) {
     c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirF16TypeGet(context.context);
-            return PyF16Type(t);
+            MlirType t = mlirF16TypeGet(context.get());
+            return PyF16Type(context.getRef(), t);
           }),
-          py::keep_alive<0, 1>(), "Create a f16 type.");
+          "Create a f16 type.");
   }
 };
 
@@ -448,10 +496,10 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
 
   static void bindDerived(ClassTy &c) {
     c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirF32TypeGet(context.context);
-            return PyF32Type(t);
+            MlirType t = mlirF32TypeGet(context.get());
+            return PyF32Type(context.getRef(), t);
           }),
-          py::keep_alive<0, 1>(), "Create a f32 type.");
+          "Create a f32 type.");
   }
 };
 
@@ -464,10 +512,10 @@ class PyF64Type : public PyConcreteType<PyF64Type> {
 
   static void bindDerived(ClassTy &c) {
     c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirF64TypeGet(context.context);
-            return PyF64Type(t);
+            MlirType t = mlirF64TypeGet(context.get());
+            return PyF64Type(context.getRef(), t);
           }),
-          py::keep_alive<0, 1>(), "Create a f64 type.");
+          "Create a f64 type.");
   }
 };
 
@@ -480,10 +528,10 @@ class PyNoneType : public PyConcreteType<PyNoneType> {
 
   static void bindDerived(ClassTy &c) {
     c.def(py::init([](PyMlirContext &context) {
-            MlirType t = mlirNoneTypeGet(context.context);
-            return PyNoneType(t);
+            MlirType t = mlirNoneTypeGet(context.get());
+            return PyNoneType(context.getRef(), t);
           }),
-          py::keep_alive<0, 1>(), "Create a none type.");
+          "Create a none type.");
   }
 };
 
@@ -501,7 +549,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
           // The element must be a floating point or integer scalar type.
           if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
             MlirType t = mlirComplexTypeGet(elementType.type);
-            return PyComplexType(t);
+            return PyComplexType(elementType.getContext(), t);
           }
           throw SetPyError(
               PyExc_ValueError,
@@ -509,12 +557,12 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
                   py::repr(py::cast(elementType)).cast<std::string>() +
                   "' and expected floating point or integer type.");
         },
-        py::keep_alive<0, 1>(), "Create a complex type");
+        "Create a complex type");
     c.def_property_readonly(
         "element_type",
         [](PyComplexType &self) -> PyType {
           MlirType t = mlirComplexTypeGetElementType(self.type);
-          return PyType(t);
+          return PyType(self.getContext(), t);
         },
         "Returns element type.");
   }
@@ -531,9 +579,9 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
         "element_type",
         [](PyShapedType &self) {
           MlirType t = mlirShapedTypeGetElementType(self.type);
-          return PyType(t);
+          return PyType(self.getContext(), t);
         },
-        py::keep_alive<0, 1>(), "Returns the element type of the shaped type.");
+        "Returns the element type of the shaped type.");
     c.def_property_readonly(
         "has_rank",
         [](PyShapedType &self) -> bool {
@@ -616,9 +664,9 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
                     py::repr(py::cast(elementType)).cast<std::string>() +
                     "' and expected floating point or integer type.");
           }
-          return PyVectorType(t);
+          return PyVectorType(elementType.getContext(), t);
         },
-        py::keep_alive<0, 2>(), "Create a vector type");
+        "Create a vector type");
   }
 };
 
@@ -648,9 +696,9 @@ class PyRankedTensorType
                     "complex "
                     "type.");
           }
-          return PyRankedTensorType(t);
+          return PyRankedTensorType(elementType.getContext(), t);
         },
-        py::keep_alive<0, 2>(), "Create a ranked tensor type");
+        "Create a ranked tensor type");
   }
 };
 
@@ -680,9 +728,9 @@ class PyUnrankedTensorType
                     "complex "
                     "type.");
           }
-          return PyUnrankedTensorType(t);
+          return PyUnrankedTensorType(elementType.getContext(), t);
         },
-        py::keep_alive<0, 1>(), "Create a unranked tensor type");
+        "Create a unranked tensor type");
   }
 };
 
@@ -715,9 +763,9 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
                      "complex "
                      "type.");
            }
-           return PyMemRefType(t);
+           return PyMemRefType(elementType.getContext(), t);
          },
-         py::keep_alive<0, 1>(), "Create a memref type")
+         "Create a memref type")
         .def_property_readonly(
             "num_affine_maps",
             [](PyMemRefType &self) -> intptr_t {
@@ -760,9 +808,9 @@ class PyUnrankedMemRefType
                      "complex "
                      "type.");
            }
-           return PyUnrankedMemRefType(t);
+           return PyUnrankedMemRefType(elementType.getContext(), t);
          },
-         py::keep_alive<0, 1>(), "Create a unranked memref type")
+         "Create a unranked memref type")
         .def_property_readonly(
             "memory_space",
             [](PyUnrankedMemRefType &self) -> unsigned {
@@ -788,17 +836,17 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
           SmallVector<MlirType, 4> elements;
           for (auto element : elementList)
             elements.push_back(element.cast<PyType>().type);
-          MlirType t = mlirTupleTypeGet(context.context, num, elements.data());
-          return PyTupleType(t);
+          MlirType t = mlirTupleTypeGet(context.get(), num, elements.data());
+          return PyTupleType(context.getRef(), t);
         },
-        py::keep_alive<0, 1>(), "Create a tuple type");
+        "Create a tuple type");
     c.def(
         "get_type",
         [](PyTupleType &self, intptr_t pos) -> PyType {
           MlirType t = mlirTupleTypeGetType(self.type, pos);
-          return PyType(t);
+          return PyType(self.getContext(), t);
         },
-        py::keep_alive<0, 1>(), "Returns the pos-th type in the tuple type.");
+        "Returns the pos-th type in the tuple type.");
     c.def_property_readonly(
         "num_types",
         [](PyTupleType &self) -> intptr_t {
@@ -817,12 +865,21 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
 void mlir::python::populateIRSubmodule(py::module &m) {
   // Mapping of MlirContext
   py::class_<PyMlirContext>(m, "Context")
-      .def(py::init<>())
+      .def(py::init<>([]() {
+        MlirContext context = mlirContextCreate();
+        auto contextRef = PyMlirContext::forContext(context);
+        return contextRef.release();
+      }))
+      .def_static("_get_live_count", &PyMlirContext::getLiveCount)
+      .def("_get_context_again",
+           [](PyMlirContext &self) {
+             auto ref = PyMlirContext::forContext(self.get());
+             return ref.release();
+           })
       .def(
           "parse_module",
           [](PyMlirContext &self, const std::string module) {
-            auto moduleRef =
-                mlirModuleCreateParse(self.context, module.c_str());
+            auto moduleRef = mlirModuleCreateParse(self.get(), module.c_str());
             // TODO: Rework error reporting once diagnostic engine is exposed
             // in C API.
             if (mlirModuleIsNull(moduleRef)) {
@@ -830,14 +887,14 @@ void mlir::python::populateIRSubmodule(py::module &m) {
                   PyExc_ValueError,
                   "Unable to parse module assembly (see diagnostics)");
             }
-            return PyModule(moduleRef);
+            return PyModule(self.getRef(), moduleRef);
           },
-          py::keep_alive<0, 1>(), kContextParseDocstring)
+          kContextParseDocstring)
       .def(
           "parse_attr",
           [](PyMlirContext &self, std::string attrSpec) {
             MlirAttribute type =
-                mlirAttributeParseGet(self.context, attrSpec.c_str());
+                mlirAttributeParseGet(self.get(), attrSpec.c_str());
             // TODO: Rework error reporting once diagnostic engine is exposed
             // in C API.
             if (mlirAttributeIsNull(type)) {
@@ -845,13 +902,13 @@ void mlir::python::populateIRSubmodule(py::module &m) {
                                llvm::Twine("Unable to parse attribute: '") +
                                    attrSpec + "'");
             }
-            return PyAttribute(type);
+            return PyAttribute(self.getRef(), type);
           },
           py::keep_alive<0, 1>())
       .def(
           "parse_type",
           [](PyMlirContext &self, std::string typeSpec) {
-            MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
+            MlirType type = mlirTypeParseGet(self.get(), typeSpec.c_str());
             // TODO: Rework error reporting once diagnostic engine is exposed
             // in C API.
             if (mlirTypeIsNull(type)) {
@@ -859,30 +916,32 @@ void mlir::python::populateIRSubmodule(py::module &m) {
                                llvm::Twine("Unable to parse type: '") +
                                    typeSpec + "'");
             }
-            return PyType(type);
+            return PyType(self.getRef(), type);
           },
-          py::keep_alive<0, 1>(), kContextParseTypeDocstring)
+          kContextParseTypeDocstring)
       .def(
           "get_unknown_location",
           [](PyMlirContext &self) {
-            return PyLocation(mlirLocationUnknownGet(self.context));
+            return PyLocation(self.getRef(),
+                              mlirLocationUnknownGet(self.get()));
           },
-          py::keep_alive<0, 1>(), kContextGetUnknownLocationDocstring)
+          kContextGetUnknownLocationDocstring)
       .def(
           "get_file_location",
           [](PyMlirContext &self, std::string filename, int line, int col) {
-            return PyLocation(mlirLocationFileLineColGet(
-                self.context, filename.c_str(), line, col));
+            return PyLocation(self.getRef(),
+                              mlirLocationFileLineColGet(
+                                  self.get(), filename.c_str(), line, col));
           },
-          py::keep_alive<0, 1>(), kContextGetFileLocationDocstring,
-          py::arg("filename"), py::arg("line"), py::arg("col"))
+          kContextGetFileLocationDocstring, py::arg("filename"),
+          py::arg("line"), py::arg("col"))
       .def(
           "create_region",
           [](PyMlirContext &self) {
             // The creating context is explicitly captured on regions to
             // facilitate illegal assemblies of objects from multiple contexts
             // that would invalidate the memory model.
-            return PyRegion(self.context, mlirRegionCreate(),
+            return PyRegion(self.get(), mlirRegionCreate(),
                             /*detached=*/true);
           },
           py::keep_alive<0, 1>(), kContextCreateRegionDocstring)
@@ -893,7 +952,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
             // types must be from the same context.
             for (auto pyType : pyTypes) {
               if (!mlirContextEqual(mlirTypeGetContext(pyType.type),
-                                    self.context)) {
+                                    self.get())) {
                 throw SetPyError(
                     PyExc_ValueError,
                     "All types used to construct a block must be from "
@@ -902,8 +961,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
             }
             llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
                                                  pyTypes.end());
-            return PyBlock(self.context,
-                           mlirBlockCreate(types.size(), &types[0]),
+            return PyBlock(self.get(), mlirBlockCreate(types.size(), &types[0]),
                            /*detached=*/true);
           },
           py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
@@ -1063,7 +1121,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def_property_readonly(
           "attr",
           [](PyNamedAttribute &self) {
-            return PyAttribute(self.namedAttr.attribute);
+            // TODO: When named attribute is removed/refactored, also remove
+            // this constructor (it does an inefficient table lookup).
+            auto contextRef = PyMlirContext::forContext(
+                mlirAttributeGetContext(self.namedAttr.attribute));
+            return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
           },
           py::keep_alive<0, 1>(),
           "The underlying generic attribute of the NamedAttribute binding");

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 20fe8014e138..fa52c3979359 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -12,6 +12,7 @@
 #include <pybind11/pybind11.h>
 
 #include "mlir-c/IR.h"
+#include "llvm/ADT/DenseMap.h"
 
 namespace mlir {
 namespace python {
@@ -19,28 +20,105 @@ namespace python {
 class PyMlirContext;
 class PyModule;
 
+/// Holds a C++ PyMlirContext and associated py::object, making it convenient
+/// to have an auto-releasing C++-side keep-alive reference to the context.
+/// The reference to the PyMlirContext is a simple C++ reference and the
+/// py::object holds the reference count which keeps it alive.
+class PyMlirContextRef {
+public:
+  PyMlirContextRef(PyMlirContext &referrent, pybind11::object object)
+      : referrent(referrent), object(std::move(object)) {}
+  ~PyMlirContextRef() {}
+
+  /// Releases the object held by this instance, causing its reference count
+  /// to remain artifically inflated by one. This must be used to return
+  /// the referenced PyMlirContext from a function. Otherwise, the destructor
+  /// of this reference would be called prior to the default take_ownership
+  /// policy assuming that the reference count has been transferred to it.
+  PyMlirContext *release();
+
+  PyMlirContext &operator->() { return referrent; }
+  pybind11::object getObject() { return object; }
+
+private:
+  PyMlirContext &referrent;
+  pybind11::object object;
+};
+
 /// Wrapper around MlirContext.
 class PyMlirContext {
 public:
-  PyMlirContext() { context = mlirContextCreate(); }
-  ~PyMlirContext() { mlirContextDestroy(context); }
+  PyMlirContext() = delete;
+  PyMlirContext(const PyMlirContext &) = delete;
+  PyMlirContext(PyMlirContext &&) = delete;
+
+  /// Returns a context reference for the singleton PyMlirContext wrapper for
+  /// the given context.
+  static PyMlirContextRef forContext(MlirContext context);
+  ~PyMlirContext();
+
+  /// Accesses the underlying MlirContext.
+  MlirContext get() { return context; }
+
+  /// Gets a strong reference to this context, which will ensure it is kept
+  /// alive for the life of the reference.
+  PyMlirContextRef getRef() {
+    return PyMlirContextRef(
+        *this, pybind11::reinterpret_borrow<pybind11::object>(handle));
+  }
+
+  /// Gets the count of live context objects. Used for testing.
+  static size_t getLiveCount();
+
+private:
+  PyMlirContext(MlirContext context);
+
+  // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
+  // preserving the relationship that an MlirContext maps to a single
+  // PyMlirContext wrapper. This could be replaced in the future with an
+  // extension mechanism on the MlirContext for stashing user pointers.
+  // Note that this holds a handle, which does not imply ownership.
+  // Mappings will be removed when the context is destructed.
+  using LiveContextMap =
+      llvm::DenseMap<void *, std::pair<pybind11::handle, PyMlirContext *>>;
+  static LiveContextMap &getLiveContexts();
 
   MlirContext context;
+  // The handle is set as part of lookup with forContext() (post construction).
+  pybind11::handle handle;
+};
+
+/// Base class for all objects that directly or indirectly depend on an
+/// MlirContext. The lifetime of the context will extend at least to the
+/// lifetime of these instances.
+/// Immutable objects that depend on a context extend this directly.
+class BaseContextObject {
+public:
+  BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {}
+
+  /// Accesses the context reference.
+  PyMlirContextRef &getContext() { return contextRef; }
+
+private:
+  PyMlirContextRef contextRef;
 };
 
 /// Wrapper around an MlirLocation.
-class PyLocation {
+class PyLocation : public BaseContextObject {
 public:
-  PyLocation(MlirLocation loc) : loc(loc) {}
+  PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
+      : BaseContextObject(std::move(contextRef)), loc(loc) {}
   MlirLocation loc;
 };
 
 /// Wrapper around MlirModule.
-class PyModule {
+class PyModule : public BaseContextObject {
 public:
-  PyModule(MlirModule module) : module(module) {}
+  PyModule(PyMlirContextRef contextRef, MlirModule module)
+      : BaseContextObject(std::move(contextRef)), module(module) {}
   PyModule(PyModule &) = delete;
-  PyModule(PyModule &&other) {
+  PyModule(PyModule &&other)
+      : BaseContextObject(std::move(other.getContext())) {
     module = other.module;
     other.module.ptr = nullptr;
   }
@@ -120,9 +198,10 @@ class PyBlock {
 
 /// Wrapper around the generic MlirAttribute.
 /// The lifetime of a type is bound by the PyContext that created it.
-class PyAttribute {
+class PyAttribute : public BaseContextObject {
 public:
-  PyAttribute(MlirAttribute attr) : attr(attr) {}
+  PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
+      : BaseContextObject(std::move(contextRef)), attr(attr) {}
   bool operator==(const PyAttribute &other);
 
   MlirAttribute attr;
@@ -153,9 +232,10 @@ class PyNamedAttribute {
 
 /// Wrapper around the generic MlirType.
 /// The lifetime of a type is bound by the PyContext that created it.
-class PyType {
+class PyType : public BaseContextObject {
 public:
-  PyType(MlirType type) : type(type) {}
+  PyType(PyMlirContextRef contextRef, MlirType type)
+      : BaseContextObject(std::move(contextRef)), type(type) {}
   bool operator==(const PyType &other);
   operator MlirType() const { return type; }
 

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 8611d6537371..0304d977f494 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -48,6 +48,10 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
   return wrap(UnknownLoc::get(unwrap(context)));
 }
 
+MlirContext mlirLocationGetContext(MlirLocation location) {
+  return wrap(unwrap(location).getContext());
+}
+
 void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
                        void *userData) {
   detail::CallbackOstream stream(callback, userData);
@@ -70,6 +74,10 @@ MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
   return MlirModule{owning.release().getOperation()};
 }
 
+MlirContext mlirModuleGetContext(MlirModule module) {
+  return wrap(unwrap(module).getContext());
+}
+
 void mlirModuleDestroy(MlirModule module) {
   // Transfer ownership to an OwningModuleRef so that its destructor is called.
   OwningModuleRef(unwrap(module));
@@ -349,6 +357,10 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
   return wrap(mlir::parseAttribute(attr, unwrap(context)));
 }
 
+MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
+  return wrap(unwrap(attribute).getContext());
+}
+
 int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
   return unwrap(a1) == unwrap(a2);
 }

diff  --git a/mlir/test/Bindings/Python/context_lifecycle.py b/mlir/test/Bindings/Python/context_lifecycle.py
new file mode 100644
index 000000000000..e2b287061b22
--- /dev/null
+++ b/mlir/test/Bindings/Python/context_lifecycle.py
@@ -0,0 +1,42 @@
+# RUN: %PYTHON %s
+# Standalone sanity check of context life-cycle.
+import gc
+import mlir
+
+assert mlir.ir.Context._get_live_count() == 0
+
+# Create first context.
+print("CREATE C1")
+c1 = mlir.ir.Context()
+assert mlir.ir.Context._get_live_count() == 1
+c1_repr = repr(c1)
+print("C1 = ", c1_repr)
+
+print("GETTING AGAIN...")
+c2 = c1._get_context_again()
+c2_repr = repr(c2)
+assert mlir.ir.Context._get_live_count() == 1
+assert c1_repr == c2_repr
+
+print("C2 =", c2)
+
+# Make sure new contexts on constructor.
+print("CREATE C3")
+c3 = mlir.ir.Context()
+assert mlir.ir.Context._get_live_count() == 2
+c3_repr = repr(c3)
+print("C3 =", c3)
+assert c3_repr != c1_repr
+print("FREE C3")
+c3 = None
+gc.collect()
+assert mlir.ir.Context._get_live_count() == 1
+
+print("Free C1")
+c1 = None
+gc.collect()
+assert mlir.ir.Context._get_live_count() == 1
+print("Free C2")
+c2 = None
+gc.collect()
+assert mlir.ir.Context._get_live_count() == 0


        


More information about the Mlir-commits mailing list