[Mlir-commits] [mlir] 85185b6 - First pass on MLIR python context lifetime management.
Stella Laurenzo
llvmlistbot at llvm.org
Fri Sep 18 12:18:24 PDT 2020
Author: Stella Laurenzo
Date: 2020-09-18T12:17:50-07:00
New Revision: 85185b61b6371c29111611b8e3ac8d06403542c8
URL: https://github.com/llvm/llvm-project/commit/85185b61b6371c29111611b8e3ac8d06403542c8
DIFF: https://github.com/llvm/llvm-project/commit/85185b61b6371c29111611b8e3ac8d06403542c8.diff
LOG: First pass on MLIR python context lifetime management.
* Per thread https://llvm.discourse.group/t/revisiting-ownership-and-lifetime-in-the-python-bindings/1769
* Reworks contexts so it is always possible to get back to a py::object that holds the reference count for an arbitrary MlirContext.
* Retrofits some of the base classes to automatically take a reference to the context, eliminating the need for keep_alives.
* More needs to be done, as discussed, when moving on to the operations/blocks/regions.
Differential Revision: https://reviews.llvm.org/D87886
Added:
mlir/test/Bindings/Python/context_lifecycle.py
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/lib/CAPI/IR/IR.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 340f8c5d78ff..1cffa8a28d6a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -103,6 +103,9 @@ MlirLocation mlirLocationFileLineColGet(MlirContext context,
/** Creates a location with unknown position owned by the given context. */
MlirLocation mlirLocationUnknownGet(MlirContext context);
+/** Gets the context that a location was created with. */
+MlirContext mlirLocationGetContext(MlirLocation location);
+
/** Prints a location by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
@@ -119,6 +122,9 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location);
/** Parses a module from the string and transfers ownership to the caller. */
MlirModule mlirModuleCreateParse(MlirContext context, const char *module);
+/** Gets the context that a module was created with. */
+MlirContext mlirModuleGetContext(MlirModule module);
+
/** Checks whether a module is null. */
inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; }
@@ -342,6 +348,9 @@ void mlirTypeDump(MlirType type);
/** Parses an attribute. The attribute is owned by the context. */
MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
+/** Gets the context that an attribute was created with. */
+MlirContext mlirAttributeGetContext(MlirAttribute attribute);
+
/** Checks whether an attribute is null. */
inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 527c530518ca..d7a0bd8ec1a9 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -170,6 +170,51 @@ int mlirTypeIsAIntegerOrFloat(MlirType type) {
} // namespace
+//------------------------------------------------------------------------------
+// PyMlirContext
+//------------------------------------------------------------------------------
+
+/// Gives up this ref's strong reference without decrementing it (the Python
+/// reference count stays inflated by one) and returns the raw wrapper pointer.
+PyMlirContext *PyMlirContextRef::release() {
+ object.release();
+ return &referrent;
+}
+
+/// Only records the MlirContext handle; the new instance becomes responsible
+/// for destroying it (see ~PyMlirContext below).
+PyMlirContext::PyMlirContext(MlirContext context) : context(context) {}
+
+/// Unregisters this wrapper from the live-context map and destroys the
+/// underlying MlirContext.
+PyMlirContext::~PyMlirContext() {
+ // Note that the only public way to construct an instance is via the
+ // forContext method, which always puts the associated handle into
+ // liveContexts.
+ // NOTE(review): the GIL is (re)acquired here, presumably because the shared
+ // live map holds pybind11 handles and destruction may run without the GIL
+ // held -- confirm.
+ py::gil_scoped_acquire acquire;
+ getLiveContexts().erase(context.ptr);
+ mlirContextDestroy(context);
+}
+
+/// Returns a ref to the unique PyMlirContext wrapper for `context`, creating
+/// and registering one on first lookup. The returned ref always holds its own
+/// strong Python reference; the map itself only stores a non-owning handle.
+/// NOTE(review): Context.__init__ hands release()'s pointer from this function
+/// to a py::init factory, while this function has already py::cast that same
+/// pointer into a Python object; confirm pybind11 does not end up with two
+/// Python instances owning one wrapper.
+PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
+ py::gil_scoped_acquire acquire;
+ auto &liveContexts = getLiveContexts();
+ auto it = liveContexts.find(context.ptr);
+ if (it == liveContexts.end()) {
+ // Create.
+ // py::cast of a raw pointer uses the automatic policy, which for pointers
+ // means take_ownership: the Python object now owns the wrapper and
+ // deletes it when its reference count drops to zero.
+ PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
+ py::object pyRef = py::cast(unownedContextWrapper);
+ // Record the handle before any ref is handed out; getRef() relies on it.
+ unownedContextWrapper->handle = pyRef;
+ liveContexts[context.ptr] = std::make_pair(pyRef, unownedContextWrapper);
+ return PyMlirContextRef(*unownedContextWrapper, std::move(pyRef));
+ } else {
+ // Use existing.
+ // Borrowing increments the reference count, giving the returned ref its
+ // own strong reference to the cached Python object.
+ py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
+ return PyMlirContextRef(*it->second.second, std::move(pyRef));
+ }
+}
+
+/// Lazily constructed, process-wide map of live wrappers keyed by
+/// MlirContext::ptr. It has no lock of its own; forContext and the destructor
+/// take the GIL before touching it.
+PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
+ static LiveContextMap liveContexts;
+ return liveContexts;
+}
+
+/// Number of currently registered wrappers; exposed to Python as
+/// Context._get_live_count for lifecycle tests.
+size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
+
//------------------------------------------------------------------------------
// PyBlock, PyRegion, and PyOperation.
//------------------------------------------------------------------------------
@@ -234,9 +279,10 @@ class PyConcreteAttribute : public BaseTy {
using IsAFunctionTy = int (*)(MlirAttribute);
PyConcreteAttribute() = default;
- PyConcreteAttribute(MlirAttribute attr) : BaseTy(attr) {}
+ PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
+ : BaseTy(std::move(contextRef), attr) {}
PyConcreteAttribute(PyAttribute &orig)
- : PyConcreteAttribute(castFrom(orig)) {}
+ : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
static MlirAttribute castFrom(PyAttribute &orig) {
if (!DerivedTy::isaFunction(orig.attr)) {
@@ -269,18 +315,18 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
"get",
[](PyMlirContext &context, std::string value) {
MlirAttribute attr =
- mlirStringAttrGet(context.context, value.size(), &value[0]);
- return PyStringAttribute(attr);
+ mlirStringAttrGet(context.get(), value.size(), &value[0]);
+ return PyStringAttribute(context.getRef(), attr);
},
- py::keep_alive<0, 1>(), "Gets a uniqued string attribute");
+ "Gets a uniqued string attribute");
c.def_static(
"get_typed",
[](PyType &type, std::string value) {
MlirAttribute attr =
mlirStringAttrTypedGet(type.type, value.size(), &value[0]);
- return PyStringAttribute(attr);
+ return PyStringAttribute(type.getContext(), attr);
},
- py::keep_alive<0, 1>(),
+
"Gets a uniqued string attribute associated to a type");
c.def_property_readonly(
"value",
@@ -315,8 +361,10 @@ class PyConcreteType : public BaseTy {
using IsAFunctionTy = int (*)(MlirType);
PyConcreteType() = default;
- PyConcreteType(MlirType t) : BaseTy(t) {}
- PyConcreteType(PyType &orig) : PyConcreteType(castFrom(orig)) {}
+ PyConcreteType(PyMlirContextRef contextRef, MlirType t)
+ : BaseTy(std::move(contextRef), t) {}
+ PyConcreteType(PyType &orig)
+ : PyConcreteType(orig.getContext(), castFrom(orig)) {}
static MlirType castFrom(PyType &orig) {
if (!DerivedTy::isaFunction(orig.type)) {
@@ -348,24 +396,24 @@ class PyIntegerType : public PyConcreteType<PyIntegerType> {
c.def_static(
"get_signless",
[](PyMlirContext &context, unsigned width) {
- MlirType t = mlirIntegerTypeGet(context.context, width);
- return PyIntegerType(t);
+ MlirType t = mlirIntegerTypeGet(context.get(), width);
+ return PyIntegerType(context.getRef(), t);
},
- py::keep_alive<0, 1>(), "Create a signless integer type");
+ "Create a signless integer type");
c.def_static(
"get_signed",
[](PyMlirContext &context, unsigned width) {
- MlirType t = mlirIntegerTypeSignedGet(context.context, width);
- return PyIntegerType(t);
+ MlirType t = mlirIntegerTypeSignedGet(context.get(), width);
+ return PyIntegerType(context.getRef(), t);
},
- py::keep_alive<0, 1>(), "Create a signed integer type");
+ "Create a signed integer type");
c.def_static(
"get_unsigned",
[](PyMlirContext &context, unsigned width) {
- MlirType t = mlirIntegerTypeUnsignedGet(context.context, width);
- return PyIntegerType(t);
+ MlirType t = mlirIntegerTypeUnsignedGet(context.get(), width);
+ return PyIntegerType(context.getRef(), t);
},
- py::keep_alive<0, 1>(), "Create an unsigned integer type");
+ "Create an unsigned integer type");
c.def_property_readonly(
"width",
[](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self.type); },
@@ -400,10 +448,10 @@ class PyIndexType : public PyConcreteType<PyIndexType> {
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirIndexTypeGet(context.context);
- return PyIndexType(t);
+ MlirType t = mlirIndexTypeGet(context.get());
+ return PyIndexType(context.getRef(), t);
}),
- py::keep_alive<0, 1>(), "Create a index type.");
+ "Create a index type.");
}
};
@@ -416,10 +464,10 @@ class PyBF16Type : public PyConcreteType<PyBF16Type> {
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirBF16TypeGet(context.context);
- return PyBF16Type(t);
+ MlirType t = mlirBF16TypeGet(context.get());
+ return PyBF16Type(context.getRef(), t);
}),
- py::keep_alive<0, 1>(), "Create a bf16 type.");
+ "Create a bf16 type.");
}
};
@@ -432,10 +480,10 @@ class PyF16Type : public PyConcreteType<PyF16Type> {
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirF16TypeGet(context.context);
- return PyF16Type(t);
+ MlirType t = mlirF16TypeGet(context.get());
+ return PyF16Type(context.getRef(), t);
}),
- py::keep_alive<0, 1>(), "Create a f16 type.");
+ "Create a f16 type.");
}
};
@@ -448,10 +496,10 @@ class PyF32Type : public PyConcreteType<PyF32Type> {
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirF32TypeGet(context.context);
- return PyF32Type(t);
+ MlirType t = mlirF32TypeGet(context.get());
+ return PyF32Type(context.getRef(), t);
}),
- py::keep_alive<0, 1>(), "Create a f32 type.");
+ "Create a f32 type.");
}
};
@@ -464,10 +512,10 @@ class PyF64Type : public PyConcreteType<PyF64Type> {
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirF64TypeGet(context.context);
- return PyF64Type(t);
+ MlirType t = mlirF64TypeGet(context.get());
+ return PyF64Type(context.getRef(), t);
}),
- py::keep_alive<0, 1>(), "Create a f64 type.");
+ "Create a f64 type.");
}
};
@@ -480,10 +528,10 @@ class PyNoneType : public PyConcreteType<PyNoneType> {
static void bindDerived(ClassTy &c) {
c.def(py::init([](PyMlirContext &context) {
- MlirType t = mlirNoneTypeGet(context.context);
- return PyNoneType(t);
+ MlirType t = mlirNoneTypeGet(context.get());
+ return PyNoneType(context.getRef(), t);
}),
- py::keep_alive<0, 1>(), "Create a none type.");
+ "Create a none type.");
}
};
@@ -501,7 +549,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
// The element must be a floating point or integer scalar type.
if (mlirTypeIsAIntegerOrFloat(elementType.type)) {
MlirType t = mlirComplexTypeGet(elementType.type);
- return PyComplexType(t);
+ return PyComplexType(elementType.getContext(), t);
}
throw SetPyError(
PyExc_ValueError,
@@ -509,12 +557,12 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point or integer type.");
},
- py::keep_alive<0, 1>(), "Create a complex type");
+ "Create a complex type");
c.def_property_readonly(
"element_type",
[](PyComplexType &self) -> PyType {
MlirType t = mlirComplexTypeGetElementType(self.type);
- return PyType(t);
+ return PyType(self.getContext(), t);
},
"Returns element type.");
}
@@ -531,9 +579,9 @@ class PyShapedType : public PyConcreteType<PyShapedType> {
"element_type",
[](PyShapedType &self) {
MlirType t = mlirShapedTypeGetElementType(self.type);
- return PyType(t);
+ return PyType(self.getContext(), t);
},
- py::keep_alive<0, 1>(), "Returns the element type of the shaped type.");
+ "Returns the element type of the shaped type.");
c.def_property_readonly(
"has_rank",
[](PyShapedType &self) -> bool {
@@ -616,9 +664,9 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
py::repr(py::cast(elementType)).cast<std::string>() +
"' and expected floating point or integer type.");
}
- return PyVectorType(t);
+ return PyVectorType(elementType.getContext(), t);
},
- py::keep_alive<0, 2>(), "Create a vector type");
+ "Create a vector type");
}
};
@@ -648,9 +696,9 @@ class PyRankedTensorType
"complex "
"type.");
}
- return PyRankedTensorType(t);
+ return PyRankedTensorType(elementType.getContext(), t);
},
- py::keep_alive<0, 2>(), "Create a ranked tensor type");
+ "Create a ranked tensor type");
}
};
@@ -680,9 +728,9 @@ class PyUnrankedTensorType
"complex "
"type.");
}
- return PyUnrankedTensorType(t);
+ return PyUnrankedTensorType(elementType.getContext(), t);
},
- py::keep_alive<0, 1>(), "Create a unranked tensor type");
+ "Create a unranked tensor type");
}
};
@@ -715,9 +763,9 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
"complex "
"type.");
}
- return PyMemRefType(t);
+ return PyMemRefType(elementType.getContext(), t);
},
- py::keep_alive<0, 1>(), "Create a memref type")
+ "Create a memref type")
.def_property_readonly(
"num_affine_maps",
[](PyMemRefType &self) -> intptr_t {
@@ -760,9 +808,9 @@ class PyUnrankedMemRefType
"complex "
"type.");
}
- return PyUnrankedMemRefType(t);
+ return PyUnrankedMemRefType(elementType.getContext(), t);
},
- py::keep_alive<0, 1>(), "Create a unranked memref type")
+ "Create a unranked memref type")
.def_property_readonly(
"memory_space",
[](PyUnrankedMemRefType &self) -> unsigned {
@@ -788,17 +836,17 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
SmallVector<MlirType, 4> elements;
for (auto element : elementList)
elements.push_back(element.cast<PyType>().type);
- MlirType t = mlirTupleTypeGet(context.context, num, elements.data());
- return PyTupleType(t);
+ MlirType t = mlirTupleTypeGet(context.get(), num, elements.data());
+ return PyTupleType(context.getRef(), t);
},
- py::keep_alive<0, 1>(), "Create a tuple type");
+ "Create a tuple type");
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) -> PyType {
MlirType t = mlirTupleTypeGetType(self.type, pos);
- return PyType(t);
+ return PyType(self.getContext(), t);
},
- py::keep_alive<0, 1>(), "Returns the pos-th type in the tuple type.");
+ "Returns the pos-th type in the tuple type.");
c.def_property_readonly(
"num_types",
[](PyTupleType &self) -> intptr_t {
@@ -817,12 +865,21 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of MlirContext
py::class_<PyMlirContext>(m, "Context")
- .def(py::init<>())
+ .def(py::init<>([]() {
+ MlirContext context = mlirContextCreate();
+ auto contextRef = PyMlirContext::forContext(context);
+ return contextRef.release();
+ }))
+ .def_static("_get_live_count", &PyMlirContext::getLiveCount)
+ .def("_get_context_again",
+ [](PyMlirContext &self) {
+ auto ref = PyMlirContext::forContext(self.get());
+ return ref.release();
+ })
.def(
"parse_module",
[](PyMlirContext &self, const std::string module) {
- auto moduleRef =
- mlirModuleCreateParse(self.context, module.c_str());
+ auto moduleRef = mlirModuleCreateParse(self.get(), module.c_str());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirModuleIsNull(moduleRef)) {
@@ -830,14 +887,14 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyExc_ValueError,
"Unable to parse module assembly (see diagnostics)");
}
- return PyModule(moduleRef);
+ return PyModule(self.getRef(), moduleRef);
},
- py::keep_alive<0, 1>(), kContextParseDocstring)
+ kContextParseDocstring)
.def(
"parse_attr",
[](PyMlirContext &self, std::string attrSpec) {
MlirAttribute type =
- mlirAttributeParseGet(self.context, attrSpec.c_str());
+ mlirAttributeParseGet(self.get(), attrSpec.c_str());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirAttributeIsNull(type)) {
@@ -845,13 +902,13 @@ void mlir::python::populateIRSubmodule(py::module &m) {
llvm::Twine("Unable to parse attribute: '") +
attrSpec + "'");
}
- return PyAttribute(type);
+ return PyAttribute(self.getRef(), type);
},
py::keep_alive<0, 1>())
.def(
"parse_type",
[](PyMlirContext &self, std::string typeSpec) {
- MlirType type = mlirTypeParseGet(self.context, typeSpec.c_str());
+ MlirType type = mlirTypeParseGet(self.get(), typeSpec.c_str());
// TODO: Rework error reporting once diagnostic engine is exposed
// in C API.
if (mlirTypeIsNull(type)) {
@@ -859,30 +916,32 @@ void mlir::python::populateIRSubmodule(py::module &m) {
llvm::Twine("Unable to parse type: '") +
typeSpec + "'");
}
- return PyType(type);
+ return PyType(self.getRef(), type);
},
- py::keep_alive<0, 1>(), kContextParseTypeDocstring)
+ kContextParseTypeDocstring)
.def(
"get_unknown_location",
[](PyMlirContext &self) {
- return PyLocation(mlirLocationUnknownGet(self.context));
+ return PyLocation(self.getRef(),
+ mlirLocationUnknownGet(self.get()));
},
- py::keep_alive<0, 1>(), kContextGetUnknownLocationDocstring)
+ kContextGetUnknownLocationDocstring)
.def(
"get_file_location",
[](PyMlirContext &self, std::string filename, int line, int col) {
- return PyLocation(mlirLocationFileLineColGet(
- self.context, filename.c_str(), line, col));
+ return PyLocation(self.getRef(),
+ mlirLocationFileLineColGet(
+ self.get(), filename.c_str(), line, col));
},
- py::keep_alive<0, 1>(), kContextGetFileLocationDocstring,
- py::arg("filename"), py::arg("line"), py::arg("col"))
+ kContextGetFileLocationDocstring, py::arg("filename"),
+ py::arg("line"), py::arg("col"))
.def(
"create_region",
[](PyMlirContext &self) {
// The creating context is explicitly captured on regions to
// facilitate illegal assemblies of objects from multiple contexts
// that would invalidate the memory model.
- return PyRegion(self.context, mlirRegionCreate(),
+ return PyRegion(self.get(), mlirRegionCreate(),
/*detached=*/true);
},
py::keep_alive<0, 1>(), kContextCreateRegionDocstring)
@@ -893,7 +952,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// types must be from the same context.
for (auto pyType : pyTypes) {
if (!mlirContextEqual(mlirTypeGetContext(pyType.type),
- self.context)) {
+ self.get())) {
throw SetPyError(
PyExc_ValueError,
"All types used to construct a block must be from "
@@ -902,8 +961,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
}
llvm::SmallVector<MlirType, 4> types(pyTypes.begin(),
pyTypes.end());
- return PyBlock(self.context,
- mlirBlockCreate(types.size(), &types[0]),
+ return PyBlock(self.get(), mlirBlockCreate(types.size(), &types[0]),
/*detached=*/true);
},
py::keep_alive<0, 1>(), kContextCreateBlockDocstring);
@@ -1063,7 +1121,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def_property_readonly(
"attr",
[](PyNamedAttribute &self) {
- return PyAttribute(self.namedAttr.attribute);
+ // TODO: When named attribute is removed/refactored, also remove
+ // this constructor (it does an inefficient table lookup).
+ auto contextRef = PyMlirContext::forContext(
+ mlirAttributeGetContext(self.namedAttr.attribute));
+ return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
},
py::keep_alive<0, 1>(),
"The underlying generic attribute of the NamedAttribute binding");
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 20fe8014e138..fa52c3979359 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -12,6 +12,7 @@
#include <pybind11/pybind11.h>
#include "mlir-c/IR.h"
+#include "llvm/ADT/DenseMap.h"
namespace mlir {
namespace python {
@@ -19,28 +20,105 @@ namespace python {
class PyMlirContext;
class PyModule;
+/// Holds a C++ PyMlirContext and associated py::object, making it convenient
+/// to have an auto-releasing C++-side keep-alive reference to the context.
+/// The reference to the PyMlirContext is a simple C++ reference and the
+/// py::object holds the reference count which keeps it alive.
+/// Copying a ref takes a new Python reference; moving transfers the existing one.
+class PyMlirContextRef {
+public:
+ /// Binds `referrent` and takes over the strong reference held by `object`.
+ PyMlirContextRef(PyMlirContext &referrent, pybind11::object object)
+ : referrent(referrent), object(std::move(object)) {}
+ // Deliberately no user-declared destructor (Rule of Zero): declaring one
+ // suppresses the implicit move constructor, silently turning every
+ // std::move of a ref into a copy with extra refcount traffic.
+
+ /// Releases the object held by this instance, causing its reference count
+ /// to remain artificially inflated by one. This must be used to return
+ /// the referenced PyMlirContext from a function. Otherwise, the destructor
+ /// of this reference would be called prior to the default take_ownership
+ /// policy assuming that the reference count has been transferred to it.
+ PyMlirContext *release();
+
+ /// Member access to the referenced context. operator-> must return a
+ /// pointer; returning a reference made `ref->member` ill-formed.
+ PyMlirContext *operator->() { return &referrent; }
+ /// Returns a new strong reference to the context's Python object.
+ pybind11::object getObject() const { return object; }
+
+private:
+ PyMlirContext &referrent;
+ pybind11::object object;
+};
+
/// Wrapper around MlirContext.
class PyMlirContext {
public:
- PyMlirContext() { context = mlirContextCreate(); }
- ~PyMlirContext() { mlirContextDestroy(context); }
+ // Instances are neither copyable nor movable: each one owns its MlirContext
+ // and is tracked by address in the live-context map. All copy/move members
+ // are deleted explicitly so the intent is visible at a glance.
+ PyMlirContext() = delete;
+ PyMlirContext(const PyMlirContext &) = delete;
+ PyMlirContext(PyMlirContext &&) = delete;
+ PyMlirContext &operator=(const PyMlirContext &) = delete;
+ PyMlirContext &operator=(PyMlirContext &&) = delete;
+
+ /// Returns a context reference for the singleton PyMlirContext wrapper for
+ /// the given context.
+ static PyMlirContextRef forContext(MlirContext context);
+ ~PyMlirContext();
+
+ /// Accesses the underlying MlirContext.
+ MlirContext get() const { return context; }
+
+ /// Gets a strong reference to this context, which will ensure it is kept
+ /// alive for the life of the reference. Relies on `handle`, which is only
+ /// populated once forContext() has registered this instance.
+ PyMlirContextRef getRef() {
+ return PyMlirContextRef(
+ *this, pybind11::reinterpret_borrow<pybind11::object>(handle));
+ }
+
+ /// Gets the count of live context objects. Used for testing.
+ static size_t getLiveCount();
+
+private:
+ /// Only reachable through forContext(); explicit to forbid implicit
+ /// MlirContext -> PyMlirContext conversions.
+ explicit PyMlirContext(MlirContext context);
+
+ // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
+ // preserving the relationship that an MlirContext maps to a single
+ // PyMlirContext wrapper. This could be replaced in the future with an
+ // extension mechanism on the MlirContext for stashing user pointers.
+ // Note that this holds a handle, which does not imply ownership.
+ // Mappings will be removed when the context is destructed.
+ using LiveContextMap =
+ llvm::DenseMap<void *, std::pair<pybind11::handle, PyMlirContext *>>;
+ static LiveContextMap &getLiveContexts();
MlirContext context;
+ // The handle is set as part of lookup with forContext() (post construction).
+ pybind11::handle handle;
+};
+
+/// Base class for all objects that directly or indirectly depend on an
+/// MlirContext. The lifetime of the context will extend at least to the
+/// lifetime of these instances.
+/// Immutable objects that depend on a context extend this directly.
+class BaseContextObject {
+public:
+ /// Takes over `ref`, which keeps the context alive for as long as this
+ /// object exists. Explicit so a bare PyMlirContextRef never converts
+ /// silently into a context-dependent object.
+ explicit BaseContextObject(PyMlirContextRef ref)
+ : contextRef(std::move(ref)) {}
+
+ /// Accesses the context reference. Callers may copy it to share the same
+ /// context with a newly created object.
+ PyMlirContextRef &getContext() { return contextRef; }
+
+private:
+ PyMlirContextRef contextRef;
};
/// Wrapper around an MlirLocation.
-class PyLocation {
+class PyLocation : public BaseContextObject {
public:
- PyLocation(MlirLocation loc) : loc(loc) {}
+ /// Wraps `loc`; `contextRef` is handed to the base class, which keeps the
+ /// owning context alive for the lifetime of this object.
+ PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
+ : BaseContextObject(std::move(contextRef)), loc(loc) {}
MlirLocation loc;
};
/// Wrapper around MlirModule.
-class PyModule {
+class PyModule : public BaseContextObject {
public:
- PyModule(MlirModule module) : module(module) {}
+ PyModule(PyMlirContextRef contextRef, MlirModule module)
+ : BaseContextObject(std::move(contextRef)), module(module) {}
PyModule(PyModule &) = delete;
- PyModule(PyModule &&other) {
+ PyModule(PyModule &&other)
+ : BaseContextObject(std::move(other.getContext())) {
module = other.module;
other.module.ptr = nullptr;
}
@@ -120,9 +198,10 @@ class PyBlock {
/// Wrapper around the generic MlirAttribute.
/// The lifetime of a type is bound by the PyContext that created it.
-class PyAttribute {
+class PyAttribute : public BaseContextObject {
public:
- PyAttribute(MlirAttribute attr) : attr(attr) {}
+ PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
+ : BaseContextObject(std::move(contextRef)), attr(attr) {}
bool operator==(const PyAttribute &other);
MlirAttribute attr;
@@ -153,9 +232,10 @@ class PyNamedAttribute {
/// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it.
-class PyType {
+class PyType : public BaseContextObject {
public:
- PyType(MlirType type) : type(type) {}
+ PyType(PyMlirContextRef contextRef, MlirType type)
+ : BaseContextObject(std::move(contextRef)), type(type) {}
bool operator==(const PyType &other);
operator MlirType() const { return type; }
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 8611d6537371..0304d977f494 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -48,6 +48,10 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(UnknownLoc::get(unwrap(context)));
}
+// Returns the context that owns `location`.
+MlirContext mlirLocationGetContext(MlirLocation location) {
+ auto loc = unwrap(location);
+ return wrap(loc.getContext());
+}
+
void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
void *userData) {
detail::CallbackOstream stream(callback, userData);
@@ -70,6 +74,10 @@ MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
return MlirModule{owning.release().getOperation()};
}
+// Returns the context that `module` was created in.
+MlirContext mlirModuleGetContext(MlirModule module) {
+ auto op = unwrap(module);
+ return wrap(op.getContext());
+}
+
void mlirModuleDestroy(MlirModule module) {
// Transfer ownership to an OwningModuleRef so that its destructor is called.
OwningModuleRef(unwrap(module));
@@ -349,6 +357,10 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
return wrap(mlir::parseAttribute(attr, unwrap(context)));
}
+// Returns the context that owns `attribute`.
+MlirContext mlirAttributeGetContext(MlirAttribute attribute) {
+ auto attr = unwrap(attribute);
+ return wrap(attr.getContext());
+}
+
int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
return unwrap(a1) == unwrap(a2);
}
diff --git a/mlir/test/Bindings/Python/context_lifecycle.py b/mlir/test/Bindings/Python/context_lifecycle.py
new file mode 100644
index 000000000000..e2b287061b22
--- /dev/null
+++ b/mlir/test/Bindings/Python/context_lifecycle.py
@@ -0,0 +1,42 @@
+# RUN: %PYTHON %s
+# Standalone sanity check of context life-cycle, tracked via Context._get_live_count().
+import gc
+import mlir
+
+assert mlir.ir.Context._get_live_count() == 0  # No contexts exist before any is built.
+
+# Create first context.
+print("CREATE C1")
+c1 = mlir.ir.Context()
+assert mlir.ir.Context._get_live_count() == 1
+c1_repr = repr(c1)
+print("C1 = ", c1_repr)
+
+print("GETTING AGAIN...")
+c2 = c1._get_context_again()  # Looks up the same underlying context again.
+c2_repr = repr(c2)
+assert mlir.ir.Context._get_live_count() == 1  # The lookup must not register a new context.
+assert c1_repr == c2_repr  # Both names should refer to the same wrapper.
+
+print("C2 =", c2)
+
+# Make sure the constructor always creates a new, distinct context.
+print("CREATE C3")
+c3 = mlir.ir.Context()
+assert mlir.ir.Context._get_live_count() == 2
+c3_repr = repr(c3)
+print("C3 =", c3)
+assert c3_repr != c1_repr
+print("FREE C3")
+c3 = None
+gc.collect()
+assert mlir.ir.Context._get_live_count() == 1  # c3 was the only reference to its context.
+
+print("Free C1")
+c1 = None
+gc.collect()
+assert mlir.ir.Context._get_live_count() == 1  # c2 still keeps the shared context alive.
+print("Free C2")
+c2 = None
+gc.collect()
+assert mlir.ir.Context._get_live_count() == 0  # Last reference dropped; nothing may leak.
More information about the Mlir-commits
mailing list