[Mlir-commits] [mlir] 3f3d1c9 - [MLIR][Python] Add capsule methods for pybind11 to PyValue.
Mike Urbach
llvmlistbot at llvm.org
Tue Apr 27 19:14:23 PDT 2021
Author: Mike Urbach
Date: 2021-04-27T20:14:16-06:00
New Revision: 3f3d1c901d7abcc5b91468335679b1b27d8a02dd
URL: https://github.com/llvm/llvm-project/commit/3f3d1c901d7abcc5b91468335679b1b27d8a02dd
DIFF: https://github.com/llvm/llvm-project/commit/3f3d1c901d7abcc5b91468335679b1b27d8a02dd.diff
LOG: [MLIR][Python] Add capsule methods for pybind11 to PyValue.
Add the `getCapsule()` and `createFromCapsule()` methods to the
PyValue class, as well as the necessary interoperability.
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D101090
Added:
mlir/test/Bindings/Python/ir_value.py
Modified:
mlir/include/mlir-c/Bindings/Python/Interop.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index d853159468e0..882f73d84383 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -42,6 +42,7 @@
#define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_TYPE "mlir.ir.Type._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_PASS_MANAGER "mlir.passmanager.PassManager._CAPIPtr"
+#define MLIR_PYTHON_CAPSULE_VALUE "mlir.ir.Value._CAPIPtr"
/** Attribute on MLIR Python objects that expose their C-API pointer.
* This will be a type-specific capsule created as per one of the helpers
@@ -285,6 +286,25 @@ mlirPythonCapsuleToExecutionEngine(PyObject *capsule) {
return jit;
}
+/** Creates a capsule object encapsulating the raw C-API MlirValue.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the value in any way. The capsule is tagged with the
+ * MLIR_PYTHON_CAPSULE_VALUE name and is created without a destructor.
+ */
+static inline PyObject *mlirPythonValueToCapsule(MlirValue value) {
+ return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(value),
+ MLIR_PYTHON_CAPSULE_VALUE, NULL);
+}
+
+/** Extracts an MlirValue from a capsule as produced from
+ * mlirPythonValueToCapsule. If the capsule is not of the right type, then a
+ * null value is returned (as checked via mlirValueIsNull). In such a case, the
+ * Python APIs will have already set an error. */
+static inline MlirValue mlirPythonCapsuleToValue(PyObject *capsule) {
+ void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_VALUE);
+ MlirValue value = {ptr};
+ return value;
+}
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index a2655d9c4faf..b93786e05f15 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -15,6 +15,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Debug.h"
+#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "llvm/ADT/SmallVector.h"
#include <pybind11/stl.h>
@@ -1467,6 +1468,27 @@ PyType PyType::createFromCapsule(py::object capsule) {
// PyValue and subclases.
//------------------------------------------------------------------------------
+/// Returns a Python capsule wrapping the MlirValue held by this PyValue.
+/// The reference produced by mlirPythonValueToCapsule is adopted (stolen) by
+/// the returned py::object rather than being incremented again.
+pybind11::object PyValue::getCapsule() {
+ return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
+}
+
+/// Reconstructs a PyValue from a capsule such as the one returned by
+/// PyValue::getCapsule().
+///
+/// The owning operation is recovered from the value itself: directly for an
+/// op result, or through the parent operation of the owning block for a block
+/// argument.
+///
+/// Throws py::error_already_set if the capsule does not yield an MlirValue
+/// (the capsule helper documents that the Python error is already set in that
+/// case), and py::value_error if no owning operation can be determined.
+PyValue PyValue::createFromCapsule(pybind11::object capsule) {
+  MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
+  if (mlirValueIsNull(value))
+    throw py::error_already_set();
+  // Start from a null handle so the null check below is well-defined even when
+  // the value is neither an op result nor a block argument.
+  MlirOperation owner = {nullptr};
+  if (mlirValueIsAOpResult(value))
+    owner = mlirOpResultGetOwner(value);
+  else if (mlirValueIsABlockArgument(value))
+    owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
+  // No Python error is pending on this path, so raise a concrete exception
+  // instead of py::error_already_set.
+  if (mlirOperationIsNull(owner))
+    throw py::value_error(
+        "Cannot create a Value from a capsule without an owning operation");
+  MlirContext ctx = mlirOperationGetContext(owner);
+  PyOperationRef ownerRef =
+      PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
+  return PyValue(ownerRef, value);
+}
+
namespace {
/// CRTP base class for Python MLIR values that subclass Value and should be
/// castable from it. The value hierarchy is one level deep and is not supposed
@@ -2353,6 +2375,8 @@ void mlir::python::populateIRCore(py::module &m) {
// Mapping of Value.
//----------------------------------------------------------------------------
py::class_<PyValue>(m, "Value")
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
.def_property_readonly(
"context",
[](PyValue &self) { return self.getParentOperation()->getContext(); },
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index f3f5ee5edf52..ff3faeefd994 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -721,6 +721,13 @@ class PyValue {
void checkValid() { return parentOperation->checkValid(); }
+ /// Gets a capsule wrapping the void* within the MlirValue.
+ pybind11::object getCapsule();
+
+ /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
+ /// the underlying MlirValue is still tied to the owning operation.
+ static PyValue createFromCapsule(pybind11::object capsule);
+
private:
PyOperationRef parentOperation;
MlirValue value;
diff --git a/mlir/test/Bindings/Python/ir_value.py b/mlir/test/Bindings/Python/ir_value.py
new file mode 100644
index 000000000000..3b88fee375a0
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_value.py
@@ -0,0 +1,27 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+
+def run(f):
+ # Test driver: prints a banner with the test function's name, runs it, then
+ # forces a garbage collection and asserts that no Context objects are still
+ # alive afterwards.
+ print("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+
+
+# CHECK-LABEL: TEST: testCapsuleConversions
+def testCapsuleConversions():
+ # Round-trips a Value through its capsule form (Value -> _CAPIPtr capsule ->
+ # Value._CAPICreate) and verifies the result compares equal to the original.
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ # Use the single i32 result of a freshly created op as the value under test.
+ value = Operation.create("custom.op1", results=[i32]).result
+ value_capsule = value._CAPIPtr
+ # The capsule's repr is expected to contain the Value capsule name.
+ assert '"mlir.ir.Value._CAPIPtr"' in repr(value_capsule)
+ value2 = Value._CAPICreate(value_capsule)
+ assert value2 == value
+
+
+run(testCapsuleConversions)
More information about the Mlir-commits
mailing list