[Mlir-commits] [mlir] 74a58ec - [mlir][CAPI][Python] Plumb OpPrintingFlags to C and Python APIs.

Stella Laurenzo llvmlistbot at llvm.org
Wed Oct 21 12:15:26 PDT 2020


Author: Stella Laurenzo
Date: 2020-10-21T12:14:06-07:00
New Revision: 74a58ec9c27f48eb26094667156934c6ca9d0012

URL: https://github.com/llvm/llvm-project/commit/74a58ec9c27f48eb26094667156934c6ca9d0012
DIFF: https://github.com/llvm/llvm-project/commit/74a58ec9c27f48eb26094667156934c6ca9d0012.diff

LOG: [mlir][CAPI][Python] Plumb OpPrintingFlags to C and Python APIs.

* Adds a new MlirOpPrintingFlags type and supporting accessors.
* Adds a new mlirOperationPrintWithFlags function.
* Adds a full-featured Python Operation.print method with all options and the ability to print directly to files/stdout in text or binary form.
* Adds an Operation.get_asm which delegates to print and returns a str or bytes.
* Reworks Operation.__str__ to be based on get_asm.

Differential Revision: https://reviews.llvm.org/D89848

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir/CAPI/IR.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/Bindings/Python/IRModules.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/Bindings/Python/ir_operation.py
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 2aeb306f7256..a08fe77da37c 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -50,6 +50,7 @@ extern "C" {
 DEFINE_C_API_STRUCT(MlirContext, void);
 DEFINE_C_API_STRUCT(MlirDialect, void);
 DEFINE_C_API_STRUCT(MlirOperation, void);
+DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void); /* Heap-allocated handle. */
 DEFINE_C_API_STRUCT(MlirBlock, void);
 DEFINE_C_API_STRUCT(MlirRegion, void);
 
@@ -228,6 +229,42 @@ void mlirOperationStateAddSuccessors(MlirOperationState *state, intptr_t n,
 void mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
                                      MlirNamedAttribute *attributes);
 
+/*============================================================================*/
+/* Op Printing flags API.                                                     */
+/* While many of these are simple settings that could be represented in a     */
+/* struct, they are wrapped in a heap-allocated object and accessed via       */
+/* functions to maximize the possibility of compatibility over time.          */
+/*============================================================================*/
+
+/** Creates new printing flags with defaults, intended for customization.
+ * Must be freed with a call to mlirOpPrintingFlagsDestroy(). */
+MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void);
+
+/** Destroys printing flags created with mlirOpPrintingFlagsCreate. */
+void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags);
+
+/** Enables the elision of large elements attributes by printing a lexically
+ * valid but otherwise meaningless form instead of the element data. The
+ * `largeElementLimit` is used to configure what is considered to be a "large"
+ * ElementsAttr by providing an upper limit to the number of elements. */
+void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
+                                                intptr_t largeElementLimit);
+
+/** Enables printing of debug information. If 'prettyForm' is nonzero, debug
+ * information is printed in a more readable 'pretty' form. Note: The IR
+ * generated with 'prettyForm' is not parsable. */
+void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
+                                        int prettyForm);
+
+/** Always prints operations in the generic form. */
+void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags);
+
+/** Uses local scope when printing the operation. This allows for using the
+ * printer in a more localized and thread-safe setting, but may not
+ * necessarily be identical to what the IR will look like when dumping
+ * the full module. */
+void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
+
 /*============================================================================*/
 /* Operation API.                                                             */
 /*============================================================================*/
@@ -298,6 +335,11 @@ int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name);
 void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
                         void *userData);
 
+/** Same as mlirOperationPrint but accepts flags controlling the printing
+ * behavior. Does not take ownership of `flags`. */
+void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
+                                 MlirStringCallback callback, void *userData);
+
 /** Prints an operation to stderr. */
 void mlirOperationDump(MlirOperation op);
 

diff  --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index dce293d05588..b3e481dfb665 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -24,6 +24,7 @@ DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
 DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
 DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
 DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
+DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
 DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
 
 DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute)

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index e1e34f8da6c6..c745c1dedea3 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -562,10 +562,10 @@ class OpPrintingFlags {
   OpPrintingFlags();
   OpPrintingFlags(llvm::NoneType) : OpPrintingFlags() {}
 
-  /// Enable the elision of large elements attributes, by printing a '...'
-  /// instead of the element data. Note: The IR generated with this option is
-  /// not parsable. `largeElementLimit` is used to configure what is considered
-  /// to be a "large" ElementsAttr by providing an upper limit to the number of
+  /// Enables the elision of large elements attributes by printing a lexically
+  /// valid but otherwise meaningless form instead of the element data. The
+  /// `largeElementLimit` is used to configure what is considered to be a
+  /// "large" ElementsAttr by providing an upper limit to the number of
   /// elements.
   OpPrintingFlags &elideLargeElementsAttrs(int64_t largeElementLimit = 16);
 

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index db8a220c9d31..014b312971b7 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -64,12 +64,44 @@ static const char kContextGetUnknownLocationDocstring[] =
 static const char kContextGetFileLocationDocstring[] =
     R"(Gets a Location representing a file, line and column)";
 
+static const char kOperationPrintDocstring[] =
+    R"(Prints the assembly form of the operation to a file-like object.
+
+Args:
+  file: The file-like object to write to. Defaults to sys.stdout.
+  binary: Whether to write bytes (True) or str (False). Defaults to False.
+  large_elements_limit: If set, elides elements attributes with more than
+    this number of elements. Defaults to None (no limit).
+  enable_debug_info: Whether to print debug/location information. Defaults
+    to False.
+  pretty_debug_info: Whether to format debug information for easier reading
+    by a human (warning: the result is unparseable). Defaults to False.
+  print_generic_op_form: Whether to print the generic assembly forms of all
+    ops. Defaults to False.
+  use_local_scope: Whether to print in a way that is more optimized for
+    multi-threaded access but may not be consistent with how the overall
+    module prints. Defaults to False.
+)";
+
+static const char kOperationGetAsmDocstring[] =
+    R"(Gets the assembly form of the operation with all options available.
+
+Args:
+  binary: Whether to return a bytes (True) or str (False) object. Defaults to
+    False.
+  ... others ...: See the print() method for common keyword arguments for
+    configuring the printout.
+Returns:
+  Either a bytes or str object, depending on the setting of the 'binary'
+  argument.
+)";
+
 static const char kOperationStrDunderDocstring[] =
-    R"(Prints the assembly form of the operation with default options.
+    R"(Gets the assembly form of the operation with default options.
 
 If more advanced control over the assembly formatting or I/O options is needed,
-use the dedicated print method, which supports keyword arguments to customize
-behavior.
+use the dedicated print or get_asm methods, which support keyword arguments to
+customize behavior.
 )";
 
 static const char kDumpDocstring[] =
@@ -118,6 +150,35 @@ struct PyPrintAccumulator {
   }
 };
 
+/// Writes printed output to a Python file-like object's write() method, as
+/// str (default) or, when `binary` is set, as bytes.
+class PyFileAccumulator {
+public:
+  PyFileAccumulator(py::object fileObject, bool binary)
+      : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
+
+  void *getUserData() { return this; }
+
+  MlirStringCallback getCallback() {
+    return [](const char *part, intptr_t size, void *userData) {
+      py::gil_scoped_acquire acquire; // Must be a named object to hold GIL.
+      PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
+      if (accum->binary) {
+        // Note: Still has to copy and not avoidable with this API.
+        py::bytes pyBytes(part, size);
+        accum->pyWriteFunction(pyBytes);
+      } else {
+        py::str pyStr(part, size); // Decodes as UTF-8 by default.
+        accum->pyWriteFunction(pyStr);
+      }
+    };
+  }
+
+private:
+  py::object pyWriteFunction;
+  bool binary;
+};
+
 /// Accumulates into a python string from a method that is expected to make
 /// one (no more, no less) call to the callback (asserts internally on
 /// violation).
@@ -712,6 +773,62 @@ void PyOperation::checkValid() {
   }
 }
 
+/// Writes the assembly form of this operation to `fileObject` (sys.stdout
+/// when None), configured by the given printing options.
+void PyOperation::print(py::object fileObject, bool binary,
+                        llvm::Optional<int64_t> largeElementsLimit,
+                        bool enableDebugInfo, bool prettyDebugInfo,
+                        bool printGenericOpForm, bool useLocalScope) {
+  checkValid();
+  if (fileObject.is_none())
+    fileObject = py::module::import("sys").attr("stdout");
+
+  // Owns the C API flags object so it is destroyed on every exit path,
+  // including when the Python write() callback raises and the exception
+  // unwinds out of mlirOperationPrintWithFlags.
+  struct FlagsOwner {
+    MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+    ~FlagsOwner() { mlirOpPrintingFlagsDestroy(flags); }
+  } flagsOwner;
+  MlirOpPrintingFlags flags = flagsOwner.flags;
+  if (largeElementsLimit)
+    mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
+  if (enableDebugInfo)
+    mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
+  if (printGenericOpForm)
+    mlirOpPrintingFlagsPrintGenericOpForm(flags);
+  if (useLocalScope)
+    mlirOpPrintingFlagsUseLocalScope(flags);
+
+  // The GIL is intentionally kept held while printing: releasing it could let
+  // other Python threads touch the IR mid-print, and the callback needs it.
+  PyFileAccumulator accum(fileObject, binary);
+  mlirOperationPrintWithFlags(get(), flags, accum.getCallback(),
+                              accum.getUserData());
+}
+
+/// Returns the assembly form of this operation as bytes (if `binary`) or str,
+/// by printing into an in-memory io.BytesIO / io.StringIO buffer.
+py::object PyOperation::getAsm(bool binary,
+                               llvm::Optional<int64_t> largeElementsLimit,
+                               bool enableDebugInfo, bool prettyDebugInfo,
+                               bool printGenericOpForm, bool useLocalScope) {
+  py::object fileObject;
+  if (binary) {
+    fileObject = py::module::import("io").attr("BytesIO")();
+  } else {
+    fileObject = py::module::import("io").attr("StringIO")();
+  }
+  print(fileObject, /*binary=*/binary,
+        /*largeElementsLimit=*/largeElementsLimit,
+        /*enableDebugInfo=*/enableDebugInfo,
+        /*prettyDebugInfo=*/prettyDebugInfo,
+        /*printGenericOpForm=*/printGenericOpForm,
+        /*useLocalScope=*/useLocalScope);
+
+  return fileObject.attr("getvalue")();
+}
+
 //------------------------------------------------------------------------------
 // PyAttribute.
 //------------------------------------------------------------------------------
@@ -745,7 +848,8 @@ namespace {
 /// CRTP base class for Python MLIR values that subclass Value and should be
 /// castable from it. The value hierarchy is one level deep and is not supposed
 /// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy> class PyConcreteValue : public PyValue {
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
 public:
   // Derived classes must define statics for:
   //   IsAFunctionTy isaFunction
@@ -1969,13 +2073,30 @@ void mlir::python::populateIRSubmodule(py::module &m) {
       .def(
           "__str__",
           [](PyOperation &self) {
-            self.checkValid();
-            PyPrintAccumulator printAccum;
-            mlirOperationPrint(self.get(), printAccum.getCallback(),
-                               printAccum.getUserData());
-            return printAccum.join();
+            return self.getAsm(/*binary=*/false,
+                               /*largeElementsLimit=*/llvm::None,
+                               /*enableDebugInfo=*/false,
+                               /*prettyDebugInfo=*/false,
+                               /*printGenericOpForm=*/false,
+                               /*useLocalScope=*/false);
           },
-          "Returns the assembly form of the operation.");
+          "Returns the assembly form of the operation.")
+      .def("print", &PyOperation::print,
+           // Careful: Lots of arguments must match up with print method.
+           py::arg("file") = py::none(), py::arg("binary") = false,
+           py::arg("large_elements_limit") = py::none(),
+           py::arg("enable_debug_info") = false,
+           py::arg("pretty_debug_info") = false,
+           py::arg("print_generic_op_form") = false,
+           py::arg("use_local_scope") = false, kOperationPrintDocstring)
+      .def("get_asm", &PyOperation::getAsm,
+           // Careful: Lots of arguments must match up with get_asm method.
+           py::arg("binary") = false,
+           py::arg("large_elements_limit") = py::none(),
+           py::arg("enable_debug_info") = false,
+           py::arg("pretty_debug_info") = false,
+           py::arg("print_generic_op_form") = false,
+           py::arg("use_local_scope") = false, kOperationGetAsmDocstring);
 
   // Mapping of PyRegion.
   py::class_<PyRegion>(m, "Region")

diff  --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
index 947b7343e35a..b438e8ac408d 100644
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -277,6 +277,18 @@ class PyOperation : public BaseContextObject {
   }
   void checkValid();
 
+  /// Implements the bound 'print' method: writes the assembly form to
+  /// `fileObject` (sys.stdout when None) as str, or as bytes if `binary`.
+  void print(pybind11::object fileObject, bool binary,
+             llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
+             bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
+  /// Implements the bound 'get_asm' method: prints to an in-memory buffer and
+  /// returns its contents as str, or as bytes if `binary`.
+  pybind11::object getAsm(bool binary,
+                          llvm::Optional<int64_t> largeElementsLimit,
+                          bool enableDebugInfo, bool prettyDebugInfo,
+                          bool printGenericOpForm, bool useLocalScope);
+
 private:
   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
   static PyOperationRef createInstance(PyMlirContextRef contextRef,

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 104f6fda5c02..379770c8962f 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -74,6 +74,46 @@ MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) {
   return wrap(unwrap(dialect)->getNamespace());
 }
 
+/* ========================================================================== */
+/* Printing flags API.                                                        */
+/* ========================================================================== */
+
+// Heap-allocates default flags; the caller releases them with
+// mlirOpPrintingFlagsDestroy().
+MlirOpPrintingFlags mlirOpPrintingFlagsCreate() {
+  OpPrintingFlags *created = new OpPrintingFlags();
+  return wrap(created);
+}
+
+void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags) {
+  OpPrintingFlags *owned = unwrap(flags);
+  delete owned;
+}
+
+// The setters below forward to the matching OpPrintingFlags builder methods;
+// the builder reference those methods return is not needed by the C API.
+void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags,
+                                                intptr_t largeElementLimit) {
+  OpPrintingFlags &cppFlags = *unwrap(flags);
+  cppFlags.elideLargeElementsAttrs(largeElementLimit);
+}
+
+void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags,
+                                        int prettyForm) {
+  OpPrintingFlags &cppFlags = *unwrap(flags);
+  cppFlags.enableDebugInfo(/*prettyForm=*/prettyForm != 0);
+}
+
+void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) {
+  OpPrintingFlags &cppFlags = *unwrap(flags);
+  cppFlags.printGenericOpForm();
+}
+
+void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
+  OpPrintingFlags &cppFlags = *unwrap(flags);
+  cppFlags.useLocalScope();
+}
+
 /* ========================================================================== */
 /* Location API.                                                              */
 /* ========================================================================== */
@@ -282,6 +312,16 @@ void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
   stream.flush();
 }
 
+// Like mlirOperationPrint, but honors the caller-owned `flags`, which are
+// only read here and must still be destroyed by the caller.
+void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
+                                 MlirStringCallback callback, void *userData) {
+  const OpPrintingFlags &printingFlags = *unwrap(flags);
+  detail::CallbackOstream out(callback, userData);
+  unwrap(op)->print(out, printingFlags);
+  out.flush();
+}
+
 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
 
 /* ========================================================================== */

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index e4dc71ac26ef..84f303ca570b 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -1,6 +1,7 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
+import io
 import itertools
 import mlir
 
@@ -248,3 +249,55 @@ def testOperationResultList():
 
 
 run(testOperationResultList)
+
+
+# CHECK-LABEL: TEST: testOperationPrint
+def testOperationPrint():
+  ctx = mlir.ir.Context()
+  module = ctx.parse_module(r"""
+    func @f1(%arg0: i32) -> i32 {
+      %0 = constant dense<[1, 2, 3, 4]> : tensor<4xi32>
+      return %arg0 : i32
+    }
+  """)
+
+  # Test print to stdout.
+  # CHECK: return %arg0 : i32
+  module.operation.print()
+
+  # Test print to text file.
+  f = io.StringIO()
+  # CHECK: <class 'str'>
+  # CHECK: return %arg0 : i32
+  module.operation.print(file=f)
+  str_value = f.getvalue()
+  print(str_value.__class__)
+  print(str_value)
+
+  # Test print to binary file.
+  f = io.BytesIO()
+  # CHECK: <class 'bytes'>
+  # CHECK: return %arg0 : i32
+  module.operation.print(file=f, binary=True)
+  bytes_value = f.getvalue()
+  print(bytes_value.__class__)
+  print(bytes_value)
+
+  # Test get_asm returns str by default and bytes when binary=True.
+  # CHECK: <class 'str'>
+  # CHECK: <class 'bytes'>
+  print(module.operation.get_asm().__class__)
+  print(module.operation.get_asm(binary=True).__class__)
+
+  # Test get_asm with options.
+  # CHECK: "std.return"(%arg0) : (i32) -> ()
+  print(module.operation.get_asm(print_generic_op_form=True))
+
+  # Test print with all options.
+  # CHECK: value = opaque<"", "0xDEADBEEF"> : tensor<4xi32>
+  # CHECK: "std.return"(%arg0) : (i32) -> () -:4:7
+  module.operation.print(large_elements_limit=2, enable_debug_info=True,
+      pretty_debug_info=True, print_generic_op_form=True, use_local_scope=True)
+
+
+run(testOperationPrint)

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 7c86f403b339..fa9a6258a472 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -10,9 +10,9 @@
 /* RUN: mlir-capi-ir-test 2>&1 | FileCheck %s
  */
 
+#include "mlir-c/IR.h"
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/Diagnostics.h"
-#include "mlir-c/IR.h"
 #include "mlir-c/Registration.h"
 #include "mlir-c/StandardAttributes.h"
 #include "mlir-c/StandardDialect.h"
@@ -319,6 +319,25 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   fprintf(stderr, "Removed attr is null: %d\n",
           mlirAttributeIsNull(
               mlirOperationGetAttributeByName(operation, "custom_attr")));
+
+  // Add a large attribute to verify printing flags.
+  int64_t eltsShape[] = {4};
+  int32_t eltsData[] = {1, 2, 3, 4};
+  mlirOperationSetAttributeByName(
+      operation, "elts",
+      mlirDenseElementsAttrInt32Get(
+          mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32)), 4,
+          eltsData));
+  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
+  mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
+  mlirOpPrintingFlagsPrintGenericOpForm(flags);
+  mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/0);
+  mlirOpPrintingFlagsUseLocalScope(flags);
+  fprintf(stderr, "Op print with all flags: ");
+  mlirOperationPrintWithFlags(operation, flags, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
+  mlirOpPrintingFlagsDestroy(flags);
 }
 
 /// Creates an operation with a region containing multiple blocks with
@@ -991,6 +1010,7 @@ int main() {
   // CHECK: Remove attr: 1
   // CHECK: Remove attr again: 0
   // CHECK: Removed attr is null: 1
+  // CHECK: Op print with all flags: %{{.*}} = "std.constant"() {elts = opaque<"", "0xDEADBEEF"> : tensor<4xi32>, value = 0 : index} : () -> index loc(unknown)
   // clang-format on
 
   mlirModuleDestroy(moduleOp);


        


More information about the Mlir-commits mailing list