[Mlir-commits] [mlir] 3f3d1c9 - [MLIR][Python] Add capsule methods for pybind11 to PyValue.

Mike Urbach llvmlistbot at llvm.org
Tue Apr 27 19:14:23 PDT 2021


Author: Mike Urbach
Date: 2021-04-27T20:14:16-06:00
New Revision: 3f3d1c901d7abcc5b91468335679b1b27d8a02dd

URL: https://github.com/llvm/llvm-project/commit/3f3d1c901d7abcc5b91468335679b1b27d8a02dd
DIFF: https://github.com/llvm/llvm-project/commit/3f3d1c901d7abcc5b91468335679b1b27d8a02dd.diff

LOG: [MLIR][Python] Add capsule methods for pybind11 to PyValue.

Add the `getCapsule()` and `createFromCapsule()` methods to the
PyValue class, as well as the necessary interoperability.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D101090

Added: 
    mlir/test/Bindings/Python/ir_value.py

Modified: 
    mlir/include/mlir-c/Bindings/Python/Interop.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index d853159468e0..882f73d84383 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -42,6 +42,7 @@
 #define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_TYPE "mlir.ir.Type._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_PASS_MANAGER "mlir.passmanager.PassManager._CAPIPtr"
+#define MLIR_PYTHON_CAPSULE_VALUE "mlir.ir.Value._CAPIPtr"
 
 /** Attribute on MLIR Python objects that expose their C-API pointer.
  * This will be a type-specific capsule created as per one of the helpers
@@ -285,6 +286,25 @@ mlirPythonCapsuleToExecutionEngine(PyObject *capsule) {
   return jit;
 }
 
+/** Creates a capsule object encapsulating the raw C-API MlirValue.
+ * The capsule is tagged with MLIR_PYTHON_CAPSULE_VALUE and has no destructor.
+ * It does not extend or affect ownership of any Python objects that reference
+ * the value in any way.
+ */
+static inline PyObject *mlirPythonValueToCapsule(MlirValue value) {
+  void *rawPtr = MLIR_PYTHON_GET_WRAPPED_POINTER(value);
+  PyObject *capsule = PyCapsule_New(rawPtr, MLIR_PYTHON_CAPSULE_VALUE, NULL);
+  return capsule;
+}
+
+/** Extracts an MlirValue from a capsule as produced from
+ * mlirPythonValueToCapsule. If the capsule is not of the right type, then a
+ * null value is returned (as checked via mlirValueIsNull). In such a case, the
+ * Python APIs will have already set an error. */
+static inline MlirValue mlirPythonCapsuleToValue(PyObject *capsule) {
+  MlirValue result;
+  /* PyCapsule_GetPointer yields NULL (and sets a Python error) on mismatch. */
+  result.ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_VALUE);
+  return result;
+}
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index a2655d9c4faf..b93786e05f15 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -15,6 +15,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Debug.h"
+#include "mlir-c/IR.h"
 #include "mlir-c/Registration.h"
 #include "llvm/ADT/SmallVector.h"
 #include <pybind11/stl.h>
@@ -1467,6 +1468,27 @@ PyType PyType::createFromCapsule(py::object capsule) {
 // PyValue and subclases.
 //------------------------------------------------------------------------------
 
+/// Returns a new capsule wrapping the raw MlirValue held by this PyValue.
+/// The capsule does not keep the owning operation alive.
+pybind11::object PyValue::getCapsule() {
+  PyObject *rawCapsule = mlirPythonValueToCapsule(get());
+  // The C helper returns a new reference; hand ownership to py::object.
+  return py::reinterpret_steal<py::object>(rawCapsule);
+}
+
+/// Creates a PyValue from a capsule produced by PyValue::getCapsule (or
+/// mlirPythonValueToCapsule). The returned PyValue holds a reference to the
+/// PyOperation that owns the value.
+/// Throws py::error_already_set if the capsule has the wrong type (the capsule
+/// API has already set the Python error in that case), and py::value_error if
+/// no owning operation can be found (e.g. a block argument whose block has no
+/// parent operation).
+PyValue PyValue::createFromCapsule(pybind11::object capsule) {
+  MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
+  if (mlirValueIsNull(value))
+    throw py::error_already_set();
+  // Start from a null handle so a value that is neither an OpResult nor a
+  // BlockArgument is rejected below instead of reading an indeterminate
+  // handle.
+  MlirOperation owner = {nullptr};
+  if (mlirValueIsAOpResult(value))
+    owner = mlirOpResultGetOwner(value);
+  else if (mlirValueIsABlockArgument(value))
+    owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
+  // No Python error is pending on this path, so error_already_set would have
+  // nothing to propagate; raise an explicit ValueError instead.
+  if (mlirOperationIsNull(owner))
+    throw py::value_error(
+        "Cannot create a Value from a capsule whose value has no owning "
+        "operation");
+  MlirContext ctx = mlirOperationGetContext(owner);
+  PyOperationRef ownerRef =
+      PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
+  return PyValue(ownerRef, value);
+}
+
 namespace {
 /// CRTP base class for Python MLIR values that subclass Value and should be
 /// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2353,6 +2375,8 @@ void mlir::python::populateIRCore(py::module &m) {
   // Mapping of Value.
   //----------------------------------------------------------------------------
   py::class_<PyValue>(m, "Value")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
       .def_property_readonly(
           "context",
           [](PyValue &self) { return self.getParentOperation()->getContext(); },

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index f3f5ee5edf52..ff3faeefd994 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -721,6 +721,13 @@ class PyValue {
 
   void checkValid() { return parentOperation->checkValid(); }
 
+  /// Gets a capsule wrapping the void* within the MlirValue.
+  pybind11::object getCapsule();
+
+  /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
+  /// the underlying MlirValue is still tied to the owning operation.
+  static PyValue createFromCapsule(pybind11::object capsule);
+
 private:
   PyOperationRef parentOperation;
   MlirValue value;

diff  --git a/mlir/test/Bindings/Python/ir_value.py b/mlir/test/Bindings/Python/ir_value.py
new file mode 100644
index 000000000000..3b88fee375a0
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_value.py
@@ -0,0 +1,27 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+
+
+# CHECK-LABEL: TEST: testCapsuleConversions
+def testCapsuleConversions():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    i32 = IntegerType.get_signless(32)
+    value = Operation.create("custom.op1", results=[i32]).result
+    value_capsule = value._CAPIPtr
+    assert '"mlir.ir.Value._CAPIPtr"' in repr(value_capsule)
+    value2 = Value._CAPICreate(value_capsule)
+    assert value2 == value
+
+
+run(testCapsuleConversions)


        


More information about the Mlir-commits mailing list