[Mlir-commits] [mlir] [mlir][python] Make the Context/Operation capsule creation methods work as documented. (PR #76010)

Stella Laurenzo llvmlistbot at llvm.org
Tue Dec 19 21:29:20 PST 2023


https://github.com/stellaraccident created https://github.com/llvm/llvm-project/pull/76010

None

>From d1f34bad1b00e310e84572296240303e9b664529 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo <stellaraccident at gmail.com>
Date: Tue, 19 Dec 2023 21:27:48 -0800
Subject: [PATCH] [mlir][python] Make the Context/Operation capsule creation
 methods work as documented.

---
 mlir/lib/Bindings/Python/IRCore.cpp      | 78 +++++++++++++++++++++---
 mlir/lib/Bindings/Python/IRModule.h      | 19 +++++-
 mlir/test/python/ir/context_lifecycle.py | 45 +++++++++++++-
 mlir/test/python/ir/operation.py         | 13 ----
 4 files changed, 129 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5412c3dec4b1b6..39757dfad5be1d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -602,7 +602,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) {
   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
   if (mlirContextIsNull(rawContext))
     throw py::error_already_set();
-  return forContext(rawContext).releaseObject();
+  return stealExternalContext(rawContext).releaseObject();
 }
 
 PyMlirContext *PyMlirContext::createNewContextForInit() {
@@ -615,18 +615,35 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
   auto &liveContexts = getLiveContexts();
   auto it = liveContexts.find(context.ptr);
   if (it == liveContexts.end()) {
-    // Create.
-    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
-    py::object pyRef = py::cast(unownedContextWrapper);
-    assert(pyRef && "cast to py::object failed");
-    liveContexts[context.ptr] = unownedContextWrapper;
-    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
+    throw std::runtime_error(
+        "Cannot use a context that is not owned by the Python bindings.");
   }
+
   // Use existing.
   py::object pyRef = py::cast(it->second);
   return PyMlirContextRef(it->second, std::move(pyRef));
 }
 
+PyMlirContextRef PyMlirContext::stealExternalContext(MlirContext context) {
+  py::gil_scoped_acquire acquire;
+  auto &liveContexts = getLiveContexts();
+  auto it = liveContexts.find(context.ptr);
+  if (it != liveContexts.end()) {
+    throw std::runtime_error(
+        "Cannot transfer ownership of the context to Python "
+        "as it is already owned by Python.");
+  }
+
+  PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
+  // Note that the default return value policy on cast is automatic_reference,
+  // which does not take ownership (delete will not be called).
+  // Just be explicit.
+  py::object pyRef =
+      py::cast(unownedContextWrapper, py::return_value_policy::take_ownership);
+  assert(pyRef && "cast to py::object failed");
+  return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
+}
+
 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
   static LiveContextMap liveContexts;
   return liveContexts;
@@ -1145,6 +1162,18 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
   return PyOperationRef(existing, std::move(pyRef));
 }
 
+PyOperationRef PyOperation::stealExternalOperation(PyMlirContextRef contextRef,
+                                                   MlirOperation operation) {
+  auto &liveOperations = contextRef->liveOperations;
+  auto it = liveOperations.find(operation.ptr);
+  if (it != liveOperations.end()) {
+    throw std::runtime_error(
+        "Cannot transfer ownership of the operation to Python "
+        "as it is already owned by Python.");
+  }
+  return createInstance(std::move(contextRef), operation, py::none());
+}
+
 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
                                            MlirOperation operation,
                                            py::object parentKeepAlive) {
@@ -1316,7 +1345,8 @@ py::object PyOperation::createFromCapsule(py::object capsule) {
   if (mlirOperationIsNull(rawOperation))
     throw py::error_already_set();
   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
-  return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
+  return stealExternalOperation(PyMlirContext::forContext(rawCtxt),
+                                rawOperation)
       .releaseObject();
 }
 
@@ -2548,6 +2578,16 @@ void mlir::python::populateIRCore(py::module &m) {
       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
+      .def_static("_testing_create_raw_context_capsule",
+                  []() {
+                    // Creates an MlirContext not known to the Python bindings
+                    // and puts it in a capsule. Used to test interop. Using
+                    // this without passing it back to the capsule creation
+                    // API will leak.
+                    return py::reinterpret_steal<py::object>(
+                        mlirPythonContextToCapsule(
+                            mlirContextCreateWithThreading(false)));
+                  })
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
@@ -2973,8 +3013,7 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("binary") = false, kOperationPrintStateDocstring)
       .def("print",
            py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
-                             bool, py::object, bool>(
-               &PyOperationBase::print),
+                             bool, py::object, bool>(&PyOperationBase::print),
            // Careful: Lots of arguments must match up with print method.
            py::arg("large_elements_limit") = py::none(),
            py::arg("enable_debug_info") = false,
@@ -3046,6 +3085,25 @@ void mlir::python::populateIRCore(py::module &m) {
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyOperation::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
+      .def_static(
+          "_testing_create_raw_capsule",
+          [](std::string sourceStr) {
+            // Creates a raw context and an operation via parsing the given
+            // source and returns them in a capsule. Error handling is
+            // minimal as this is purely intended for testing interop with
+            // operation creation from capsule functions.
+            MlirContext context = mlirContextCreateWithThreading(false);
+            MlirOperation op = mlirOperationCreateParse(
+                context, toMlirStringRef(sourceStr), toMlirStringRef("temp"));
+            if (mlirOperationIsNull(op)) {
+              mlirContextDestroy(context);
+              throw std::invalid_argument("Failed to parse");
+            }
+            return py::make_tuple(py::reinterpret_steal<py::object>(
+                                      mlirPythonContextToCapsule(context)),
+                                  py::reinterpret_steal<py::object>(
+                                      mlirPythonOperationToCapsule(op)));
+          })
       .def_property_readonly("operation", [](py::object self) { return self; })
       .def_property_readonly("opview", &PyOperation::createOpView)
       .def_property_readonly(
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 79b7e0c96188c1..04164b78b3e250 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -176,8 +176,19 @@ class PyMlirContext {
   static PyMlirContext *createNewContextForInit();
 
   /// Returns a context reference for the singleton PyMlirContext wrapper for
-  /// the given context.
+  /// the given context. It is only valid to call this on an MlirContext that
+  /// is already owned by the Python bindings. Typically this will be because
+  /// it came in some fashion from createNewContextForInit(). However, it
+  /// is also possible to explicitly transfer ownership of an existing
+  /// MlirContext to the Python bindings via stealExternalContext().
   static PyMlirContextRef forContext(MlirContext context);
+
+  /// Explicitly takes ownership of an MlirContext that must not already be
+  /// known to the Python bindings. Once done, the life-cycle of the context
+  /// will be controlled by the Python bindings, and it will be destroyed
+  /// when the reference count goes to zero.
+  static PyMlirContextRef stealExternalContext(MlirContext context);
+
   ~PyMlirContext();
 
   /// Accesses the underlying MlirContext.
@@ -606,6 +617,12 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   forOperation(PyMlirContextRef contextRef, MlirOperation operation,
                pybind11::object parentKeepAlive = pybind11::object());
 
+  /// Explicitly takes ownership of an operation that must not already be known
+  /// to the Python bindings. Once done, the life-cycle of the operation
+  /// will be controlled by the Python bindings.
+  static PyOperationRef stealExternalOperation(PyMlirContextRef contextRef,
+                                               MlirOperation operation);
+
   /// Creates a detached operation. The operation must not be associated with
   /// any existing live operation.
   static PyOperationRef
diff --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py
index c20270999425ee..fbd1851ba70aee 100644
--- a/mlir/test/python/ir/context_lifecycle.py
+++ b/mlir/test/python/ir/context_lifecycle.py
@@ -45,5 +45,46 @@
 c4 = mlir.ir.Context()
 c4_capsule = c4._CAPIPtr
 assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
-c5 = mlir.ir.Context._CAPICreate(c4_capsule)
-assert c4 is c5
+# Because the context is already owned by Python, it cannot be created
+# a second time.
+try:
+    c5 = mlir.ir.Context._CAPICreate(c4_capsule)
+except RuntimeError:
+    pass
+else:
+    raise AssertionError(
+        "Should have gotten a RuntimeError when attempting to "
+        "re-create an already owned context"
+    )
+c4 = None
+c4_capsule = None
+gc.collect()
+assert mlir.ir.Context._get_live_count() == 0
+
+# Use a private testing method to create an unowned context capsule and
+# import it.
+c6_capsule = mlir.ir.Context._testing_create_raw_context_capsule()
+c6 = mlir.ir.Context._CAPICreate(c6_capsule)
+assert mlir.ir.Context._get_live_count() == 1
+c6_capsule = None
+c6 = None
+gc.collect()
+assert mlir.ir.Context._get_live_count() == 0
+
+# Also test operation import/export as it is tightly coupled to the context.
+(
+    raw_context_capsule,
+    raw_operation_capsule,
+) = mlir.ir.Operation._testing_create_raw_capsule("builtin.module {}")
+assert '"mlir.ir.Operation._CAPIPtr"' in repr(raw_operation_capsule)
+# Attempting to import an operation for an unknown context should fail.
+try:
+    mlir.ir.Operation._CAPICreate(raw_operation_capsule)
+except RuntimeError:
+    pass
+else:
+    raise AssertionError("Expected exception for unknown context")
+
+# Try again having imported the context.
+c7 = mlir.ir.Context._CAPICreate(raw_context_capsule)
+op7 = mlir.ir.Operation._CAPICreate(raw_operation_capsule)
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 04f8a9936e31f7..f59b1a26ba48b5 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -844,19 +844,6 @@ def testOperationName():
         print(op.operation.name)
 
 
-# CHECK-LABEL: TEST: testCapsuleConversions
- at run
-def testCapsuleConversions():
-    ctx = Context()
-    ctx.allow_unregistered_dialects = True
-    with Location.unknown(ctx):
-        m = Operation.create("custom.op1").operation
-        m_capsule = m._CAPIPtr
-        assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
-        m2 = Operation._CAPICreate(m_capsule)
-        assert m2 is m
-
-
 # CHECK-LABEL: TEST: testOperationErase
 @run
 def testOperationErase():



More information about the Mlir-commits mailing list