[Mlir-commits] [mlir] 0126e90 - [MLIR] [Python] Add capsule methods for pybind11 to PyOperation

John Demme llvmlistbot at llvm.org
Tue Apr 6 14:30:11 PDT 2021


Author: John Demme
Date: 2021-04-06T14:29:03-07:00
New Revision: 0126e906483c50c47db0687195e4b0216479846e

URL: https://github.com/llvm/llvm-project/commit/0126e906483c50c47db0687195e4b0216479846e
DIFF: https://github.com/llvm/llvm-project/commit/0126e906483c50c47db0687195e4b0216479846e.diff

LOG: [MLIR] [Python] Add capsule methods for pybind11 to PyOperation

Add the `getCapsule()` and `createFromCapsule()` methods to the PyOperation class and expose them to Python as the `_CAPIPtr` property and the `_CAPICreate` factory.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D99927

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/test/Bindings/Python/ir_operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5046eedb11940..7a7bae92cf87a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -868,6 +868,19 @@ PyBlock PyOperation::getBlock() {
   return PyBlock{std::move(parentOperation), block};
 }
 
+py::object PyOperation::getCapsule() { // wraps get() in a PyCapsule
+  return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
+} // NOTE(review): reinterpret_steal assumes the C API returns an owned ref
+
+py::object PyOperation::createFromCapsule(py::object capsule) {
+  MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
+  if (mlirOperationIsNull(rawOperation)) // presumably a bad/mismatched capsule
+    throw py::error_already_set(); // assumes the C API set a Python error
+  MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
+  return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
+      .releaseObject(); // may hand back an already-live wrapper for this op
+}
+
 py::object PyOperation::create(
     std::string name, llvm::Optional<std::vector<PyType *>> results,
     llvm::Optional<std::vector<PyValue *>> operands,
@@ -2031,6 +2044,9 @@ void mlir::python::populateIRCore(py::module &m) {
                   py::arg("successors") = py::none(), py::arg("regions") = 0,
                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
                   kOperationCreateDocstring)
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyOperation::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
       .def_property_readonly("name",
                              [](PyOperation &self) {
                                MlirOperation operation = self.get();

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 5c710abe789a1..861673abc7018 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -454,6 +454,14 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   /// no parent.
   PyOperationRef getParentOperation();
 
+  /// Returns a PyCapsule wrapping the void* inside this MlirOperation.
+  pybind11::object getCapsule();
+
+  /// Returns the PyOperation for the MlirOperation wrapped by `capsule`
+  /// (possibly an existing wrapper). Intended to take ownership of the op;
+  /// NOTE(review): verify against forOperation(), to which this delegates.
+  static pybind11::object createFromCapsule(pybind11::object capsule);
+
   /// Creates an operation. See corresponding python docstring.
   static pybind11::object
   create(std::string name, llvm::Optional<std::vector<PyType *>> results,

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index 847c1093cd37f..f7036cde771e1 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -601,3 +601,16 @@ def testOperationName():
     print(op.operation.name)
 
 run(testOperationName)
+
+# CHECK-LABEL: TEST: testCapsuleConversions
+def testCapsuleConversions():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True  # "custom.op1" has no registered dialect
+  with Location.unknown(ctx):
+    m = Operation.create("custom.op1").operation
+    m_capsule = m._CAPIPtr
+    assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)  # capsule name check
+    m2 = Operation._CAPICreate(m_capsule)
+    assert m2 is m  # round-trip must yield the same live wrapper, not a copy
+
+run(testCapsuleConversions)


        


More information about the Mlir-commits mailing list