[Mlir-commits] [mlir] 78bd124 - Revert "[mlir][python] Make the Context/Operation capsule creation methods work as documented. (#76010)"

Alex Zinenko llvmlistbot at llvm.org
Thu Dec 21 02:06:51 PST 2023


Author: Alex Zinenko
Date: 2023-12-21T10:06:44Z
New Revision: 78bd124649ece163d3a26b33608bdbe518d8ff76

URL: https://github.com/llvm/llvm-project/commit/78bd124649ece163d3a26b33608bdbe518d8ff76
DIFF: https://github.com/llvm/llvm-project/commit/78bd124649ece163d3a26b33608bdbe518d8ff76.diff

LOG: Revert "[mlir][python] Make the Context/Operation capsule creation methods work as documented. (#76010)"

This reverts commit bbc29768683b394b34600347f46be2b8245ddb30.

This change seems to be at odds with the non-owning semantics of
MlirOperation in the C API. Since downstream clients can only take and
return MlirOperation, it does not sound correct to force all returns of
MlirOperation to transfer ownership. Specifically, this makes it impossible
for downstreams to implement IR-traversing functions that, e.g., look at
neighbors of an operation.

The following patch triggers the exception, and there does not seem to
be an alternative way for a downstream binding writer to express this:

```
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 39757dfad5be..2ce640674245 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3071,6 +3071,11 @@ void mlir::python::populateIRCore(py::module &m) {
                   py::arg("successors") = py::none(), py::arg("regions") = 0,
                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
                   py::arg("infer_type") = false, kOperationCreateDocstring)
+      .def("_get_first_in_block", [](PyOperation &self) -> MlirOperation {
+        MlirBlock block = mlirOperationGetBlock(self.get());
+        MlirOperation first = mlirBlockGetFirstOperation(block);
+        return first;
+      })
       .def_static(
           "parse",
           [](const std::string &sourceStr, const std::string &sourceName,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index f59b1a26ba48..6b12b8da5c24 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -24,6 +24,25 @@ def expect_index_error(callback):
     except IndexError:
         pass

+@run
+def testCustomBind():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(
+        r"""
+    func.func @f1(%arg0: i32) -> i32 {
+      %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32
+      return %1 : i32
+    }
+  """,
+        ctx,
+    )
+    add = module.body.operations[0].regions[0].blocks[0].operations[0]
+    op = add.operation
+    # This will get a reference to itself.
+    f1 = op._get_first_in_block()
+
+

 # Verify iterator based traversal of the op/region/block hierarchy.
 # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
```

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/test/python/ir/context_lifecycle.py
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 39757dfad5be1d..5412c3dec4b1b6 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -602,7 +602,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) {
   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
   if (mlirContextIsNull(rawContext))
     throw py::error_already_set();
-  return stealExternalContext(rawContext).releaseObject();
+  return forContext(rawContext).releaseObject();
 }
 
 PyMlirContext *PyMlirContext::createNewContextForInit() {
@@ -615,35 +615,18 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
   auto &liveContexts = getLiveContexts();
   auto it = liveContexts.find(context.ptr);
   if (it == liveContexts.end()) {
-    throw std::runtime_error(
-        "Cannot use a context that is not owned by the Python bindings.");
+    // Create.
+    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
+    py::object pyRef = py::cast(unownedContextWrapper);
+    assert(pyRef && "cast to py::object failed");
+    liveContexts[context.ptr] = unownedContextWrapper;
+    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
   }
-
   // Use existing.
   py::object pyRef = py::cast(it->second);
   return PyMlirContextRef(it->second, std::move(pyRef));
 }
 
-PyMlirContextRef PyMlirContext::stealExternalContext(MlirContext context) {
-  py::gil_scoped_acquire acquire;
-  auto &liveContexts = getLiveContexts();
-  auto it = liveContexts.find(context.ptr);
-  if (it != liveContexts.end()) {
-    throw std::runtime_error(
-        "Cannot transfer ownership of the context to Python "
-        "as it is already owned by Python.");
-  }
-
-  PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
-  // Note that the default return value policy on cast is automatic_reference,
-  // which does not take ownership (delete will not be called).
-  // Just be explicit.
-  py::object pyRef =
-      py::cast(unownedContextWrapper, py::return_value_policy::take_ownership);
-  assert(pyRef && "cast to py::object failed");
-  return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
-}
-
 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
   static LiveContextMap liveContexts;
   return liveContexts;
@@ -1162,18 +1145,6 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
   return PyOperationRef(existing, std::move(pyRef));
 }
 
-PyOperationRef PyOperation::stealExternalOperation(PyMlirContextRef contextRef,
-                                                   MlirOperation operation) {
-  auto &liveOperations = contextRef->liveOperations;
-  auto it = liveOperations.find(operation.ptr);
-  if (it != liveOperations.end()) {
-    throw std::runtime_error(
-        "Cannot transfer ownership of the operation to Python "
-        "as it is already owned by Python.");
-  }
-  return createInstance(std::move(contextRef), operation, py::none());
-}
-
 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
                                            MlirOperation operation,
                                            py::object parentKeepAlive) {
@@ -1345,8 +1316,7 @@ py::object PyOperation::createFromCapsule(py::object capsule) {
   if (mlirOperationIsNull(rawOperation))
     throw py::error_already_set();
   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
-  return stealExternalOperation(PyMlirContext::forContext(rawCtxt),
-                                rawOperation)
+  return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
       .releaseObject();
 }
 
@@ -2578,16 +2548,6 @@ void mlir::python::populateIRCore(py::module &m) {
       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
-      .def_static("_testing_create_raw_context_capsule",
-                  []() {
-                    // Creates an MlirContext not known to the Python bindings
-                    // and puts it in a capsule. Used to test interop. Using
-                    // this without passing it back to the capsule creation
-                    // API will leak.
-                    return py::reinterpret_steal<py::object>(
-                        mlirPythonContextToCapsule(
-                            mlirContextCreateWithThreading(false)));
-                  })
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
@@ -3013,7 +2973,8 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("binary") = false, kOperationPrintStateDocstring)
       .def("print",
            py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
-                             bool, py::object, bool>(&PyOperationBase::print),
+                             bool, py::object, bool>(
+               &PyOperationBase::print),
            // Careful: Lots of arguments must match up with print method.
            py::arg("large_elements_limit") = py::none(),
            py::arg("enable_debug_info") = false,
@@ -3085,25 +3046,6 @@ void mlir::python::populateIRCore(py::module &m) {
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyOperation::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
-      .def_static(
-          "_testing_create_raw_capsule",
-          [](std::string sourceStr) {
-            // Creates a raw context and an operation via parsing the given
-            // source and returns them in a capsule. Error handling is
-            // minimal as this is purely intended for testing interop with
-            // operation creation from capsule functions.
-            MlirContext context = mlirContextCreateWithThreading(false);
-            MlirOperation op = mlirOperationCreateParse(
-                context, toMlirStringRef(sourceStr), toMlirStringRef("temp"));
-            if (mlirOperationIsNull(op)) {
-              mlirContextDestroy(context);
-              throw std::invalid_argument("Failed to parse");
-            }
-            return py::make_tuple(py::reinterpret_steal<py::object>(
-                                      mlirPythonContextToCapsule(context)),
-                                  py::reinterpret_steal<py::object>(
-                                      mlirPythonOperationToCapsule(op)));
-          })
       .def_property_readonly("operation", [](py::object self) { return self; })
       .def_property_readonly("opview", &PyOperation::createOpView)
       .def_property_readonly(

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 04164b78b3e250..79b7e0c96188c1 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -176,19 +176,8 @@ class PyMlirContext {
   static PyMlirContext *createNewContextForInit();
 
   /// Returns a context reference for the singleton PyMlirContext wrapper for
-  /// the given context. It is only valid to call this on an MlirContext that
-  /// is already owned by the Python bindings. Typically this will be because
-  /// it came in some fashion from createNewContextForInit(). However, it
-  /// is also possible to explicitly transfer ownership of an existing
-  /// MlirContext to the Python bindings via stealExternalContext().
+  /// the given context.
   static PyMlirContextRef forContext(MlirContext context);
-
-  /// Explicitly takes ownership of an MlirContext that must not already be
-  /// known to the Python bindings. Once done, the life-cycle of the context
-  /// will be controlled by the Python bindings, and it will be destroyed
-  /// when the reference count goes to zero.
-  static PyMlirContextRef stealExternalContext(MlirContext context);
-
   ~PyMlirContext();
 
   /// Accesses the underlying MlirContext.
@@ -617,12 +606,6 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   forOperation(PyMlirContextRef contextRef, MlirOperation operation,
                pybind11::object parentKeepAlive = pybind11::object());
 
-  /// Explicitly takes ownership of an operation that must not already be known
-  /// to the Python bindings. Once done, the life-cycle of the operation
-  /// will be controlled by the Python bindings.
-  static PyOperationRef stealExternalOperation(PyMlirContextRef contextRef,
-                                               MlirOperation operation);
-
   /// Creates a detached operation. The operation must not be associated with
   /// any existing live operation.
   static PyOperationRef

diff  --git a/mlir/test/python/ir/context_lifecycle.py b/mlir/test/python/ir/context_lifecycle.py
index fbd1851ba70aee..c20270999425ee 100644
--- a/mlir/test/python/ir/context_lifecycle.py
+++ b/mlir/test/python/ir/context_lifecycle.py
@@ -45,46 +45,5 @@
 c4 = mlir.ir.Context()
 c4_capsule = c4._CAPIPtr
 assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
-# Because the context is already owned by Python, it cannot be created
-# a second time.
-try:
-    c5 = mlir.ir.Context._CAPICreate(c4_capsule)
-except RuntimeError:
-    pass
-else:
-    raise AssertionError(
-        "Should have gotten a RuntimeError when attempting to "
-        "re-create an already owned context"
-    )
-c4 = None
-c4_capsule = None
-gc.collect()
-assert mlir.ir.Context._get_live_count() == 0
-
-# Use a private testing method to create an unowned context capsule and
-# import it.
-c6_capsule = mlir.ir.Context._testing_create_raw_context_capsule()
-c6 = mlir.ir.Context._CAPICreate(c6_capsule)
-assert mlir.ir.Context._get_live_count() == 1
-c6_capsule = None
-c6 = None
-gc.collect()
-assert mlir.ir.Context._get_live_count() == 0
-
-# Also test operation import/export as it is tightly coupled to the context.
-(
-    raw_context_capsule,
-    raw_operation_capsule,
-) = mlir.ir.Operation._testing_create_raw_capsule("builtin.module {}")
-assert '"mlir.ir.Operation._CAPIPtr"' in repr(raw_operation_capsule)
-# Attempting to import an operation for an unknown context should fail.
-try:
-    mlir.ir.Operation._CAPICreate(raw_operation_capsule)
-except RuntimeError:
-    pass
-else:
-    raise AssertionError("Expected exception for unknown context")
-
-# Try again having imported the context.
-c7 = mlir.ir.Context._CAPICreate(raw_context_capsule)
-op7 = mlir.ir.Operation._CAPICreate(raw_operation_capsule)
+c5 = mlir.ir.Context._CAPICreate(c4_capsule)
+assert c4 is c5

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index f59b1a26ba48b5..04f8a9936e31f7 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -844,6 +844,19 @@ def testOperationName():
         print(op.operation.name)
 
 
+# CHECK-LABEL: TEST: testCapsuleConversions
+@run
+def testCapsuleConversions():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        m = Operation.create("custom.op1").operation
+        m_capsule = m._CAPIPtr
+        assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
+        m2 = Operation._CAPICreate(m_capsule)
+        assert m2 is m
+
+
 # CHECK-LABEL: TEST: testOperationErase
 @run
 def testOperationErase():


        


More information about the Mlir-commits mailing list