[Mlir-commits] [mlir] 95b77f2 - Adds __str__ support to python mlir.ir.MlirModule.

Stella Laurenzo llvmlistbot at llvm.org
Mon Aug 17 09:47:04 PDT 2020


Author: Stella Laurenzo
Date: 2020-08-17T09:46:33-07:00
New Revision: 95b77f2eac8f3498fc20299f2515b3ce3440b82e

URL: https://github.com/llvm/llvm-project/commit/95b77f2eac8f3498fc20299f2515b3ce3440b82e
DIFF: https://github.com/llvm/llvm-project/commit/95b77f2eac8f3498fc20299f2515b3ce3440b82e.diff

LOG: Adds __str__ support to python mlir.ir.MlirModule.

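* Raises a ValueError when parsing fails (mlirModuleCreateParse now returns a
  null module on parse error instead of wrapping a null operation).
* Removes the placeholder smoketest (smoke_test.py and get_test_value) and
  replaces ir_test.py with ir_module_test.py.
* Moves the bindings into the mlir::python namespace and adds a PybindUtils
  helper (SetPyError) for raising Python exceptions.
* Adds docstrings.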

Differential Revision: https://reviews.llvm.org/D86046

Added: 
    mlir/lib/Bindings/Python/PybindUtils.cpp
    mlir/lib/Bindings/Python/PybindUtils.h
    mlir/test/Bindings/Python/ir_module_test.py

Modified: 
    mlir/lib/Bindings/Python/CMakeLists.txt
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/lib/CAPI/IR/IR.cpp

Removed: 
    mlir/test/Bindings/Python/ir_test.py
    mlir/test/Bindings/Python/smoke_test.py


################################################################################
diff  --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index fab4061bcb9f..0f03445c2711 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -22,6 +22,7 @@ endif()
 add_library(MLIRBindingsPythonExtension ${PYEXT_LINK_MODE}
   MainModule.cpp
   IRModules.cpp
+  PybindUtils.cpp
 )
 
 target_include_directories(MLIRBindingsPythonExtension PRIVATE
@@ -68,7 +69,6 @@ endif()
 
 target_link_libraries(MLIRBindingsPythonExtension
   PRIVATE
-  MLIRIR
   MLIRCAPIIR
   MLIRCAPIRegistration
   ${PYEXT_LIBADD}

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index b3b790af5516..27e1854e7455 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -7,6 +7,61 @@
 //===----------------------------------------------------------------------===//
 
 #include "IRModules.h"
+#include "PybindUtils.h"
+
+namespace py = pybind11;
+using namespace mlir::python;
+
+//------------------------------------------------------------------------------
+// Docstrings (trivial, non-duplicated docstrings are included inline).
+//------------------------------------------------------------------------------
+
+static const char kContextParseDocstring[] =
+    R"(Parses a module's assembly format from a string.
+
+Returns a new MlirModule or raises a ValueError if the parsing fails.
+)";
+
+static const char kOperationStrDunderDocstring[] =
+    R"(Prints the assembly form of the operation with default options.
+
+If more advanced control over the assembly formatting or I/O options is needed,
+use the dedicated print method, which supports keyword arguments to customize
+behavior.
+)";
+
+static const char kDumpDocstring[] =
+    R"(Dumps a debug representation of the object to stderr.)";
+
+//------------------------------------------------------------------------------
+// Conversion utilities.
+//------------------------------------------------------------------------------
+
+namespace {
+
+/// Accumulates into a python string from a method that accepts an
+/// MlirPrintCallback.
+struct PyPrintAccumulator {
+  py::list parts;
+
+  void *getUserData() { return this; }
+
+  MlirPrintCallback getCallback() {
+    return [](const char *part, intptr_t size, void *userData) {
+      PyPrintAccumulator *printAccum =
+          static_cast<PyPrintAccumulator *>(userData);
+      py::str pyPart(part, size); // Decodes as UTF-8 by default.
+      printAccum->parts.append(std::move(pyPart));
+    };
+  }
+
+  py::str join() {
+    py::str delim("", 0);
+    return delim.attr("join")(parts);
+  }
+};
+
+} // namespace
 
 //------------------------------------------------------------------------------
 // Context Wrapper Class.
@@ -14,6 +69,10 @@
 
 PyMlirModule PyMlirContext::parse(const std::string &module) {
   auto moduleRef = mlirModuleCreateParse(context, module.c_str());
+  if (!moduleRef.ptr) {
+    throw SetPyError(PyExc_ValueError,
+                     "Unable to parse module assembly (see diagnostics)");
+  }
   return PyMlirModule(moduleRef);
 }
 
@@ -27,10 +86,22 @@ void PyMlirModule::dump() { mlirOperationDump(mlirModuleGetOperation(module)); }
 // Populates the pybind11 IR submodule.
 //------------------------------------------------------------------------------
 
-void populateIRSubmodule(py::module &m) {
+void mlir::python::populateIRSubmodule(py::module &m) {
   py::class_<PyMlirContext>(m, "MlirContext")
       .def(py::init<>())
-      .def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>());
+      .def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>(),
+           kContextParseDocstring);
 
-  py::class_<PyMlirModule>(m, "MlirModule").def("dump", &PyMlirModule::dump);
+  py::class_<PyMlirModule>(m, "MlirModule")
+      .def("dump", &PyMlirModule::dump, kDumpDocstring)
+      .def(
+          "__str__",
+          [](PyMlirModule &self) {
+            auto operation = mlirModuleGetOperation(self.module);
+            PyPrintAccumulator printAccum;
+            mlirOperationPrint(operation, printAccum.getCallback(),
+                               printAccum.getUserData());
+            return printAccum.join();
+          },
+          kOperationStrDunderDocstring);
 }

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 6bfd00656df9..325db497e2aa 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -13,7 +13,8 @@
 
 #include "mlir-c/IR.h"
 
-namespace py = pybind11;
+namespace mlir {
+namespace python {
 
 class PyMlirContext;
 class PyMlirModule;
@@ -48,6 +49,9 @@ class PyMlirModule {
   MlirModule module;
 };
 
-void populateIRSubmodule(py::module &m);
+void populateIRSubmodule(pybind11::module &m);
+
+} // namespace python
+} // namespace mlir
 
 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 760aa86a7b1c..7dd525b4b340 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -11,21 +11,14 @@
 #include <pybind11/pybind11.h>
 
 #include "IRModules.h"
-#include "mlir/IR/MLIRContext.h"
 
+namespace py = pybind11;
 using namespace mlir;
+using namespace mlir::python;
 
 PYBIND11_MODULE(_mlir, m) {
   m.doc() = "MLIR Python Native Extension";
 
-  m.def("get_test_value", []() {
-    // This is just calling a method on the MLIRContext as a smoketest
-    // for linkage.
-    MLIRContext context;
-    return std::make_tuple(std::string("From the native module"),
-                           context.isMultithreadingEnabled());
-  });
-
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
   populateIRSubmodule(irModule);

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.cpp b/mlir/lib/Bindings/Python/PybindUtils.cpp
new file mode 100644
index 000000000000..9013c0669794
--- /dev/null
+++ b/mlir/lib/Bindings/Python/PybindUtils.cpp
@@ -0,0 +1,18 @@
+//===- PybindUtils.cpp - Utilities for interop with pybind11 --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PybindUtils.h"
+
+namespace py = pybind11;
+
+pybind11::error_already_set mlir::python::SetPyError(PyObject *excClass,
+                                                     llvm::Twine message) {
+  auto messageStr = message.str();
+  PyErr_SetString(excClass, messageStr.c_str());
+  return pybind11::error_already_set();
+}

diff  --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
new file mode 100644
index 000000000000..1a82f8e824ec
--- /dev/null
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -0,0 +1,28 @@
+//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
+#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
+
+#include <pybind11/pybind11.h>
+
+#include "llvm/ADT/Twine.h"
+
+namespace mlir {
+namespace python {
+
+// Sets a python error, ready to be thrown to return control back to the
+// python runtime.
+// Correct usage:
+//   throw SetPyError(PyExc_ValueError, "Foobar'd");
+pybind11::error_already_set SetPyError(PyObject *excClass, llvm::Twine message);
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 4ccfb45f2c43..5231096af785 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -126,6 +126,8 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location) {
 
 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
   OwningModuleRef owning = parseSourceString(module, unwrap(context));
+  if (!owning)
+    return MlirModule{nullptr};
   return MlirModule{owning.release().getOperation()};
 }
 

diff  --git a/mlir/test/Bindings/Python/ir_module_test.py b/mlir/test/Bindings/Python/ir_module_test.py
new file mode 100644
index 000000000000..26b7fe63369c
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_module_test.py
@@ -0,0 +1,49 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import mlir
+
+def run(f):
+  print("TEST:", f.__name__)
+  f()
+
+# Verify successful parse.
+# CHECK-LABEL: TEST: testParseSuccess
+# CHECK: module @successfulParse
+def testParseSuccess():
+  ctx = mlir.ir.MlirContext()
+  module = ctx.parse(r"""module @successfulParse {}""")
+  module.dump()  # Just outputs to stderr. Verifies that it functions.
+  print(str(module))
+
+run(testParseSuccess)
+
+
+# Verify parse error.
+# CHECK-LABEL: TEST: testParseError
+# CHECK: testParseError: Unable to parse module assembly (see diagnostics)
+def testParseError():
+  ctx = mlir.ir.MlirContext()
+  try:
+    module = ctx.parse(r"""}SYNTAX ERROR{""")
+  except ValueError as e:
+    print("testParseError:", e)
+  else:
+    print("Exception not produced")
+
+run(testParseError)
+
+
+# Verify round-trip of ASM that contains unicode.
+# Note that this does not test that the print path converts unicode properly
+# because MLIR asm always normalizes it to the hex encoding.
+# CHECK-LABEL: TEST: testRoundtripUnicode
+# CHECK: func @roundtripUnicode()
+# CHECK: foo = "\F0\9F\98\8A"
+def testRoundtripUnicode():
+  ctx = mlir.ir.MlirContext()
+  module = ctx.parse(r"""
+    func @roundtripUnicode() attributes { foo = "😊" }
+  """)
+  print(str(module))
+
+run(testRoundtripUnicode)

diff  --git a/mlir/test/Bindings/Python/ir_test.py b/mlir/test/Bindings/Python/ir_test.py
deleted file mode 100644
index 1dfb356f956b..000000000000
--- a/mlir/test/Bindings/Python/ir_test.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# RUN: %PYTHON %s | FileCheck %s
-
-import mlir
-
-TEST_MLIR_ASM = r"""
-module {
-}
-"""
-
-ctx = mlir.ir.MlirContext()
-module = ctx.parse(TEST_MLIR_ASM)
-module.dump()
-print(bool(module))
-# CHECK: True

diff  --git a/mlir/test/Bindings/Python/smoke_test.py b/mlir/test/Bindings/Python/smoke_test.py
deleted file mode 100644
index 3904e72e25bb..000000000000
--- a/mlir/test/Bindings/Python/smoke_test.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# RUN: %PYTHON %s | FileCheck %s
-
-import mlir
-
-# CHECK: From the native module
-print(mlir.get_test_value())


        


More information about the Mlir-commits mailing list