[Mlir-commits] [mlir] 5fef6ce - [mlir][Python] Allow PassManager to interop with the capsule APIs.

Stella Laurenzo llvmlistbot at llvm.org
Wed Nov 11 10:37:50 PST 2020


Author: Stella Laurenzo
Date: 2020-11-11T10:37:21-08:00
New Revision: 5fef6ce0cce05a9fb05f47c9d62f3724377ea076

URL: https://github.com/llvm/llvm-project/commit/5fef6ce0cce05a9fb05f47c9d62f3724377ea076
DIFF: https://github.com/llvm/llvm-project/commit/5fef6ce0cce05a9fb05f47c9d62f3724377ea076.diff

LOG: [mlir][Python] Allow PassManager to interop with the capsule APIs.

* Used in npcomp to cast Python objects via the C-API.

Differential Revision: https://reviews.llvm.org/D91232

Added: 
    

Modified: 
    mlir/include/mlir-c/Bindings/Python/Interop.h
    mlir/include/mlir-c/Pass.h
    mlir/lib/Bindings/Python/Pass.cpp
    mlir/test/Bindings/Python/pass_manager.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
index 785ecce21804..dad51563a324 100644
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -24,9 +24,11 @@
 #include <Python.h>
 
 #include "mlir-c/IR.h"
+#include "mlir-c/Pass.h"
 
 #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
+#define MLIR_PYTHON_CAPSULE_PASS_MANAGER "mlir.passmanager.PassManager._CAPIPtr"
 
 /** Attribute on MLIR Python objects that expose their C-API pointer.
  * This will be a type-specific capsule created as per one of the helpers
@@ -52,6 +54,14 @@
  * delineated). */
 #define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate"
 
+/// Gets a non-const void* from the `ptr` member of a wrapped C-API struct.
+/// Needed because the cast that drops const is spelled differently in C/C++.
+#ifdef __cplusplus
+#define MLIR_PYTHON_GET_WRAPPED_POINTER(object) const_cast<void *>((object).ptr)
+#else
+#define MLIR_PYTHON_GET_WRAPPED_POINTER(object) (void *)((object).ptr)
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -60,7 +70,7 @@ extern "C" {
  * The returned capsule does not extend or affect ownership of any Python
  * objects that reference the context in any way.
  */
-inline PyObject *mlirPythonContextToCapsule(MlirContext context) {
+static inline PyObject *mlirPythonContextToCapsule(MlirContext context) {
   return PyCapsule_New(context.ptr, MLIR_PYTHON_CAPSULE_CONTEXT, NULL);
 }
 
@@ -68,7 +78,7 @@ inline PyObject *mlirPythonContextToCapsule(MlirContext context) {
  * mlirPythonContextToCapsule. If the capsule is not of the right type, then
  * a null context is returned (as checked via mlirContextIsNull). In such a
  * case, the Python APIs will have already set an error. */
-inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
+static inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
   void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_CONTEXT);
   MlirContext context = {ptr};
   return context;
@@ -77,25 +87,39 @@ inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
 /** Creates a capsule object encapsulating the raw C-API MlirModule.
  * The returned capsule does not extend or affect ownership of any Python
  * objects that reference the module in any way. */
-inline PyObject *mlirPythonModuleToCapsule(MlirModule module) {
-#ifdef __cplusplus
-  void *ptr = const_cast<void *>(module.ptr);
-#else
-  void *ptr = (void *)ptr;
-#endif
-  return PyCapsule_New(ptr, MLIR_PYTHON_CAPSULE_MODULE, NULL);
+static inline PyObject *mlirPythonModuleToCapsule(MlirModule module) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(module),
+                       MLIR_PYTHON_CAPSULE_MODULE, NULL);
 }
 
 /** Extracts an MlirModule from a capsule as produced from
  * mlirPythonModuleToCapsule. If the capsule is not of the right type, then
  * a null module is returned (as checked via mlirModuleIsNull). In such a
  * case, the Python APIs will have already set an error. */
-inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
+static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
   void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_MODULE);
   MlirModule module = {ptr};
   return module;
 }
 
+/** Creates a capsule object encapsulating the raw C-API MlirPassManager.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the pass manager in any way. */
+static inline PyObject *mlirPythonPassManagerToCapsule(MlirPassManager pm) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm),
+                       MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL);
+}
+
+/** Extracts an MlirPassManager from a capsule produced by
+ * mlirPythonPassManagerToCapsule. On a type mismatch a null pass manager is
+ * returned (see mlirPassManagerIsNull) and a Python error has been set. */
+static inline MlirPassManager
+mlirPythonCapsuleToPassManager(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER);
+  MlirPassManager pm = {ptr};
+  return pm;
+}
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
index 7e56dd3388c2..a059c4608197 100644
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -53,6 +53,11 @@ MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx);
 /// Destroy the provided PassManager.
 MLIR_CAPI_EXPORTED void mlirPassManagerDestroy(MlirPassManager passManager);
 
+/// Returns 1 if the PassManager wraps a null pointer and 0 otherwise.
+static inline int mlirPassManagerIsNull(MlirPassManager passManager) {
+  return passManager.ptr ? 0 : 1;
+}
+
 /// Cast a top-level PassManager to a generic OpPassManager.
 MLIR_CAPI_EXPORTED MlirOpPassManager
 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);

diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
index cfd9c281e734..dd57647f0327 100644
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -9,6 +9,7 @@
 #include "Pass.h"
 
 #include "IRModules.h"
+#include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Pass.h"
 
 namespace py = pybind11;
@@ -21,9 +22,47 @@ namespace {
 class PyPassManager {
 public:
   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
-  ~PyPassManager() { mlirPassManagerDestroy(passManager); }
+  /// Takes ownership from `other`, leaving it null so that its destructor
+  /// does not destroy the same pass manager a second time.
+  PyPassManager(PyPassManager &&other) noexcept
+      : passManager(other.passManager) {
+    other.passManager.ptr = nullptr;
+  }
+  // Copies would share (and double-destroy) the pass manager.
+  PyPassManager(const PyPassManager &) = delete;
+  PyPassManager &operator=(const PyPassManager &) = delete;
+  ~PyPassManager() {
+    if (!mlirPassManagerIsNull(passManager))
+      mlirPassManagerDestroy(passManager);
+  }
   MlirPassManager get() { return passManager; }
 
+  /// Releases (leaks) the wrapped pass manager: it will no longer be
+  /// destroyed by this object. Intended for testing only.
+  void release() { passManager.ptr = nullptr; }
+
+  /// Returns a capsule wrapping the raw MlirPassManager. The capsule does not
+  /// own the pass manager, so it must not be used after this object dies.
+  pybind11::object getCapsule() {
+    PyObject *capsule = mlirPythonPassManagerToCapsule(get());
+    // PyCapsule_New fails with a Python error set when given a null pointer
+    // (e.g. after release()); propagate it rather than returning null.
+    if (!capsule)
+      throw py::error_already_set();
+    return py::reinterpret_steal<py::object>(capsule);
+  }
+
+  /// Creates a PyPassManager that takes ownership of the MlirPassManager in
+  /// `capsule`. The caller must ensure nothing else will destroy it (e.g. by
+  /// calling release() on the original owner). If the capsule is of the wrong
+  /// type, the Python error set by the capsule API is rethrown.
+  static pybind11::object createFromCapsule(pybind11::object capsule) {
+    MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
+    if (mlirPassManagerIsNull(rawPm))
+      throw py::error_already_set();
+    return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
+  }
+
 private:
   MlirPassManager passManager;
 };
@@ -43,6 +82,12 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
            }),
            py::arg("context") = py::none(),
            "Create a new PassManager for the current (or provided) Context.")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyPassManager::getCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+                  &PyPassManager::createFromCapsule)
+      .def("_testing_release", &PyPassManager::release,
+           "Releases (leaks) the backing pass manager (testing)")
       .def_static(
           "parse",
           [](const std::string pipeline, DefaultingPyMlirContext context) {

diff  --git a/mlir/test/Bindings/Python/pass_manager.py b/mlir/test/Bindings/Python/pass_manager.py
index 62a92484f0f4..bd269ce29896 100644
--- a/mlir/test/Bindings/Python/pass_manager.py
+++ b/mlir/test/Bindings/Python/pass_manager.py
@@ -16,6 +16,26 @@ def run(f):
   gc.collect()
   assert Context._get_live_count() == 0
 
+# Verify capsule interop.
+# CHECK-LABEL: TEST: testCapsule
+def testCapsule():
+  with Context():
+    pm = PassManager()
+    pm_capsule = pm._CAPIPtr
+    assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(pm_capsule)
+    pm._testing_release()
+    pm1 = PassManager._CAPICreate(pm_capsule)
+    assert pm1 is not None  # And does not crash.
+    # Objects that are not pass manager capsules must be rejected with the
+    # ValueError raised by PyCapsule_GetPointer instead of crashing.
+    try:
+      PassManager._CAPICreate(object())
+      assert False, "expected ValueError from _CAPICreate"
+    except ValueError:
+      pass
+run(testCapsule)
+
+
 # Verify successful round-trip.
 # CHECK-LABEL: TEST: testParseSuccess
 def testParseSuccess():


        


More information about the Mlir-commits mailing list