[Mlir-commits] [mlir] 321aa19 - [mlir] Expose printing functions in C API

Alex Zinenko llvmlistbot at llvm.org
Wed Aug 12 04:07:42 PDT 2020


Author: Alex Zinenko
Date: 2020-08-12T13:07:34+02:00
New Revision: 321aa19ec8ede62325b7e07d3fef4d12859275ab

URL: https://github.com/llvm/llvm-project/commit/321aa19ec8ede62325b7e07d3fef4d12859275ab
DIFF: https://github.com/llvm/llvm-project/commit/321aa19ec8ede62325b7e07d3fef4d12859275ab.diff

LOG: [mlir] Expose printing functions in C API

Provide printing functions for most IR objects in C API (except Region that
does not have a `print` function, and Module that is expected to be printed as
Operation instead). The printing is based on a callback that is called with
chunks of the string representation and forwarded user-defined data.

Reviewed By: stellaraccident, Jing, mehdi_amini

Differential Revision: https://reviews.llvm.org/D85748

Added: 
    

Modified: 
    mlir/docs/CAPI.md
    mlir/include/mlir-c/IR.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md
index 73f053db9eea..6adb9db3331c 100644
--- a/mlir/docs/CAPI.md
+++ b/mlir/docs/CAPI.md
@@ -71,10 +71,31 @@ owned by the `MlirContext` in which they were created.
 ### Nullity
 
 A handle may refer to a _null_ object. It is the responsibility of the caller to
-check if an object is null by using `MlirXIsNull(MlirX)`. API functions do _not_
+check if an object is null by using `mlirXIsNull(MlirX)`. API functions do _not_
 expect null objects as arguments unless explicitly stated otherwise. API
 functions _may_ return null objects.
 
+### Conversion To String and Printing
+
+IR objects can be converted to a string representation, for example for
+printing, using `mlirXPrint(MlirX, MlirPrintCallback, void *)` functions. These
+functions take as arguments a callback with signature `void (*)(const char
+*, intptr_t, void *)` and a pointer to user-defined data. They call the callback
+and supply it with chunks of the string representation, provided as a pointer to
+the first character and a length, and forward the user-defined data unmodified.
+It is up to the caller to allocate memory and perform the copy if the string
+representation must be stored. There is no guarantee that the pointer supplied
+to the callback points to a null-terminated string; the size argument should be
+used to find the end of the string. The callback may be called multiple times
+with consecutive chunks of the string representation (the printing itself is
+buffered).
+
+*Rationale*: this approach allows the caller to have full control of the
+allocation and avoid unnecessary allocation and copying inside the printer.
+
+For convenience, `mlirXDump(MlirX)` functions are provided to print the given
+object to the standard error stream.
+
 ### Common Patterns
 
 The API adopts the following patterns for recurrent functionality in MLIR.
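The CAPI.md text leaves storage of the printed string to the caller. As a minimal sketch of that pattern, assuming only the `MlirPrintCallback` signature introduced by this commit (re-declared locally so the snippet compiles without MLIR headers), a caller could collect the chunks into a growable, heap-allocated string. `StringAccumulator`, `accumulateChunk` and `accumulatorReset` are hypothetical names, not part of the MLIR C API.

```c
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Same shape as MlirPrintCallback in mlir-c/IR.h; repeated here so the sketch
 * compiles without MLIR headers. */
typedef void (*MlirPrintCallback)(const char *, intptr_t, void *);

/* Hypothetical caller-owned buffer that collects the printed chunks. */
typedef struct {
  char *data;      /* null-terminated once at least one byte is appended */
  size_t size;     /* number of bytes stored, excluding the terminator */
  size_t capacity; /* bytes allocated for `data` */
  int failed;      /* set if an allocation fails; later chunks are dropped */
} StringAccumulator;

/* Appends exactly `len` bytes: the chunk is not guaranteed to be
 * null-terminated, so it is never treated as a C string. */
static void accumulateChunk(const char *str, intptr_t len, void *userData) {
  StringAccumulator *acc = (StringAccumulator *)userData;
  assert(acc != NULL);
  if (acc->failed || len <= 0)
    return;
  size_t needed = acc->size + (size_t)len + 1; /* +1 for the terminator */
  if (needed > acc->capacity) {
    size_t newCapacity = acc->capacity ? acc->capacity : 64;
    while (newCapacity < needed)
      newCapacity *= 2;
    char *grown = realloc(acc->data, newCapacity);
    if (!grown) {
      acc->failed = 1;
      return;
    }
    acc->data = grown;
    acc->capacity = newCapacity;
  }
  memcpy(acc->data + acc->size, str, (size_t)len);
  acc->size += (size_t)len;
  acc->data[acc->size] = '\0';
}

/* Releases the buffer and resets the accumulator for reuse. */
static void accumulatorReset(StringAccumulator *acc) {
  free(acc->data);
  acc->data = NULL;
  acc->size = acc->capacity = 0;
  acc->failed = 0;
}
```

With MLIR linked in, one would call, e.g., `mlirTypePrint(type, accumulateChunk, &acc);` and read `acc.data` after checking `acc.failed`. The sketch maintains its own terminator because the chunks handed to the callback are not null-terminated.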

diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 2aff1226fc0e..6b5be2d0195b 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -60,7 +60,7 @@ DEFINE_C_API_STRUCT(MlirModule, const void);
 
 /** Named MLIR attribute.
  *
- * A named attribute is essentially a (name, attrbute) pair where the name is
+ * A named attribute is essentially a (name, attribute) pair where the name is
  * a string.
  */
 struct MlirNamedAttribute {
@@ -69,6 +69,17 @@ struct MlirNamedAttribute {
 };
 typedef struct MlirNamedAttribute MlirNamedAttribute;
 
+/** A callback for printing IR objects.
+ *
+ * This function is called back by the printing functions with the following
+ * arguments:
+ *   - a pointer to the beginning of a string;
+ *   - the length of the string (the pointer may point to a larger buffer, not
+ *     necessarily null-terminated);
+ *   - a pointer to user data forwarded from the printing call.
+ */
+typedef void (*MlirPrintCallback)(const char *, intptr_t, void *);
+
 /*============================================================================*/
 /* Context API.                                                               */
 /*============================================================================*/
@@ -91,6 +102,12 @@ MlirLocation mlirLocationFileLineColGet(MlirContext context,
 /** Creates a location with unknown position owned by the given context. */
 MlirLocation mlirLocationUnknownGet(MlirContext context);
 
+/** Prints a location by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
+                       void *userData);
+
 /*============================================================================*/
 /* Module API.                                                                */
 /*============================================================================*/
@@ -202,6 +219,14 @@ MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
 /** Returns an attrbute attached to the operation given its name. */
 MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
                                               const char *name);
+
+/** Prints an operation by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
+                        void *userData);
+
+/** Prints an operation to stderr. */
 void mlirOperationDump(MlirOperation op);
 
 /*============================================================================*/
@@ -263,6 +288,12 @@ intptr_t mlirBlockGetNumArguments(MlirBlock block);
 /** Returns `pos`-th argument of the block. */
 MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos);
 
+/** Prints a block by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
+                    void *userData);
+
 /*============================================================================*/
 /* Value API.                                                                 */
 /*============================================================================*/
@@ -270,6 +301,12 @@ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos);
 /** Returns the type of the value. */
 MlirType mlirValueGetType(MlirValue value);
 
+/** Prints a value by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
+                    void *userData);
+
 /*============================================================================*/
 /* Type API.                                                                  */
 /*============================================================================*/
@@ -277,6 +314,11 @@ MlirType mlirValueGetType(MlirValue value);
 /** Parses a type. The type is owned by the context. */
 MlirType mlirTypeParseGet(MlirContext context, const char *type);
 
+/** Prints a type by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData);
+
 /** Prints the type to the standard error stream. */
 void mlirTypeDump(MlirType type);
 
@@ -287,6 +329,12 @@ void mlirTypeDump(MlirType type);
 /** Parses an attribute. The attribute is owned by the context. */
 MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
 
+/** Prints an attribute by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
+                        void *userData);
+
 /** Prints the attrbute to the standard error stream. */
 void mlirAttributeDump(MlirAttribute attr);
 

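The `void *userData` argument of `MlirPrintCallback` is forwarded unmodified to every call, so it can carry arbitrary caller state; the test's `printToStderr` in ir.c simply ignores it. A minimal sketch, with hypothetical `PrintSink` and `printToSink` names, that routes output to a caller-chosen `FILE *` and counts what it received:

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

/* Same shape as MlirPrintCallback in mlir-c/IR.h. */
typedef void (*MlirPrintCallback)(const char *, intptr_t, void *);

/* Hypothetical state passed through `userData`: the printing functions hand
 * this pointer back, unchanged, on every callback invocation. */
typedef struct {
  FILE *out;        /* destination chosen by the caller */
  intptr_t written; /* bytes successfully written so far */
  int chunks;       /* number of non-empty chunks received */
} PrintSink;

static void printToSink(const char *str, intptr_t len, void *userData) {
  PrintSink *sink = (PrintSink *)userData;
  assert(sink != NULL && sink->out != NULL);
  if (len <= 0)
    return;
  /* Write exactly `len` bytes; `str` need not be null-terminated. */
  size_t n = fwrite(str, 1, (size_t)len, sink->out);
  sink->written += (intptr_t)n;
  sink->chunks += 1;
}
```

For instance, `PrintSink sink = {stdout, 0, 0}; mlirTypePrint(type, printToSink, &sink);` would print a type to standard output and leave the byte and chunk counts in `sink`.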
diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 0d2a544aaa2e..4ccfb45f2c43 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Parser.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 
@@ -56,6 +57,33 @@ static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
   return storage;
 }
 
+/* ========================================================================== */
+/* Printing helper.                                                           */
+/* ========================================================================== */
+
+namespace {
+/// A simple raw ostream subclass that forwards write_impl calls to the
+/// user-supplied callback together with opaque user-supplied data.
+class CallbackOstream : public llvm::raw_ostream {
+public:
+  CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback,
+                  void *opaqueData)
+      : callback(callback), opaqueData(opaqueData), pos(0u) {}
+
+  void write_impl(const char *ptr, size_t size) override {
+    callback(ptr, size, opaqueData);
+    pos += size;
+  }
+
+  uint64_t current_pos() const override { return pos; }
+
+private:
+  std::function<void(const char *, intptr_t, void *)> callback;
+  void *opaqueData;
+  uint64_t pos;
+};
+} // end namespace
+
 /* ========================================================================== */
 /* Context API.                                                               */
 /* ========================================================================== */
@@ -81,6 +109,13 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
   return wrap(UnknownLoc::get(unwrap(context)));
 }
 
+void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
+                       void *userData) {
+  CallbackOstream stream(callback, userData);
+  unwrap(location).print(stream);
+  stream.flush();
+}
+
 /* ========================================================================== */
 /* Module API.                                                                */
 /* ========================================================================== */
@@ -239,6 +274,13 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
   return wrap(unwrap(op)->getAttr(name));
 }
 
+void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
+                        void *userData) {
+  CallbackOstream stream(callback, userData);
+  unwrap(op)->print(stream);
+  stream.flush();
+}
+
 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
 
 /* ========================================================================== */
@@ -314,6 +356,13 @@ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
   return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
 }
 
+void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
+                    void *userData) {
+  CallbackOstream stream(callback, userData);
+  unwrap(block)->print(stream);
+  stream.flush();
+}
+
 /* ========================================================================== */
 /* Value API.                                                                 */
 /* ========================================================================== */
@@ -322,6 +371,13 @@ MlirType mlirValueGetType(MlirValue value) {
   return wrap(unwrap(value).getType());
 }
 
+void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
+                    void *userData) {
+  CallbackOstream stream(callback, userData);
+  unwrap(value).print(stream);
+  stream.flush();
+}
+
 /* ========================================================================== */
 /* Type API.                                                                  */
 /* ========================================================================== */
@@ -330,6 +386,12 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type) {
   return wrap(mlir::parseType(type, unwrap(context)));
 }
 
+void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) {
+  CallbackOstream stream(callback, userData);
+  unwrap(type).print(stream);
+  stream.flush();
+}
+
 void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
 
 /* ========================================================================== */
@@ -340,6 +402,13 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
   return wrap(mlir::parseAttribute(attr, unwrap(context)));
 }
 
+void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
+                        void *userData) {
+  CallbackOstream stream(callback, userData);
+  unwrap(attr).print(stream);
+  stream.flush();
+}
+
 void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
 
 MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
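The chunking described in CAPI.md comes from llvm::raw_ostream buffering: the stream stages bytes internally and calls write_impl when its buffer fills or on flush(), and CallbackOstream forwards each such call to the user callback. CallbackOstream does not define a destructor that flushes, so each mlirXPrint function calls stream.flush() explicitly before returning. The following is a simplified, hypothetical C analogue of that behavior; `BufferedWriter`, `bufferedWrite`, `bufferedFlush`, `countChunks` and the 8-byte capacity are illustrative only, and real raw_ostream buffer sizes and write paths differ.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

typedef void (*MlirPrintCallback)(const char *, intptr_t, void *);

/* Deliberately tiny so that chunking is easy to observe. */
#define BUFFERED_CAPACITY 8

/* Stages bytes and forwards them to `callback` in chunks, loosely mirroring
 * CallbackOstream on top of raw_ostream's internal buffer. */
typedef struct {
  MlirPrintCallback callback;
  void *userData;
  char buffer[BUFFERED_CAPACITY];
  size_t used;  /* bytes currently staged in `buffer` */
  uint64_t pos; /* bytes already forwarded, like current_pos() */
} BufferedWriter;

/* Hands any staged bytes to the callback as one chunk. */
static void bufferedFlush(BufferedWriter *w) {
  if (w->used == 0)
    return;
  w->callback(w->buffer, (intptr_t)w->used, w->userData);
  w->pos += w->used;
  w->used = 0;
}

/* Copies into the staging buffer, flushing each time it becomes full. */
static void bufferedWrite(BufferedWriter *w, const char *str, size_t len) {
  assert(w != NULL && w->callback != NULL);
  while (len > 0) {
    size_t room = BUFFERED_CAPACITY - w->used;
    size_t n = len < room ? len : room;
    memcpy(w->buffer + w->used, str, n);
    w->used += n;
    str += n;
    len -= n;
    if (w->used == BUFFERED_CAPACITY)
      bufferedFlush(w);
  }
}

/* Example consumer: counts chunks and the bytes they carry. */
typedef struct {
  int chunks;
  intptr_t bytes;
} ChunkStats;

static void countChunks(const char *str, intptr_t len, void *userData) {
  ChunkStats *stats = (ChunkStats *)userData;
  (void)str;
  stats->chunks += 1;
  stats->bytes += len;
}
```

Writing a 23-byte string through this writer yields two full 8-byte chunks immediately, and the remaining 7 bytes reach the callback only on the final bufferedFlush, which is the role stream.flush() plays in the functions above.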

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index b8c5e0d6e76b..d6ab3513384f 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -197,11 +197,47 @@ void collectStats(MlirOperation operation) {
     head = next;
   } while (head);
 
-  printf("Number of operations: %u\n", stats.numOperations);
-  printf("Number of attributes: %u\n", stats.numAttributes);
-  printf("Number of blocks: %u\n", stats.numBlocks);
-  printf("Number of regions: %u\n", stats.numRegions);
-  printf("Number of values: %u\n", stats.numValues);
+  fprintf(stderr, "Number of operations: %u\n", stats.numOperations);
+  fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes);
+  fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
+  fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
+  fprintf(stderr, "Number of values: %u\n", stats.numValues);
+}
+
+static void printToStderr(const char *str, intptr_t len, void *userData) {
+  (void)userData;
+  fwrite(str, 1, len, stderr);
+}
+
+static void printFirstOfEach(MlirOperation operation) {
+  // Assuming we are given a module, go to the first operation of the first
+  // function.
+  MlirRegion region = mlirOperationGetRegion(operation, 0);
+  MlirBlock block = mlirRegionGetFirstBlock(region);
+  operation = mlirBlockGetFirstOperation(block);
+  region = mlirOperationGetRegion(operation, 0);
+  block = mlirRegionGetFirstBlock(region);
+  operation = mlirBlockGetFirstOperation(block);
+
+  // In the module we created, the first operation of the first function is a
+  // "std.constant", which has an attribute and a single result that we can use
+  // to test the printing mechanism.
+  mlirBlockPrint(block, printToStderr, NULL);
+  fprintf(stderr, "\n");
+  mlirOperationPrint(operation, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
+  MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation, 0);
+  mlirAttributePrint(namedAttr.attribute, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
+  MlirValue value = mlirOperationGetResult(operation, 0);
+  mlirValuePrint(value, printToStderr, NULL);
+  fprintf(stderr, "\n");
+
+  MlirType type = mlirValueGetType(value);
+  mlirTypePrint(type, printToStderr, NULL);
+  fprintf(stderr, "\n");
 }
 
 int main() {
@@ -238,6 +274,24 @@ int main() {
   // CHECK: Number of values: 9
   // clang-format on
 
+  printFirstOfEach(module);
+  // clang-format off
+  // CHECK:   %[[C0:.*]] = constant 0 : index
+  // CHECK:   %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32>
+  // CHECK:   %[[C1:.*]] = constant 1 : index
+  // CHECK:   scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
+  // CHECK:     %[[LHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32>
+  // CHECK:     %[[RHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32>
+  // CHECK:     %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32
+  // CHECK:     store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
+  // CHECK:   }
+  // CHECK: return
+  // CHECK: constant 0 : index
+  // CHECK: 0 : index
+  // CHECK: constant 0 : index
+  // CHECK: index
+  // clang-format on
+
   mlirModuleDestroy(moduleOp);
   mlirContextDestroy(ctx);
 
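Callers that want to avoid heap allocation, for example when printing a short type such as `index` as printFirstOfEach does, might instead copy into a fixed-size buffer and record whether anything was cut off. A sketch under that assumption; `FixedSink`, `fixedSinkInit` and `printToFixedSink` are hypothetical names:

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

typedef void (*MlirPrintCallback)(const char *, intptr_t, void *);

/* Hypothetical fixed-capacity sink: copies as much as fits, keeps the result
 * null-terminated, and records whether any output was dropped. */
typedef struct {
  char *buffer;
  size_t capacity; /* total bytes in `buffer`, including the terminator */
  size_t size;     /* bytes stored, excluding the terminator */
  int truncated;   /* nonzero if some output did not fit */
} FixedSink;

static void fixedSinkInit(FixedSink *sink, char *buffer, size_t capacity) {
  assert(buffer != NULL && capacity > 0);
  sink->buffer = buffer;
  sink->capacity = capacity;
  sink->size = 0;
  sink->truncated = 0;
  buffer[0] = '\0';
}

static void printToFixedSink(const char *str, intptr_t len, void *userData) {
  FixedSink *sink = (FixedSink *)userData;
  assert(sink != NULL);
  if (len <= 0)
    return;
  size_t room = sink->capacity - 1 - sink->size; /* reserve one byte for '\0' */
  size_t n = (size_t)len;
  if (n > room) {
    n = room;
    sink->truncated = 1;
  }
  memcpy(sink->buffer + sink->size, str, n);
  sink->size += n;
  sink->buffer[sink->size] = '\0';
}
```

Usage would look like `char name[64]; FixedSink sink; fixedSinkInit(&sink, name, sizeof name); mlirTypePrint(type, printToFixedSink, &sink);`, followed by a check of `sink.truncated` before relying on `name`.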