[Mlir-commits] [mlir] ad958f6 - [mlir][Python] Add missing capsule->module and Context.create_module.
Stella Laurenzo
llvmlistbot at llvm.org
Tue Oct 13 13:17:21 PDT 2020
Author: Stella Laurenzo
Date: 2020-10-13T13:10:33-07:00
New Revision: ad958f648e46680966375a93a3f2f1f5ee870671
URL: https://github.com/llvm/llvm-project/commit/ad958f648e46680966375a93a3f2f1f5ee870671
DIFF: https://github.com/llvm/llvm-project/commit/ad958f648e46680966375a93a3f2f1f5ee870671.diff
LOG: [mlir][Python] Add missing capsule->module and Context.create_module.
* Extends Context/Operation interning to cover Module as well.
* Implements Module.context, Operation.context, Attribute.context, Type.context, and Location.context back-references (these facilitate testing and were also on the TODO list).
* Adds method to create an empty Module.
* These were discovered to be missing while working on npcomp.
Differential Revision: https://reviews.llvm.org/D89294
Added:
Modified:
mlir/include/mlir-c/Bindings/Python/Interop.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/test/Bindings/Python/ir_attributes.py
mlir/test/Bindings/Python/ir_location.py
mlir/test/Bindings/Python/ir_module.py
mlir/test/Bindings/Python/ir_operation.py
mlir/test/Bindings/Python/ir_types.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index 24b2a8b9de39..acb168c3fb73 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -86,6 +86,16 @@ inline PyObject *mlirPythonModuleToCapsule(MlirModule module) {
return PyCapsule_New(ptr, MLIR_PYTHON_CAPSULE_MODULE, NULL);
}
+/** Unwraps the MlirModule stored in a capsule that was produced by
+ * mlirPythonModuleToCapsule. When the capsule does not carry the expected
+ * name, the returned module is null (test it with mlirModuleIsNull) and the
+ * Python C API has already set an exception describing the failure. */
+inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
+ MlirModule module;
+ module.ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_MODULE);
+ return module;
+}
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 36e25eebfc71..8f525e8b6239 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -497,6 +497,8 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
+size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
+
py::object PyMlirContext::createOperation(
std::string name, PyLocation location,
llvm::Optional<std::vector<PyType *>> results,
@@ -582,15 +584,49 @@ py::object PyMlirContext::createOperation(
// PyModule
//------------------------------------------------------------------------------
-PyModuleRef PyModule::create(PyMlirContextRef contextRef, MlirModule module) {
- PyModule *unownedModule = new PyModule(std::move(contextRef), module);
- // Note that the default return value policy on cast is automatic_reference,
- // which does not take ownership (delete will not be called).
- // Just be explicit.
- py::object pyRef =
- py::cast(unownedModule, py::return_value_policy::take_ownership);
- unownedModule->handle = pyRef;
- return PyModuleRef(unownedModule, std::move(pyRef));
+// Private constructor (declared private in the header); forModule is the
+// entry point that creates wrappers and interns them in the context.
+PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
+ : BaseContextObject(std::move(contextRef)),
+ module(module) {}
+
+// Removes this wrapper from its context's live-module map and then destroys
+// the wrapped MlirModule. The GIL is held because the map stores Python
+// handles.
+PyModule::~PyModule() {
+ py::gil_scoped_acquire gil;
+ auto &registry = getContext()->liveModules;
+ bool removed = registry.erase(module.ptr);
+ assert(removed && "destroying module not in live map");
+ (void)removed;
+ mlirModuleDestroy(module);
+}
+
+/// Returns the unique PyModule wrapping `module`. If no wrapper is interned
+/// in the owning context's live-module map yet, a new one is created,
+/// registered, and takes ownership of the MlirModule.
+PyModuleRef PyModule::forModule(MlirModule module) {
+ // Take the GIL before resolving the context: forContext returns a ref that
+ // holds a Python object (and presumably may create one), which requires
+ // the GIL to be held.
+ py::gil_scoped_acquire acquire;
+ MlirContext context = mlirModuleGetContext(module);
+ PyMlirContextRef contextRef = PyMlirContext::forContext(context);
+
+ auto &liveModules = contextRef->liveModules;
+ auto it = liveModules.find(module.ptr);
+ if (it != liveModules.end()) {
+ // Reuse the interned wrapper so Python identity (`is`) is preserved.
+ PyModule *existing = it->second.second;
+ py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
+ return PyModuleRef(existing, std::move(pyRef));
+ }
+
+ // Create. `liveModules` still refers into the context object after
+ // contextRef is moved, and the new PyModule keeps that context alive.
+ PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+ // The default return value policy on cast is automatic_reference, which
+ // does not take ownership (delete would never be called); be explicit.
+ py::object pyRef =
+ py::cast(unownedModule, py::return_value_policy::take_ownership);
+ unownedModule->handle = pyRef;
+ liveModules[module.ptr] =
+ std::make_pair(unownedModule->handle, unownedModule);
+ return PyModuleRef(unownedModule, std::move(pyRef));
+}
+
+// Wraps the MlirModule carried by `capsule`, returning the interned wrapper
+// if one already exists. On a capsule mismatch, the Python error already set
+// by the capsule API is propagated as py::error_already_set.
+py::object PyModule::createFromCapsule(py::object capsule) {
+ MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
+ if (!mlirModuleIsNull(rawModule))
+ return forModule(rawModule).releaseObject();
+ throw py::error_already_set();
}
py::object PyModule::getCapsule() {
@@ -1461,6 +1497,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return ref.releaseObject();
})
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+ .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
@@ -1489,9 +1526,16 @@ void mlir::python::populateIRSubmodule(py::module &m) {
PyExc_ValueError,
"Unable to parse module assembly (see diagnostics)");
}
- return PyModule::create(self.getRef(), module).releaseObject();
+ return PyModule::forModule(module).releaseObject();
},
kContextParseDocstring)
+ .def(
+ "create_module",
+ [](PyMlirContext &self, PyLocation &loc) {
+ MlirModule module = mlirModuleCreateEmpty(loc.loc);
+ return PyModule::forModule(module).releaseObject();
+ },
+ py::arg("loc"), "Creates an empty module")
.def(
"parse_attr",
[](PyMlirContext &self, std::string attrSpec) {
@@ -1538,16 +1582,26 @@ void mlir::python::populateIRSubmodule(py::module &m) {
kContextGetFileLocationDocstring, py::arg("filename"),
py::arg("line"), py::arg("col"));
- py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
- PyPrintAccumulator printAccum;
- mlirLocationPrint(self.loc, printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
- });
+ py::class_<PyLocation>(m, "Location")
+ .def_property_readonly(
+ "context",
+ [](PyLocation &self) { return self.getContext().getObject(); },
+ "Context that owns the Location")
+ .def("__repr__", [](PyLocation &self) {
+ PyPrintAccumulator printAccum;
+ mlirLocationPrint(self.loc, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ });
// Mapping of Module
py::class_<PyModule>(m, "Module")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
+ .def_property_readonly(
+ "context",
+ [](PyModule &self) { return self.getContext().getObject(); },
+ "Context that created the Module")
.def_property_readonly(
"operation",
[](PyModule &self) {
@@ -1576,6 +1630,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of Operation.
py::class_<PyOperation>(m, "Operation")
+ .def_property_readonly(
+ "context",
+ [](PyOperation &self) { return self.getContext().getObject(); },
+ "Context that owns the Operation")
.def_property_readonly(
"regions",
[](PyOperation &self) { return PyRegionList(self.getRef()); })
@@ -1657,6 +1715,10 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of Type.
py::class_<PyAttribute>(m, "Attribute")
+ .def_property_readonly(
+ "context",
+ [](PyAttribute &self) { return self.getContext().getObject(); },
+ "Context that owns the Attribute")
.def(
"get_named",
[](PyAttribute &self, std::string name) {
@@ -1737,6 +1799,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of Type.
py::class_<PyType>(m, "Type")
+ .def_property_readonly(
+ "context", [](PyType &self) { return self.getContext().getObject(); },
+ "Context that owns the Type")
.def("__eq__",
[](PyType &self, py::object &other) {
try {
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index e67142e56c00..c175018c8bb6 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -113,7 +113,8 @@ class PyMlirContext {
/// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
/// Note that PyMlirContext instances are uniqued, so the returned object
- /// may be a pre-existing object.
+ /// may be a pre-existing object. Ownership of the underlying MlirContext
+ /// is taken by calling this function.
static pybind11::object createFromCapsule(pybind11::object capsule);
/// Gets the count of live context objects. Used for testing.
@@ -123,6 +124,10 @@ class PyMlirContext {
/// Used for testing.
size_t getLiveOperationCount();
+ /// Gets the count of live modules associated with this context.
+ /// Used for testing.
+ size_t getLiveModuleCount();
+
/// Creates an operation. See corresponding python docstring.
pybind11::object
createOperation(std::string name, PyLocation location,
@@ -142,6 +147,14 @@ class PyMlirContext {
using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
static LiveContextMap &getLiveContexts();
+ // Interns all live modules associated with this context. Modules tracked
+ // in this map are valid. When a module is invalidated, it is removed
+ // from this map, and while it still exists as an instance, any
+ // attempt to access it will raise an error.
+ using LiveModuleMap =
+ llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
+ LiveModuleMap liveModules;
+
// Interns all live operations associated with this context. Operations
// tracked in this map are valid. When an operation is invalidated, it is
// removed from this map, and while it still exists as an instance, any
@@ -151,6 +164,7 @@ class PyMlirContext {
LiveOperationMap liveOperations;
MlirContext context;
+ friend class PyModule;
friend class PyOperation;
};
@@ -186,13 +200,12 @@ class PyModule;
using PyModuleRef = PyObjectRef<PyModule>;
class PyModule : public BaseContextObject {
public:
- /// Creates a reference to the module
- static PyModuleRef create(PyMlirContextRef contextRef, MlirModule module);
+ /// Returns a PyModule reference for the given MlirModule. This may return
+ /// a pre-existing or new object.
+ static PyModuleRef forModule(MlirModule module);
PyModule(PyModule &) = delete;
- ~PyModule() {
- if (module.ptr)
- mlirModuleDestroy(module);
- }
+ PyModule(PyMlirContext &&) = delete;
+ ~PyModule();
/// Gets the backing MlirModule.
MlirModule get() { return module; }
@@ -209,9 +222,14 @@ class PyModule : public BaseContextObject {
/// instances, which is not currently done.
pybind11::object getCapsule();
+ /// Creates a PyModule from the MlirModule wrapped by a capsule.
+ /// Note that PyModule instances are uniqued, so the returned object
+ /// may be a pre-existing object. Ownership of the underlying MlirModule
+ /// is taken by calling this function.
+ static pybind11::object createFromCapsule(pybind11::object capsule);
+
private:
- PyModule(PyMlirContextRef contextRef, MlirModule module)
- : BaseContextObject(std::move(contextRef)), module(module) {}
+ PyModule(PyMlirContextRef contextRef, MlirModule module);
MlirModule module;
pybind11::handle handle;
};
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
index dfdc81909a9a..bf99a7686b17 100644
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -14,6 +14,7 @@ def run(f):
def testParsePrint():
ctx = mlir.ir.Context()
t = ctx.parse_attr('"hello"')
+ assert t.context is ctx
ctx = None
gc.collect()
# CHECK: "hello"
diff --git a/mlir/test/Bindings/Python/ir_location.py b/mlir/test/Bindings/Python/ir_location.py
index ac42c61a0723..f7e99242b8ad 100644
--- a/mlir/test/Bindings/Python/ir_location.py
+++ b/mlir/test/Bindings/Python/ir_location.py
@@ -14,6 +14,7 @@ def run(f):
def testUnknown():
ctx = mlir.ir.Context()
loc = ctx.get_unknown_location()
+ assert loc.context is ctx
ctx = None
gc.collect()
# CHECK: unknown str: loc(unknown)
diff --git a/mlir/test/Bindings/Python/ir_module.py b/mlir/test/Bindings/Python/ir_module.py
index d85a415308ae..5f3403809e83 100644
--- a/mlir/test/Bindings/Python/ir_module.py
+++ b/mlir/test/Bindings/Python/ir_module.py
@@ -16,6 +16,7 @@ def run(f):
def testParseSuccess():
ctx = mlir.ir.Context()
module = ctx.parse_module(r"""module @successfulParse {}""")
+ assert module.context is ctx
print("CLEAR CONTEXT")
ctx = None # Ensure that module captures the context.
gc.collect()
@@ -40,6 +41,21 @@ def testParseError():
run(testParseError)
+# Verify successful parse.
+# CHECK-LABEL: TEST: testCreateEmpty
+# CHECK: module {
+def testCreateEmpty():
+ ctx = mlir.ir.Context()
+ loc = ctx.get_unknown_location()
+ module = ctx.create_module(loc)
+ print("CLEAR CONTEXT")
+ ctx = None # Ensure that module captures the context.
+ gc.collect()
+ print(str(module))
+
+run(testCreateEmpty)
+
+
# Verify round-trip of ASM that contains unicode.
# Note that this does not test that the print path converts unicode properly
# because MLIR asm always normalizes it to the hex encoding.
@@ -61,6 +77,7 @@ def testRoundtripUnicode():
def testModuleOperation():
ctx = mlir.ir.Context()
module = ctx.parse_module(r"""module @successfulParse {}""")
+ assert ctx._get_live_module_count() == 1
op1 = module.operation
assert ctx._get_live_operation_count() == 1
# CHECK: module @successfulParse
@@ -82,6 +99,7 @@ def testModuleOperation():
gc.collect()
print("LIVE OPERATIONS:", ctx._get_live_operation_count())
assert ctx._get_live_operation_count() == 0
+ assert ctx._get_live_module_count() == 0
run(testModuleOperation)
@@ -90,7 +108,19 @@ def testModuleOperation():
def testModuleCapsule():
ctx = mlir.ir.Context()
module = ctx.parse_module(r"""module @successfulParse {}""")
+ assert ctx._get_live_module_count() == 1
# CHECK: "mlir.ir.Module._CAPIPtr"
- print(module._CAPIPtr)
+ module_capsule = module._CAPIPtr
+ print(module_capsule)
+ # Round-tripping through the capsule must yield the same interned object.
+ module_dup = mlir.ir.Module._CAPICreate(module_capsule)
+ assert module is module_dup
+ assert module_dup.context is ctx
+ # Drop every reference, collect garbage, and verify the module was
+ # destroyed (removed from the context's live-module map).
+ module = None
+ module_capsule = None
+ module_dup = None
+ gc.collect()
+ assert ctx._get_live_module_count() == 0
+
run(testModuleCapsule)
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 881398e1eba3..37b830558528 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -23,6 +23,7 @@ def testTraverseOpRegionBlockIterators():
}
""")
op = module.operation
+ assert op.context is ctx
# Get the block using iterators off of the named collections.
regions = list(op.regions)
blocks = list(regions[0].blocks)
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
index d8ae77f1f092..5a9c5a16bc92 100644
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -14,6 +14,7 @@ def run(f):
def testParsePrint():
ctx = mlir.ir.Context()
t = ctx.parse_type("i32")
+ assert t.context is ctx
ctx = None
gc.collect()
# CHECK: i32
More information about the Mlir-commits
mailing list