[Mlir-commits] [mlir] da56297 - [mlir] expose standard attributes to C API
Alex Zinenko
llvmlistbot at llvm.org
Wed Aug 19 09:50:26 PDT 2020
Author: Alex Zinenko
Date: 2020-08-19T18:50:19+02:00
New Revision: da562974628017ae92c451ca064fea5b59ad71a4
URL: https://github.com/llvm/llvm-project/commit/da562974628017ae92c451ca064fea5b59ad71a4
DIFF: https://github.com/llvm/llvm-project/commit/da562974628017ae92c451ca064fea5b59ad71a4.diff
LOG: [mlir] expose standard attributes to C API
Provide C API for MLIR standard attributes. Since standard attributes live
under lib/IR in core MLIR, place the C APIs in the IR library as well (standard
ops will go in a separate library).
Affine map and integer set attributes are only exposed as placeholder types
with IsA support due to the lack of C APIs for the corresponding types.
Integer and floating point attribute APIs expecting APInt and APFloat are not
exposed pending decision on how to support APInt and APFloat.
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D86143
Added:
mlir/include/mlir-c/StandardAttributes.h
mlir/lib/CAPI/IR/StandardAttributes.cpp
Modified:
mlir/docs/CAPI.md
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRModules.cpp
mlir/lib/CAPI/IR/CMakeLists.txt
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md
index a8fcfbafb8b1..68a28950ebc3 100644
--- a/mlir/docs/CAPI.md
+++ b/mlir/docs/CAPI.md
@@ -97,10 +97,27 @@ as follows.
its first argument is `Y`, and it is the responsibility of the caller to
ensure it is indeed the case.
+### Returning String References
+
+Numerous MLIR functions return instances of `StringRef` to refer to a non-owning
+segment of a string. This segment may or may not be null-terminated. In the C
+API, these functions take an additional callback argument of type
+`MlirStringCallback` (pointer to a function with signature `void (*)(const char
+*, intptr_t, void *)`) and a pointer to user-defined data. The callback is
+invoked with a pointer to the string segment and its size, and also receives
+the user-defined data. The caller is in charge of managing the string segment
+according to its memory model: for strings owned by the object (e.g., string
+attributes), the caller can store the pointer and the size and use them directly
+as long as the parent object is live or copy the string to a new location with a
+null terminator if expected; for generated strings (e.g., in printing), the
+caller is expected to copy the string segment if it intends to use it later.
+
+**Note:** this interface may be revised in the near future.
+
### Conversion To String and Printing
IR objects can be converted to a string representation, for example for
-printing, using `mlirXPrint(MlirX, MlirPrintCallback, void *)` functions. These
+printing, using `mlirXPrint(MlirX, MlirStringCallback, void *)` functions. These
functions take as arguments a callback with signature `void (*)(const char
*, intptr_t, void *)` and a pointer to user-defined data. They call the callback
and supply it with chunks of the string representation, provided as a pointer to
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index d97491b9f08a..9293a40ebbab 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -67,16 +67,16 @@ struct MlirNamedAttribute {
};
typedef struct MlirNamedAttribute MlirNamedAttribute;
-/** A callback for printing to IR objects.
+/** A callback for returning string references.
*
- * This function is called back by the printing functions with the following
- * arguments:
+ * This function is called back by the functions that need to return a reference
+ * to a portion of a string, with the following arguments:
* - a pointer to the beginning of a string;
* - the length of the string (the pointer may point to a larger buffer, not
* necessarily null-terminated);
* - a pointer to user data forwarded from the API call that invoked the callback.
*/
-typedef void (*MlirPrintCallback)(const char *, intptr_t, void *);
+typedef void (*MlirStringCallback)(const char *, intptr_t, void *);
/*============================================================================*/
/* Context API. */
@@ -103,7 +103,7 @@ MlirLocation mlirLocationUnknownGet(MlirContext context);
/** Prints a location by sending chunks of the string representation and
* forwarding `userData` to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
-void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
+void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
void *userData);
/*============================================================================*/
@@ -224,7 +224,7 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
/** Prints an operation by sending chunks of the string representation and
* forwarding `userData` to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
-void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
+void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
void *userData);
/** Prints an operation to stderr. */
@@ -292,7 +292,7 @@ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos);
/** Prints a block by sending chunks of the string representation and
* forwarding `userData` to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
-void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
+void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
void *userData);
/*============================================================================*/
@@ -305,7 +305,7 @@ MlirType mlirValueGetType(MlirValue value);
/** Prints a value by sending chunks of the string representation and
* forwarding `userData` to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
-void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
+void mlirValuePrint(MlirValue value, MlirStringCallback callback,
void *userData);
/*============================================================================*/
@@ -324,7 +324,7 @@ int mlirTypeEqual(MlirType t1, MlirType t2);
/** Prints a type by sending chunks of the string representation and
* forwarding `userData` to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
-void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData);
+void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData);
/** Prints the type to the standard error stream. */
void mlirTypeDump(MlirType type);
@@ -336,10 +336,13 @@ void mlirTypeDump(MlirType type);
/** Parses an attribute. The attribute is owned by the context. */
MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr);
+/** Checks if two attributes are equal. */
+int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2);
+
/** Prints an attribute by sending chunks of the string representation and
* forwarding `userData` to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
-void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
+void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
void *userData);
/** Prints the attribute to the standard error stream. */
diff --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/StandardAttributes.h
new file mode 100644
index 000000000000..ab8d837aeeb8
--- /dev/null
+++ b/mlir/include/mlir-c/StandardAttributes.h
@@ -0,0 +1,442 @@
+/*===-- mlir-c/StandardAttributes.h - C API for Std Attributes-----*- C -*-===*\
+|* *|
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
+|* Exceptions. *|
+|* See https://llvm.org/LICENSE.txt for license information. *|
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
+|* *|
+|*===----------------------------------------------------------------------===*|
+|* *|
+|* This header declares the C interface to MLIR Standard attributes. *|
+|* *|
+\*===----------------------------------------------------------------------===*/
+
+#ifndef MLIR_C_STANDARDATTRIBUTES_H
+#define MLIR_C_STANDARDATTRIBUTES_H
+
+#include "mlir-c/AffineMap.h"
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*============================================================================*/
+/* Affine map attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is an affine map attribute. */
+int mlirAttributeIsAAffineMap(MlirAttribute attr);
+
+/** Creates an affine map attribute wrapping the given map. The attribute
+ * belongs to the same context as the affine map. */
+MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map);
+
+/** Returns the affine map wrapped in the given affine map attribute. */
+MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr);
+
+/*============================================================================*/
+/* Array attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is an array attribute. */
+int mlirAttributeIsAArray(MlirAttribute attr);
+
+/** Creates an array attribute containing the given list of elements in the
+ * given context. */
+MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
+ MlirAttribute *elements);
+
+/** Returns the number of elements stored in the given array attribute. */
+intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr);
+
+/** Returns pos-th element stored in the given array attribute. */
+MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos);
+
+/*============================================================================*/
+/* Dictionary attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is a dictionary attribute. */
+int mlirAttributeIsADictionary(MlirAttribute attr);
+
+/** Creates a dictionary attribute containing the given list of elements in the
+ * provided context. */
+MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
+ MlirNamedAttribute *elements);
+
+/** Returns the number of attributes contained in a dictionary attribute. */
+intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr);
+
+/** Returns pos-th element of the given dictionary attribute. */
+MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
+ intptr_t pos);
+
+/** Returns the dictionary attribute element with the given name or NULL if the
+ * given name does not exist in the dictionary. */
+MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
+ const char *name);
+
+/*============================================================================*/
+/* Floating point attribute. */
+/*============================================================================*/
+
+/* TODO: add support for APFloat and APInt to LLVM IR C API, then expose the
+ * relevant functions here. */
+
+/** Checks whether the given attribute is a floating point attribute. */
+int mlirAttributeIsAFloat(MlirAttribute attr);
+
+/** Creates a floating point attribute in the given context with the given
+ * double value and double-precision FP semantics. */
+MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
+ double value);
+
+/** Returns the value stored in the given floating point attribute, interpreting
+ * the value as double. */
+double mlirFloatAttrGetValueDouble(MlirAttribute attr);
+
+/*============================================================================*/
+/* Integer attribute. */
+/*============================================================================*/
+
+/* TODO: add support for APFloat and APInt to LLVM IR C API, then expose the
+ * relevant functions here. */
+
+/** Checks whether the given attribute is an integer attribute. */
+int mlirAttributeIsAInteger(MlirAttribute attr);
+
+/** Creates an integer attribute of the given type with the given integer
+ * value. */
+MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value);
+
+/** Returns the value stored in the given integer attribute, assuming the value
+ * fits into a 64-bit integer. */
+int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr);
+
+/*============================================================================*/
+/* Bool attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is a bool attribute. */
+int mlirAttributeIsABool(MlirAttribute attr);
+
+/** Creates a bool attribute in the given context with the given value. */
+MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value);
+
+/** Returns the value stored in the given bool attribute. */
+int mlirBoolAttrGetValue(MlirAttribute attr);
+
+/*============================================================================*/
+/* Integer set attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is an integer set attribute. */
+int mlirAttributeIsAIntegerSet(MlirAttribute attr);
+
+/*============================================================================*/
+/* Opaque attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is an opaque attribute. */
+int mlirAttributeIsAOpaque(MlirAttribute attr);
+
+/** Creates an opaque attribute in the given context associated with the dialect
+ * identified by its namespace. The attribute contains opaque byte data of the
+ * specified length (data need not be null-terminated). */
+MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, const char *dialectNamespace,
+ intptr_t dataLength, const char *data,
+ MlirType type);
+
+/** Returns the namespace of the dialect with which the given opaque attribute
+ * is associated. The namespace string is owned by the context. */
+const char *mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr);
+
+/** Calls the provided callback with the opaque byte data stored in the given
+ * opaque attribute. The callback is invoked once, and the data it receives is
+ * not necessarily null terminated. The data remains live as long as the context
+ * in which the attribute lives. */
+/* TODO: consider exposing StringRef and using it instead of the callback. */
+void mlirOpaqueAttrGetData(MlirAttribute attr, MlirStringCallback callback,
+ void *userData);
+
+/*============================================================================*/
+/* String attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is a string attribute. */
+int mlirAttributeIsAString(MlirAttribute attr);
+
+/** Creates a string attribute in the given context containing the given string.
+ * The string need not be null-terminated and its length must be specified. */
+MlirAttribute mlirStringAttrGet(MlirContext ctx, intptr_t length,
+ const char *data);
+
+/** Creates a string attribute in the given context containing the given string.
+ * The string need not be null-terminated and its length must be specified.
+ * Additionally, the attribute has the given type. */
+MlirAttribute mlirStringAttrTypedGet(MlirType type, intptr_t length,
+ const char *data);
+
+/** Calls the provided callback with the string stored in the given string
+ * attribute. The callback is invoked once, and the data it receives is not
+ * necessarily null terminated. The data remains live as long as the context in
+ * which the attribute lives. */
+/* TODO: consider exposing StringRef and using it instead of the callback. */
+void mlirStringAttrGetValue(MlirAttribute attr, MlirStringCallback callback,
+ void *userData);
+
+/*============================================================================*/
+/* SymbolRef attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is a symbol reference attribute. */
+int mlirAttributeIsASymbolRef(MlirAttribute attr);
+
+/** Creates a symbol reference attribute in the given context referencing a
+ * symbol identified by the given string inside a list of nested references.
+ * Each of the references in the list must not be nested. The string need not be
+ * null-terminated and its length must be specified. */
+MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, intptr_t length,
+ const char *symbol, intptr_t numReferences,
+ MlirAttribute *references);
+
+/** Calls the provided callback with the string containing the root referenced
+ * symbol. The callback is invoked once, and the data it receives is not
+ * necessarily null terminated. The data remains live as long as the context in
+ * which the attribute lives. */
+/* TODO: consider exposing StringRef and using it instead of the callback. */
+void mlirSymbolRefAttrGetRootReference(MlirAttribute attr,
+ MlirStringCallback callback,
+ void *userData);
+
+/** Calls the provided callback with the string containing the leaf referenced
+ * symbol. The callback is invoked once, and the data it receives is not
+ * necessarily null terminated. The data remains live as long as the context in
+ * which the attribute lives. */
+/* TODO: consider exposing StringRef and using it instead of the callback. */
+void mlirSymbolRefAttrGetLeafReference(MlirAttribute attr,
+ MlirStringCallback callback,
+ void *userData);
+
+/** Returns the number of references nested in the given symbol reference
+ * attribute. */
+intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr);
+
+/** Returns pos-th reference nested in the given symbol reference attribute. */
+MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
+ intptr_t pos);
+
+/*============================================================================*/
+/* Flat SymbolRef attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is a flat symbol reference attribute. */
+int mlirAttributeIsAFlatSymbolRef(MlirAttribute attr);
+
+/** Creates a flat symbol reference attribute in the given context referencing a
+ * symbol identified by the given string. The string need not be null-terminated
+ * and its length must be specified. */
+MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, intptr_t length,
+ const char *symbol);
+
+/** Calls the provided callback with the string containing the referenced
+ * symbol. The callback is invoked once, and the data it receives is not
+ * necessarily null terminated. The data remains live as long as the context in
+ * which the attribute lives. */
+/* TODO: consider exposing StringRef and using it instead of the callback. */
+/* NOTE(review): "Float" in this name looks like a typo for "Flat": the other
+ * functions in this section are mlirAttributeIsAFlatSymbolRef and
+ * mlirFlatSymbolRefAttrGet. Renaming must be done together with the matching
+ * definition (presumably in lib/CAPI/IR/StandardAttributes.cpp), so the name is
+ * left unchanged here. */
+void mlirFloatSymbolRefAttrGetValue(MlirAttribute attr,
+ MlirStringCallback callback,
+ void *userData);
+
+/*============================================================================*/
+/* Type attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is a type attribute. */
+int mlirAttributeIsAType(MlirAttribute attr);
+
+/** Creates a type attribute wrapping the given type in the same context as the
+ * type. */
+MlirAttribute mlirTypeAttrGet(MlirType type);
+
+/** Returns the type stored in the given type attribute. */
+MlirType mlirTypeAttrGetValue(MlirAttribute attr);
+
+/*============================================================================*/
+/* Unit attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is a unit attribute. */
+int mlirAttributeIsAUnit(MlirAttribute attr);
+
+/** Creates a unit attribute in the given context. */
+MlirAttribute mlirUnitAttrGet(MlirContext ctx);
+
+/*============================================================================*/
+/* Elements attributes. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is an elements attribute. */
+int mlirAttributeIsAElements(MlirAttribute attr);
+
+/** Returns the element at the given rank-dimensional index. */
+MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
+ uint64_t *idxs);
+
+/** Checks whether the given rank-dimensional index is valid in the given
+ * elements attribute. */
+int mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
+ uint64_t *idxs);
+
+/** Gets the total number of elements in the given elements attribute. In order
+ * to iterate over the attribute, obtain its type, which must be a statically
+ * shaped type, and use its sizes to build a multi-dimensional index. */
+int64_t mlirElementsAttrGetNumElements(MlirAttribute attr);
+
+/*============================================================================*/
+/* Dense elements attribute. */
+/*============================================================================*/
+
+/* TODO: decide on the interface and add support for complex elements. */
+/* TODO: add support for APFloat and APInt to LLVM IR C API, then expose the
+ * relevant functions here. */
+
+/** Checks whether the given attribute is a dense elements attribute. */
+int mlirAttributeIsADenseElements(MlirAttribute attr);
+int mlirAttributeIsADenseIntElements(MlirAttribute attr);
+int mlirAttributeIsADenseFPElements(MlirAttribute attr);
+
+/** Creates a dense elements attribute with the given Shaped type and elements
+ * in the same context as the type. */
+MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
+ intptr_t numElements,
+ MlirAttribute *elements);
+
+/** Creates a dense elements attribute with the given Shaped type containing a
+ * single replicated element (splat). */
+MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
+ MlirAttribute element);
+MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
+ int element);
+MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
+ uint32_t element);
+MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
+ int32_t element);
+MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
+ uint64_t element);
+MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
+ int64_t element);
+MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
+ float element);
+MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
+ double element);
+
+/** Creates a dense elements attribute with the given shaped type from elements
+ * of a specific type. Expects the element type of the shaped type to match the
+ * data element type. */
+MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
+ intptr_t numElements, int *elements);
+MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
+ intptr_t numElements,
+ uint32_t *elements);
+MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
+ intptr_t numElements,
+ int32_t *elements);
+MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
+ intptr_t numElements,
+ uint64_t *elements);
+MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
+ intptr_t numElements,
+ int64_t *elements);
+MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
+ intptr_t numElements,
+ float *elements);
+MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
+ intptr_t numElements,
+ double *elements);
+
+/** Creates a dense elements attribute with the given shaped type from string
+ * elements. The strings need not be null-terminated and their lengths are
+ * provided as a separate argument co-indexed with the strs argument. */
+MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
+ intptr_t numElements,
+ intptr_t *strLengths,
+ const char **strs);
+/** Creates a dense elements attribute that has the same data as the given dense
+ * elements attribute and a different shaped type. The new type must have the
+ * same total number of elements. */
+MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
+ MlirType shapedType);
+
+/** Checks whether the given dense elements attribute contains a single
+ * replicated value (splat). */
+int mlirDenseElementsAttrIsSplat(MlirAttribute attr);
+
+/** Returns the single replicated value (splat) of a specific type contained by
+ * the given dense elements attribute. */
+MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr);
+int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr);
+int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr);
+uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr);
+int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr);
+uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr);
+float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr);
+double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr);
+/* TODO: consider exposing StringRef and using it instead of the callback. */
+void mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr,
+ MlirStringCallback callback,
+ void *userData);
+
+/** Returns the pos-th value (flat contiguous indexing) of a specific type
+ * contained by the given dense elements attribute. */
+int mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos);
+int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos);
+uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos);
+int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos);
+uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos);
+float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos);
+double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos);
+/* TODO: consider exposing StringRef and using it instead of the callback. */
+void mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos,
+ MlirStringCallback callback,
+ void *userData);
+
+/*============================================================================*/
+/* Opaque elements attribute. */
+/*============================================================================*/
+
+/* TODO: expose Dialect to the bindings and implement accessors here. */
+
+/** Checks whether the given attribute is an opaque elements attribute. */
+int mlirAttributeIsAOpaqueElements(MlirAttribute attr);
+
+/*============================================================================*/
+/* Sparse elements attribute. */
+/*============================================================================*/
+
+/** Checks whether the given attribute is a sparse elements attribute. */
+int mlirAttributeIsASparseElements(MlirAttribute attr);
+
+/** Creates a sparse elements attribute of the given shape from a list of
+ * indices and a list of associated values. Both lists are expected to be dense
+ * elements attributes with the same number of elements. The list of indices is
+ * expected to contain 64-bit integers. The attribute is created in the same
+ * context as the type. */
+/* NOTE(review): unlike the other constructors in this header (e.g.
+ * mlirDenseElementsAttrGet, mlirArrayAttrGet), this name lacks the "AttrGet"
+ * suffix; consider renaming it to mlirSparseElementsAttrGet together with its
+ * definition. */
+MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
+ MlirAttribute denseIndices,
+ MlirAttribute denseValues);
+
+/** Returns the dense elements attribute containing 64-bit integer indices of
+ * non-null elements in the given sparse elements attribute. */
+MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr);
+
+/** Returns the dense elements attribute containing the non-null elements in the
+ * given sparse elements attribute. */
+MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_STANDARDATTRIBUTES_H
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index bdce390188fa..188fdf39ff14 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -55,13 +55,13 @@ static const char kDumpDocstring[] =
namespace {
/// Accumulates into a python string from a method that accepts an
-/// MlirPrintCallback.
+/// MlirStringCallback.
struct PyPrintAccumulator {
py::list parts;
void *getUserData() { return this; }
- MlirPrintCallback getCallback() {
+ MlirStringCallback getCallback() {
return [](const char *part, intptr_t size, void *userData) {
PyPrintAccumulator *printAccum =
static_cast<PyPrintAccumulator *>(userData);
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index 64e715e33d88..3e2e3d6a22d8 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -2,6 +2,7 @@
add_mlir_library(MLIRCAPIIR
AffineMap.cpp
IR.cpp
+ StandardAttributes.cpp
StandardTypes.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 9d3028deffd2..23337213e2c7 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -71,7 +71,7 @@ MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(UnknownLoc::get(unwrap(context)));
}
-void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback,
+void mlirLocationPrint(MlirLocation location, MlirStringCallback callback,
void *userData) {
CallbackOstream stream(callback, userData);
unwrap(location).print(stream);
@@ -238,7 +238,7 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op,
return wrap(unwrap(op)->getAttr(name));
}
-void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback,
+void mlirOperationPrint(MlirOperation op, MlirStringCallback callback,
void *userData) {
CallbackOstream stream(callback, userData);
unwrap(op)->print(stream);
@@ -320,7 +320,7 @@ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) {
return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos)));
}
-void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback,
+void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
void *userData) {
CallbackOstream stream(callback, userData);
unwrap(block)->print(stream);
@@ -335,7 +335,7 @@ MlirType mlirValueGetType(MlirValue value) {
return wrap(unwrap(value).getType());
}
-void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
+void mlirValuePrint(MlirValue value, MlirStringCallback callback,
void *userData) {
CallbackOstream stream(callback, userData);
unwrap(value).print(stream);
@@ -352,7 +352,7 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type) {
/* Returns 1 if the two wrapped types compare equal and 0 otherwise. */
int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
-void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) {
+void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) {
CallbackOstream stream(callback, userData);
unwrap(type).print(stream);
stream.flush();
@@ -368,7 +368,11 @@ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr) {
return wrap(mlir::parseAttribute(attr, unwrap(context)));
}
-void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback,
+/* Returns 1 when the two wrapped attributes compare equal and 0 otherwise. */
+int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
+  auto lhs = unwrap(a1);
+  auto rhs = unwrap(a2);
+  return lhs == rhs;
+}
+
+void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback,
void *userData) {
CallbackOstream stream(callback, userData);
unwrap(attr).print(stream);
diff --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/StandardAttributes.cpp
new file mode 100644
index 000000000000..cade603132dc
--- /dev/null
+++ b/mlir/lib/CAPI/IR/StandardAttributes.cpp
@@ -0,0 +1,561 @@
+//===- StandardAttributes.cpp - C Interface to MLIR Standard Attributes ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/StandardAttributes.h"
+#include "mlir/CAPI/AffineMap.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+
+/*============================================================================*/
+/* Affine map attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is an AffineMapAttr, 0 otherwise.
+int mlirAttributeIsAAffineMap(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<AffineMapAttr>();
+}
+
+/// Wraps the affine map `map` into an AffineMapAttr.
+MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) {
+  AffineMapAttr result = AffineMapAttr::get(unwrap(map));
+  return wrap(result);
+}
+
+/// Returns the affine map held by `attr`. The caller must ensure `attr` is an
+/// AffineMapAttr (see mlirAttributeIsAAffineMap).
+MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) {
+  AffineMapAttr affineAttr = unwrap(attr).cast<AffineMapAttr>();
+  return wrap(affineAttr.getValue());
+}
+
+/*============================================================================*/
+/* Array attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is an ArrayAttr, 0 otherwise.
+int mlirAttributeIsAArray(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<ArrayAttr>();
+}
+
+/// Creates an ArrayAttr in `ctx` holding the `numElements` attributes pointed
+/// to by `elements`.
+MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements,
+                               MlirAttribute *elements) {
+  SmallVector<Attribute, 8> storage;
+  ArrayRef<Attribute> unwrapped =
+      unwrapList(static_cast<size_t>(numElements), elements, storage);
+  return wrap(ArrayAttr::get(unwrapped, unwrap(ctx)));
+}
+
+/// Returns the number of elements of the ArrayAttr `attr`.
+intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) {
+  ArrayAttr array = unwrap(attr).cast<ArrayAttr>();
+  return static_cast<intptr_t>(array.size());
+}
+
+/// Returns element `pos` of the ArrayAttr `attr`; `pos` must be in range.
+MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) {
+  ArrayRef<Attribute> values = unwrap(attr).cast<ArrayAttr>().getValue();
+  return wrap(values[pos]);
+}
+
+/*============================================================================*/
+/* Dictionary attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is a DictionaryAttr, 0 otherwise.
+int mlirAttributeIsADictionary(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<DictionaryAttr>();
+}
+
+/// Creates a DictionaryAttr in `ctx` from `numElements` (name, attribute)
+/// pairs. Each name is a null-terminated C string interned as an Identifier.
+MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements,
+                                    MlirNamedAttribute *elements) {
+  // The context is loop-invariant; unwrap it once.
+  MLIRContext *context = unwrap(ctx);
+  SmallVector<NamedAttribute, 8> namedAttrs;
+  namedAttrs.reserve(numElements);
+  for (intptr_t idx = 0; idx < numElements; ++idx) {
+    const MlirNamedAttribute &entry = elements[idx];
+    namedAttrs.emplace_back(Identifier::get(entry.name, context),
+                            unwrap(entry.attribute));
+  }
+  return wrap(DictionaryAttr::get(namedAttrs, context));
+}
+
+/// Returns the number of entries in the DictionaryAttr `attr`.
+intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) {
+  DictionaryAttr dict = unwrap(attr).cast<DictionaryAttr>();
+  return static_cast<intptr_t>(dict.size());
+}
+
+/// Returns entry `pos` of the DictionaryAttr `attr`. The returned name comes
+/// from Identifier::c_str(), so it points into identifier storage rather than
+/// a copy (presumably owned by the context -- confirm before relying on it).
+MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr,
+                                                intptr_t pos) {
+  DictionaryAttr dict = unwrap(attr).cast<DictionaryAttr>();
+  NamedAttribute entry = dict.getValue()[pos];
+  return {entry.first.c_str(), wrap(entry.second)};
+}
+
+/// Looks up `name` in the DictionaryAttr `attr` and wraps whatever
+/// DictionaryAttr::get returns (for a missing key, presumably a null
+/// attribute -- TODO confirm).
+MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
+                                                 const char *name) {
+  DictionaryAttr dict = unwrap(attr).cast<DictionaryAttr>();
+  return wrap(dict.get(name));
+}
+
+/*============================================================================*/
+/* Floating point attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is a FloatAttr, 0 otherwise.
+int mlirAttributeIsAFloat(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<FloatAttr>();
+}
+
+/// Creates a FloatAttr of type `type` holding `value`. `ctx` is not used:
+/// FloatAttr::get takes no context argument here.
+MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type,
+                                     double value) {
+  Type elementType = unwrap(type);
+  return wrap(FloatAttr::get(elementType, value));
+}
+
+/// Returns the value of the FloatAttr `attr` converted to double.
+double mlirFloatAttrGetValueDouble(MlirAttribute attr) {
+  FloatAttr floatAttr = unwrap(attr).cast<FloatAttr>();
+  return floatAttr.getValueAsDouble();
+}
+
+/*============================================================================*/
+/* Integer attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is an IntegerAttr, 0 otherwise.
+int mlirAttributeIsAInteger(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<IntegerAttr>();
+}
+
+/// Creates an IntegerAttr of type `type` holding `value`.
+MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
+  Type intType = unwrap(type);
+  return wrap(IntegerAttr::get(intType, value));
+}
+
+/// Returns the value of the IntegerAttr `attr` via IntegerAttr::getInt().
+int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
+  IntegerAttr intAttr = unwrap(attr).cast<IntegerAttr>();
+  return intAttr.getInt();
+}
+
+/*============================================================================*/
+/* Bool attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is a BoolAttr, 0 otherwise.
+int mlirAttributeIsABool(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<BoolAttr>();
+}
+
+/// Creates a BoolAttr in `ctx`; any non-zero `value` means true.
+MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
+  bool flag = value != 0;
+  return wrap(BoolAttr::get(flag, unwrap(ctx)));
+}
+
+/// Returns 1 if the BoolAttr `attr` is true, 0 otherwise.
+int mlirBoolAttrGetValue(MlirAttribute attr) {
+  BoolAttr boolAttr = unwrap(attr).cast<BoolAttr>();
+  return boolAttr.getValue();
+}
+
+/*============================================================================*/
+/* Integer set attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is an IntegerSetAttr, 0 otherwise. Only the type check
+/// is exposed because the C API has no integer-set handle yet.
+int mlirAttributeIsAIntegerSet(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<IntegerSetAttr>();
+}
+
+/*============================================================================*/
+/* Opaque attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is an OpaqueAttr, 0 otherwise.
+int mlirAttributeIsAOpaque(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<OpaqueAttr>();
+}
+
+/// Creates an OpaqueAttr in `ctx` for the dialect `dialectNamespace`
+/// (null-terminated). Its payload is the `dataLength` bytes at `data`, and
+/// its type is `type`.
+MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, const char *dialectNamespace,
+                                intptr_t dataLength, const char *data,
+                                MlirType type) {
+  MLIRContext *context = unwrap(ctx);
+  Identifier dialect = Identifier::get(dialectNamespace, context);
+  StringRef payload(data, dataLength);
+  return wrap(OpaqueAttr::get(dialect, payload, unwrap(type), context));
+}
+
+/// Returns the dialect namespace of the OpaqueAttr `attr` as a C string
+/// obtained from Identifier::c_str().
+const char *mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
+  OpaqueAttr opaque = unwrap(attr).cast<OpaqueAttr>();
+  return opaque.getDialectNamespace().c_str();
+}
+
+/// Passes the raw payload of the OpaqueAttr `attr` to `callback` as a
+/// (pointer, length) pair that may not be null-terminated.
+void mlirOpaqueAttrGetData(MlirAttribute attr, MlirStringCallback callback,
+                           void *userData) {
+  StringRef payload = unwrap(attr).cast<OpaqueAttr>().getAttrData();
+  intptr_t length = static_cast<intptr_t>(payload.size());
+  callback(payload.data(), length, userData);
+}
+
+/*============================================================================*/
+/* String attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is a StringAttr, 0 otherwise.
+int mlirAttributeIsAString(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<StringAttr>();
+}
+
+/// Creates a StringAttr in `ctx` from the `length` bytes at `data`.
+MlirAttribute mlirStringAttrGet(MlirContext ctx, intptr_t length,
+                                const char *data) {
+  StringRef str(data, length);
+  return wrap(StringAttr::get(str, unwrap(ctx)));
+}
+
+/// Creates a StringAttr of type `type` from the `length` bytes at `data`.
+MlirAttribute mlirStringAttrTypedGet(MlirType type, intptr_t length,
+                                     const char *data) {
+  StringRef str(data, length);
+  return wrap(StringAttr::get(str, unwrap(type)));
+}
+
+/// Passes the contents of the StringAttr `attr` to `callback` as a
+/// (pointer, length) pair that may not be null-terminated.
+void mlirStringAttrGetValue(MlirAttribute attr, MlirStringCallback callback,
+                            void *userData) {
+  StringRef contents = unwrap(attr).cast<StringAttr>().getValue();
+  intptr_t length = static_cast<intptr_t>(contents.size());
+  callback(contents.data(), length, userData);
+}
+
+/*============================================================================*/
+/* SymbolRef attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is a SymbolRefAttr, 0 otherwise.
+int mlirAttributeIsASymbolRef(MlirAttribute attr) {
+  return unwrap(attr).isa<SymbolRefAttr>();
+}
+
+/// Creates a SymbolRefAttr in `ctx`.
+/// - The root symbol is the `length` bytes at `symbol`.
+/// - The nested references are the `numReferences` attributes in
+///   `references`; each must be a FlatSymbolRefAttr.
+MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, intptr_t length,
+                                   const char *symbol, intptr_t numReferences,
+                                   MlirAttribute *references) {
+  SmallVector<FlatSymbolRefAttr, 4> refs;
+  refs.reserve(numReferences);
+  for (intptr_t i = 0; i < numReferences; ++i)
+    refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
+  return wrap(SymbolRefAttr::get(StringRef(symbol, length), refs, unwrap(ctx)));
+}
+
+/// Passes the root reference of `attr` to `callback`. The (pointer, length)
+/// pair is not guaranteed to be null-terminated.
+void mlirSymbolRefAttrGetRootReference(MlirAttribute attr,
+                                       MlirStringCallback callback,
+                                       void *userData) {
+  StringRef ref = unwrap(attr).cast<SymbolRefAttr>().getRootReference();
+  // The callback takes a signed intptr_t length; cast explicitly like the
+  // other string accessors in this file instead of converting size_t silently.
+  callback(ref.data(), static_cast<intptr_t>(ref.size()), userData);
+}
+
+/// Passes the leaf reference of `attr` to `callback`. The (pointer, length)
+/// pair is not guaranteed to be null-terminated.
+void mlirSymbolRefAttrGetLeafReference(MlirAttribute attr,
+                                       MlirStringCallback callback,
+                                       void *userData) {
+  StringRef ref = unwrap(attr).cast<SymbolRefAttr>().getLeafReference();
+  callback(ref.data(), static_cast<intptr_t>(ref.size()), userData);
+}
+
+/// Returns the number of nested references of the SymbolRefAttr `attr`.
+intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
+  return static_cast<intptr_t>(
+      unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size());
+}
+
+/// Returns nested reference `pos` of the SymbolRefAttr `attr`; `pos` must be
+/// in range.
+MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
+                                                  intptr_t pos) {
+  return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]);
+}
+
+/*============================================================================*/
+/* Flat SymbolRef attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is a FlatSymbolRefAttr, 0 otherwise.
+int mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
+  return unwrap(attr).isa<FlatSymbolRefAttr>();
+}
+
+/// Creates a FlatSymbolRefAttr in `ctx` referring to the `length`-byte symbol
+/// name at `symbol`.
+MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, intptr_t length,
+                                       const char *symbol) {
+  return wrap(FlatSymbolRefAttr::get(StringRef(symbol, length), unwrap(ctx)));
+}
+
+/// Passes the symbol name of the FlatSymbolRefAttr `attr` to `callback` as a
+/// (pointer, length) pair that may not be null-terminated.
+/// NOTE(review): "Float" in this name is a typo for "Flat". The name is kept
+/// because it is presumably declared in mlir-c/StandardAttributes.h and is
+/// called by the C tests. A rename must update the declaration and all
+/// callers together.
+void mlirFloatSymbolRefAttrGetValue(MlirAttribute attr,
+                                    MlirStringCallback callback,
+                                    void *userData) {
+  StringRef symbol = unwrap(attr).cast<FlatSymbolRefAttr>().getValue();
+  // Explicit size_t -> intptr_t cast, consistent with the other accessors.
+  callback(symbol.data(), static_cast<intptr_t>(symbol.size()), userData);
+}
+
+/*============================================================================*/
+/* Type attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is a TypeAttr, 0 otherwise.
+int mlirAttributeIsAType(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<TypeAttr>();
+}
+
+/// Wraps the type `type` into a TypeAttr.
+MlirAttribute mlirTypeAttrGet(MlirType type) {
+  Type wrapped = unwrap(type);
+  return wrap(TypeAttr::get(wrapped));
+}
+
+/// Returns the type held by the TypeAttr `attr`.
+MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
+  TypeAttr typeAttr = unwrap(attr).cast<TypeAttr>();
+  return wrap(typeAttr.getValue());
+}
+
+/*============================================================================*/
+/* Unit attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is a UnitAttr, 0 otherwise.
+int mlirAttributeIsAUnit(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<UnitAttr>();
+}
+
+/// Returns the UnitAttr of context `ctx`.
+MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
+  MLIRContext *context = unwrap(ctx);
+  return wrap(UnitAttr::get(context));
+}
+
+/*============================================================================*/
+/* Elements attributes. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is an ElementsAttr, 0 otherwise.
+int mlirAttributeIsAElements(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<ElementsAttr>();
+}
+
+/// Returns the element of the ElementsAttr `attr` at the `rank`-dimensional
+/// index stored in `idxs`.
+MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
+                                       uint64_t *idxs) {
+  ElementsAttr elements = unwrap(attr).cast<ElementsAttr>();
+  return wrap(elements.getValue(llvm::makeArrayRef(idxs, rank)));
+}
+
+/// Returns 1 if the `rank`-dimensional index in `idxs` is valid for the
+/// ElementsAttr `attr`, 0 otherwise.
+int mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
+                                 uint64_t *idxs) {
+  ElementsAttr elements = unwrap(attr).cast<ElementsAttr>();
+  return elements.isValidIndex(llvm::makeArrayRef(idxs, rank));
+}
+
+/// Returns the total number of elements of the ElementsAttr `attr`.
+int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
+  ElementsAttr elements = unwrap(attr).cast<ElementsAttr>();
+  return elements.getNumElements();
+}
+
+/*============================================================================*/
+/* Dense elements attribute. */
+/*============================================================================*/
+
+//===----------------------------------------------------------------------===//
+// IsA support.
+
+/// Returns 1 if `attr` is a DenseElementsAttr, 0 otherwise.
+int mlirAttributeIsADenseElements(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<DenseElementsAttr>();
+}
+/// Returns 1 if `attr` is a DenseIntElementsAttr, 0 otherwise.
+int mlirAttributeIsADenseIntElements(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<DenseIntElementsAttr>();
+}
+/// Returns 1 if `attr` is a DenseFPElementsAttr, 0 otherwise.
+int mlirAttributeIsADenseFPElements(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<DenseFPElementsAttr>();
+}
+
+//===----------------------------------------------------------------------===//
+// Constructors.
+
+/// Creates a DenseElementsAttr of type `shapedType` from the `numElements`
+/// element attributes stored in `elements`.
+MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
+                                       intptr_t numElements,
+                                       MlirAttribute *elements) {
+  ShapedType type = unwrap(shapedType).cast<ShapedType>();
+  SmallVector<Attribute, 8> storage;
+  ArrayRef<Attribute> values =
+      unwrapList(static_cast<size_t>(numElements), elements, storage);
+  return wrap(DenseElementsAttr::get(type, values));
+}
+
+/// Creates a splat DenseElementsAttr of type `shapedType` whose every element
+/// is the attribute `element`.
+MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
+                                            MlirAttribute element) {
+  return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
+                                     unwrap(element)));
+}
+
+/// Creates a splat DenseElementsAttr of type `shapedType` from a single C++
+/// scalar `element`. This is shared by the typed splat constructors below,
+/// which previously repeated the same body. The element type of
+/// `shapedType` must be compatible with `T` as required by
+/// DenseElementsAttr::get.
+template <typename T>
+static MlirAttribute getDenseSplatAttribute(MlirType shapedType, T element) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
+
+/// Boolean splat; the C API passes booleans as int (non-zero means true).
+MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
+                                                int element) {
+  // Normalize to bool so DenseElementsAttr::get receives a bool value.
+  return getDenseSplatAttribute(shapedType, static_cast<bool>(element));
+}
+MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
+                                                  uint32_t element) {
+  return getDenseSplatAttribute(shapedType, element);
+}
+MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
+                                                 int32_t element) {
+  return getDenseSplatAttribute(shapedType, element);
+}
+MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
+                                                  uint64_t element) {
+  return getDenseSplatAttribute(shapedType, element);
+}
+MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
+                                                 int64_t element) {
+  return getDenseSplatAttribute(shapedType, element);
+}
+MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
+                                                 float element) {
+  return getDenseSplatAttribute(shapedType, element);
+}
+MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
+                                                  double element) {
+  return getDenseSplatAttribute(shapedType, element);
+}
+
+/// Creates a boolean DenseElementsAttr of type `shapedType` from the
+/// `numElements` C ints in `elements`; a non-zero int becomes true.
+MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
+                                           intptr_t numElements,
+                                           int *elements) {
+  SmallVector<bool, 8> flags;
+  flags.reserve(numElements);
+  for (intptr_t idx = 0; idx < numElements; ++idx)
+    flags.push_back(elements[idx] != 0);
+  ShapedType type = unwrap(shapedType).cast<ShapedType>();
+  return wrap(DenseElementsAttr::get(type, flags));
+}
+
+/// Creates a DenseElementsAttr of type `shapedType` from the `numElements`
+/// values of C++ type `T` (deduced from `elements`) stored in `elements`.
+template <typename T>
+static MlirAttribute getDenseAttribute(MlirType shapedType,
+                                       intptr_t numElements, T *elements) {
+  ShapedType type = unwrap(shapedType).cast<ShapedType>();
+  ArrayRef<T> values = llvm::makeArrayRef(elements, numElements);
+  return wrap(DenseElementsAttr::get(type, values));
+}
+
+MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
+                                             intptr_t numElements,
+                                             uint32_t *elements) {
+  return getDenseAttribute<uint32_t>(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
+                                            intptr_t numElements,
+                                            int32_t *elements) {
+  return getDenseAttribute<int32_t>(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
+                                             intptr_t numElements,
+                                             uint64_t *elements) {
+  return getDenseAttribute<uint64_t>(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
+                                            intptr_t numElements,
+                                            int64_t *elements) {
+  return getDenseAttribute<int64_t>(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
+                                            intptr_t numElements,
+                                            float *elements) {
+  return getDenseAttribute<float>(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
+                                             intptr_t numElements,
+                                             double *elements) {
+  return getDenseAttribute<double>(shapedType, numElements, elements);
+}
+
+/// Creates a string DenseElementsAttr of type `shapedType`. Element `i` is
+/// the `strLengths[i]` bytes starting at `strs[i]`, so the strings need not be
+/// null-terminated.
+MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
+                                             intptr_t numElements,
+                                             intptr_t *strLengths,
+                                             const char **strs) {
+  SmallVector<StringRef, 8> strings;
+  strings.reserve(numElements);
+  for (intptr_t idx = 0; idx < numElements; ++idx)
+    strings.emplace_back(strs[idx], strLengths[idx]);
+
+  ShapedType type = unwrap(shapedType).cast<ShapedType>();
+  return wrap(DenseElementsAttr::get(type, strings));
+}
+
+/// Returns a DenseElementsAttr holding the data of `attr` reinterpreted with
+/// the shaped type `shapedType`, via DenseElementsAttr::reshape.
+MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
+                                              MlirType shapedType) {
+  DenseElementsAttr dense = unwrap(attr).cast<DenseElementsAttr>();
+  ShapedType newType = unwrap(shapedType).cast<ShapedType>();
+  return wrap(dense.reshape(newType));
+}
+
+//===----------------------------------------------------------------------===//
+// Splat accessors.
+
+/// Returns 1 if the DenseElementsAttr `attr` is a splat, 0 otherwise.
+int mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().isSplat();
+}
+
+/// Returns the splat value of the DenseElementsAttr `attr` as an attribute.
+MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
+  return wrap(unwrap(attr).cast<DenseElementsAttr>().getSplatValue());
+}
+
+/// Reads the splat value of the DenseElementsAttr `attr` as a `T`. This is
+/// shared by the typed accessors below. `attr` must be a splat whose element
+/// type matches `T`.
+template <typename T>
+static T getDenseSplatValueAs(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<T>();
+}
+
+int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
+  return getDenseSplatValueAs<bool>(attr);
+}
+int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
+  return getDenseSplatValueAs<int32_t>(attr);
+}
+uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
+  return getDenseSplatValueAs<uint32_t>(attr);
+}
+int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
+  return getDenseSplatValueAs<int64_t>(attr);
+}
+uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
+  return getDenseSplatValueAs<uint64_t>(attr);
+}
+float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
+  return getDenseSplatValueAs<float>(attr);
+}
+double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
+  return getDenseSplatValueAs<double>(attr);
+}
+
+/// Passes the string splat value of `attr` to `callback` as a
+/// (pointer, length) pair that may not be null-terminated.
+void mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr,
+                                              MlirStringCallback callback,
+                                              void *userData) {
+  StringRef str = getDenseSplatValueAs<StringRef>(attr);
+  // Explicit size_t -> intptr_t cast, consistent with the other accessors.
+  callback(str.data(), static_cast<intptr_t>(str.size()), userData);
+}
+
+//===----------------------------------------------------------------------===//
+// Indexed accessors.
+
+/// Returns the element at linear position `pos` of the DenseElementsAttr
+/// `attr`, read as a `T`. This replaces the iterator arithmetic previously
+/// repeated in every accessor below. No bounds check is performed here:
+/// `pos` must be less than the number of elements, and the element type must
+/// match `T`.
+template <typename T>
+static T getDenseValueAt(MlirAttribute attr, intptr_t pos) {
+  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<T>().begin() +
+           pos);
+}
+
+int mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
+  return getDenseValueAt<bool>(attr, pos);
+}
+int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
+  return getDenseValueAt<int32_t>(attr, pos);
+}
+uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
+  return getDenseValueAt<uint32_t>(attr, pos);
+}
+int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
+  return getDenseValueAt<int64_t>(attr, pos);
+}
+uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
+  return getDenseValueAt<uint64_t>(attr, pos);
+}
+float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
+  return getDenseValueAt<float>(attr, pos);
+}
+double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
+  return getDenseValueAt<double>(attr, pos);
+}
+
+/// Passes string element `pos` of `attr` to `callback` as a (pointer, length)
+/// pair that may not be null-terminated.
+void mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos,
+                                         MlirStringCallback callback,
+                                         void *userData) {
+  StringRef str = getDenseValueAt<StringRef>(attr, pos);
+  // Explicit size_t -> intptr_t cast, consistent with the other accessors.
+  callback(str.data(), static_cast<intptr_t>(str.size()), userData);
+}
+
+/*============================================================================*/
+/* Opaque elements attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is an OpaqueElementsAttr, 0 otherwise.
+int mlirAttributeIsAOpaqueElements(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<OpaqueElementsAttr>();
+}
+
+/*============================================================================*/
+/* Sparse elements attribute. */
+/*============================================================================*/
+
+/// Returns 1 if `attr` is a SparseElementsAttr, 0 otherwise.
+int mlirAttributeIsASparseElements(MlirAttribute attr) {
+  Attribute attribute = unwrap(attr);
+  return attribute.isa<SparseElementsAttr>();
+}
+
+/// Creates a SparseElementsAttr of type `shapedType` from the dense index and
+/// value attributes. Both must be DenseElementsAttr.
+/// NOTE(review): unlike the other constructors, this one is not named
+/// `...AttrGet`. Renaming it is an API change.
+MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
+                                          MlirAttribute denseIndices,
+                                          MlirAttribute denseValues) {
+  ShapedType type = unwrap(shapedType).cast<ShapedType>();
+  DenseElementsAttr indices = unwrap(denseIndices).cast<DenseElementsAttr>();
+  DenseElementsAttr values = unwrap(denseValues).cast<DenseElementsAttr>();
+  return wrap(SparseElementsAttr::get(type, indices, values));
+}
+
+/// Returns the index attribute of the SparseElementsAttr `attr`.
+MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
+  SparseElementsAttr sparse = unwrap(attr).cast<SparseElementsAttr>();
+  return wrap(sparse.getIndices());
+}
+
+/// Returns the value attribute of the SparseElementsAttr `attr`.
+MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
+  SparseElementsAttr sparse = unwrap(attr).cast<SparseElementsAttr>();
+  return wrap(sparse.getValues());
+}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 12dc100b0bec..0a8ebae4e19e 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -12,11 +12,14 @@
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
+#include "mlir-c/StandardAttributes.h"
#include "mlir-c/StandardTypes.h"
#include <assert.h>
+#include <math.h>
#include <stdio.h>
#include <stdlib.h>
+#include <string.h>
void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
MlirLocation location, MlirBlock funcBody) {
@@ -380,6 +383,210 @@ static int printStandardTypes(MlirContext ctx) {
return 0;
}
+// MlirStringCallback used by the attribute tests. It copies exactly `len`
+// bytes of the length-delimited segment `data` into the char buffer
+// `userData`. No terminator is written; callers compare only the copied bytes.
+void callbackSetFixedLengthString(const char *data, intptr_t len,
+                                  void *userData) {
+  // memcpy, not strncpy: the segment is length-delimited and may contain
+  // embedded NULs, at which strncpy would stop and zero-fill the remainder.
+  // A non-positive length copies nothing instead of converting to a huge
+  // size_t.
+  if (len > 0)
+    memcpy(userData, data, (size_t)len);
+}
+
+// Exercises the standard-attribute C API. For each kind it builds one
+// attribute, checks it with the matching IsA/accessor functions, and dumps it
+// with mlirAttributeDump. The dump order must match the CHECK lines under
+// "@attrs" in main. Returns 0 on success, or a distinct code (1-17)
+// identifying the first failed check.
+int printStandardAttributes(MlirContext ctx) {
+ MlirAttribute floating =
+ mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
+ if (!mlirAttributeIsAFloat(floating) ||
+ fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6)
+ return 1;
+ mlirAttributeDump(floating);
+
+ MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
+ if (!mlirAttributeIsAInteger(integer) ||
+ mlirIntegerAttrGetValueInt(integer) != 42)
+ return 2;
+ mlirAttributeDump(integer);
+
+ MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
+ if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
+ return 3;
+ mlirAttributeDump(boolean);
+
+ // `data` is a shared character pool. The string-based attributes below each
+ // take a slice of it: "abc", "de", "fgh" and "ij".
+ const char data[] = "abcdefghijklmnopqestuvwxyz";
+ // Receives fixed-length copies from callbackSetFixedLengthString. It is never
+ // null-terminated, so only the copied bytes are compared.
+ char buffer[10];
+ MlirAttribute opaque =
+ mlirOpaqueAttrGet(ctx, "std", 3, data, mlirNoneTypeGet(ctx));
+ if (!mlirAttributeIsAOpaque(opaque) ||
+ strcmp("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
+ return 4;
+ mlirOpaqueAttrGetData(opaque, callbackSetFixedLengthString, buffer);
+ if (buffer[0] != 'a' || buffer[1] != 'b' || buffer[2] != 'c')
+ return 5;
+ mlirAttributeDump(opaque);
+
+ MlirAttribute string = mlirStringAttrGet(ctx, 2, data + 3);
+ if (!mlirAttributeIsAString(string))
+ return 6;
+ mlirStringAttrGetValue(string, callbackSetFixedLengthString, buffer);
+ if (buffer[0] != 'd' || buffer[1] != 'e')
+ return 7;
+ mlirAttributeDump(string);
+
+ MlirAttribute flatSymbolRef = mlirFlatSymbolRefAttrGet(ctx, 3, data + 5);
+ if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
+ return 8;
+ // Note the C API spells this accessor "Float" rather than "Flat".
+ mlirFloatSymbolRefAttrGetValue(flatSymbolRef, callbackSetFixedLengthString,
+ buffer);
+ if (buffer[0] != 'f' || buffer[1] != 'g' || buffer[2] != 'h')
+ return 9;
+ mlirAttributeDump(flatSymbolRef);
+
+ // Root "ij" with the flat reference "fgh" nested twice.
+ MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
+ MlirAttribute symbolRef = mlirSymbolRefAttrGet(ctx, 2, data + 8, 2, symbols);
+ if (!mlirAttributeIsASymbolRef(symbolRef) ||
+ mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
+ !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),
+ flatSymbolRef) ||
+ !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1),
+ flatSymbolRef))
+ return 10;
+ // Leaf ("fgh") goes to buffer[0..2], root ("ij") to buffer[3..4].
+ mlirSymbolRefAttrGetLeafReference(symbolRef, callbackSetFixedLengthString,
+ buffer);
+ mlirSymbolRefAttrGetRootReference(symbolRef, callbackSetFixedLengthString,
+ buffer + 3);
+ if (buffer[0] != 'f' || buffer[1] != 'g' || buffer[2] != 'h' ||
+ buffer[3] != 'i' || buffer[4] != 'j')
+ return 11;
+ mlirAttributeDump(symbolRef);
+
+ MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx));
+ if (!mlirAttributeIsAType(type) ||
+ !mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type)))
+ return 12;
+ mlirAttributeDump(type);
+
+ MlirAttribute unit = mlirUnitAttrGet(ctx);
+ if (!mlirAttributeIsAUnit(unit))
+ return 13;
+ mlirAttributeDump(unit);
+
+ // Every dense attribute below is a rank-2 tensor of shape 1x2.
+ int64_t shape[] = {1, 2};
+
+ int bools[] = {0, 1};
+ uint32_t uints32[] = {0u, 1u};
+ int32_t ints32[] = {0, 1};
+ uint64_t uints64[] = {0u, 1u};
+ int64_t ints64[] = {0, 1};
+ float floats[] = {0.0f, 1.0f};
+ double doubles[] = {0.0, 1.0};
+ MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 2, bools);
+ MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32)), 2,
+ uints32);
+ MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 2,
+ ints32);
+ MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64)), 2,
+ uints64);
+ MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 2,
+ ints64);
+ MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
+ mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 2, floats);
+ MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
+ mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 2, doubles);
+
+ if (!mlirAttributeIsADenseElements(boolElements) ||
+ !mlirAttributeIsADenseElements(uint32Elements) ||
+ !mlirAttributeIsADenseElements(int32Elements) ||
+ !mlirAttributeIsADenseElements(uint64Elements) ||
+ !mlirAttributeIsADenseElements(int64Elements) ||
+ !mlirAttributeIsADenseElements(floatElements) ||
+ !mlirAttributeIsADenseElements(doubleElements))
+ return 14;
+
+ // Position 1 is the second element in linear order, which is 1 everywhere.
+ if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
+ mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
+ mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
+ mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
+ mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 ||
+ fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) >
+ 1E-6f ||
+ fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6)
+ return 15;
+
+ mlirAttributeDump(boolElements);
+ mlirAttributeDump(uint32Elements);
+ mlirAttributeDump(int32Elements);
+ mlirAttributeDump(uint64Elements);
+ mlirAttributeDump(int64Elements);
+ mlirAttributeDump(floatElements);
+ mlirAttributeDump(doubleElements);
+
+ // NOTE(review): the UInt32/UInt64 splats below are given signless i32/i64
+ // tensor types, unlike uint32Elements/uint64Elements above (ui32/ui64). The
+ // CHECK lines in main expect the signless types, so changing this requires
+ // updating them as well.
+ MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 1);
+ MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
+ MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
+ MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
+ MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
+ MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
+ mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 1.0f);
+ MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
+ mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 1.0);
+
+ if (!mlirAttributeIsADenseElements(splatBool) ||
+ !mlirDenseElementsAttrIsSplat(splatBool) ||
+ !mlirAttributeIsADenseElements(splatUInt32) ||
+ !mlirDenseElementsAttrIsSplat(splatUInt32) ||
+ !mlirAttributeIsADenseElements(splatInt32) ||
+ !mlirDenseElementsAttrIsSplat(splatInt32) ||
+ !mlirAttributeIsADenseElements(splatUInt64) ||
+ !mlirDenseElementsAttrIsSplat(splatUInt64) ||
+ !mlirAttributeIsADenseElements(splatInt64) ||
+ !mlirDenseElementsAttrIsSplat(splatInt64) ||
+ !mlirAttributeIsADenseElements(splatFloat) ||
+ !mlirDenseElementsAttrIsSplat(splatFloat) ||
+ !mlirAttributeIsADenseElements(splatDouble) ||
+ !mlirDenseElementsAttrIsSplat(splatDouble))
+ return 16;
+
+ if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
+ mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
+ mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
+ mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
+ mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 ||
+ fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) >
+ 1E-6f ||
+ fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
+ return 17;
+
+ mlirAttributeDump(splatBool);
+ mlirAttributeDump(splatUInt32);
+ mlirAttributeDump(splatInt32);
+ mlirAttributeDump(splatUInt64);
+ mlirAttributeDump(splatInt64);
+ mlirAttributeDump(splatFloat);
+ mlirAttributeDump(splatDouble);
+
+ // uints64 == {0, 1} is reused as the rank-2 index [0, 1], selecting the
+ // element whose value is 1.0.
+ mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
+ mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
+
+ // NOTE(review): per the MLIR SparseElementsAttr documentation, indices
+ // should be a 2-D i64 tensor of shape [numNonZeros, rank] holding in-bounds
+ // coordinates. {4, 7} as a 1-D tensor is neither for the 1x2 result type.
+ // The attribute is only dumped, never read back, so the test passes, but it
+ // is not a meaningful sparse tensor. TODO: fix together with the sparse
+ // CHECK line in main.
+ int64_t indices[] = {4, 7};
+ int64_t two = 2;
+ MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
+ mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64)), 2,
+ indices);
+ MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
+ mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx)), 2, floats);
+ MlirAttribute sparseAttr = mlirSparseElementsAttribute(
+ mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), indicesAttr,
+ valuesAttr);
+ mlirAttributeDump(sparseAttr);
+
+ return 0;
+}
+
int main() {
MlirContext ctx = mlirContextCreate();
mlirRegisterAllDialects(ctx);
@@ -454,10 +661,43 @@ int main() {
// CHECK: tuple<memref<*xf32, 4>, f32>
// CHECK: 0
// clang-format on
- fprintf(stderr, "@types");
+ fprintf(stderr, "@types\n");
int errcode = printStandardTypes(ctx);
fprintf(stderr, "%d\n", errcode);
+ // clang-format off
+ // CHECK-LABEL: @attrs
+ // CHECK: 2.000000e+00 : f64
+ // CHECK: 42 : i32
+ // CHECK: true
+ // CHECK: #std.abc
+ // CHECK: "de"
+ // CHECK: @fgh
+ // CHECK: @ij::@fgh::@fgh
+ // CHECK: f32
+ // CHECK: unit
+ // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
+ // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
+ // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
+ // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
+ // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
+ // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
+ // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
+ // CHECK: dense<true> : tensor<1x2xi1>
+ // CHECK: dense<1> : tensor<1x2xi32>
+ // CHECK: dense<1> : tensor<1x2xi32>
+ // CHECK: dense<1> : tensor<1x2xi64>
+ // CHECK: dense<1> : tensor<1x2xi64>
+ // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
+ // CHECK: dense<1.000000e+00> : tensor<1x2xf64>
+ // CHECK: 1.000000e+00 : f32
+ // CHECK: 1.000000e+00 : f64
+ // CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32>
+ // clang-format on
+ fprintf(stderr, "@attrs\n");
+ errcode = printStandardAttributes(ctx);
+ fprintf(stderr, "%d\n", errcode);
+
mlirContextDestroy(ctx);
return 0;
More information about the Mlir-commits
mailing list