[Mlir-commits] [mlir] dc43f78 - Add basic Python bindings for the PassManager and bind libTransforms

Mehdi Amini llvmlistbot at llvm.org
Tue Nov 10 11:55:32 PST 2020


Author: Mehdi Amini
Date: 2020-11-10T19:55:21Z
New Revision: dc43f78565491471604103715cc4329abfca0f6d

URL: https://github.com/llvm/llvm-project/commit/dc43f78565491471604103715cc4329abfca0f6d
DIFF: https://github.com/llvm/llvm-project/commit/dc43f78565491471604103715cc4329abfca0f6d.diff

LOG: Add basic Python bindings for the PassManager and bind libTransforms

This only exposes the ability to round-trip a textual pipeline at the
moment.
To exercise it, we also bind libTransforms in a new Python extension. This
does not include any interesting bindings, but it includes all the
machinery needed to add separate native extensions and load them dynamically.
As such, passes in libTransforms are only registered after `import
mlir.transforms`.
To support this global registration, the TableGen backend is also
extended to expose the group registration for passes through the C API.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D90819

Added: 
    mlir/lib/Bindings/Python/Pass.cpp
    mlir/lib/Bindings/Python/Pass.h
    mlir/lib/Bindings/Python/Transforms/CMakeLists.txt
    mlir/lib/Bindings/Python/Transforms/Transforms.cpp
    mlir/lib/Bindings/Python/mlir/passmanager.py
    mlir/lib/Bindings/Python/mlir/transforms/__init__.py
    mlir/test/Bindings/Python/pass_manager.py

Modified: 
    mlir/lib/Bindings/Python/CMakeLists.txt
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/lib/Bindings/Python/mlir/__init__.py
    mlir/test/CMakeLists.txt
    mlir/tools/mlir-tblgen/PassCAPIGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index 499d684c076b..bf6f1d8d5567 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -6,8 +6,10 @@ add_custom_target(MLIRBindingsPythonExtension)
 
 set(PY_SRC_FILES
   mlir/__init__.py
-  mlir/ir.py
   mlir/dialects/__init__.py
+  mlir/ir.py
+  mlir/passmanager.py
+  mlir/transforms/__init__.py
 )
 
 add_custom_target(MLIRBindingsPythonSources ALL
@@ -38,6 +40,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
     MainModule.cpp
     IRModules.cpp
     PybindUtils.cpp
+    Pass.cpp
 )
 add_dependencies(MLIRBindingsPythonExtension MLIRCoreBindingsPythonExtension)
 
@@ -57,3 +60,5 @@ if (NOT LLVM_ENABLE_IDE)
     DEPENDS MLIRBindingsPythonSources
     COMPONENT MLIRBindingsPythonSources)
 endif()
+
+add_subdirectory(Transforms)

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index b2c1bafa5d69..1f4b69dc4b7f 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,6 +12,7 @@
 
 #include "Globals.h"
 #include "IRModules.h"
+#include "Pass.h"
 
 namespace py = pybind11;
 using namespace mlir;
@@ -210,4 +211,9 @@ PYBIND11_MODULE(_mlir, m) {
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
   populateIRSubmodule(irModule);
+
+  // Define and populate PassManager submodule.
+  auto passModule =
+      m.def_submodule("passmanager", "MLIR Pass Management Bindings");
+  populatePassManagerSubmodule(passModule);
 }

diff  --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
new file mode 100644
index 000000000000..228f69b6deba
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -0,0 +1,75 @@
+//===- Pass.cpp - Pass Management -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Pass.h"
+
+#include "IRModules.h"
+#include "mlir-c/Pass.h"
+
+namespace py = pybind11;
+using namespace mlir;
+using namespace mlir::python;
+
+namespace {
+/// Owning wrapper around a PassManager; non-copyable to avoid double-free.
+class PyPassManager {
+public:
+  explicit PyPassManager(MlirPassManager passManager)
+      : passManager(passManager) {}
+  PyPassManager(const PyPassManager &) = delete;
+  PyPassManager &operator=(const PyPassManager &) = delete;
+  ~PyPassManager() { mlirPassManagerDestroy(passManager); }
+  MlirPassManager get() { return passManager; }
+private:
+  MlirPassManager passManager;
+};
+} // anonymous namespace
+
+/// Populate the `mlir.passmanager` submodule `m` with the PassManager class.
+void mlir::python::populatePassManagerSubmodule(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the top-level PassManager
+  //----------------------------------------------------------------------------
+  py::class_<PyPassManager>(m, "PassManager")
+      .def(py::init<>([](DefaultingPyMlirContext context) {
+             return new PyPassManager(mlirPassManagerCreate(context->get()));
+           }),
+           py::arg("context") = py::none(),
+           "Create a new PassManager for the current (or provided) Context.")
+      .def_static(
+          "parse",
+          [](const std::string &pipeline, DefaultingPyMlirContext context) {
+            MlirPassManager passManager = mlirPassManagerCreate(context->get());
+            MlirLogicalResult status = mlirParsePassPipeline(
+                mlirPassManagerGetAsOpPassManager(passManager),
+                mlirStringRefCreate(pipeline.data(), pipeline.size()));
+            if (mlirLogicalResultIsFailure(status)) {
+              mlirPassManagerDestroy(passManager); // Don't leak on failure.
+              throw SetPyError(PyExc_ValueError,
+                               llvm::Twine("invalid pass pipeline '") +
+                                   pipeline + "'.");
+            }
+            return new PyPassManager(passManager);
+          },
+          py::arg("pipeline"), py::arg("context") = py::none(),
+          "Parse a textual pass-pipeline and return a top-level PassManager "
+          "that can be applied on a Module. Throw a ValueError if the pipeline "
+          "can't be parsed")
+      .def(
+          "__str__",
+          [](PyPassManager &self) {
+            MlirPassManager passManager = self.get();
+            PyPrintAccumulator printAccum;
+            mlirPrintPassPipeline(
+                mlirPassManagerGetAsOpPassManager(passManager),
+                printAccum.getCallback(), printAccum.getUserData());
+            return printAccum.join();
+          },
+          "Print the textual representation for this PassManager, suitable to "
+          "be passed to `parse` for round-tripping.");
+}

diff  --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h
new file mode 100644
index 000000000000..550ff47c396d
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Pass.h
@@ -0,0 +1,22 @@
+//===- Pass.h - PassManager Submodules of pybind module -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_PASS_H
+#define MLIR_BINDINGS_PYTHON_PASS_H
+
+#include "PybindUtils.h"
+
+namespace mlir {
+namespace python {
+
+/// Adds the `PassManager` class bindings to the Python (sub)module `m`.
+void populatePassManagerSubmodule(pybind11::module &m);
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_PASS_H
\ No newline at end of file

diff  --git a/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt b/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000..8b53f03d42b9
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt
@@ -0,0 +1,10 @@
+################################################################################
+# Build the _mlirTransforms python extension (registers passes when loaded)
+################################################################################
+
+add_mlir_python_extension(MLIRTransformsBindingsPythonExtension _mlirTransforms
+  INSTALL_DIR
+    python
+  SOURCES
+    Transforms.cpp
+)
\ No newline at end of file

diff  --git a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp b/mlir/lib/Bindings/Python/Transforms/Transforms.cpp
new file mode 100644
index 000000000000..46c4691923c7
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Transforms/Transforms.cpp
@@ -0,0 +1,24 @@
+//===- Transforms.cpp - Pybind module for the Transforms library ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Transforms.h"
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+// -----------------------------------------------------------------------------
+// Module initialization.
+// -----------------------------------------------------------------------------
+
+PYBIND11_MODULE(_mlirTransforms, m) {
+  // Importing this extension registers every pass of libTransforms.
+  mlirRegisterTransformsPasses();
+
+  m.doc() = "MLIR Transforms library";
+}

diff  --git a/mlir/lib/Bindings/Python/mlir/__init__.py b/mlir/lib/Bindings/Python/mlir/__init__.py
index 8f3b52c30f35..c63c4332be68 100644
--- a/mlir/lib/Bindings/Python/mlir/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/__init__.py
@@ -10,6 +10,7 @@
 
 __all__ = [
   "ir",
+  "passmanager",
 ]
 
 # Expose the corresponding C-Extension module with a well-known name at this
@@ -38,7 +39,7 @@ def _reexport_cext(cext_module_name, target_module_name):
 
 # Import sub-modules. Since these may import from here, this must come after
 # any exported definitions.
-from . import ir
+from . import ir, passmanager
 
 # Add our 'dialects' parent module to the search path for implementations.
 _cext.globals.append_dialect_search_prefix("mlir.dialects")

diff  --git a/mlir/lib/Bindings/Python/mlir/passmanager.py b/mlir/lib/Bindings/Python/mlir/passmanager.py
new file mode 100644
index 000000000000..95119e52f971
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/passmanager.py
@@ -0,0 +1,8 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Simply a wrapper around the extension module of the same name.
+from . import _reexport_cext
+_reexport_cext("passmanager", __name__)
+del _reexport_cext

diff  --git a/mlir/lib/Bindings/Python/mlir/transforms/__init__.py b/mlir/lib/Bindings/Python/mlir/transforms/__init__.py
new file mode 100644
index 000000000000..d6172521295a
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/transforms/__init__.py
@@ -0,0 +1,8 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Import the C-extension for its side effect: loading it registers every pass
+# in libTransforms so that textual pass pipelines can refer to them.
+import _mlirTransforms as _cextTransforms
+

diff  --git a/mlir/test/Bindings/Python/pass_manager.py b/mlir/test/Bindings/Python/pass_manager.py
new file mode 100644
index 000000000000..c8e041c2eae0
--- /dev/null
+++ b/mlir/test/Bindings/Python/pass_manager.py
@@ -0,0 +1,54 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+
+# Log everything to stderr and flush so that we have a unified stream to match
+# errors emitted by MLIR to stderr. TODO: this shouldn't be needed when
+# everything is plumbed.
+def log(*args):
+  print(*args, file=sys.stderr)
+  sys.stderr.flush()
+
+def run(f):
+  log("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+
+# Verify that parsing fails before the pass is registered, then round-trips.
+# CHECK-LABEL: TEST: testParseSuccess
+def testParseSuccess():
+  with Context():
+    # The first parse is expected to fail because the pass isn't registered
+    # until mlir.transforms is imported.
+    try:
+      pm = PassManager.parse("module(func(print-op-stats))")
+      # TODO: this error should be propagated to Python, but the C API does not help right now.
+      # CHECK: error: 'print-op-stats' does not refer to a registered pass or pass pipeline
+    except ValueError as e:
+      # CHECK: ValueError exception: invalid pass pipeline 'module(func(print-op-stats))'.
+      log("ValueError exception:", e)
+    else:
+      log("Exception not produced")
+
+    # Importing mlir.transforms registers the pass, so the round-trip succeeds.
+    import mlir.transforms
+    pm = PassManager.parse("module(func(print-op-stats))")
+    # CHECK: Roundtrip: module(func(print-op-stats))
+    log("Roundtrip: ", pm)
+run(testParseSuccess)
+
+# Verify failure on unregistered pass.
+# CHECK-LABEL: TEST: testParseFail
+def testParseFail():
+  with Context():
+    try:
+      pm = PassManager.parse("unknown-pass")
+    except ValueError as e:
+      # CHECK: ValueError exception: invalid pass pipeline 'unknown-pass'.
+      log("ValueError exception:", e)
+    else:
+      log("Exception not produced")
+run(testParseFail)

diff  --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 8b5d3d409e43..60fd2d982453 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -100,6 +100,7 @@ endif()
 if(MLIR_BINDINGS_PYTHON_ENABLED)
   list(APPEND MLIR_TEST_DEPENDS
     MLIRBindingsPythonExtension
+    MLIRTransformsBindingsPythonExtension
   )
 endif()
 

diff  --git a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
index 32d540967f91..4fa1150957c5 100644
--- a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp
@@ -58,6 +58,8 @@ const char *const fileFooter = R"(
 /// Emit TODO
 static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) {
   os << fileHeader;
+  os << "// Registration for the entire group\n";
+  os << "MLIR_CAPI_EXPORTED void mlirRegister" << groupName << "Passes();\n\n";
   for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
     Pass pass(def);
     StringRef defName = pass.getDef()->getName();
@@ -77,8 +79,21 @@ void mlirRegister{0}{1}() {
 
 )";
 
+/// C API entry point registering the whole pass group; {0}: the group name.
+const char *const passGroupRegistrationCode = R"(
+//===----------------------------------------------------------------------===//
+// {0} Group Registration
+//===----------------------------------------------------------------------===//
+
+void mlirRegister{0}Passes() {{
+  register{0}Passes();
+}
+)";
+
 static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
   os << "/* Autogenerated by mlir-tblgen; don't manually edit. */";
+  os << llvm::formatv(passGroupRegistrationCode, groupName);
+
   for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
     Pass pass(def);
     StringRef defName = pass.getDef()->getName();


        


More information about the Mlir-commits mailing list