[Mlir-commits] [mlir] [mlir][python] Make the Context/Operation capsule creation methods work as documented. (PR #76010)
Stella Laurenzo
llvmlistbot at llvm.org
Tue Dec 19 21:29:20 PST 2023
https://github.com/stellaraccident created https://github.com/llvm/llvm-project/pull/76010
None
>From d1f34bad1b00e310e84572296240303e9b664529 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo <stellaraccident at gmail.com>
Date: Tue, 19 Dec 2023 21:27:48 -0800
Subject: [PATCH] [mlir][python] Make the Context/Operation capsule creation
methods work as documented.
---
mlir/lib/Bindings/Python/IRCore.cpp | 78 +++++++++++++++++++++---
mlir/lib/Bindings/Python/IRModule.h | 19 +++++-
mlir/test/python/ir/context_lifecycle.py | 45 +++++++++++++-
mlir/test/python/ir/operation.py | 13 ----
4 files changed, 129 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5412c3dec4b1b6..39757dfad5be1d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -602,7 +602,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) {
MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
if (mlirContextIsNull(rawContext))
throw py::error_already_set();
- return forContext(rawContext).releaseObject();
+ return stealExternalContext(rawContext).releaseObject();
}
PyMlirContext *PyMlirContext::createNewContextForInit() {
@@ -615,18 +615,35 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
auto &liveContexts = getLiveContexts();
auto it = liveContexts.find(context.ptr);
if (it == liveContexts.end()) {
- // Create.
- PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
- py::object pyRef = py::cast(unownedContextWrapper);
- assert(pyRef && "cast to py::object failed");
- liveContexts[context.ptr] = unownedContextWrapper;
- return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
+ throw std::runtime_error(
+ "Cannot use a context that is not owned by the Python bindings.");
}
+
// Use existing.
py::object pyRef = py::cast(it->second);
return PyMlirContextRef(it->second, std::move(pyRef));
}
+// Adopts an MlirContext that was created outside of the Python bindings.
+// The wrapper built here is handed to Python with take_ownership, so the
+// context's life-cycle becomes controlled by the resulting Python object.
+// Adopting a context that the bindings already track is rejected with
+// std::runtime_error.
+PyMlirContextRef PyMlirContext::stealExternalContext(MlirContext context) {
+  py::gil_scoped_acquire acquire;
+
+  // Refuse to adopt a context that already has a live Python wrapper.
+  auto &registry = getLiveContexts();
+  const bool alreadyTracked = registry.find(context.ptr) != registry.end();
+  if (alreadyTracked)
+    throw std::runtime_error(
+        "Cannot transfer ownership of the context to Python "
+        "as it is already owned by Python.");
+
+  PyMlirContext *ownedWrapper = new PyMlirContext(context);
+  // py::cast on a raw pointer defaults to automatic_reference, which does not
+  // take ownership (delete would never be called on the wrapper). Override
+  // that default with take_ownership so Python deletes the wrapper once its
+  // last reference goes away.
+  py::object pyWrapper =
+      py::cast(ownedWrapper, py::return_value_policy::take_ownership);
+  assert(pyWrapper && "cast to py::object failed");
+  return PyMlirContextRef(ownedWrapper, std::move(pyWrapper));
+}
+
PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
static LiveContextMap liveContexts;
return liveContexts;
@@ -1145,6 +1162,18 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
return PyOperationRef(existing, std::move(pyRef));
}
+// Wraps an MlirOperation that has no live Python wrapper in `contextRef`, so
+// that the Python bindings manage it from now on. An operation that is
+// already tracked by this context cannot be adopted a second time and
+// results in std::runtime_error.
+PyOperationRef PyOperation::stealExternalOperation(PyMlirContextRef contextRef,
+                                                   MlirOperation operation) {
+  auto &knownOperations = contextRef->liveOperations;
+  if (knownOperations.find(operation.ptr) == knownOperations.end()) {
+    // Not yet tracked: build a fresh wrapper. py::none() is passed as the
+    // trailing argument (presumably "no parent keep-alive"; see
+    // createInstance).
+    return createInstance(std::move(contextRef), operation, py::none());
+  }
+  throw std::runtime_error(
+      "Cannot transfer ownership of the operation to Python "
+      "as it is already owned by Python.");
+}
+
PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
MlirOperation operation,
py::object parentKeepAlive) {
@@ -1316,7 +1345,8 @@ py::object PyOperation::createFromCapsule(py::object capsule) {
if (mlirOperationIsNull(rawOperation))
throw py::error_already_set();
MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
- return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
+ return stealExternalOperation(PyMlirContext::forContext(rawCtxt),
+ rawOperation)
.releaseObject();
}
@@ -2548,6 +2578,16 @@ void mlir::python::populateIRCore(py::module &m) {
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
+ .def_static("_testing_create_raw_context_capsule",
+ []() {
+ // Creates an MlirContext not known to the Python bindings
+ // and puts it in a capsule. Used to test interop. Using
+ // this without passing it back to the capsule creation
+ // API will leak.
+ return py::reinterpret_steal<py::object>(
+ mlirPythonContextToCapsule(
+ mlirContextCreateWithThreading(false)));
+ })
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
@@ -2973,8 +3013,7 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
- bool, py::object, bool>(
- &PyOperationBase::print),
+ bool, py::object, bool>(&PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
@@ -3046,6 +3085,25 @@ void mlir::python::populateIRCore(py::module &m) {
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyOperation::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
+ .def_static(
+ "_testing_create_raw_capsule",
+ [](std::string sourceStr) {
+ // Creates a raw context and an operation via parsing the given
+ // source and returns them in a capsule. Error handling is
+ // minimal as this is purely intended for testing interop with
+ // operation creation from capsule functions.
+ MlirContext context = mlirContextCreateWithThreading(false);
+ MlirOperation op = mlirOperationCreateParse(
+ context, toMlirStringRef(sourceStr), toMlirStringRef("temp"));
+ if (mlirOperationIsNull(op)) {
+ mlirContextDestroy(context);
+ throw std::invalid_argument("Failed to parse");
+ }
+ return py::make_tuple(py::reinterpret_steal<py::object>(
+ mlirPythonContextToCapsule(context)),
+ py::reinterpret_steal<py::object>(
+ mlirPythonOperationToCapsule(op)));
+ })
.def_property_readonly("operation", [](py::object self) { return self; })
.def_property_readonly("opview", &PyOperation::createOpView)
.def_property_readonly(
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 79b7e0c96188c1..04164b78b3e250 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -176,8 +176,19 @@ class PyMlirContext {
static PyMlirContext *createNewContextForInit();
/// Returns a context reference for the singleton PyMlirContext wrapper for
- /// the given context.
+ /// the given context. It is only valid to call this on an MlirContext that
+ /// is already owned by the Python bindings. Typically this will be because
+ /// it came in some fashion from createNewContextForInit(). However, it
+ /// is also possible to explicitly transfer ownership of an existing
+ /// MlirContext to the Python bindings via stealExternalContext().
static PyMlirContextRef forContext(MlirContext context);
+
+ /// Explicitly takes ownership of an MlirContext that must not already be
+ /// known to the Python bindings. Once done, the life-cycle of the context
+ /// will be controlled by the Python bindings, and it will be destroyed
+ /// when the reference count goes to zero.
+ static PyMlirContextRef stealExternalContext(MlirContext context);
+
~PyMlirContext();
/// Accesses the underlying MlirContext.
@@ -606,6 +617,12 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
forOperation(PyMlirContextRef contextRef, MlirOperation operation,
pybind11::object parentKeepAlive = pybind11::object());
+ /// Explicitly takes ownership of an operation that must not already be known
+ /// to the Python bindings. Once done, the life-cycle of the operation
+ /// will be controlled by the Python bindings.
+ static PyOperationRef stealExternalOperation(PyMlirContextRef contextRef,
+ MlirOperation operation);
+
/// Creates a detached operation. The operation must not be associated with
/// any existing live operation.
static PyOperationRef
diff --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py
index c20270999425ee..fbd1851ba70aee 100644
--- a/mlir/test/python/ir/context_lifecycle.py
+++ b/mlir/test/python/ir/context_lifecycle.py
@@ -45,5 +45,46 @@
c4 = mlir.ir.Context()
c4_capsule = c4._CAPIPtr
assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
-c5 = mlir.ir.Context._CAPICreate(c4_capsule)
-assert c4 is c5
+# Because the context is already owned by Python, it cannot be created
+# a second time.
+try:
+ c5 = mlir.ir.Context._CAPICreate(c4_capsule)
+except RuntimeError:
+ pass
+else:
+ raise AssertionError(
+ "Should have gotten a RuntimeError when attempting to "
+ "re-create an already owned context"
+ )
+c4 = None
+c4_capsule = None
+gc.collect()
+assert mlir.ir.Context._get_live_count() == 0
+
+# Use a private testing method to create an unowned context capsule and
+# import it.
+c6_capsule = mlir.ir.Context._testing_create_raw_context_capsule()
+c6 = mlir.ir.Context._CAPICreate(c6_capsule)
+assert mlir.ir.Context._get_live_count() == 1
+c6_capsule = None
+c6 = None
+gc.collect()
+assert mlir.ir.Context._get_live_count() == 0
+
+# Also test operation import/export as it is tightly coupled to the context.
+(
+ raw_context_capsule,
+ raw_operation_capsule,
+) = mlir.ir.Operation._testing_create_raw_capsule("builtin.module {}")
+assert '"mlir.ir.Operation._CAPIPtr"' in repr(raw_operation_capsule)
+# Attempting to import an operation for an unknown context should fail.
+try:
+ mlir.ir.Operation._CAPICreate(raw_operation_capsule)
+except RuntimeError:
+ pass
+else:
+ raise AssertionError("Expected exception for unknown context")
+
+# Try again having imported the context.
+c7 = mlir.ir.Context._CAPICreate(raw_context_capsule)
+op7 = mlir.ir.Operation._CAPICreate(raw_operation_capsule)
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a9936e31f7..f59b1a26ba48b5 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -844,19 +844,6 @@ def testOperationName():
print(op.operation.name)
-# CHECK-LABEL: TEST: testCapsuleConversions
- at run
-def testCapsuleConversions():
- ctx = Context()
- ctx.allow_unregistered_dialects = True
- with Location.unknown(ctx):
- m = Operation.create("custom.op1").operation
- m_capsule = m._CAPIPtr
- assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
- m2 = Operation._CAPICreate(m_capsule)
- assert m2 is m
-
-
# CHECK-LABEL: TEST: testOperationErase
@run
def testOperationErase():
More information about the Mlir-commits
mailing list