[Mlir-commits] [mlir] 5fef6ce - [mlir][Python] Allow PassManager to interop with the capsule APIs.
Stella Laurenzo
llvmlistbot at llvm.org
Wed Nov 11 10:37:50 PST 2020
Author: Stella Laurenzo
Date: 2020-11-11T10:37:21-08:00
New Revision: 5fef6ce0cce05a9fb05f47c9d62f3724377ea076
URL: https://github.com/llvm/llvm-project/commit/5fef6ce0cce05a9fb05f47c9d62f3724377ea076
DIFF: https://github.com/llvm/llvm-project/commit/5fef6ce0cce05a9fb05f47c9d62f3724377ea076.diff
LOG: [mlir][Python] Allow PassManager to interop with the capsule APIs.
* Used in npcomp to cast Python objects via the C-API.
Differential Revision: https://reviews.llvm.org/D91232
Added:
Modified:
mlir/include/mlir-c/Bindings/Python/Interop.h
mlir/include/mlir-c/Pass.h
mlir/lib/Bindings/Python/Pass.cpp
mlir/test/Bindings/Python/pass_manager.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index 785ecce21804..dad51563a324 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -24,9 +24,11 @@
#include <Python.h>
#include "mlir-c/IR.h"
+#include "mlir-c/Pass.h"
#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
+#define MLIR_PYTHON_CAPSULE_PASS_MANAGER "mlir.passmanager.PassManager._CAPIPtr"
/** Attribute on MLIR Python objects that expose their C-API pointer.
* This will be a type-specific capsule created as per one of the helpers
@@ -52,6 +54,14 @@
* delineated). */
#define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate"
+/// Gets a void* from a wrapped struct. Needed because const cast is different
+/// between C/C++.
+#ifdef __cplusplus
+#define MLIR_PYTHON_GET_WRAPPED_POINTER(object) const_cast<void *>(object.ptr)
+#else
+#define MLIR_PYTHON_GET_WRAPPED_POINTER(object) (void *)(object.ptr)
+#endif
+
#ifdef __cplusplus
extern "C" {
#endif
@@ -60,7 +70,7 @@ extern "C" {
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the context in any way.
*/
-inline PyObject *mlirPythonContextToCapsule(MlirContext context) {
+static inline PyObject *mlirPythonContextToCapsule(MlirContext context) {
return PyCapsule_New(context.ptr, MLIR_PYTHON_CAPSULE_CONTEXT, NULL);
}
@@ -68,7 +78,7 @@ inline PyObject *mlirPythonContextToCapsule(MlirContext context) {
* mlirPythonContextToCapsule. If the capsule is not of the right type, then
* a null context is returned (as checked via mlirContextIsNull). In such a
* case, the Python APIs will have already set an error. */
-inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
+static inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_CONTEXT);
MlirContext context = {ptr};
return context;
@@ -77,25 +87,39 @@ inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
/** Creates a capsule object encapsulating the raw C-API MlirModule.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the module in any way. */
-inline PyObject *mlirPythonModuleToCapsule(MlirModule module) {
-#ifdef __cplusplus
- void *ptr = const_cast<void *>(module.ptr);
-#else
- void *ptr = (void *)ptr;
-#endif
- return PyCapsule_New(ptr, MLIR_PYTHON_CAPSULE_MODULE, NULL);
+static inline PyObject *mlirPythonModuleToCapsule(MlirModule module) {
+ return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(module),
+ MLIR_PYTHON_CAPSULE_MODULE, NULL);
}
/** Extracts an MlirModule from a capsule as produced from
* mlirPythonModuleToCapsule. If the capsule is not of the right type, then
* a null module is returned (as checked via mlirModuleIsNull). In such a
* case, the Python APIs will have already set an error. */
-inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
+static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_MODULE);
MlirModule module = {ptr};
return module;
}
+/** Creates a capsule object encapsulating the raw C-API MlirPassManager.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the module in any way. */
+static inline PyObject *mlirPythonPassManagerToCapsule(MlirPassManager pm) {
+ return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm),
+ MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL);
+}
+
+/** Extracts an MlirPassManager from a capsule as produced from
+ * mlirPythonPassManagerToCapsule. If the capsule is not of the right type, then
+ * a null pass manager is returned (as checked via mlirPassManagerIsNull). */
+static inline MlirPassManager
+mlirPythonCapsuleToPassManager(PyObject *capsule) {
+ void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER);
+ MlirPassManager pm = {ptr};
+ return pm;
+}
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 7e56dd3388c2..a059c4608197 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -53,6 +53,11 @@ MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx);
/// Destroy the provided PassManager.
MLIR_CAPI_EXPORTED void mlirPassManagerDestroy(MlirPassManager passManager);
+/// Checks if a PassManager is null.
+static inline int mlirPassManagerIsNull(MlirPassManager passManager) {
+ return !passManager.ptr;
+}
+
/// Cast a top-level PassManager to a generic OpPassManager.
MLIR_CAPI_EXPORTED MlirOpPassManager
mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cfd9c281e734..dd57647f0327 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -9,6 +9,7 @@
#include "Pass.h"
#include "IRModules.h"
+#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Pass.h"
namespace py = pybind11;
@@ -21,9 +22,28 @@ namespace {
class PyPassManager {
public:
PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
- ~PyPassManager() { mlirPassManagerDestroy(passManager); }
+ PyPassManager(PyPassManager &&other) : passManager(other.passManager) {
+ other.passManager.ptr = nullptr;
+ }
+ ~PyPassManager() {
+ if (!mlirPassManagerIsNull(passManager))
+ mlirPassManagerDestroy(passManager);
+ }
MlirPassManager get() { return passManager; }
+ void release() { passManager.ptr = nullptr; }
+ pybind11::object getCapsule() {
+ return py::reinterpret_steal<py::object>(
+ mlirPythonPassManagerToCapsule(get()));
+ }
+
+ static pybind11::object createFromCapsule(pybind11::object capsule) {
+ MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
+ if (mlirPassManagerIsNull(rawPm))
+ throw py::error_already_set();
+ return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
+ }
+
private:
MlirPassManager passManager;
};
@@ -43,6 +63,11 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
}),
py::arg("context") = py::none(),
"Create a new PassManager for the current (or provided) Context.")
+ .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+ &PyPassManager::getCapsule)
+ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
+ .def("_testing_release", &PyPassManager::release,
+ "Releases (leaks) the backing pass manager (testing)")
.def_static(
"parse",
[](const std::string pipeline, DefaultingPyMlirContext context) {
diff --git a/mlir/test/Bindings/Python/pass_manager.py b/mlir/test/Bindings/Python/pass_manager.py
index 62a92484f0f4..bd269ce29896 100644
--- a/mlir/test/Bindings/Python/pass_manager.py
+++ b/mlir/test/Bindings/Python/pass_manager.py
@@ -16,6 +16,19 @@ def run(f):
gc.collect()
assert Context._get_live_count() == 0
+# Verify capsule interop.
+# CHECK-LABEL: TEST: testCapsule
+def testCapsule():
+ with Context():
+ pm = PassManager()
+ pm_capsule = pm._CAPIPtr
+ assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(pm_capsule)
+ pm._testing_release()
+ pm1 = PassManager._CAPICreate(pm_capsule)
+ assert pm1 is not None # And does not crash.
+run(testCapsule)
+
+
# Verify successful round-trip.
# CHECK-LABEL: TEST: testParseSuccess
def testParseSuccess():
More information about the Mlir-commits
mailing list