[Mlir-commits] [mlir] 74f5778 - [mlir] expose standard types to C API
Alex Zinenko
llvmlistbot at llvm.org
Tue Aug 18 04:11:49 PDT 2020
Author: Alex Zinenko
Date: 2020-08-18T13:11:37+02:00
New Revision: 74f577845e8174e255688589d845d43eacf3923f
URL: https://github.com/llvm/llvm-project/commit/74f577845e8174e255688589d845d43eacf3923f
DIFF: https://github.com/llvm/llvm-project/commit/74f577845e8174e255688589d845d43eacf3923f.diff
LOG: [mlir] expose standard types to C API
Provide C API for MLIR standard types. Since standard types live under lib/IR
in core MLIR, place the C APIs in the IR library as well (standard ops will go
into a separate library). This also defines a placeholder for affine maps that
are necessary to construct a memref, but are not yet exposed to the C API.
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D86094
Added:
mlir/include/mlir-c/AffineMap.h
mlir/include/mlir-c/StandardTypes.h
mlir/include/mlir/CAPI/AffineMap.h
mlir/include/mlir/CAPI/IR.h
mlir/include/mlir/CAPI/Wrap.h
mlir/lib/CAPI/IR/AffineMap.cpp
mlir/lib/CAPI/IR/StandardTypes.cpp
Modified:
mlir/docs/CAPI.md
mlir/include/mlir-c/IR.h
mlir/include/mlir/IR/AffineMap.h
mlir/lib/CAPI/IR/CMakeLists.txt
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md
index 6adb9db3331c..a8fcfbafb8b1 100644
--- a/mlir/docs/CAPI.md
+++ b/mlir/docs/CAPI.md
@@ -75,6 +75,28 @@ check if an object is null by using `mlirXIsNull(MlirX)`. API functions do _not_
expect null objects as arguments unless explicitly stated otherwise. API
functions _may_ return null objects.
+### Type Hierarchies
+
+MLIR objects can form type hierarchies in C++. For example, all IR classes
+representing types are derived from `mlir::Type`, and some of them may also be
+derived from common base classes such as `mlir::ShapedType` or dialect-specific
+base classes. Type hierarchies are exposed to the C API through naming
+conventions as follows.
+
+- Only the top-level class of each hierarchy is exposed, e.g. `MlirType` is
+ defined as a type but `MlirShapedType` is not. This avoids the need for
+ explicit upcasting when passing an object of a derived type to a function
+ that expects a base type (this happens more often in core/standard APIs,
+ while downcasting usually involves further checks anyway).
+- A type `Y` that derives from `X` provides a function `int mlirXIsAY(MlirX)`
+ that returns a non-zero value if the given dynamic instance of `X` is also
+ an instance of `Y`. For example, `int mlirTypeIsAInteger(MlirType)`.
+- A function that expects a derived type as its first argument takes the base
+ type instead and documents the expectation by using `Y` in its name
+ `mlirY<...>(MlirX, ...)`. This function asserts that the dynamic instance of
+ its first argument is `Y`, and it is the responsibility of the caller to
+ ensure it is indeed the case, e.g. for `mlirIntegerTypeGetWidth(MlirType)`.
+
### Conversion To String and Printing
IR objects can be converted to a string representation, for example for
@@ -96,11 +118,11 @@ allocation and avoid unnecessary allocation and copying inside the printer.
For convenience, `mlirXDump(MlirX)` functions are provided to print the given
object to the standard error stream.
-### Common Patterns
+## Common Patterns
The API adopts the following patterns for recurrent functionality in MLIR.
-#### Indexed Components
+### Indexed Components
An object has an _indexed component_ if it has fields accessible using a
zero-based contiguous integer index, typically arrays. For example, an
@@ -120,7 +142,7 @@ Note that the name of subobject in the function does not necessarily match the
type of the subobject. For example, `mlirOperationGetOperand` returns a
`MlirValue`.
-#### Iterable Components
+### Iterable Components
An object has an _iterable component_ if it has iterators accessing its fields
in some order other than integer indexing, typically linked lists. For example,
@@ -146,3 +168,17 @@ for (iter = mlirXGetFirst<Y>(x); !mlirYIsNull(iter);
/* User 'iter'. */
}
```
+
+## Extending the API
+
+### Extensions for Dialect Attributes and Types
+
+Dialect attributes and types can follow the example of standard attributes and
+types, provided that implementations live in separate directories, i.e.
+`include/mlir-c/<...>Dialect/` and `lib/CAPI/<...>Dialect/`. The core APIs
+provide implementation-private headers in `include/mlir/CAPI/` that allow one
+to convert between opaque C structures for core IR components and their C++
+counterparts. `wrap` converts a C++ class into a C structure and `unwrap` does
+the inverse conversion. Once a C++ object is available, the API
+implementation should rely on `isa` to implement `mlirXIsAY` and is expected to
+use `cast` inside other API calls.
diff --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h
new file mode 100644
index 000000000000..bef13fd0bfa8
--- /dev/null
+++ b/mlir/include/mlir-c/AffineMap.h
@@ -0,0 +1,25 @@
+/*===-- mlir-c/AffineMap.h - C API for MLIR Affine maps -----------*- C -*-===*\
+|* *|
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
+|* Exceptions. *|
+|* See https://llvm.org/LICENSE.txt for license information. *|
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
+|* *|
+\*===----------------------------------------------------------------------===*/
+
+#ifndef MLIR_C_AFFINEMAP_H
+#define MLIR_C_AFFINEMAP_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+/** Opaque C handle for an mlir::AffineMap; no operations are exposed yet. */
+DEFINE_C_API_STRUCT(MlirAffineMap, const void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_AFFINEMAP_H
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 6b5be2d0195b..68546bf35625 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -56,8 +56,6 @@ DEFINE_C_API_STRUCT(MlirType, const void);
DEFINE_C_API_STRUCT(MlirLocation, const void);
DEFINE_C_API_STRUCT(MlirModule, const void);
-#undef DEFINE_C_API_STRUCT
-
/** Named MLIR attribute.
*
* A named attribute is essentially a (name, attribute) pair where the name is
@@ -314,6 +312,9 @@ void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
/** Parses a type. The type is owned by the context. */
MlirType mlirTypeParseGet(MlirContext context, const char *type);
+/** Checks if two types are equal. Returns non-zero if they are. */
+int mlirTypeEqual(MlirType t1, MlirType t2);
+
/** Prints a location by sending chunks of the string representation and
* forwarding `userData to `callback`. Note that the callback may be called
* several times with consecutive chunks of the string. */
diff --git a/mlir/include/mlir-c/StandardTypes.h b/mlir/include/mlir-c/StandardTypes.h
new file mode 100644
index 000000000000..ad28ea546717
--- /dev/null
+++ b/mlir/include/mlir-c/StandardTypes.h
@@ -0,0 +1,249 @@
+/*===-- mlir-c/StandardTypes.h - C API for MLIR Standard types ----*- C -*-===*\
+|* *|
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
+|* Exceptions. *|
+|* See https://llvm.org/LICENSE.txt for license information. *|
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
+|* *|
+\*===----------------------------------------------------------------------===*/
+
+#ifndef MLIR_C_STANDARDTYPES_H
+#define MLIR_C_STANDARDTYPES_H
+
+#include "mlir-c/AffineMap.h"
+#include "mlir-c/IR.h"
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*============================================================================*/
+/* Integer types. */
+/*============================================================================*/
+
+/** Checks whether the given type is an integer type (non-zero if so). */
+int mlirTypeIsAInteger(MlirType type);
+
+/** Creates a signless integer type of the given bitwidth in the context. The
+ * type is owned by the context. */
+MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth);
+
+/** Creates a signed integer type of the given bitwidth in the context. The type
+ * is owned by the context. */
+MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth);
+
+/** Creates an unsigned integer type of the given bitwidth in the context. The
+ * type is owned by the context. */
+MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth);
+
+/** Returns the bitwidth of the given type, which must be an integer type. */
+unsigned mlirIntegerTypeGetWidth(MlirType type);
+
+/** Checks whether the given integer type is signless. */
+int mlirIntegerTypeIsSignless(MlirType type);
+
+/** Checks whether the given integer type is signed. */
+int mlirIntegerTypeIsSigned(MlirType type);
+
+/** Checks whether the given integer type is unsigned. */
+int mlirIntegerTypeIsUnsigned(MlirType type);
+
+/*============================================================================*/
+/* Index type. */
+/*============================================================================*/
+
+/** Checks whether the given type is an index type (non-zero if so). */
+int mlirTypeIsAIndex(MlirType type);
+
+/** Creates an index type in the given context. The type is owned by the
+ * context. */
+MlirType mlirIndexTypeGet(MlirContext ctx);
+
+/*============================================================================*/
+/* Floating-point types. */
+/*============================================================================*/
+
+/** Checks whether the given type is a bf16 type. */
+int mlirTypeIsABF16(MlirType type);
+
+/** Creates a bf16 type in the given context. The type is owned by the
+ * context. */
+MlirType mlirBF16TypeGet(MlirContext ctx);
+
+/** Checks whether the given type is an f16 type. */
+int mlirTypeIsAF16(MlirType type);
+
+/** Creates an f16 type in the given context. The type is owned by the
+ * context. */
+MlirType mlirF16TypeGet(MlirContext ctx);
+
+/** Checks whether the given type is an f32 type. */
+int mlirTypeIsAF32(MlirType type);
+
+/** Creates an f32 type in the given context. The type is owned by the
+ * context. */
+MlirType mlirF32TypeGet(MlirContext ctx);
+
+/** Checks whether the given type is an f64 type. */
+int mlirTypeIsAF64(MlirType type);
+
+/** Creates an f64 type in the given context. The type is owned by the
+ * context. */
+MlirType mlirF64TypeGet(MlirContext ctx);
+
+/*============================================================================*/
+/* None type. */
+/*============================================================================*/
+
+/** Checks whether the given type is a None type. */
+int mlirTypeIsANone(MlirType type);
+
+/** Creates a None type in the given context. The type is owned by the
+ * context. */
+MlirType mlirNoneTypeGet(MlirContext ctx);
+
+/*============================================================================*/
+/* Complex type. */
+/*============================================================================*/
+
+/** Checks whether the given type is a Complex type. */
+int mlirTypeIsAComplex(MlirType type);
+
+/** Creates a complex type with the given element type in the same context as
+ * the element type. The type is owned by the context. */
+MlirType mlirComplexTypeGet(MlirType elementType);
+
+/** Returns the element type of the given complex type. */
+MlirType mlirComplexTypeGetElementType(MlirType type);
+
+/*============================================================================*/
+/* Shaped type. */
+/*============================================================================*/
+
+/** Checks whether the given type is a Shaped type. */
+int mlirTypeIsAShaped(MlirType type);
+
+/** Returns the element type of the given shaped type. */
+MlirType mlirShapedTypeGetElementType(MlirType type);
+
+/** Checks whether the given shaped type is ranked. */
+int mlirShapedTypeHasRank(MlirType type);
+
+/** Returns the rank of the given ranked shaped type. */
+int64_t mlirShapedTypeGetRank(MlirType type);
+
+/** Checks whether the given shaped type has a static shape. */
+int mlirShapedTypeHasStaticShape(MlirType type);
+
+/** Checks whether the dim-th dimension of the given shaped type is dynamic. */
+int mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim);
+
+/** Returns the size of the dim-th dimension of the given ranked shaped type. */
+int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim);
+
+/** Checks whether the given value is used as a placeholder for dynamic sizes
+ * in shaped types. */
+int mlirShapedTypeIsDynamicSize(int64_t size);
+
+/** Checks whether the given value is used as a placeholder for dynamic strides
+ * and offsets in shaped types. */
+int mlirShapedTypeIsDynamicStrideOrOffset(int64_t val);
+
+/*============================================================================*/
+/* Vector type. */
+/*============================================================================*/
+
+/** Checks whether the given type is a Vector type. */
+int mlirTypeIsAVector(MlirType type);
+
+/** Creates a vector type of the shape identified by its rank and dimensions,
+ * with the given element type in the same context as the element type. The type
+ * is owned by the context. */
+MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape, MlirType elementType);
+
+/*============================================================================*/
+/* Ranked / Unranked Tensor type. */
+/*============================================================================*/
+
+/** Checks whether the given type is a Tensor type (ranked or unranked). */
+int mlirTypeIsATensor(MlirType type);
+
+/** Checks whether the given type is a ranked tensor type. */
+int mlirTypeIsARankedTensor(MlirType type);
+
+/** Checks whether the given type is an unranked tensor type. */
+int mlirTypeIsAUnrankedTensor(MlirType type);
+
+/** Creates a tensor type of a fixed rank with the given shape and element type
+ * in the same context as the element type. The type is owned by the context. */
+MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape,
+ MlirType elementType);
+
+/** Creates an unranked tensor type with the given element type in the same
+ * context as the element type. The type is owned by the context. */
+MlirType mlirUnrankedTensorTypeGet(MlirType elementType);
+
+/*============================================================================*/
+/* Ranked / Unranked MemRef type. */
+/*============================================================================*/
+
+/** Checks whether the given type is a MemRef type. */
+int mlirTypeIsAMemRef(MlirType type);
+
+/** Checks whether the given type is an UnrankedMemRef type. */
+int mlirTypeIsAUnrankedMemRef(MlirType type);
+
+/** Creates a MemRef type with the given rank and shape, a potentially empty
+ * list of `numMaps` affine layout maps, the given memory space and element
+ * type, in the same context as element type. The type is owned by the context. */
+MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, int64_t *shape,
+ intptr_t numMaps, MlirAffineMap *affineMaps,
+ unsigned memorySpace);
+
+/** Creates a MemRef type with the given rank, shape, memory space and element
+ * type in the same context as the element type. The type has no affine maps,
+ * i.e. represents a default row-major contiguous memref. The type is owned by
+ * the context. */
+MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
+ int64_t *shape, unsigned memorySpace);
+
+/** Creates an Unranked MemRef type with the given element type and in the given
+ * memory space. The type is owned by the context of element type. */
+MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace);
+
+/** Returns the number of affine layout maps in the given MemRef type. */
+intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type);
+
+/** Returns the pos-th affine map of the given MemRef type. */
+MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos);
+
+/** Returns the memory space of the given MemRef type. */
+unsigned mlirMemRefTypeGetMemorySpace(MlirType type);
+
+/** Returns the memory space of the given Unranked MemRef type. */
+unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type);
+
+/*============================================================================*/
+/* Tuple type. */
+/*============================================================================*/
+
+/** Checks whether the given type is a tuple type. */
+int mlirTypeIsATuple(MlirType type);
+
+/** Creates a tuple type that consists of the `numElements` given types. The
+ * type is owned by the context. */
+MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
+ MlirType *elements);
+
+/** Returns the number of types contained in the given tuple type. */
+intptr_t mlirTupleTypeGetNumTypes(MlirType type);
+
+/** Returns the pos-th type in the tuple type. */
+MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_STANDARDTYPES_H
diff --git a/mlir/include/mlir/CAPI/AffineMap.h b/mlir/include/mlir/CAPI/AffineMap.h
new file mode 100644
index 000000000000..cea48ffae8b6
--- /dev/null
+++ b/mlir/include/mlir/CAPI/AffineMap.h
@@ -0,0 +1,24 @@
+//===- AffineMap.h - C API Utils for Affine Maps ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// MLIR Affine maps. This file should not be included from C++ code other than
+// C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_AFFINEMAP_H
+#define MLIR_CAPI_AFFINEMAP_H
+
+#include "mlir-c/AffineMap.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/AffineMap.h"
+// Generates static wrap()/unwrap() between MlirAffineMap and mlir::AffineMap.
+DEFINE_C_API_METHODS(MlirAffineMap, mlir::AffineMap)
+
+#endif // MLIR_CAPI_AFFINEMAP_H
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
new file mode 100644
index 000000000000..9a60ecf04fc8
--- /dev/null
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -0,0 +1,34 @@
+//===- IR.h - C API Utils for Core MLIR classes -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// core MLIR classes. This file should not be included from C++ code other than
+// C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INCLUDE_MLIR_CAPI_IR_H
+#define MLIR_INCLUDE_MLIR_CAPI_IR_H
+
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+// Handles around C++ pointers: wrap() stores the pointer in the `ptr` field.
+DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
+DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
+DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
+DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
+// Value-like C++ classes round-trip through get{As,From}OpaquePointer().
+DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute)
+DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
+DEFINE_C_API_METHODS(MlirType, mlir::Type)
+DEFINE_C_API_METHODS(MlirValue, mlir::Value)
+DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
+
+#endif // MLIR_INCLUDE_MLIR_CAPI_IR_H
diff --git a/mlir/include/mlir/CAPI/Wrap.h b/mlir/include/mlir/CAPI/Wrap.h
new file mode 100644
index 000000000000..940007caac06
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Wrap.h
@@ -0,0 +1,56 @@
+//===- Wrap.h - C API Utilities ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains common definitions for wrapping opaque C++ pointers into
+// C structures for the purpose of C API. This file should not be included from
+// C++ code other than C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_WRAP_H
+#define MLIR_CAPI_WRAP_H
+
+#include "mlir-c/IR.h"
+#include "mlir/Support/LLVM.h"
+
+/* ========================================================================== */
+/* Generators of wrap()/unwrap() for non-owning C API handle structures. */
+/* ========================================================================== */
+// For C++ objects held by pointer: the handle's `ptr` is the object address.
+#define DEFINE_C_API_PTR_METHODS(name, cpptype) \
+ static inline name wrap(cpptype *cpp) { return name{cpp}; } \
+ static inline cpptype *unwrap(name c) { \
+ return static_cast<cpptype *>(c.ptr); \
+ }
+// For value-like C++ classes that provide get{As,From}OpaquePointer().
+#define DEFINE_C_API_METHODS(name, cpptype) \
+ static inline name wrap(cpptype cpp) { \
+ return name{cpp.getAsOpaquePointer()}; \
+ } \
+ static inline cpptype unwrap(name c) { \
+ return cpptype::getFromOpaquePointer(c.ptr); \
+ }
+/// Unwraps C handles first[0..size) into `storage` and returns a view of them.
+template <typename CppTy, typename CTy>
+static llvm::ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
+ llvm::SmallVectorImpl<CppTy> &storage) {
+ static_assert(
+ std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
+ "incompatible C and C++ types");
+ // C API counts are intptr_t (see callers); a negative count is a caller bug.
+ assert(size >= 0 && "expected a non-negative number of elements");
+ if (size == 0)
+ return llvm::None;
+ assert(storage.empty() && "expected to populate storage");
+ storage.reserve(static_cast<size_t>(size));
+ for (intptr_t i = 0; i < size; ++i)
+ storage.push_back(unwrap(first[i]));
+ return storage;
+}
+
+#endif // MLIR_CAPI_WRAP_H
diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index d946c7591c2a..dd4960a02c5c 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -196,6 +196,14 @@ class AffineMap {
friend ::llvm::hash_code hash_value(AffineMap arg);
+ /// Methods supporting C API.
+ /// Returns the `map` implementation pointer as an untyped opaque pointer.
+ const void *getAsOpaquePointer() const { return map; }
+ /// Rebuilds an AffineMap from a pointer made by getAsOpaquePointer().
+ static AffineMap getFromOpaquePointer(const void *pointer) {
+ return AffineMap(static_cast<ImplType *>(const_cast<void *>(pointer)));
+ }
+
private:
ImplType *map;
diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp
new file mode 100644
index 000000000000..d80d9e20486a
--- /dev/null
+++ b/mlir/lib/CAPI/IR/AffineMap.cpp
@@ -0,0 +1,15 @@
+//===- AffineMap.cpp - C API for MLIR Affine Maps -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/AffineMap.h"
+#include "mlir-c/IR.h"
+#include "mlir/CAPI/AffineMap.h"
+#include "mlir/IR/AffineMap.h"
+
+// This is a placeholder for affine map bindings. The file is here to serve as a
+// compilation unit that includes the headers.
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index fdf239f975d3..64e715e33d88 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -1,6 +1,8 @@
# Main API.
add_mlir_library(MLIRCAPIIR
+ AffineMap.cpp
IR.cpp
+ StandardTypes.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 5231096af785..1ba1a6aca6f8 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -8,6 +8,7 @@
#include "mlir-c/IR.h"
+#include "mlir/CAPI/IR.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
@@ -17,46 +18,6 @@
using namespace mlir;
-/* ========================================================================== */
-/* Definitions of methods for non-owning structures used in C API. */
-/* ========================================================================== */
-
-#define DEFINE_C_API_PTR_METHODS(name, cpptype) \
- static name wrap(cpptype *cpp) { return name{cpp}; } \
- static cpptype *unwrap(name c) { return static_cast<cpptype *>(c.ptr); }
-
-DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext)
-DEFINE_C_API_PTR_METHODS(MlirOperation, Operation)
-DEFINE_C_API_PTR_METHODS(MlirBlock, Block)
-DEFINE_C_API_PTR_METHODS(MlirRegion, Region)
-
-#define DEFINE_C_API_METHODS(name, cpptype) \
- static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; } \
- static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); }
-
-DEFINE_C_API_METHODS(MlirAttribute, Attribute)
-DEFINE_C_API_METHODS(MlirLocation, Location);
-DEFINE_C_API_METHODS(MlirType, Type)
-DEFINE_C_API_METHODS(MlirValue, Value)
-DEFINE_C_API_METHODS(MlirModule, ModuleOp)
-
-template <typename CppTy, typename CTy>
-static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
- SmallVectorImpl<CppTy> &storage) {
- static_assert(
- std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
- "incompatible C and C++ types");
-
- if (size == 0)
- return llvm::None;
-
- assert(storage.empty() && "expected to populate storage");
- storage.reserve(size);
- for (intptr_t i = 0; i < size; ++i)
- storage.push_back(unwrap(*(first + i)));
- return storage;
-}
-
/* ========================================================================== */
/* Printing helper. */
/* ========================================================================== */
@@ -388,6 +349,8 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type) {
return wrap(mlir::parseType(type, unwrap(context)));
}
+// Returns non-zero iff the unwrapped mlir::Type values compare equal.
+int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) {
CallbackOstream stream(callback, userData);
unwrap(type).print(stream);
diff --git a/mlir/lib/CAPI/IR/StandardTypes.cpp b/mlir/lib/CAPI/IR/StandardTypes.cpp
new file mode 100644
index 000000000000..eb006242e880
--- /dev/null
+++ b/mlir/lib/CAPI/IR/StandardTypes.cpp
@@ -0,0 +1,263 @@
+//===- StandardTypes.cpp - C Interface to MLIR Standard Types -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/StandardTypes.h"
+#include "mlir-c/AffineMap.h"
+#include "mlir-c/IR.h"
+#include "mlir/CAPI/AffineMap.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+
+/* ========================================================================== */
+/* Integer types. Width/signedness accessors cast<IntegerType> the input. */
+/* ========================================================================== */
+
+int mlirTypeIsAInteger(MlirType type) {
+ return unwrap(type).isa<IntegerType>();
+}
+
+MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
+ return wrap(IntegerType::get(bitwidth, unwrap(ctx)));
+}
+
+MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
+ return wrap(IntegerType::get(bitwidth, IntegerType::Signed, unwrap(ctx)));
+}
+
+MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
+ return wrap(IntegerType::get(bitwidth, IntegerType::Unsigned, unwrap(ctx)));
+}
+
+unsigned mlirIntegerTypeGetWidth(MlirType type) {
+ return unwrap(type).cast<IntegerType>().getWidth();
+}
+
+int mlirIntegerTypeIsSignless(MlirType type) {
+ return unwrap(type).cast<IntegerType>().isSignless();
+}
+
+int mlirIntegerTypeIsSigned(MlirType type) {
+ return unwrap(type).cast<IntegerType>().isSigned();
+}
+
+int mlirIntegerTypeIsUnsigned(MlirType type) {
+ return unwrap(type).cast<IntegerType>().isUnsigned();
+}
+
+/* ========================================================================== */
+/* Index type: thin wrappers around IndexType. */
+/* ========================================================================== */
+
+int mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa<IndexType>(); }
+
+MlirType mlirIndexTypeGet(MlirContext ctx) {
+ return wrap(IndexType::get(unwrap(ctx)));
+}
+
+/* ========================================================================== */
+/* Floating-point types. Predicates use Type::isBF16/isF16/isF32/isF64. */
+/* ========================================================================== */
+
+int mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
+
+MlirType mlirBF16TypeGet(MlirContext ctx) {
+ return wrap(FloatType::getBF16(unwrap(ctx)));
+}
+
+int mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
+
+MlirType mlirF16TypeGet(MlirContext ctx) {
+ return wrap(FloatType::getF16(unwrap(ctx)));
+}
+
+int mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
+
+MlirType mlirF32TypeGet(MlirContext ctx) {
+ return wrap(FloatType::getF32(unwrap(ctx)));
+}
+
+int mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
+
+MlirType mlirF64TypeGet(MlirContext ctx) {
+ return wrap(FloatType::getF64(unwrap(ctx)));
+}
+
+/* ========================================================================== */
+/* None type: thin wrappers around NoneType. */
+/* ========================================================================== */
+
+int mlirTypeIsANone(MlirType type) { return unwrap(type).isa<NoneType>(); }
+
+MlirType mlirNoneTypeGet(MlirContext ctx) {
+ return wrap(NoneType::get(unwrap(ctx)));
+}
+
+/* ========================================================================== */
+/* Complex type. The element type accessor cast<ComplexType>s its input. */
+/* ========================================================================== */
+
+int mlirTypeIsAComplex(MlirType type) {
+ return unwrap(type).isa<ComplexType>();
+}
+
+MlirType mlirComplexTypeGet(MlirType elementType) {
+ return wrap(ComplexType::get(unwrap(elementType)));
+}
+
+MlirType mlirComplexTypeGetElementType(MlirType type) {
+ return wrap(unwrap(type).cast<ComplexType>().getElementType());
+}
+
+/* ========================================================================== */
+/* Shaped type. Queries cast<ShapedType>; `dim` is narrowed to unsigned. */
+/* ========================================================================== */
+
+int mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa<ShapedType>(); }
+
+MlirType mlirShapedTypeGetElementType(MlirType type) {
+ return wrap(unwrap(type).cast<ShapedType>().getElementType());
+}
+
+int mlirShapedTypeHasRank(MlirType type) {
+ return unwrap(type).cast<ShapedType>().hasRank();
+}
+
+int64_t mlirShapedTypeGetRank(MlirType type) {
+ return unwrap(type).cast<ShapedType>().getRank();
+}
+
+int mlirShapedTypeHasStaticShape(MlirType type) {
+ return unwrap(type).cast<ShapedType>().hasStaticShape();
+}
+
+int mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
+ return unwrap(type).cast<ShapedType>().isDynamicDim(
+ static_cast<unsigned>(dim));
+}
+
+int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
+ return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
+}
+
+int mlirShapedTypeIsDynamicSize(int64_t size) {
+ return ShapedType::isDynamic(size);
+}
+
+int mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
+ return ShapedType::isDynamicStrideOrOffset(val);
+}
+
+/* ========================================================================== */
+/* Vector type. `shape` is read as `rank` int64_t dimension sizes. */
+/* ========================================================================== */
+
+int mlirTypeIsAVector(MlirType type) { return unwrap(type).isa<VectorType>(); }
+
+MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape,
+ MlirType elementType) {
+ return wrap(
+ VectorType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+ unwrap(elementType)));
+}
+
+/* ========================================================================== */
+/* Ranked / Unranked tensor type. Ranked shapes are read as (shape, rank). */
+/* ========================================================================== */
+
+int mlirTypeIsATensor(MlirType type) { return unwrap(type).isa<TensorType>(); }
+
+int mlirTypeIsARankedTensor(MlirType type) {
+ return unwrap(type).isa<RankedTensorType>();
+}
+
+int mlirTypeIsAUnrankedTensor(MlirType type) {
+ return unwrap(type).isa<UnrankedTensorType>();
+}
+
+MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape,
+ MlirType elementType) {
+ return wrap(RankedTensorType::get(
+ llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+ unwrap(elementType)));
+}
+
+MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
+ return wrap(UnrankedTensorType::get(unwrap(elementType)));
+}
+
+/* ========================================================================== */
+/* Ranked / Unranked MemRef type. Layout maps are MlirAffineMap handles. */
+/* ========================================================================== */
+
+int mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
+
+MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, int64_t *shape,
+ intptr_t numMaps, MlirAffineMap *affineMaps,
+ unsigned memorySpace) {
+ SmallVector<AffineMap, 1> maps;
+ (void)unwrapList(numMaps, affineMaps, maps);
+ return wrap(
+ MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+ unwrap(elementType), maps, memorySpace));
+}
+
+MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
+ int64_t *shape, unsigned memorySpace) {
+ return wrap(
+ MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+ unwrap(elementType), llvm::None, memorySpace));
+}
+
+intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
+ return static_cast<intptr_t>(
+ unwrap(type).cast<MemRefType>().getAffineMaps().size());
+}
+
+MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
+ return wrap(unwrap(type).cast<MemRefType>().getAffineMaps()[pos]);
+}
+
+unsigned mlirMemRefTypeGetMemorySpace(MlirType type) {
+ return unwrap(type).cast<MemRefType>().getMemorySpace();
+}
+
+int mlirTypeIsAUnrankedMemRef(MlirType type) {
+ return unwrap(type).isa<UnrankedMemRefType>();
+}
+
+MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
+ return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace));
+}
+
+unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
+ return unwrap(type).cast<UnrankedMemRefType>().getMemorySpace();
+}
+
+/* ========================================================================== */
+/* Tuple type. Counts and positions cross the C API as intptr_t. */
+/* ========================================================================== */
+
+int mlirTypeIsATuple(MlirType type) { return unwrap(type).isa<TupleType>(); }
+
+MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
+ MlirType *elements) {
+ SmallVector<Type, 4> types;
+ ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
+ return wrap(TupleType::get(typeRef, unwrap(ctx)));
+}
+
+intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
+ return static_cast<intptr_t>(unwrap(type).cast<TupleType>().size());
+}
+
+MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
+ return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
+}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index d6ab3513384f..56b7ecd7fd7c 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -12,6 +12,7 @@
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
+#include "mlir-c/StandardTypes.h"
#include <assert.h>
#include <stdio.h>
@@ -240,6 +241,145 @@ static void printFirstOfEach(MlirOperation operation) {
fprintf(stderr, "\n");
}
+/// Dumps one instance of each standard type exposed through the C API to
+/// stderr, one per line, so that the FileCheck lines in main() can match them.
+/// Additionally performs simple identity checks: each type constructed with the
+/// C API must be recognized by its mlirTypeIsA* predicate (and rejected by the
+/// sibling predicates tested here), and its accessors must return the values it
+/// was built from. This is not exhaustive coverage of the C API; for example,
+/// mlirMemRefTypeGetAffineMap is not exercised because the memref built here
+/// has no affine maps. Returns 0 on success, or a non-zero code identifying the
+/// first failed check (code 8 is unused).
+static int printStandardTypes(MlirContext ctx) {
+  // Integer types.
+  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
+  MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32);
+  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
+  if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32))
+    return 1;
+  if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32))
+    return 2;
+  if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32))
+    return 3;
+  // Signedness is part of the type identity, so these must all differ.
+  if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32))
+    return 4;
+  // All three were requested with width 32.
+  if (mlirIntegerTypeGetWidth(i32) != 32 ||
+      mlirIntegerTypeGetWidth(si32) != 32 ||
+      mlirIntegerTypeGetWidth(ui32) != 32)
+    return 5;
+  mlirTypeDump(i32);
+  fprintf(stderr, "\n");
+  mlirTypeDump(si32);
+  fprintf(stderr, "\n");
+  mlirTypeDump(ui32);
+  fprintf(stderr, "\n");
+
+  // Index type.
+  MlirType index = mlirIndexTypeGet(ctx);
+  if (!mlirTypeIsAIndex(index))
+    return 6;
+  mlirTypeDump(index);
+  fprintf(stderr, "\n");
+
+  // Floating-point types.
+  MlirType bf16 = mlirBF16TypeGet(ctx);
+  MlirType f16 = mlirF16TypeGet(ctx);
+  MlirType f32 = mlirF32TypeGet(ctx);
+  MlirType f64 = mlirF64TypeGet(ctx);
+  if (!mlirTypeIsABF16(bf16))
+    return 7;
+  if (!mlirTypeIsAF16(f16))
+    return 9;
+  if (!mlirTypeIsAF32(f32))
+    return 10;
+  if (!mlirTypeIsAF64(f64))
+    return 11;
+  mlirTypeDump(bf16);
+  fprintf(stderr, "\n");
+  mlirTypeDump(f16);
+  fprintf(stderr, "\n");
+  mlirTypeDump(f32);
+  fprintf(stderr, "\n");
+  mlirTypeDump(f64);
+  fprintf(stderr, "\n");
+
+  // None type.
+  MlirType none = mlirNoneTypeGet(ctx);
+  if (!mlirTypeIsANone(none))
+    return 12;
+  mlirTypeDump(none);
+  fprintf(stderr, "\n");
+
+  // Complex type.
+  MlirType cplx = mlirComplexTypeGet(f32);
+  if (!mlirTypeIsAComplex(cplx) ||
+      !mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32))
+    return 13;
+  mlirTypeDump(cplx);
+  fprintf(stderr, "\n");
+
+  // Vector (and Shaped) type. ShapedType is a common base class for vectors,
+  // memrefs and tensors; one cannot create instances of this class directly,
+  // so it is tested on an instance of vector type.
+  int64_t shape[] = {2, 3};
+  // Number of dimensions in `shape`, shared by the vector, tensor and memref.
+  const intptr_t rank = (intptr_t)(sizeof shape / sizeof shape[0]);
+  MlirType vector = mlirVectorTypeGet(rank, shape, f32);
+  if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector))
+    return 14;
+  if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) ||
+      !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 ||
+      mlirShapedTypeGetDimSize(vector, 0) != 2 ||
+      mlirShapedTypeIsDynamicDim(vector, 0) ||
+      mlirShapedTypeGetDimSize(vector, 1) != 3 ||
+      mlirShapedTypeIsDynamicDim(vector, 1) ||
+      !mlirShapedTypeHasStaticShape(vector))
+    return 15;
+  mlirTypeDump(vector);
+  fprintf(stderr, "\n");
+
+  // Ranked tensor type.
+  MlirType rankedTensor = mlirRankedTensorTypeGet(rank, shape, f32);
+  if (!mlirTypeIsATensor(rankedTensor) ||
+      !mlirTypeIsARankedTensor(rankedTensor) ||
+      mlirTypeIsAUnrankedTensor(rankedTensor))
+    return 16;
+  mlirTypeDump(rankedTensor);
+  fprintf(stderr, "\n");
+
+  // Unranked tensor type.
+  MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32);
+  if (!mlirTypeIsATensor(unrankedTensor) ||
+      !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
+      mlirTypeIsARankedTensor(unrankedTensor) ||
+      mlirShapedTypeHasRank(unrankedTensor))
+    return 17;
+  mlirTypeDump(unrankedTensor);
+  fprintf(stderr, "\n");
+
+  // MemRef type, built without affine maps in memory space 2.
+  MlirType memRef = mlirMemRefTypeContiguousGet(f32, rank, shape, 2);
+  if (!mlirTypeIsAMemRef(memRef) ||
+      mlirMemRefTypeGetNumAffineMaps(memRef) != 0 ||
+      mlirMemRefTypeGetMemorySpace(memRef) != 2)
+    return 18;
+  mlirTypeDump(memRef);
+  fprintf(stderr, "\n");
+
+  // Unranked MemRef type in memory space 4; it must not be a (ranked) memref.
+  MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, 4);
+  if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
+      mlirTypeIsAMemRef(unrankedMemRef) ||
+      mlirUnrankedMemrefGetMemorySpace(unrankedMemRef) != 4)
+    return 19;
+  mlirTypeDump(unrankedMemRef);
+  fprintf(stderr, "\n");
+
+  // Tuple type.
+  MlirType types[] = {unrankedMemRef, f32};
+  MlirType tuple = mlirTupleTypeGet(ctx, 2, types);
+  if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
+      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
+      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
+    return 20;
+  mlirTypeDump(tuple);
+  fprintf(stderr, "\n");
+
+  return 0;
+}
+
int main() {
mlirRegisterAllDialects();
MlirContext ctx = mlirContextCreate();
@@ -293,6 +433,31 @@ int main() {
// clang-format on
mlirModuleDestroy(moduleOp);
+
+ // clang-format off
+ // CHECK-LABEL: @types
+ // CHECK: i32
+ // CHECK: si32
+ // CHECK: ui32
+ // CHECK: index
+ // CHECK: bf16
+ // CHECK: f16
+ // CHECK: f32
+ // CHECK: f64
+ // CHECK: none
+ // CHECK: complex<f32>
+ // CHECK: vector<2x3xf32>
+ // CHECK: tensor<2x3xf32>
+ // CHECK: tensor<*xf32>
+ // CHECK: memref<2x3xf32, 2>
+ // CHECK: memref<*xf32, 4>
+ // CHECK: tuple<memref<*xf32, 4>, f32>
+ // CHECK: 0
+ // clang-format on
+ fprintf(stderr, "@types");
+ int errcode = printStandardTypes(ctx);
+ fprintf(stderr, "%d\n", errcode);
+
mlirContextDestroy(ctx);
return 0;
More information about the Mlir-commits
mailing list