[Mlir-commits] [mlir] 543922c - Adds MLIR C-API for marshaling Python capsules.
Stella Laurenzo
llvmlistbot at llvm.org
Tue Sep 29 10:49:27 PDT 2020
Author: Stella Laurenzo
Date: 2020-09-29T10:48:53-07:00
New Revision: 543922cd3630ca3a1e06a6a946d148bc0e22e720
URL: https://github.com/llvm/llvm-project/commit/543922cd3630ca3a1e06a6a946d148bc0e22e720
DIFF: https://github.com/llvm/llvm-project/commit/543922cd3630ca3a1e06a6a946d148bc0e22e720.diff
LOG: Adds MLIR C-API for marshaling Python capsules.
* By providing stable, C-accessible definitions for bridging the MLIR Python<->C APIs, we eliminate inter-extension dependencies (i.e. they can all share a diamond dependency on the MLIR C-API).
* Just provides accessors for context and module right now.
* Needed in NPComp in ~a week or so for high level Torch APIs.
Differential Revision: https://reviews.llvm.org/D88426
Added:
mlir/include/mlir-c/Bindings/Python/Interop.h
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/test/Bindings/Python/context_lifecycle.py
mlir/test/Bindings/Python/ir_module.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
new file mode 100644
index 000000000000..24b2a8b9de39
--- /dev/null
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -0,0 +1,93 @@
+/*===-- mlir-c/Bindings/Python/Interop.h - Python/C-API interop --*- C -*-===*\
+|* *|
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
+|* Exceptions. *|
+|* See https://llvm.org/LICENSE.txt for license information. *|
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
+|* *|
+|*===----------------------------------------------------------------------===*|
+|* *|
+|* This header declares constants and helpers necessary for C-level *|
+|* interop with the MLIR Python extension module. Since the Python bindings *|
+|* are a thin wrapper around the MLIR C-API, a further C-API is not provided *|
+|* specifically for the Python extension. Instead, simple facilities are *|
+|* provided for translating between Python types and corresponding MLIR C-API *|
+|* types. *|
+|* *|
+|* This header is self-contained: beyond the MLIR C-API headers it includes, *|
+|* it requires nothing other than normal linking against the Python          *|
+|* implementation.                                                           *|
+\*===----------------------------------------------------------------------===*/
+
+#ifndef MLIR_C_BINDINGS_PYTHON_INTEROP_H
+#define MLIR_C_BINDINGS_PYTHON_INTEROP_H
+
+#include <Python.h>
+
+#include "mlir-c/IR.h"
+
+/** Name of the attribute through which MLIR Python objects expose their
+ * C-API pointer, wrapped in a type-specific capsule built by one of the
+ * helpers below.
+ *
+ * Fetching a capsule through this attribute does not transfer ownership: the
+ * wrapped pointer stays valid only for as long as the Python object that
+ * produced it is alive. Only the capsule's name and pointer are populated;
+ * callers may attach their own destructor and context if they need to manage
+ * anything beyond that. */
+#define MLIR_PYTHON_CAPI_PTR_ATTR "_CAPIPtr"
+
+/** Name of the attribute holding a factory that builds the corresponding
+ * Python object from a type-specific capsule wrapping a C-API pointer. Its
+ * Python signature is:
+ *   def _CAPICreate(capsule) -> object
+ * Invoking the factory transfers ownership of the wrapped object: afterwards
+ * the capsule must be treated as invalid, and its wrapped pointer must not be
+ * destroyed by the caller.
+ *
+ * Only a few top-level types whose lifetime can be cleanly delineated (such
+ * as Context) can be constructed this way. */
+#define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate"
+
+/** Capsule names used to tag wrapped C-API pointers by type; these are the
+ * names passed to PyCapsule_New / PyCapsule_GetPointer by the helpers below. */
+#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
+#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* All helpers below are `static inline`: under C99 inline semantics a plain
+ * `inline` function defined in a header provides no external definition, so
+ * a C translation unit whose call is not inlined would fail to link. */
+
+/** Creates a capsule object encapsulating the raw C-API MlirContext.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the context in any way. Returns NULL (with a Python
+ * error set) if the capsule could not be created.
+ */
+static inline PyObject *mlirPythonContextToCapsule(MlirContext context) {
+  return PyCapsule_New(context.ptr, MLIR_PYTHON_CAPSULE_CONTEXT, NULL);
+}
+
+/** Extracts a MlirContext from a capsule as produced from
+ * mlirPythonContextToCapsule. If the capsule is not of the right type, then
+ * a null context is returned (as checked via mlirContextIsNull). In such a
+ * case, the Python APIs will have already set an error. */
+static inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_CONTEXT);
+  MlirContext context = {ptr};
+  return context;
+}
+
+/** Creates a capsule object encapsulating the raw C-API MlirModule.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the module in any way. Returns NULL (with a Python
+ * error set) if the capsule could not be created. */
+static inline PyObject *mlirPythonModuleToCapsule(MlirModule module) {
+  /* MlirModule wraps a const pointer while PyCapsule_New takes void*, so the
+   * constness is cast away; the capsule never writes through it. */
+#ifdef __cplusplus
+  void *ptr = const_cast<void *>(module.ptr);
+#else
+  void *ptr = (void *)module.ptr;
+#endif
+  return PyCapsule_New(ptr, MLIR_PYTHON_CAPSULE_MODULE, NULL);
+}
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_BINDINGS_PYTHON_INTEROP_H
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 82149c7fce06..c751da804097 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -91,6 +91,9 @@ MlirContext mlirContextCreate();
/** Checks if two contexts are equal. */
int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
+/** Checks whether a context is null. Defined `static inline` so that every C
+ * translation unit including this header gets its own definition (a plain
+ * C99 `inline` header definition provides no external symbol to link). */
+static inline int mlirContextIsNull(MlirContext context) {
+  return !context.ptr;
+}
+
/** Takes an MLIR context owned by the caller and destroys it. */
void mlirContextDestroy(MlirContext context);
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index f3bd96856d09..8d64b2d8de0a 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -9,6 +9,7 @@
#include "IRModules.h"
#include "PybindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Registration.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
@@ -453,6 +454,17 @@ PyMlirContext::~PyMlirContext() {
mlirContextDestroy(context);
}
+/// Returns a new capsule wrapping this context's raw MlirContext pointer.
+/// The capsule does not own the context.
+py::object PyMlirContext::getCapsule() {
+  PyObject *capsule = mlirPythonContextToCapsule(get());
+  // PyCapsule_New returns NULL with a Python error set on failure; surface
+  // that error rather than returning a null object to pybind11.
+  if (!capsule)
+    throw py::error_already_set();
+  return py::reinterpret_steal<py::object>(capsule);
+}
+
+py::object PyMlirContext::createFromCapsule(py::object capsule) {
+  // Unwrap the raw context. A null result means the argument was not an
+  // MlirContext capsule, and the capsule API has already raised.
+  MlirContext unwrapped = mlirPythonCapsuleToContext(capsule.ptr());
+  if (!mlirContextIsNull(unwrapped)) {
+    // Contexts are uniqued, so this may hand back an existing PyMlirContext.
+    return forContext(unwrapped).releaseObject();
+  }
+  throw py::error_already_set();
+}
+
PyMlirContext *PyMlirContext::createNewContextForInit() {
MlirContext context = mlirContextCreate();
mlirRegisterAllDialects(context);
@@ -581,6 +593,10 @@ PyModuleRef PyModule::create(PyMlirContextRef contextRef, MlirModule module) {
return PyModuleRef(unownedModule, std::move(pyRef));
}
+/// Returns a new capsule wrapping this module's raw MlirModule pointer.
+/// The capsule does not own the module.
+py::object PyModule::getCapsule() {
+  PyObject *capsule = mlirPythonModuleToCapsule(get());
+  // PyCapsule_New returns NULL with a Python error set on failure; surface
+  // that error rather than returning a null object to pybind11.
+  if (!capsule)
+    throw py::error_already_set();
+  return py::reinterpret_steal<py::object>(capsule);
+}
+
//------------------------------------------------------------------------------
// PyOperation
//------------------------------------------------------------------------------
@@ -1345,6 +1361,9 @@ void mlir::python::populateIRSubmodule(py::module &m) {
return ref.releaseObject();
})
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+ &PyMlirContext::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def_property(
"allow_unregistered_dialects",
[](PyMlirContext &self) -> bool {
@@ -1428,6 +1447,7 @@ void mlir::python::populateIRSubmodule(py::module &m) {
// Mapping of Module
py::class_<PyModule>(m, "Module")
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
.def_property_readonly(
"operation",
[](PyModule &self) {
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 41b18d216026..e67142e56c00 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -108,6 +108,14 @@ class PyMlirContext {
return PyMlirContextRef(this, pybind11::cast(this));
}
+ /// Gets a capsule wrapping the void* within the MlirContext.
+ /// The capsule is named MLIR_PYTHON_CAPSULE_CONTEXT and does not own the
+ /// context: it is only valid while this PyMlirContext is alive.
+ pybind11::object getCapsule();
+
+ /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
+ /// Note that PyMlirContext instances are uniqued, so the returned object
+ /// may be a pre-existing object.
+ /// Throws pybind11::error_already_set (carrying the error raised by the
+ /// Python capsule API) if the argument is not an MlirContext capsule.
+ static pybind11::object createFromCapsule(pybind11::object capsule);
+
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();
@@ -195,6 +203,12 @@ class PyModule : public BaseContextObject {
pybind11::reinterpret_borrow<pybind11::object>(handle));
}
+ /// Gets a capsule wrapping the void* within the MlirModule.
+ /// The capsule is named MLIR_PYTHON_CAPSULE_MODULE and does not transfer
+ /// ownership: it is only valid while this PyModule is alive.
+ /// Note that the module does not (yet) provide a corresponding factory for
+ /// constructing from a capsule as that would require uniquing PyModule
+ /// instances, which is not currently done.
+ pybind11::object getCapsule();
+
private:
PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}
diff --git a/mlir/test/Bindings/Python/context_lifecycle.py b/mlir/test/Bindings/Python/context_lifecycle.py
index e2b287061b22..460f41ccd4a7 100644
--- a/mlir/test/Bindings/Python/context_lifecycle.py
+++ b/mlir/test/Bindings/Python/context_lifecycle.py
@@ -40,3 +40,10 @@
c2 = None
gc.collect()
assert mlir.ir.Context._get_live_count() == 0
+
+# Create a context, get its capsule and create from capsule.
+c4 = mlir.ir.Context()
+c4_capsule = c4._CAPIPtr
+assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
+c5 = mlir.ir.Context._CAPICreate(c4_capsule)
+assert c4 is c5
diff --git a/mlir/test/Bindings/Python/ir_module.py b/mlir/test/Bindings/Python/ir_module.py
index 614e1af8b8e7..d85a415308ae 100644
--- a/mlir/test/Bindings/Python/ir_module.py
+++ b/mlir/test/Bindings/Python/ir_module.py
@@ -84,3 +84,13 @@ def testModuleOperation():
assert ctx._get_live_operation_count() == 0
run(testModuleOperation)
+
+
+# CHECK-LABEL: TEST: testModuleCapsule
+def testModuleCapsule():
+ ctx = mlir.ir.Context()
+ module = ctx.parse_module(r"""module @successfulParse {}""")
+ # CHECK: "mlir.ir.Module._CAPIPtr"
+ print(module._CAPIPtr)
+
+run(testModuleCapsule)
More information about the Mlir-commits
mailing list