[Mlir-commits] [mlir] 321aa19 - [mlir] Expose printing functions in C API
Alex Zinenko
llvmlistbot at llvm.org
Wed Aug 12 04:07:42 PDT 2020
Author: Alex Zinenko
Date: 2020-08-12T13:07:34+02:00
New Revision: 321aa19ec8ede62325b7e07d3fef4d12859275ab
URL: https://github.com/llvm/llvm-project/commit/321aa19ec8ede62325b7e07d3fef4d12859275ab
DIFF: https://github.com/llvm/llvm-project/commit/321aa19ec8ede62325b7e07d3fef4d12859275ab.diff
LOG: [mlir] Expose printing functions in C API
Provide printing functions for most IR objects in C API (except Region that
does not have a `print` function, and Module that is expected to be printed as
Operation instead). The printing is based on a callback that is called with
chunks of the string representation and forwarded user-defined data.
Reviewed By: stellaraccident, Jing, mehdi_amini
Differential Revision: https://reviews.llvm.org/D85748
Added:
Modified:
mlir/docs/CAPI.md
mlir/include/mlir-c/IR.h
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md
index 73f053db9eea..6adb9db3331c 100644
--- a/mlir/docs/CAPI.md
+++ b/mlir/docs/CAPI.md
@@ -71,10 +71,31 @@ owned by the `MlirContext` in which they were created.
### Nullity
A handle may refer to a _null_ object. It is the responsibility of the caller to
-check if an object is null by using `MlirXIsNull(MlirX)`. API functions do _not_
+check if an object is null by using `mlirXIsNull(MlirX)`. API functions do _not_
expect null objects as arguments unless explicitly stated otherwise. API
functions _may_ return null objects.
+### Conversion To String and Printing
+
+IR objects can be converted to a string representation, for example for
+printing, using `mlirXPrint(MlirX, MlirPrintCallback, void *)` functions. These
+functions take as arguments a callback with signature `void (*)(const char
+*, intptr_t, void *)` and a pointer to user-defined data. They call the callback
+and supply it with chunks of the string representation, provided as a pointer to
+the first character and a length, and forward the user-defined data unmodified.
+It is up to the caller to allocate memory if the string representation must be
+stored and perform the copy. There is no guarantee that the pointer supplied to
+the callback points to a null-terminated string; the size argument should be
+used to find the end of the string. The callback may be called multiple times
+with consecutive chunks of the string representation (the printing itself is
+buffered).
+
+*Rationale*: this approach allows the caller to have full control of the
+allocation and avoid unnecessary allocation and copying inside the printer.
+
+For convenience, `mlirXDump(MlirX)` functions are provided to print the given
+object to the standard error stream.
+
### Common Patterns
The API adopts the following patterns for recurrent functionality in MLIR.
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 2aff1226fc0e..6b5be2d0195b 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -60,7 +60,7 @@ DEFINE_C_API_STRUCT(MlirModule, const void);
/** Named MLIR attribute.
*
- * A named attribute is essentially a (name, attrbute) pair where the name is
+ * A named attribute is essentially a (name, attribute) pair where the name is
* a string.
*/
struct MlirNamedAttribute {
@@ -69,6 +69,17 @@ struct MlirNamedAttribute {
};
typedef struct MlirNamedAttribute MlirNamedAttribute;
+/** A callback for printing IR objects.
+ *
+ * This function is called back by the printing functions with the following
+ * arguments:
+ * - a pointer to the beginning of a string;
+ * - the length of the string (the pointer may point to a larger buffer, not
+ * necessarily null-terminated);
+ * - a pointer to user data forwarded from the printing call.
+ */
+typedef void (*MlirPrintCallback)(const char *, intptr_t, void *);
+
/*============================================================================*/
/* Context API. */
/*============================================================================*/
@@ -91,6 +102,12 @@ MlirLocation mlirLocationFileLineColGet(MlirContext context,
/** Creates a location with unknown position owned by the given context. */
MlirLocation mlirLocationUnknownGet(MlirContext context);
+/** Prints a location by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
+ void *userData);
+
/*============================================================================*/
/* Module API. */
/*============================================================================*/
@@ -202,6 +219,14 @@ MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos);
/** Returns an attrbute attached to the operation given its name. */
MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
const char *name);
+
+/** Prints an operation by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
+ void *userData);
+
+/** Prints an operation to stderr. */
void mlirOperationDump(MlirOperation op);
/*============================================================================*/
@@ -263,6 +288,12 @@ intptr_t mlirBlockGetNumArguments(MlirBlock block);
/** Returns `pos`-th argument of the block. */
MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos);
+/** Prints a block by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
+ void *userData);
+
/*============================================================================*/
/* Value API. */
/*============================================================================*/
@@ -270,6 +301,12 @@ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos);
/** Returns the type of the value. */
MlirType mlirValueGetType(MlirValue value);
+/** Prints a value by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
+ void *userData);
+
/*============================================================================*/
/* Type API. */
/*============================================================================*/
@@ -277,6 +314,11 @@ MlirType mlirValueGetType(MlirValue value);
/** Parses a type. The type is owned by the context. */
MlirType mlirTypeParseGet(MlirContext context, const char *type);
+/** Prints a type by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData);
+
/** Prints the type to the standard error stream. */
void mlirTypeDump(MlirType type);
@@ -287,6 +329,12 @@ void mlirTypeDump(MlirType type);
/** Parses an attribute. The attribute is owned by the context. */
MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
+/** Prints an attribute by sending chunks of the string representation and
+ * forwarding `userData` to `callback`. Note that the callback may be called
+ * several times with consecutive chunks of the string. */
+void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
+ void *userData);
+
/** Prints the attrbute to the standard error stream. */
void mlirAttributeDump(MlirAttribute attr);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 0d2a544aaa2e..4ccfb45f2c43 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
#include "mlir/Parser.h"
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -56,6 +57,33 @@ static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
return storage;
}
+/* ========================================================================== */
+/* Printing helper. */
+/* ========================================================================== */
+
+namespace {
+/// A simple raw ostream subclass that forwards write_impl calls to the
+/// user-supplied callback together with opaque user-supplied data.
+class CallbackOstream : public llvm::raw_ostream {
+public:
+ CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback,
+ void *opaqueData)
+ : callback(callback), opaqueData(opaqueData), pos(0u) {}
+
+ void write_impl(const char *ptr, size_t size) override {
+ callback(ptr, size, opaqueData);
+ pos += size;
+ }
+
+ uint64_t current_pos() const override { return pos; }
+
+private:
+ std::function<void(const char *, intptr_t, void *)> callback;
+ void *opaqueData;
+ uint64_t pos;
+};
+} // end namespace
+
/* ========================================================================== */
/* Context API. */
/* ========================================================================== */
@@ -81,6 +109,13 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(UnknownLoc::get(unwrap(context)));
}
+void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
+ void *userData) {
+ CallbackOstream stream(callback, userData);
+ unwrap(location).print(stream);
+ stream.flush();
+}
+
/* ========================================================================== */
/* Module API. */
/* ========================================================================== */
@@ -239,6 +274,13 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
return wrap(unwrap(op)->getAttr(name));
}
+void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
+ void *userData) {
+ CallbackOstream stream(callback, userData);
+ unwrap(op)->print(stream);
+ stream.flush();
+}
+
void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
/* ========================================================================== */
@@ -314,6 +356,13 @@ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
}
+void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
+ void *userData) {
+ CallbackOstream stream(callback, userData);
+ unwrap(block)->print(stream);
+ stream.flush();
+}
+
/* ========================================================================== */
/* Value API. */
/* ========================================================================== */
@@ -322,6 +371,13 @@ MlirType mlirValueGetType(MlirValue value) {
return wrap(unwrap(value).getType());
}
+void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
+ void *userData) {
+ CallbackOstream stream(callback, userData);
+ unwrap(value).print(stream);
+ stream.flush();
+}
+
/* ========================================================================== */
/* Type API. */
/* ========================================================================== */
@@ -330,6 +386,12 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type) {
return wrap(mlir::parseType(type, unwrap(context)));
}
+void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) {
+ CallbackOstream stream(callback, userData);
+ unwrap(type).print(stream);
+ stream.flush();
+}
+
void mlirTypeDump(MlirType type) { unwrap(type).dump(); }
/* ========================================================================== */
@@ -340,6 +402,13 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
return wrap(mlir::parseAttribute(attr, unwrap(context)));
}
+void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
+ void *userData) {
+ CallbackOstream stream(callback, userData);
+ unwrap(attr).print(stream);
+ stream.flush();
+}
+
void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); }
MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) {
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index b8c5e0d6e76b..d6ab3513384f 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -197,11 +197,47 @@ void collectStats(MlirOperation operation) {
head = next;
} while (head);
- printf("Number of operations: %u\n", stats.numOperations);
- printf("Number of attributes: %u\n", stats.numAttributes);
- printf("Number of blocks: %u\n", stats.numBlocks);
- printf("Number of regions: %u\n", stats.numRegions);
- printf("Number of values: %u\n", stats.numValues);
+ fprintf(stderr, "Number of operations: %u\n", stats.numOperations);
+ fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes);
+ fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
+ fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
+ fprintf(stderr, "Number of values: %u\n", stats.numValues);
+}
+
+static void printToStderr(const char *str, intptr_t len, void *userData) {
+ (void)userData;
+ fwrite(str, 1, len, stderr);
+}
+
+static void printFirstOfEach(MlirOperation operation) {
+ // Assuming we are given a module, go to the first operation of the first
+ // function.
+ MlirRegion region = mlirOperationGetRegion(operation, 0);
+ MlirBlock block = mlirRegionGetFirstBlock(region);
+ operation = mlirBlockGetFirstOperation(block);
+ region = mlirOperationGetRegion(operation, 0);
+ block = mlirRegionGetFirstBlock(region);
+ operation = mlirBlockGetFirstOperation(block);
+
+ // In the module we created, the first operation of the first function is an
+ // "std.dim", which has an attribute and a single result that we can use to
+ // test the printing mechanism.
+ mlirBlockPrint(block, printToStderr, NULL);
+ fprintf(stderr, "\n");
+ mlirOperationPrint(operation, printToStderr, NULL);
+ fprintf(stderr, "\n");
+
+ MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation, 0);
+ mlirAttributePrint(namedAttr.attribute, printToStderr, NULL);
+ fprintf(stderr, "\n");
+
+ MlirValue value = mlirOperationGetResult(operation, 0);
+ mlirValuePrint(value, printToStderr, NULL);
+ fprintf(stderr, "\n");
+
+ MlirType type = mlirValueGetType(value);
+ mlirTypePrint(type, printToStderr, NULL);
+ fprintf(stderr, "\n");
}
int main() {
@@ -238,6 +274,24 @@ int main() {
// CHECK: Number of values: 9
// clang-format on
+ printFirstOfEach(module);
+ // clang-format off
+ // CHECK: %[[C0:.*]] = constant 0 : index
+ // CHECK: %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32>
+ // CHECK: %[[C1:.*]] = constant 1 : index
+ // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] {
+ // CHECK: %[[LHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32>
+ // CHECK: %[[RHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32>
+ // CHECK: %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32
+ // CHECK: store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32>
+ // CHECK: }
+ // CHECK: return
+ // CHECK: constant 0 : index
+ // CHECK: 0 : index
+ // CHECK: constant 0 : index
+ // CHECK: index
+ // clang-format on
+
mlirModuleDestroy(moduleOp);
mlirContextDestroy(ctx);
More information about the Mlir-commits
mailing list