[Mlir-commits] [mlir] 74a58ec - [mlir][CAPI][Python] Plumb OpPrintingFlags to C and Python APIs.
Stella Laurenzo
llvmlistbot at llvm.org
Wed Oct 21 12:15:26 PDT 2020
Author: Stella Laurenzo
Date: 2020-10-21T12:14:06-07:00
New Revision: 74a58ec9c27f48eb26094667156934c6ca9d0012
URL: https://github.com/llvm/llvm-project/commit/74a58ec9c27f48eb26094667156934c6ca9d0012
DIFF: https://github.com/llvm/llvm-project/commit/74a58ec9c27f48eb26094667156934c6ca9d0012.diff
LOG: [mlir][CAPI][Python] Plumb OpPrintingFlags to C and Python APIs.
* Adds a new MlirOpPrintingFlags type and supporting accessors.
* Adds a new mlirOperationPrintWithFlags function.
* Adds a full-featured Python Operation.print method with all options and the ability to print directly to files/stdout in text or binary.
* Adds an Operation.get_asm which delegates to print and returns a str or bytes.
* Reworks Operation.__str__ to be based on get_asm.
Differential Revision: https://reviews.llvm.org/D89848
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/include/mlir/CAPI/IR.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/Bindings/Python/IRModules.h
mlir/lib/CAPI/IR/IR.cpp
mlir/test/Bindings/Python/ir_operation.py
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 2aeb306f7256..a08fe77da37c 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -50,6 +50,7 @@ extern "C" {
DEFINE_C_API_STRUCT(MlirContext, void);
DEFINE_C_API_STRUCT(MlirDialect, void);
DEFINE_C_API_STRUCT(MlirOperation, void);
+/* Opaque handle to a heap-allocated set of operation printing options. */
+DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
DEFINE_C_API_STRUCT(MlirBlock, void);
DEFINE_C_API_STRUCT(MlirRegion, void);
@@ -228,6 +229,42 @@ void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
MlirNamedAttribute *attributes);
+/*============================================================================*/
+/* Op Printing flags API. */
+/* While many of these are simple settings that could be represented in a */
+/* struct, they are wrapped in a heap-allocated object and accessed via */
+/* functions to maximize the possibility of compatibility over time. */
+/*============================================================================*/
+
+/** Creates new printing flags with defaults, intended for customization.
+ * Must be freed with a call to mlirOpPrintingFlagsDestroy(). */
+MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void);
+
+/** Destroys printing flags created with mlirOpPrintingFlagsCreate. */
+void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags);
+
+/** Enables the elision of large elements attributes by printing a lexically
+ * valid but otherwise meaningless form instead of the element data. The
+ * `largeElementLimit` is used to configure what is considered to be a "large"
+ * ElementsAttr by providing an upper limit to the number of elements. */
+void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
+                                                intptr_t largeElementLimit);
+
+/** Enable printing of debug information. If `prettyForm` is nonzero, debug
+ * information is printed in a more readable 'pretty' form. Note: The IR
+ * generated with 'prettyForm' is not parsable. */
+void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
+                                        int prettyForm);
+
+/** Always print operations in the generic form. */
+void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags);
+
+/** Use local scope when printing the operation. This allows for using the
+ * printer in a more localized and thread-safe setting, but may not
+ * necessarily be identical to what the IR will look like when dumping
+ * the full module. */
+void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
+
/*============================================================================*/
/* Operation API. */
/*============================================================================*/
@@ -298,6 +335,11 @@ int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name);
void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
void *userData);
+/** Same as mlirOperationPrint but accepts flags controlling the printing
+ * behavior. The flags are only read: the caller keeps ownership and must
+ * still release them with mlirOpPrintingFlagsDestroy(). */
+void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
+ MlirStringCallback callback, void *userData);
+
/** Prints an operation to stderr. */
void mlirOperationDump(MlirOperation op);
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index dce293d05588..b3e481dfb665 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -24,6 +24,7 @@ DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
+DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute)
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index e1e34f8da6c6..c745c1dedea3 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -562,10 +562,10 @@ class OpPrintingFlags {
OpPrintingFlags();
OpPrintingFlags(llvm::NoneType) : OpPrintingFlags() {}
- /// Enable the elision of large elements attributes, by printing a '...'
- /// instead of the element data. Note: The IR generated with this option is
- /// not parsable. `largeElementLimit` is used to configure what is considered
- /// to be a "large" ElementsAttr by providing an upper limit to the number of
+ /// Enables the elision of large elements attributes by printing a lexically
+ /// valid but otherwise meaningless form instead of the element data. The
+ /// `largeElementLimit` is used to configure what is considered to be a
+ /// "large" ElementsAttr by providing an upper limit to the number of
/// elements.
+ /// When called without an argument, the limit defaults to 16 elements.
OpPrintingFlags &elideLargeElementsAttrs(int64_t largeElementLimit = 16);
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index db8a220c9d31..014b312971b7 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -64,12 +64,44 @@ static const char kContextGetUnknownLocationDocstring[] =
static const char kContextGetFileLocationDocstring[] =
R"(Gets a Location representing a file, line and column)";
+// Docstring for Operation.print. Argument names must match the py::arg names
+// used when binding the method.
+static const char kOperationPrintDocstring[] =
+    R"(Prints the assembly form of the operation to a file-like object.
+
+Args:
+  file: The file-like object to write to. Defaults to sys.stdout.
+  binary: Whether to write bytes (True) or str (False). Defaults to False.
+  large_elements_limit: Elides elements attributes with more than this
+    number of elements. Defaults to None (no limit).
+  enable_debug_info: Whether to print debug/location information. Defaults
+    to False.
+  pretty_debug_info: Whether to format debug information for easier reading
+    by a human (warning: the result is unparseable). Only takes effect when
+    enable_debug_info is True. Defaults to False.
+  print_generic_op_form: Whether to print the generic assembly forms of all
+    ops. Defaults to False.
+  use_local_scope: Whether to print in a way that is more optimized for
+    multi-threaded access but may not be consistent with how the overall
+    module prints. Defaults to False.
+)";
+
+// Docstring for Operation.get_asm, which takes the same keyword arguments as
+// print() except `file`.
+static const char kOperationGetAsmDocstring[] =
+ R"(Gets the assembly form of the operation with all options available.
+
+Args:
+ binary: Whether to return a bytes (True) or str (False) object. Defaults to
+ False.
+ ... others ...: See the print() method for common keyword arguments for
+ configuring the printout.
+Returns:
+ Either a bytes or str object, depending on the setting of the 'binary'
+ argument.
+)";
+
static const char kOperationStrDunderDocstring[] =
- R"(Prints the assembly form of the operation with default options.
+ R"(Gets the assembly form of the operation with default options.
If more advanced control over the assembly formatting or I/O options is needed,
-use the dedicated print method, which supports keyword arguments to customize
-behavior.
+use the dedicated print or get_asm methods, which support keyword arguments to
+customize behavior.
)";
static const char kDumpDocstring[] =
@@ -118,6 +150,35 @@ struct PyPrintAccumulator {
}
};
+/// Accumulates into a Python file-like object by forwarding each printed
+/// chunk to its `write()` method, either as text (default) or as bytes.
+class PyFileAccumulator {
+public:
+  PyFileAccumulator(py::object fileObject, bool binary)
+      : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
+
+  void *getUserData() { return this; }
+
+  MlirStringCallback getCallback() {
+    return [](const char *part, intptr_t size, void *userData) {
+      // Make sure the GIL is held while calling into Python, even if the
+      // printer is ever run with it released. The guard must be a named
+      // local: an unnamed `py::gil_scoped_acquire()` is a temporary that is
+      // destroyed again before the next statement executes.
+      py::gil_scoped_acquire acquire;
+      PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
+      if (accum->binary) {
+        // Note: Still has to copy and not avoidable with this API.
+        py::bytes pyBytes(part, size);
+        accum->pyWriteFunction(pyBytes);
+      } else {
+        py::str pyStr(part, size); // Decodes as UTF-8 by default.
+        accum->pyWriteFunction(pyStr);
+      }
+    };
+  }
+
+private:
+  py::object pyWriteFunction;
+  bool binary;
+};
+
/// Accumulates into a python string from a method that is expected to make
/// one (no more, no less) call to the callback (asserts internally on
/// violation).
@@ -712,6 +773,48 @@ void PyOperation::checkValid() {
}
}
+/// Prints the operation's assembly to `fileObject` (sys.stdout when None),
+/// translating each keyword option into the corresponding C API flag.
+void PyOperation::print(py::object fileObject, bool binary,
+                        llvm::Optional<int64_t> largeElementsLimit,
+                        bool enableDebugInfo, bool prettyDebugInfo,
+                        bool printGenericOpForm, bool useLocalScope) {
+  checkValid();
+  if (fileObject.is_none())
+    fileObject = py::module::import("sys").attr("stdout");
+
+  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+  // Destroys the flags on every exit path, including when building the
+  // accumulator fails (e.g. `file` has no `write` attribute) or a `write()`
+  // call raises a Python exception.
+  struct FlagsDestroyer {
+    MlirOpPrintingFlags handle;
+    ~FlagsDestroyer() { mlirOpPrintingFlagsDestroy(handle); }
+  } flagsDestroyer{flags};
+
+  if (largeElementsLimit)
+    mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
+  if (enableDebugInfo)
+    mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
+  if (printGenericOpForm)
+    mlirOpPrintingFlagsPrintGenericOpForm(flags);
+  if (useLocalScope)
+    mlirOpPrintingFlagsUseLocalScope(flags);
+
+  // The GIL is intentionally kept held while printing: the accumulator calls
+  // back into Python for every chunk, and holding it means no other Python
+  // thread runs (and e.g. mutates the IR) while the operation is printed.
+  PyFileAccumulator accum(fileObject, binary);
+  mlirOperationPrintWithFlags(get(), flags, accum.getCallback(),
+                              accum.getUserData());
+}
+
+/// Returns the operation's assembly as `bytes` (when `binary`) or `str` by
+/// printing into an in-memory `io` buffer of the matching kind.
+py::object PyOperation::getAsm(bool binary,
+                               llvm::Optional<int64_t> largeElementsLimit,
+                               bool enableDebugInfo, bool prettyDebugInfo,
+                               bool printGenericOpForm, bool useLocalScope) {
+  py::object ioModule = py::module::import("io");
+  const char *bufferType = binary ? "BytesIO" : "StringIO";
+  py::object buffer = ioModule.attr(bufferType)();
+  print(buffer, binary, largeElementsLimit, enableDebugInfo, prettyDebugInfo,
+        printGenericOpForm, useLocalScope);
+  return buffer.attr("getvalue")();
+}
+
//------------------------------------------------------------------------------
// PyAttribute.
//------------------------------------------------------------------------------
@@ -745,7 +848,8 @@ namespace {
/// CRTP base class for Python MLIR values that subclass Value and should be
/// castable from it. The value hierarchy is one level deep and is not supposed
/// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy> class PyConcreteValue : public PyValue {
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
@@ -1969,13 +2073,30 @@ void mlir::python::populateIRSubmodule(py::module &m) {
.def(
"__str__",
[](PyOperation &self) {
- self.checkValid();
- PyPrintAccumulator printAccum;
- mlirOperationPrint(self.get(), printAccum.getCallback(),
- printAccum.getUserData());
- return printAccum.join();
+ return self.getAsm(/*binary=*/false,
+ /*largeElementsLimit=*/llvm::None,
+ /*enableDebugInfo=*/false,
+ /*prettyDebugInfo=*/false,
+ /*printGenericOpForm=*/false,
+ /*useLocalScope=*/false);
},
- "Returns the assembly form of the operation.");
+ "Returns the assembly form of the operation.")
+ .def("print", &PyOperation::print,
+ // Careful: Lots of arguments must match up with print method.
+ py::arg("file") = py::none(), py::arg("binary") = false,
+ py::arg("large_elements_limit") = py::none(),
+ py::arg("enable_debug_info") = false,
+ py::arg("pretty_debug_info") = false,
+ py::arg("print_generic_op_form") = false,
+ py::arg("use_local_scope") = false, kOperationPrintDocstring)
+ .def("get_asm", &PyOperation::getAsm,
+ // Careful: Lots of arguments must match up with get_asm method.
+ py::arg("binary") = false,
+ py::arg("large_elements_limit") = py::none(),
+ py::arg("enable_debug_info") = false,
+ py::arg("pretty_debug_info") = false,
+ py::arg("print_generic_op_form") = false,
+ py::arg("use_local_scope") = false, kOperationGetAsmDocstring);
// Mapping of PyRegion.
py::class_<PyRegion>(m, "Region")
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 947b7343e35a..b438e8ac408d 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -277,6 +277,15 @@ class PyOperation : public BaseContextObject {
}
void checkValid();
+ /// Implements the bound 'print' method; getAsm also delegates to it.
+ void print(pybind11::object fileObject, bool binary,
+ llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+ bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
+ /// Implements the bound 'get_asm' method: prints via print() into an
+ /// io.StringIO or io.BytesIO (chosen by `binary`) and returns its contents.
+ pybind11::object getAsm(bool binary,
+ llvm::Optional<int64_t> largeElementsLimit,
+ bool enableDebugInfo, bool prettyDebugInfo,
+ bool printGenericOpForm, bool useLocalScope);
+
private:
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
static PyOperationRef createInstance(PyMlirContextRef contextRef,
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 104f6fda5c02..379770c8962f 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -74,6 +74,36 @@ MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
return wrap(unwrap(dialect)->getNamespace());
}
+/* ========================================================================== */
+/* Printing flags API. */
+/* ========================================================================== */
+
+MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
+  // Heap-allocated; ownership passes to the caller, who must release it with
+  // mlirOpPrintingFlagsDestroy().
+  OpPrintingFlags *cppFlags = new OpPrintingFlags();
+  return wrap(cppFlags);
+}
+
+void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
+  OpPrintingFlags *owned = unwrap(flags);
+  delete owned;
+}
+
+void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
+                                                intptr_t largeElementLimit) {
+  OpPrintingFlags &cppFlags = *unwrap(flags);
+  cppFlags.elideLargeElementsAttrs(largeElementLimit);
+}
+
+void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
+                                        int prettyForm) {
+  // The C API passes booleans as int; any nonzero value selects pretty form.
+  OpPrintingFlags &cppFlags = *unwrap(flags);
+  cppFlags.enableDebugInfo(/*prettyForm=*/prettyForm != 0);
+}
+
+void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
+  OpPrintingFlags &cppFlags = *unwrap(flags);
+  cppFlags.printGenericOpForm();
+}
+
+void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
+  OpPrintingFlags &cppFlags = *unwrap(flags);
+  cppFlags.useLocalScope();
+}
+
/* ========================================================================== */
/* Location API. */
/* ========================================================================== */
@@ -282,6 +312,13 @@ void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
stream.flush();
}
+void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
+                                 MlirStringCallback callback, void *userData) {
+  // The flags are only read here; the caller keeps ownership of them.
+  const OpPrintingFlags &printingFlags = *unwrap(flags);
+  Operation *operation = unwrap(op);
+  detail::CallbackOstream stream(callback, userData);
+  operation->print(stream, printingFlags);
+  stream.flush();
+}
+
void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
/* ========================================================================== */
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index e4dc71ac26ef..84f303ca570b 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -1,6 +1,7 @@
# RUN: %PYTHON %s | FileCheck %s
import gc
+import io
import itertools
import mlir
@@ -248,3 +249,44 @@ def testOperationResultList():
run(testOperationResultList)
+
+
+# CHECK-LABEL: TEST: testOperationPrint
+def testOperationPrint():
+  ctx = mlir.ir.Context()
+  # Note: the indentation inside this string fixes the `-:4:7` location
+  # checked below; keep `return` on line 4, column 7.
+  module = ctx.parse_module(r"""
+    func @f1(%arg0: i32) -> i32 {
+      %0 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+      return %arg0 : i32
+    }
+  """)
+
+  # Test print to stdout.
+  # CHECK: return %arg0 : i32
+  module.operation.print()
+
+  # Test print to text file.
+  f = io.StringIO()
+  # CHECK: <class 'str'>
+  # CHECK: return %arg0 : i32
+  module.operation.print(file=f)
+  str_value = f.getvalue()
+  print(str_value.__class__)
+  print(str_value)
+
+  # Test print to binary file.
+  f = io.BytesIO()
+  # CHECK: <class 'bytes'>
+  # CHECK: return %arg0 : i32
+  module.operation.print(file=f, binary=True)
+  bytes_value = f.getvalue()
+  print(bytes_value.__class__)
+  print(bytes_value)
+
+  # Test get_asm with options.
+  # CHECK: value = opaque<"", "0xDEADBEEF"> : tensor<4xi32>
+  # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
+  asm = module.operation.get_asm(large_elements_limit=2,
+      enable_debug_info=True, pretty_debug_info=True,
+      print_generic_op_form=True, use_local_scope=True)
+  print(asm)
+
+  # Test get_asm in binary mode returns bytes.
+  # CHECK: <class 'bytes'>
+  # CHECK: return %arg0 : i32
+  asm_bytes = module.operation.get_asm(binary=True)
+  print(asm_bytes.__class__)
+  print(asm_bytes)
+
+run(testOperationPrint)
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 7c86f403b339..fa9a6258a472 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -10,9 +10,9 @@
/* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
*/
+#include "mlir-c/IR.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/Diagnostics.h"
-#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardDialect.h"
@@ -319,6 +319,25 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
fprintf(stderr, "Removed attr is null: %d\n",
mlirAttributeIsNull(
mlirOperationGetAttributeByName(operation, "custom_attr")));
+
+ // Add a large attribute to verify printing flags.
+ int64_t eltsShape[] = {4};
+ int32_t eltsData[] = {1, 2, 3, 4};
+ mlirOperationSetAttributeByName(
+ operation, "elts",
+ mlirDenseElementsAttrInt32Get(
+ mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32)), 4,
+ eltsData));
+ MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+ mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
+ mlirOpPrintingFlagsPrintGenericOpForm(flags);
+ mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/0);
+ mlirOpPrintingFlagsUseLocalScope(flags);
+ fprintf(stderr, "Op print with all flags: ");
+ mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
+ fprintf(stderr, "\n");
+
+ mlirOpPrintingFlagsDestroy(flags);
}
/// Creates an operation with a region containing multiple blocks with
@@ -991,6 +1010,7 @@ int main() {
// CHECK: Remove attr: 1
// CHECK: Remove attr again: 0
// CHECK: Removed attr is null: 1
+ // CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
// clang-format on
mlirModuleDestroy(moduleOp);
More information about the Mlir-commits
mailing list