[Mlir-commits] [mlir] fcd2969 - Initial MLIR python bindings based on the C API.

Stella Laurenzo llvmlistbot at llvm.org
Sun Aug 16 19:36:52 PDT 2020


Author: zhanghb97
Date: 2020-08-16T19:34:25-07:00
New Revision: fcd2969da9e04a70103bfbf8a382c0842fcf6aaf

URL: https://github.com/llvm/llvm-project/commit/fcd2969da9e04a70103bfbf8a382c0842fcf6aaf
DIFF: https://github.com/llvm/llvm-project/commit/fcd2969da9e04a70103bfbf8a382c0842fcf6aaf.diff

LOG: Initial MLIR Python bindings based on the C API.

* Basic support for context creation, module parsing and dumping.

Differential Revision: https://reviews.llvm.org/D85481

Added: 
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/test/Bindings/Python/ir_test.py

Modified: 
    mlir/lib/Bindings/Python/CMakeLists.txt
    mlir/lib/Bindings/Python/MainModule.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index ed08da968812..fab4061bcb9f 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -17,8 +17,11 @@ endif()
 # The actual extension library produces a shared-object or DLL and has
 # sources that must be compiled in accordance with pybind11 needs (RTTI and
 # exceptions).
+# TODO: Link the libraries separately once a generic helper function for
+# adding pybind11-compliant libraries is available.
 add_library(MLIRBindingsPythonExtension ${PYEXT_LINK_MODE}
   MainModule.cpp
+  IRModules.cpp
 )
 
 target_include_directories(MLIRBindingsPythonExtension PRIVATE
@@ -66,5 +69,7 @@ endif()
 target_link_libraries(MLIRBindingsPythonExtension
   PRIVATE
   MLIRIR
+  MLIRCAPIIR
+  MLIRCAPIRegistration
   ${PYEXT_LIBADD}
 )

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
new file mode 100644
index 000000000000..b3b790af5516
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -0,0 +1,48 @@
+//===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModules.h"
+
+//------------------------------------------------------------------------------
+// Context Wrapper Class.
+//------------------------------------------------------------------------------
+
+/// Parses `module` as MLIR assembly in this context and returns an owning
+/// wrapper around the resulting module.
+/// Throws a Python ValueError when parsing fails, so callers never receive a
+/// wrapper around a null module (whose `dump` would pass a null operation to
+/// the C API).
+PyMlirModule PyMlirContext::parse(const std::string &module) {
+  MlirModule moduleRef = mlirModuleCreateParse(context, module.c_str());
+  // The C API reports a parse failure by returning a null module.
+  if (!moduleRef.ptr)
+    throw py::value_error("Unable to parse module assembly");
+  return PyMlirModule(moduleRef);
+}
+
+//------------------------------------------------------------------------------
+// Module Wrapper Class.
+//------------------------------------------------------------------------------
+
+/// Dumps the module's top-level operation through the C API's
+/// `mlirOperationDump`.
+void PyMlirModule::dump() { mlirOperationDump(mlirModuleGetOperation(module)); }
+
+//------------------------------------------------------------------------------
+// Populates the pybind11 IR submodule.
+//------------------------------------------------------------------------------
+
+/// Registers the `MlirContext` and `MlirModule` Python classes on `m`.
+void populateIRSubmodule(py::module &m) {
+  py::class_<PyMlirContext>(m, "MlirContext")
+      .def(py::init<>())
+      // keep_alive<0, 1>: the returned module (0) keeps its context (1) alive.
+      .def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>());
+
+  py::class_<PyMlirModule>(m, "MlirModule").def("dump", &PyMlirModule::dump);
+}

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
new file mode 100644
index 000000000000..6bfd00656df9
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -0,0 +1,76 @@
+//===- IRModules.h - IR Submodules of pybind module -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
+#define MLIR_BINDINGS_PYTHON_IRMODULES_H
+
+#include <pybind11/pybind11.h>
+#include <string>
+
+#include "mlir-c/IR.h"
+
+namespace py = pybind11;
+
+class PyMlirContext;
+class PyMlirModule;
+
+/// Owning wrapper around an MlirContext.
+///
+/// The context is created on construction and destroyed by the destructor, so
+/// the wrapper is neither copyable nor movable: a copy would destroy the same
+/// context twice.
+class PyMlirContext {
+public:
+  PyMlirContext() : context(mlirContextCreate()) {}
+  ~PyMlirContext() { mlirContextDestroy(context); }
+  PyMlirContext(const PyMlirContext &) = delete;
+  PyMlirContext &operator=(const PyMlirContext &) = delete;
+
+  /// Parses the module from asm.
+  PyMlirModule parse(const std::string &module);
+
+  MlirContext context;
+};
+
+/// Owning, move-only wrapper around an MlirModule.
+///
+/// A wrapper whose `module.ptr` is null (e.g. after being moved from) owns
+/// nothing; otherwise the module is destroyed together with the wrapper.
+/// Copying is disabled because it would destroy the same module twice.
+class PyMlirModule {
+public:
+  PyMlirModule(MlirModule module) : module(module) {}
+  PyMlirModule(const PyMlirModule &) = delete;
+  PyMlirModule &operator=(const PyMlirModule &) = delete;
+  PyMlirModule(PyMlirModule &&other) noexcept : module(other.module) {
+    other.module.ptr = nullptr;
+  }
+  PyMlirModule &operator=(PyMlirModule &&other) noexcept {
+    if (this != &other) {
+      if (module.ptr)
+        mlirModuleDestroy(module);
+      module = other.module;
+      other.module.ptr = nullptr;
+    }
+    return *this;
+  }
+  ~PyMlirModule() {
+    if (module.ptr)
+      mlirModuleDestroy(module);
+  }
+
+  /// Dumps the module.
+  void dump();
+
+  MlirModule module;
+};
+
+/// Registers the IR wrapper classes on the given pybind11 module.
+void populateIRSubmodule(py::module &m);
+
+#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index b6d7abc4512f..760aa86a7b1c 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -10,6 +10,7 @@
 
 #include <pybind11/pybind11.h>
 
+#include "IRModules.h"
 #include "mlir/IR/MLIRContext.h"
 
 using namespace mlir;
@@ -24,4 +25,8 @@ PYBIND11_MODULE(_mlir, m) {
     return std::make_tuple(std::string("From the native module"),
                            context.isMultithreadingEnabled());
   });
+
+  // Expose the IR bindings as the `ir` submodule of `_mlir`.
+  pybind11::module irSubmodule = m.def_submodule("ir", "MLIR IR Bindings");
+  populateIRSubmodule(irSubmodule);
 }

diff  --git a/mlir/test/Bindings/Python/ir_test.py b/mlir/test/Bindings/Python/ir_test.py
new file mode 100644
index 000000000000..1dfb356f956b
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_test.py
@@ -0,0 +1,18 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import mlir
+
+TEST_MLIR_ASM = r"""
+module {
+}
+"""
+
+ctx = mlir.ir.MlirContext()
+module = ctx.parse(TEST_MLIR_ASM)
+# The RUN line merges stderr into stdout so the dumped module can be checked
+# even if the dump is written to stderr rather than stdout.
+module.dump()
+# CHECK: module {
+# CHECK-NEXT: }
+print(isinstance(module, mlir.ir.MlirModule))
+# CHECK: True


        


More information about the Mlir-commits mailing list