[Mlir-commits] [mlir] 95b77f2 - Adds __str__ support to python mlir.ir.MlirModule.
Stella Laurenzo
llvmlistbot at llvm.org
Mon Aug 17 09:47:04 PDT 2020
Author: Stella Laurenzo
Date: 2020-08-17T09:46:33-07:00
New Revision: 95b77f2eac8f3498fc20299f2515b3ce3440b82e
URL: https://github.com/llvm/llvm-project/commit/95b77f2eac8f3498fc20299f2515b3ce3440b82e
DIFF: https://github.com/llvm/llvm-project/commit/95b77f2eac8f3498fc20299f2515b3ce3440b82e.diff
LOG: Adds __str__ support to python mlir.ir.MlirModule.
* Also raises an exception on parse error.
* Removes placeholder smoketest.
* Adds docstrings.
Differential Revision: https://reviews.llvm.org/D86046
Added:
mlir/lib/Bindings/Python/PybindUtils.cpp
mlir/lib/Bindings/Python/PybindUtils.h
mlir/test/Bindings/Python/ir_module_test.py
Modified:
mlir/lib/Bindings/Python/CMakeLists.txt
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/lib/Bindings/Python/MainModule.cpp
mlir/lib/CAPI/IR/IR.cpp
Removed:
mlir/test/Bindings/Python/ir_test.py
mlir/test/Bindings/Python/smoke_test.py
################################################################################
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index fab4061bcb9f..0f03445c2711 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -22,6 +22,7 @@ endif()
add_library(MLIRBindingsPythonExtension ${PYEXT_LINK_MODE}
MainModule.cpp
IRModules.cpp
+ PybindUtils.cpp
)
target_include_directories(MLIRBindingsPythonExtension PRIVATE
@@ -68,7 +69,6 @@ endif()
target_link_libraries(MLIRBindingsPythonExtension
PRIVATE
- MLIRIR
MLIRCAPIIR
MLIRCAPIRegistration
${PYEXT_LIBADD}
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index b3b790af5516..27e1854e7455 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -7,6 +7,61 @@
//===----------------------------------------------------------------------===//
#include "IRModules.h"
+#include "PybindUtils.h"
+
+namespace py = pybind11;
+using namespace mlir::python;
+
+//------------------------------------------------------------------------------
+// Docstrings (trivial, non-duplicated docstrings are included inline).
+//------------------------------------------------------------------------------
+
+static const char kContextParseDocstring[] =
+ R"(Parses a module's assembly format from a string.
+
+Returns a new MlirModule or raises a ValueError if the parsing fails.
+)";
+
+static const char kOperationStrDunderDocstring[] =
+ R"(Prints the assembly form of the operation with default options.
+
+If more advanced control over the assembly formatting or I/O options is needed,
+use the dedicated print method, which supports keyword arguments to customize
+behavior.
+)";
+
+static const char kDumpDocstring[] =
+ R"(Dumps a debug representation of the object to stderr.)";
+
+//------------------------------------------------------------------------------
+// Conversion utilities.
+//------------------------------------------------------------------------------
+
+namespace {
+
+/// Collects the text chunks delivered to an MlirPrintCallback and joins them
+/// into one Python str. Usage: pass getCallback()/getUserData() to a printer.
+struct PyPrintAccumulator {
+ py::list parts;
+
+ void *getUserData() { return this; } // Handed back to the callback as userData.
+
+ MlirPrintCallback getCallback() {
+ return [](const char *part, intptr_t size, void *userData) {
+ PyPrintAccumulator *printAccum =
+ static_cast<PyPrintAccumulator *>(userData);
+ py::str pyPart(part, size); // Decodes as UTF-8 by default. NOTE(review): each chunk is decoded separately; assumes no chunk splits a multi-byte sequence - confirm.
+ printAccum->parts.append(std::move(pyPart));
+ };
+ }
+
+ py::str join() {
+ py::str delim("", 0); // Empty separator: chunks are concatenated as-is.
+ return delim.attr("join")(parts);
+ }
+};
+
+} // namespace
//------------------------------------------------------------------------------
// Context Wrapper Class.
@@ -14,6 +69,10 @@
PyMlirModule PyMlirContext::parse(const std::string &module) {
auto moduleRef = mlirModuleCreateParse(context, module.c_str());
+ if (!moduleRef.ptr) { // A null module is how mlirModuleCreateParse reports a parse failure.
+ throw SetPyError(PyExc_ValueError,
+ "Unable to parse module assembly (see diagnostics)");
+ }
return PyMlirModule(moduleRef);
}
@@ -27,10 +86,22 @@ void PyMlirModule::dump() { mlirOperationDump(mlirModuleGetOperation(module)); }
// Populates the pybind11 IR submodule.
//------------------------------------------------------------------------------
-void populateIRSubmodule(py::module &m) {
+void mlir::python::populateIRSubmodule(py::module &m) {
py::class_<PyMlirContext>(m, "MlirContext")
.def(py::init<>())
- .def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>());
+ .def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>(), // Returned module keeps its context alive.
+ kContextParseDocstring);
- py::class_<PyMlirModule>(m, "MlirModule").def("dump", &PyMlirModule::dump);
+ py::class_<PyMlirModule>(m, "MlirModule")
+ .def("dump", &PyMlirModule::dump, kDumpDocstring)
+ .def(
+ "__str__",
+ [](PyMlirModule &self) {
+ auto operation = mlirModuleGetOperation(self.module);
+ PyPrintAccumulator printAccum; // Gathers the printed chunks into one str.
+ mlirOperationPrint(operation, printAccum.getCallback(),
+ printAccum.getUserData());
+ return printAccum.join();
+ },
+ kOperationStrDunderDocstring);
}
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 6bfd00656df9..325db497e2aa 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -13,7 +13,8 @@
#include "mlir-c/IR.h"
-namespace py = pybind11;
+namespace mlir {
+namespace python {
class PyMlirContext;
class PyMlirModule;
@@ -48,6 +49,9 @@ class PyMlirModule {
MlirModule module;
};
-void populateIRSubmodule(py::module &m);
+void populateIRSubmodule(pybind11::module &m);
+
+} // namespace python
+} // namespace mlir
#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 760aa86a7b1c..7dd525b4b340 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -11,21 +11,14 @@
#include <pybind11/pybind11.h>
#include "IRModules.h"
-#include "mlir/IR/MLIRContext.h"
+namespace py = pybind11;
using namespace mlir;
+using namespace mlir::python;
PYBIND11_MODULE(_mlir, m) {
m.doc() = "MLIR Python Native Extension";
- m.def("get_test_value", []() {
- // This is just calling a method on the MLIRContext as a smoketest
- // for linkage.
- MLIRContext context;
- return std::make_tuple(std::string("From the native module"),
- context.isMultithreadingEnabled());
- });
-
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
populateIRSubmodule(irModule);
diff --git a/mlir/lib/Bindings/Python/PybindUtils.cpp b/mlir/lib/Bindings/Python/PybindUtils.cpp
new file mode 100644
index 000000000000..9013c0669794
--- /dev/null
+++ b/mlir/lib/Bindings/Python/PybindUtils.cpp
@@ -0,0 +1,18 @@
+//===- PybindUtils.cpp - Utilities for interop with pybind11 --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PybindUtils.h"
+
+namespace py = pybind11;
+
+pybind11::error_already_set mlir::python::SetPyError(PyObject *excClass,
+ llvm::Twine message) {
+ auto messageStr = message.str(); // Flatten the Twine into an owned string for c_str().
+ PyErr_SetString(excClass, messageStr.c_str());
+ return pybind11::error_already_set(); // Returned, not thrown: the caller throws it.
+}
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
new file mode 100644
index 000000000000..1a82f8e824ec
--- /dev/null
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -0,0 +1,28 @@
+//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
+#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
+
+#include <pybind11/pybind11.h>
+
+#include "llvm/ADT/Twine.h"
+
+namespace mlir {
+namespace python {
+
+// Sets the Python error indicator (type `excClass`, text `message`) and returns
+// an error_already_set for the caller to throw back to the Python runtime:
+//   throw SetPyError(PyExc_ValueError, "Foobar'd");
+// NOTE(review): LLVM convention takes Twine params as `const llvm::Twine &`.
+pybind11::error_already_set SetPyError(PyObject *excClass, llvm::Twine message);
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 4ccfb45f2c43..5231096af785 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -126,6 +126,8 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location) {
MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
OwningModuleRef owning = parseSourceString(module, unwrap(context));
+ if (!owning) // Parse failed: signal it to callers with a null module.
+ return MlirModule{nullptr};
return MlirModule{owning.release().getOperation()};
}
diff --git a/mlir/test/Bindings/Python/ir_module_test.py b/mlir/test/Bindings/Python/ir_module_test.py
new file mode 100644
index 000000000000..26b7fe63369c
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_module_test.py
@@ -0,0 +1,49 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import mlir
+
+def run(f):
+ print("TEST:", f.__name__)
+ f()
+
+# Verify successful parse.
+# CHECK-LABEL: TEST: testParseSuccess
+# CHECK: module @successfulParse
+def testParseSuccess():
+ ctx = mlir.ir.MlirContext()
+ module = ctx.parse(r"""module @successfulParse {}""")
+ module.dump() # Just outputs to stderr. Verifies that it functions.
+ print(str(module))
+
+run(testParseSuccess)
+
+
+# Verify parse error.
+# CHECK-LABEL: TEST: testParseError
+# CHECK: testParseError: Unable to parse module assembly (see diagnostics)
+def testParseError():
+ ctx = mlir.ir.MlirContext()
+ try:
+ module = ctx.parse(r"""}SYNTAX ERROR{""")
+ except ValueError as e:
+ print("testParseError:", e)
+ else:
+ print("Exception not produced")
+
+run(testParseError)
+
+
+# Verify round-trip of ASM that contains unicode.
+# Note that this does not test that the print path converts unicode properly
+# because MLIR asm always normalizes it to the hex encoding.
+# CHECK-LABEL: TEST: testRoundtripUnicode
+# CHECK: func @roundtripUnicode()
+# CHECK: foo = "\F0\9F\98\8A"
+def testRoundtripUnicode():
+ ctx = mlir.ir.MlirContext()
+ module = ctx.parse(r"""
+ func @roundtripUnicode() attributes { foo = "😊" }
+ """)
+ print(str(module))
+
+run(testRoundtripUnicode)
diff --git a/mlir/test/Bindings/Python/ir_test.py b/mlir/test/Bindings/Python/ir_test.py
deleted file mode 100644
index 1dfb356f956b..000000000000
--- a/mlir/test/Bindings/Python/ir_test.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# RUN: %PYTHON %s | FileCheck %s
-
-import mlir
-
-TEST_MLIR_ASM = r"""
-module {
-}
-"""
-
-ctx = mlir.ir.MlirContext()
-module = ctx.parse(TEST_MLIR_ASM)
-module.dump()
-print(bool(module))
-# CHECK: True
diff --git a/mlir/test/Bindings/Python/smoke_test.py b/mlir/test/Bindings/Python/smoke_test.py
deleted file mode 100644
index 3904e72e25bb..000000000000
--- a/mlir/test/Bindings/Python/smoke_test.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# RUN: %PYTHON %s | FileCheck %s
-
-import mlir
-
-# CHECK: From the native module
-print(mlir.get_test_value())
More information about the Mlir-commits mailing list