[Mlir-commits] [mlir] 74f5778 - [mlir] expose standard types to C API

Alex Zinenko llvmlistbot at llvm.org
Tue Aug 18 04:11:49 PDT 2020


Author: Alex Zinenko
Date: 2020-08-18T13:11:37+02:00
New Revision: 74f577845e8174e255688589d845d43eacf3923f

URL: https://github.com/llvm/llvm-project/commit/74f577845e8174e255688589d845d43eacf3923f
DIFF: https://github.com/llvm/llvm-project/commit/74f577845e8174e255688589d845d43eacf3923f.diff

LOG: [mlir] expose standard types to C API

Provide C API for MLIR standard types. Since standard types live under lib/IR
in core MLIR, place the C APIs in the IR library as well (standard ops will go
into a separate library). This also defines a placeholder for affine maps that
are necessary to construct a memref, but are not yet exposed to the C API.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D86094

Added: 
    mlir/include/mlir-c/AffineMap.h
    mlir/include/mlir-c/StandardTypes.h
    mlir/include/mlir/CAPI/AffineMap.h
    mlir/include/mlir/CAPI/IR.h
    mlir/include/mlir/CAPI/Wrap.h
    mlir/lib/CAPI/IR/AffineMap.cpp
    mlir/lib/CAPI/IR/StandardTypes.cpp

Modified: 
    mlir/docs/CAPI.md
    mlir/include/mlir-c/IR.h
    mlir/include/mlir/IR/AffineMap.h
    mlir/lib/CAPI/IR/CMakeLists.txt
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md
index 6adb9db3331c..a8fcfbafb8b1 100644
--- a/mlir/docs/CAPI.md
+++ b/mlir/docs/CAPI.md
@@ -75,6 +75,28 @@ check if an object is null by using `mlirXIsNull(MlirX)`. API functions do _not_
 expect null objects as arguments unless explicitly stated otherwise. API
 functions _may_ return null objects.
 
+### Type Hierarchies
+
+MLIR objects can form type hierarchies in C++. For example, all IR classes
+representing types are derived from `mlir::Type`, and some of them are also
+derived from common base classes such as `mlir::ShapedType` or dialect-specific
+base classes. Type hierarchies are exposed to the C API through the following
+naming conventions.
+
+-   Only the top-level class of each hierarchy is exposed, e.g. `MlirType` is
+    defined as a type but `MlirShapedType` is not. This avoids the need for
+    explicit upcasting when passing an object of a derived type to a function
+    that expects a base type (this happens more often in core/standard APIs,
+    while downcasting usually involves further checks anyway).
+-   A type `Y` that derives from `X` provides a function `int mlirXIsAY(MlirX)`
+    that returns a non-zero value if the given dynamic instance of `X` is also
+    an instance of `Y`. For example, `int mlirTypeIsAInteger(MlirType)`.
+-   A function that expects a derived type as its first argument takes the base
+    type instead and documents the expectation by using `Y` in its name
+    `mlirY<...>(MlirX, ...)`. This function asserts that the dynamic instance of
+    its first argument is `Y`, and it is the responsibility of the caller to
+    ensure it is indeed the case.
+
 ### Conversion To String and Printing
 
 IR objects can be converted to a string representation, for example for
@@ -96,11 +118,11 @@ allocation and avoid unnecessary allocation and copying inside the printer.
 For convenience, `mlirXDump(MlirX)` functions are provided to print the given
 object to the standard error stream.
 
-### Common Patterns
+## Common Patterns
 
 The API adopts the following patterns for recurrent functionality in MLIR.
 
-#### Indexed Components
+### Indexed Components
 
 An object has an _indexed component_ if it has fields accessible using a
 zero-based contiguous integer index, typically arrays. For example, an
@@ -120,7 +142,7 @@ Note that the name of subobject in the function does not necessarily match the
 type of the subobject. For example, `mlirOperationGetOperand` returns a
 `MlirValue`.
 
-#### Iterable Components
+### Iterable Components
 
 An object has an _iterable component_ if it has iterators accessing its fields
 in some order other than integer indexing, typically linked lists. For example,
@@ -146,3 +168,17 @@ for (iter = mlirXGetFirst<Y>(x); !mlirYIsNull(iter);
   /* User 'iter'. */
 }
 ```
+
+## Extending the API
+
+### Extensions for Dialect Attributes and Types
+
+Dialect attributes and types can follow the example of standard attributes and
+types, provided that implementations live in separate directories, i.e.
+`include/mlir-c/<...>Dialect/` and `lib/CAPI/<...>Dialect/`. The core APIs
+provide implementation-private headers in `include/mlir/CAPI/` that allow one
+to convert between opaque C structures for core IR components and their C++
+counterparts. `wrap` converts a C++ class into a C structure and `unwrap` does
+the inverse conversion. Once a C++ object is available, the API
+implementation should rely on `isa` to implement `mlirXIsAY` and is expected to
+use `cast` inside other API calls.
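
The naming conventions described in the CAPI.md changes can be illustrated without MLIR itself. The sketch below is a hypothetical toy hierarchy in plain C: `ToyType`, `toyTypeIsAInteger` and `toyIntegerTypeGetWidth` are invented names that mirror `MlirType`, `mlirTypeIsAInteger` and `mlirIntegerTypeGetWidth`, and the tagged `ToyTypeStorage` struct stands in for the C++ objects that the real API wraps.

```c
#include <assert.h>

/* Hypothetical stand-in for the C++ class hierarchy: a tagged struct. */
typedef enum { TOY_KIND_INTEGER, TOY_KIND_FLOAT } ToyKind;
typedef struct {
  ToyKind kind;
  unsigned width;
} ToyTypeStorage;

/* Only the top-level class is exposed: one opaque handle for all types. */
typedef struct {
  const void *ptr;
} ToyType;

/* mlirXIsAY-style predicate: non-zero if the dynamic instance is an integer. */
static int toyTypeIsAInteger(ToyType type) {
  const ToyTypeStorage *storage = (const ToyTypeStorage *)type.ptr;
  return storage->kind == TOY_KIND_INTEGER;
}

/* mlirY<...>(MlirX)-style accessor: takes the base handle and asserts that it
 * really is an integer; checking this beforehand is the caller's job. */
static unsigned toyIntegerTypeGetWidth(ToyType type) {
  assert(toyTypeIsAInteger(type) && "expected an integer type");
  return ((const ToyTypeStorage *)type.ptr)->width;
}
```

A caller tests with the predicate before using the accessor, e.g. `if (toyTypeIsAInteger(t)) width = toyIntegerTypeGetWidth(t);`, which is the C counterpart of an `isa` check followed by a `cast`.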

diff  --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h
new file mode 100644
index 000000000000..bef13fd0bfa8
--- /dev/null
+++ b/mlir/include/mlir-c/AffineMap.h
@@ -0,0 +1,25 @@
+/*===-- mlir-c/AffineMap.h - C API for MLIR Affine maps -----------*- C -*-===*\
+|*                                                                            *|
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM          *|
+|* Exceptions.                                                                *|
+|* See https://llvm.org/LICENSE.txt for license information.                  *|
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception                    *|
+|*                                                                            *|
+\*===----------------------------------------------------------------------===*/
+
+#ifndef MLIR_C_AFFINEMAP_H
+#define MLIR_C_AFFINEMAP_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+DEFINE_C_API_STRUCT(MlirAffineMap, const void);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_AFFINEMAP_H
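
Because this commit removes the `#undef DEFINE_C_API_STRUCT` from `mlir-c/IR.h`, the macro stays visible to headers that include `mlir-c/IR.h`, which is what lets `mlir-c/AffineMap.h` declare `MlirAffineMap`. The macro body is not part of this diff; the sketch below assumes it expands to a one-field struct wrapping an opaque pointer, and uses a renamed copy (`TOY_DEFINE_C_API_STRUCT`) together with a hypothetical `toyAffineMapIsNull` in the style of `mlirXIsNull`.

```c
#include <assert.h>
#include <stddef.h>

/* Assumed shape of DEFINE_C_API_STRUCT: a struct holding one opaque pointer,
 * plus a typedef so that C code can use the bare name. */
#define TOY_DEFINE_C_API_STRUCT(name, storage)                                 \
  struct name {                                                                \
    storage *ptr;                                                              \
  };                                                                           \
  typedef struct name name

/* Expands to a struct ToyAffineMap with a single `const void *ptr` field. */
TOY_DEFINE_C_API_STRUCT(ToyAffineMap, const void);

/* Null check in the style of mlirXIsNull(MlirX). */
static int toyAffineMapIsNull(ToyAffineMap map) { return map.ptr == NULL; }
```

Giving each entity its own struct, rather than passing a bare `const void *`, means the C compiler rejects passing, say, an affine map where a type handle is expected.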

diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 6b5be2d0195b..68546bf35625 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -56,8 +56,6 @@ DEFINE_C_API_STRUCT(MlirType, const void);
 DEFINE_C_API_STRUCT(MlirLocation, const void);
 DEFINE_C_API_STRUCT(MlirModule, const void);
 
-#undef DEFINE_C_API_STRUCT
-
 /** Named MLIR attribute.
  *
  * A named attribute is essentially a (name, attribute) pair where the name is
@@ -314,6 +312,9 @@ void mlirValuePrint(MlirValue value, MlirPrintCallback callback,
 /** Parses a type. The type is owned by the context. */
 MlirType mlirTypeParseGet(MlirContext context, const char *type);
 
+/** Checks if two types are equal. */
+int mlirTypeEqual(MlirType t1, MlirType t2);
+
 /** Prints a location by sending chunks of the string representation and
  * forwarding `userData` to `callback`. Note that the callback may be called
  * several times with consecutive chunks of the string. */
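
`mlirTypeEqual` can be a plain comparison because MLIR uniques types within a context: requesting the same type twice yields the same storage object, so the implementation in `IR.cpp` reduces to `unwrap(t1) == unwrap(t2)`. The toy context below is hypothetical and only models integer types keyed by bitwidth, to show why identity of the storage is enough.

```c
#include <assert.h>

#define TOY_MAX_WIDTH 128

typedef struct {
  unsigned width;
} ToyIntegerStorage;

/* Hypothetical context owning exactly one storage object per bitwidth. */
typedef struct {
  ToyIntegerStorage integers[TOY_MAX_WIDTH + 1];
} ToyContext;

typedef struct {
  const void *ptr;
} ToyType;

/* Returns the uniqued integer type: the same storage for the same width. */
static ToyType toyIntegerTypeGet(ToyContext *ctx, unsigned width) {
  assert(width <= TOY_MAX_WIDTH && "width out of range for this toy");
  ctx->integers[width].width = width;
  ToyType type = {&ctx->integers[width]};
  return type;
}

/* Equality of uniqued types is identity of their storage. */
static int toyTypeEqual(ToyType t1, ToyType t2) { return t1.ptr == t2.ptr; }
```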

diff  --git a/mlir/include/mlir-c/StandardTypes.h b/mlir/include/mlir-c/StandardTypes.h
new file mode 100644
index 000000000000..ad28ea546717
--- /dev/null
+++ b/mlir/include/mlir-c/StandardTypes.h
@@ -0,0 +1,249 @@
+/*===-- mlir-c/StandardTypes.h - C API for MLIR Standard types ----*- C -*-===*\
+|*                                                                            *|
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM          *|
+|* Exceptions.                                                                *|
+|* See https://llvm.org/LICENSE.txt for license information.                  *|
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception                    *|
+|*                                                                            *|
+\*===----------------------------------------------------------------------===*/
+
+#ifndef MLIR_C_STANDARDTYPES_H
+#define MLIR_C_STANDARDTYPES_H
+
+#include "mlir-c/AffineMap.h"
+#include "mlir-c/IR.h"
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*============================================================================*/
+/* Integer types.                                                             */
+/*============================================================================*/
+
+/** Checks whether the given type is an integer type. */
+int mlirTypeIsAInteger(MlirType type);
+
+/** Creates a signless integer type of the given bitwidth in the context. The
+ * type is owned by the context. */
+MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth);
+
+/** Creates a signed integer type of the given bitwidth in the context. The type
+ * is owned by the context. */
+MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth);
+
+/** Creates an unsigned integer type of the given bitwidth in the context. The
+ * type is owned by the context. */
+MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth);
+
+/** Returns the bitwidth of an integer type. */
+unsigned mlirIntegerTypeGetWidth(MlirType type);
+
+/** Checks whether the given integer type is signless. */
+int mlirIntegerTypeIsSignless(MlirType type);
+
+/** Checks whether the given integer type is signed. */
+int mlirIntegerTypeIsSigned(MlirType type);
+
+/** Checks whether the given integer type is unsigned. */
+int mlirIntegerTypeIsUnsigned(MlirType type);
+
+/*============================================================================*/
+/* Index type.                                                                */
+/*============================================================================*/
+
+/** Checks whether the given type is an index type. */
+int mlirTypeIsAIndex(MlirType type);
+
+/** Creates an index type in the given context. The type is owned by the
+ * context. */
+MlirType mlirIndexTypeGet(MlirContext ctx);
+
+/*============================================================================*/
+/* Floating-point types.                                                      */
+/*============================================================================*/
+
+/** Checks whether the given type is a bf16 type. */
+int mlirTypeIsABF16(MlirType type);
+
+/** Creates a bf16 type in the given context. The type is owned by the
+ * context. */
+MlirType mlirBF16TypeGet(MlirContext ctx);
+
+/** Checks whether the given type is an f16 type. */
+int mlirTypeIsAF16(MlirType type);
+
+/** Creates an f16 type in the given context. The type is owned by the
+ * context. */
+MlirType mlirF16TypeGet(MlirContext ctx);
+
+/** Checks whether the given type is an f32 type. */
+int mlirTypeIsAF32(MlirType type);
+
+/** Creates an f32 type in the given context. The type is owned by the
+ * context. */
+MlirType mlirF32TypeGet(MlirContext ctx);
+
+/** Checks whether the given type is an f64 type. */
+int mlirTypeIsAF64(MlirType type);
+
+/** Creates an f64 type in the given context. The type is owned by the
+ * context. */
+MlirType mlirF64TypeGet(MlirContext ctx);
+
+/*============================================================================*/
+/* None type.                                                                 */
+/*============================================================================*/
+
+/** Checks whether the given type is a None type. */
+int mlirTypeIsANone(MlirType type);
+
+/** Creates a None type in the given context. The type is owned by the
+ * context. */
+MlirType mlirNoneTypeGet(MlirContext ctx);
+
+/*============================================================================*/
+/* Complex type.                                                              */
+/*============================================================================*/
+
+/** Checks whether the given type is a Complex type. */
+int mlirTypeIsAComplex(MlirType type);
+
+/** Creates a complex type with the given element type in the same context as
+ * the element type. The type is owned by the context. */
+MlirType mlirComplexTypeGet(MlirType elementType);
+
+/** Returns the element type of the given complex type. */
+MlirType mlirComplexTypeGetElementType(MlirType type);
+
+/*============================================================================*/
+/* Shaped type.                                                               */
+/*============================================================================*/
+
+/** Checks whether the given type is a Shaped type. */
+int mlirTypeIsAShaped(MlirType type);
+
+/** Returns the element type of the shaped type. */
+MlirType mlirShapedTypeGetElementType(MlirType type);
+
+/** Checks whether the given shaped type is ranked. */
+int mlirShapedTypeHasRank(MlirType type);
+
+/** Returns the rank of the given ranked shaped type. */
+int64_t mlirShapedTypeGetRank(MlirType type);
+
+/** Checks whether the given shaped type has a static shape. */
+int mlirShapedTypeHasStaticShape(MlirType type);
+
+/** Checks whether the dim-th dimension of the given shaped type is dynamic. */
+int mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim);
+
+/** Returns the dim-th dimension of the given ranked shaped type. */
+int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim);
+
+/** Checks whether the given value is used as a placeholder for dynamic sizes
+ * in shaped types. */
+int mlirShapedTypeIsDynamicSize(int64_t size);
+
+/** Checks whether the given value is used as a placeholder for dynamic strides
+ * and offsets in shaped types. */
+int mlirShapedTypeIsDynamicStrideOrOffset(int64_t val);
+
+/*============================================================================*/
+/* Vector type.                                                               */
+/*============================================================================*/
+
+/** Checks whether the given type is a Vector type. */
+int mlirTypeIsAVector(MlirType type);
+
+/** Creates a vector type of the shape identified by its rank and dimensions,
+ * with the given element type in the same context as the element type. The type
+ * is owned by the context. */
+MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape, MlirType elementType);
+
+/*============================================================================*/
+/* Ranked / Unranked Tensor type.                                             */
+/*============================================================================*/
+
+/** Checks whether the given type is a Tensor type. */
+int mlirTypeIsATensor(MlirType type);
+
+/** Checks whether the given type is a ranked tensor type. */
+int mlirTypeIsARankedTensor(MlirType type);
+
+/** Checks whether the given type is an unranked tensor type. */
+int mlirTypeIsAUnrankedTensor(MlirType type);
+
+/** Creates a tensor type of a fixed rank with the given shape and element type
+ * in the same context as the element type. The type is owned by the context. */
+MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape,
+                                 MlirType elementType);
+
+/** Creates an unranked tensor type with the given element type in the same
+ * context as the element type. The type is owned by the context. */
+MlirType mlirUnrankedTensorTypeGet(MlirType elementType);
+
+/*============================================================================*/
+/* Ranked / Unranked MemRef type.                                             */
+/*============================================================================*/
+
+/** Checks whether the given type is a MemRef type. */
+int mlirTypeIsAMemRef(MlirType type);
+
+/** Checks whether the given type is an UnrankedMemRef type. */
+int mlirTypeIsAUnrankedMemRef(MlirType type);
+
+/** Creates a MemRef type with the given rank and shape, a potentially empty
+ * list of affine layout maps, the given memory space and element type, in the
+ * same context as the element type. The type is owned by the context. */
+MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, int64_t *shape,
+                           intptr_t numMaps, MlirAttribute *affineMaps,
+                           unsigned memorySpace);
+
+/** Creates a MemRef type with the given rank, shape, memory space and element
+ * type in the same context as the element type. The type has no affine maps,
+ * i.e. represents a default row-major contiguous memref. The type is owned by
+ * the context. */
+MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
+                                     int64_t *shape, unsigned memorySpace);
+
+/** Creates an unranked MemRef type with the given element type in the given
+ * memory space. The type is owned by the context of the element type. */
+MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace);
+
+/** Returns the number of affine layout maps in the given MemRef type. */
+intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type);
+
+/** Returns the pos-th affine map of the given MemRef type. */
+MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos);
+
+/** Returns the memory space of the given MemRef type. */
+unsigned mlirMemRefTypeGetMemorySpace(MlirType type);
+
+/** Returns the memory space of the given Unranked MemRef type. */
+unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type);
+
+/*============================================================================*/
+/* Tuple type.                                                                */
+/*============================================================================*/
+
+/** Checks whether the given type is a tuple type. */
+int mlirTypeIsATuple(MlirType type);
+
+/** Creates a tuple type that consists of the given list of element types. The
+ * type is owned by the context. */
+MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
+                          MlirType *elements);
+
+/** Returns the number of types contained in a tuple. */
+intptr_t mlirTupleTypeGetNumTypes(MlirType type);
+
+/** Returns the pos-th type in the tuple type. */
+MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_STANDARDTYPES_H
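
The shaped-type queries declared in `mlir-c/StandardTypes.h` are meant to be combined: a shape read back with `mlirShapedTypeGetDimSize` may contain placeholder values, which `mlirShapedTypeIsDynamicSize` recognizes without the caller hard-coding the sentinel. The self-contained sketch below counts the elements of a shape and gives up on dynamic dimensions. It assumes the sentinel is -1 purely so that it runs without MLIR, and `toyIsDynamicSize` is a hypothetical stand-in for `mlirShapedTypeIsDynamicSize`.

```c
#include <assert.h>
#include <stdint.h>

/* Assumed sentinel for a dynamic dimension, used only so this sketch runs
 * without MLIR; real callers should ask mlirShapedTypeIsDynamicSize(). */
#define TOY_DYNAMIC_SIZE ((int64_t)-1)

/* Hypothetical stand-in for mlirShapedTypeIsDynamicSize. */
static int toyIsDynamicSize(int64_t size) { return size == TOY_DYNAMIC_SIZE; }

/* Returns the number of elements described by `shape`, or -1 if any dimension
 * is dynamic. A rank-0 shape describes a single element. Overflow is not
 * checked. */
static int64_t toyNumElements(intptr_t rank, const int64_t *shape) {
  int64_t count = 1;
  for (intptr_t i = 0; i < rank; ++i) {
    if (toyIsDynamicSize(shape[i]))
      return -1;
    count *= shape[i];
  }
  return count;
}
```

With the real API, the extents would come from `mlirShapedTypeGetDimSize(type, i)` for each `i` below `mlirShapedTypeGetRank(type)`, after checking `mlirShapedTypeHasRank(type)`.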

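`mlirMemRefTypeContiguousGet` produces a memref with no affine maps, i.e. the default row-major contiguous layout. For a static shape, that layout corresponds to strides where the innermost dimension has stride 1 and each outer stride is the product of all inner sizes. The helper below is a hypothetical illustration of that arithmetic, not part of the C API, and it assumes every dimension is static.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical helper: fills `strides` with the row-major strides of a static
 * `shape` of the given rank. The innermost dimension has stride 1 and every
 * outer stride is the product of the sizes of all dimensions inside it. */
static void toyRowMajorStrides(intptr_t rank, const int64_t *shape,
                               int64_t *strides) {
  int64_t running = 1;
  for (intptr_t i = rank - 1; i >= 0; --i) {
    strides[i] = running;
    running *= shape[i];
  }
}
```

For a `memref<2x3x4xf32>` this gives strides `[12, 4, 1]`, so element `(i, j, k)` lives at linear offset `12*i + 4*j + k`.
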
diff  --git a/mlir/include/mlir/CAPI/AffineMap.h b/mlir/include/mlir/CAPI/AffineMap.h
new file mode 100644
index 000000000000..cea48ffae8b6
--- /dev/null
+++ b/mlir/include/mlir/CAPI/AffineMap.h
@@ -0,0 +1,24 @@
+//===- AffineMap.h - C API Utils for Affine Maps ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// MLIR Affine maps. It should only be included from the C API implementation,
+// not from other C++ code or from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_AFFINEMAP_H
+#define MLIR_CAPI_AFFINEMAP_H
+
+#include "mlir-c/AffineMap.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/AffineMap.h"
+
+DEFINE_C_API_METHODS(MlirAffineMap, mlir::AffineMap)
+
+#endif // MLIR_CAPI_AFFINEMAP_H

diff  --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
new file mode 100644
index 000000000000..9a60ecf04fc8
--- /dev/null
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -0,0 +1,34 @@
+//===- IR.h - C API Utils for Core MLIR classes -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// core MLIR classes. It should only be included from the C API implementation,
+// not from other C++ code or from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INCLUDE_MLIR_CAPI_IR_H
+#define MLIR_INCLUDE_MLIR_CAPI_IR_H
+
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+
+DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
+DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
+DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
+DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
+
+DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute)
+DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
+DEFINE_C_API_METHODS(MlirType, mlir::Type)
+DEFINE_C_API_METHODS(MlirValue, mlir::Value)
+DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
+
+#endif // MLIR_INCLUDE_MLIR_CAPI_IR_H
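
`DEFINE_C_API_PTR_METHODS` serves classes that are always handled by pointer (`MLIRContext`, `Operation`, `Block`, `Region`): wrapping stores the pointer in the C struct and unwrapping casts it back. Moving these definitions out of `IR.cpp` into this header as `static inline` functions lets other implementation files, such as `StandardTypes.cpp`, share them. The C sketch below mimics the round trip with a hypothetical `ToyContextImpl`; since C has no overloading, it uses distinct `toyContextWrap`/`toyContextUnwrap` names instead of the overloaded C++ `wrap`/`unwrap`.

```c
#include <assert.h>

/* Hypothetical implementation object that is always handled by pointer. */
typedef struct {
  int numTypesCreated;
} ToyContextImpl;

/* Opaque C handle holding a non-owning, non-const pointer. */
typedef struct {
  void *ptr;
} ToyContext;

/* Counterpart of wrap(cpptype *): store the pointer unchanged. */
static inline ToyContext toyContextWrap(ToyContextImpl *impl) {
  ToyContext ctx = {impl};
  return ctx;
}

/* Counterpart of unwrap(name): cast the opaque pointer back. */
static inline ToyContextImpl *toyContextUnwrap(ToyContext ctx) {
  return (ToyContextImpl *)ctx.ptr;
}
```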

diff  --git a/mlir/include/mlir/CAPI/Wrap.h b/mlir/include/mlir/CAPI/Wrap.h
new file mode 100644
index 000000000000..940007caac06
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Wrap.h
@@ -0,0 +1,56 @@
+//===- Wrap.h - C API Utilities ---------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains common definitions for wrapping opaque C++ pointers into
+// C structures for the purpose of the C API. It should only be included from
+// the C API implementation, not from other C++ code or from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_WRAP_H
+#define MLIR_CAPI_WRAP_H
+
+#include "mlir-c/IR.h"
+#include "mlir/Support/LLVM.h"
+
+/* ========================================================================== */
+/* Definitions of methods for non-owning structures used in C API.            */
+/* ========================================================================== */
+
+#define DEFINE_C_API_PTR_METHODS(name, cpptype)                                \
+  static inline name wrap(cpptype *cpp) { return name{cpp}; }                  \
+  static inline cpptype *unwrap(name c) {                                      \
+    return static_cast<cpptype *>(c.ptr);                                      \
+  }
+
+#define DEFINE_C_API_METHODS(name, cpptype)                                    \
+  static inline name wrap(cpptype cpp) {                                       \
+    return name{cpp.getAsOpaquePointer()};                                     \
+  }                                                                            \
+  static inline cpptype unwrap(name c) {                                       \
+    return cpptype::getFromOpaquePointer(c.ptr);                               \
+  }
+
+template <typename CppTy, typename CTy>
+static llvm::ArrayRef<CppTy> unwrapList(size_t size, CTy *first,
+                                        llvm::SmallVectorImpl<CppTy> &storage) {
+  static_assert(
+      std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
+      "incompatible C and C++ types");
+
+  if (size == 0)
+    return llvm::None;
+
+  assert(storage.empty() && "expected to populate storage");
+  storage.reserve(size);
+  for (size_t i = 0; i < size; ++i)
+    storage.push_back(unwrap(*(first + i)));
+  return storage;
+}
+
+#endif // MLIR_CAPI_WRAP_H
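
`unwrapList` converts a C array of handles into an `llvm::ArrayRef` of C++ objects, writing them into caller-provided storage so that the result stays valid after the call returns. A C analogue, with hypothetical names and a caller-provided output array in place of `SmallVectorImpl`, could look like this:

```c
#include <assert.h>
#include <stddef.h>

typedef struct {
  const void *ptr;
} ToyType;

/* Hypothetical C analogue of unwrapList: writes the unwrapped pointer of each
 * of the `size` handles starting at `first` into `storage`, which must have
 * room for at least `size` entries. Returns `storage`, or NULL when size is 0
 * (mirroring the empty ArrayRef returned for llvm::None). */
static const void **toyUnwrapList(size_t size, const ToyType *first,
                                  const void **storage) {
  if (size == 0)
    return NULL;
  for (size_t i = 0; i < size; ++i)
    storage[i] = first[i].ptr;
  return storage;
}
```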

diff  --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h
index d946c7591c2a..dd4960a02c5c 100644
--- a/mlir/include/mlir/IR/AffineMap.h
+++ b/mlir/include/mlir/IR/AffineMap.h
@@ -196,6 +196,14 @@ class AffineMap {
 
   friend ::llvm::hash_code hash_value(AffineMap arg);
 
+  /// Methods supporting C API.
+  const void *getAsOpaquePointer() const {
+    return static_cast<const void *>(map);
+  }
+  static AffineMap getFromOpaquePointer(const void *pointer) {
+    return AffineMap(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
+  }
+
 private:
   ImplType *map;
 
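`DEFINE_C_API_METHODS` expects the C++ class to be a thin value wrapper around a single implementation pointer, which is why `AffineMap` gains `getAsOpaquePointer` and `getFromOpaquePointer`. The round trip adds and then removes `const`, because the C handle stores a `const void *`. The C sketch below mimics it with a hypothetical `ToyAffineMap` value type.

```c
#include <assert.h>

/* Hypothetical uniqued storage owned elsewhere (e.g. by a context). */
typedef struct {
  unsigned numDims;
} ToyAffineMapStorage;

/* Value type: a thin wrapper around one implementation pointer. */
typedef struct {
  ToyAffineMapStorage *map;
} ToyAffineMap;

/* Counterpart of AffineMap::getAsOpaquePointer. */
static const void *toyAffineMapGetAsOpaquePointer(ToyAffineMap m) {
  return (const void *)m.map;
}

/* Counterpart of AffineMap::getFromOpaquePointer: casts away the const that
 * the opaque handle added, as the C++ version does with const_cast. */
static ToyAffineMap toyAffineMapGetFromOpaquePointer(const void *pointer) {
  ToyAffineMap m = {(ToyAffineMapStorage *)pointer};
  return m;
}
```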

diff  --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp
new file mode 100644
index 000000000000..d80d9e20486a
--- /dev/null
+++ b/mlir/lib/CAPI/IR/AffineMap.cpp
@@ -0,0 +1,15 @@
+//===- AffineMap.cpp - C API for MLIR Affine Maps -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/AffineMap.h"
+#include "mlir-c/IR.h"
+#include "mlir/CAPI/AffineMap.h"
+#include "mlir/IR/AffineMap.h"
+
+// This is a placeholder for affine map bindings. The file is here to serve as a
+// compilation unit that includes the headers.

diff  --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
index fdf239f975d3..64e715e33d88 100644
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -1,6 +1,8 @@
 # Main API.
 add_mlir_library(MLIRCAPIIR
+  AffineMap.cpp
   IR.cpp
+  StandardTypes.cpp
 
   EXCLUDE_FROM_LIBMLIR
 

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 5231096af785..1ba1a6aca6f8 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir-c/IR.h"
 
+#include "mlir/CAPI/IR.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/Operation.h"
@@ -17,46 +18,6 @@
 
 using namespace mlir;
 
-/* ========================================================================== */
-/* Definitions of methods for non-owning structures used in C API.            */
-/* ========================================================================== */
-
-#define DEFINE_C_API_PTR_METHODS(name, cpptype)                                \
-  static name wrap(cpptype *cpp) { return name{cpp}; }                         \
-  static cpptype *unwrap(name c) { return static_cast<cpptype *>(c.ptr); }
-
-DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext)
-DEFINE_C_API_PTR_METHODS(MlirOperation, Operation)
-DEFINE_C_API_PTR_METHODS(MlirBlock, Block)
-DEFINE_C_API_PTR_METHODS(MlirRegion, Region)
-
-#define DEFINE_C_API_METHODS(name, cpptype)                                    \
-  static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; }     \
-  static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); }
-
-DEFINE_C_API_METHODS(MlirAttribute, Attribute)
-DEFINE_C_API_METHODS(MlirLocation, Location);
-DEFINE_C_API_METHODS(MlirType, Type)
-DEFINE_C_API_METHODS(MlirValue, Value)
-DEFINE_C_API_METHODS(MlirModule, ModuleOp)
-
-template <typename CppTy, typename CTy>
-static ArrayRef<CppTy> unwrapList(intptr_t size, CTy *first,
-                                  SmallVectorImpl<CppTy> &storage) {
-  static_assert(
-      std::is_same<decltype(unwrap(std::declval<CTy>())), CppTy>::value,
-      "incompatible C and C++ types");
-
-  if (size == 0)
-    return llvm::None;
-
-  assert(storage.empty() && "expected to populate storage");
-  storage.reserve(size);
-  for (intptr_t i = 0; i < size; ++i)
-    storage.push_back(unwrap(*(first + i)));
-  return storage;
-}
-
 /* ========================================================================== */
 /* Printing helper.                                                           */
 /* ========================================================================== */
@@ -388,6 +349,8 @@ MlirType mlirTypeParseGet(MlirContext context, const char *type) {
   return wrap(mlir::parseType(type, unwrap(context)));
 }
 
+int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); }
+
 void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) {
   CallbackOstream stream(callback, userData);
   unwrap(type).print(stream);
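
`mlirTypePrint` hands the printed text to a user callback through `CallbackOstream`, possibly in several consecutive chunks, and forwards `userData` unchanged. A C caller that wants the whole string typically accumulates the chunks into a buffer passed as `userData`. The sketch assumes a callback shape of `(const char *, intptr_t, void *)` (the `MlirPrintCallback` typedef itself is not part of this diff), and `ToyStringBuffer`/`toyAppendChunk` are hypothetical.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Assumed callback shape: chunk pointer, chunk length, opaque user data. */
typedef void (*ToyPrintCallback)(const char *chunk, intptr_t length,
                                 void *userData);

/* Hypothetical fixed-size accumulator passed through userData. */
typedef struct {
  char data[256];
  size_t length;
} ToyStringBuffer;

/* Appends one chunk using its explicit length, since chunks may not be
 * NUL-terminated. Text that does not fit is truncated. */
static void toyAppendChunk(const char *chunk, intptr_t length, void *userData) {
  ToyStringBuffer *buffer = (ToyStringBuffer *)userData;
  size_t room = sizeof(buffer->data) - 1 - buffer->length;
  size_t n = (size_t)length < room ? (size_t)length : room;
  memcpy(buffer->data + buffer->length, chunk, n);
  buffer->length += n;
  buffer->data[buffer->length] = '\0';
}
```

With the real API, a zero-initialized buffer would be passed as the `userData` argument of `mlirTypePrint(type, callback, userData)`.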

diff  --git a/mlir/lib/CAPI/IR/StandardTypes.cpp b/mlir/lib/CAPI/IR/StandardTypes.cpp
new file mode 100644
index 000000000000..eb006242e880
--- /dev/null
+++ b/mlir/lib/CAPI/IR/StandardTypes.cpp
@@ -0,0 +1,263 @@
+//===- StandardTypes.cpp - C Interface to MLIR Standard Types -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/StandardTypes.h"
+#include "mlir-c/AffineMap.h"
+#include "mlir-c/IR.h"
+#include "mlir/CAPI/AffineMap.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+
+/* ========================================================================== */
+/* Integer types.                                                             */
+/* ========================================================================== */
+
+int mlirTypeIsAInteger(MlirType type) {
+  return unwrap(type).isa<IntegerType>();
+}
+
+MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
+  return wrap(IntegerType::get(bitwidth, unwrap(ctx)));
+}
+
+MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
+  return wrap(IntegerType::get(bitwidth, IntegerType::Signed, unwrap(ctx)));
+}
+
+MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
+  return wrap(IntegerType::get(bitwidth, IntegerType::Unsigned, unwrap(ctx)));
+}
+
+unsigned mlirIntegerTypeGetWidth(MlirType type) {
+  return unwrap(type).cast<IntegerType>().getWidth();
+}
+
+int mlirIntegerTypeIsSignless(MlirType type) {
+  return unwrap(type).cast<IntegerType>().isSignless();
+}
+
+int mlirIntegerTypeIsSigned(MlirType type) {
+  return unwrap(type).cast<IntegerType>().isSigned();
+}
+
+int mlirIntegerTypeIsUnsigned(MlirType type) {
+  return unwrap(type).cast<IntegerType>().isUnsigned();
+}
+
+/* ========================================================================== */
+/* Index type.                                                                */
+/* ========================================================================== */
+
+int mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa<IndexType>(); }
+
+MlirType mlirIndexTypeGet(MlirContext ctx) {
+  return wrap(IndexType::get(unwrap(ctx)));
+}
+
+/* ========================================================================== */
+/* Floating-point types.                                                      */
+/* ========================================================================== */
+
+int mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); }
+
+MlirType mlirBF16TypeGet(MlirContext ctx) {
+  return wrap(FloatType::getBF16(unwrap(ctx)));
+}
+
+int mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); }
+
+MlirType mlirF16TypeGet(MlirContext ctx) {
+  return wrap(FloatType::getF16(unwrap(ctx)));
+}
+
+int mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); }
+
+MlirType mlirF32TypeGet(MlirContext ctx) {
+  return wrap(FloatType::getF32(unwrap(ctx)));
+}
+
+int mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); }
+
+MlirType mlirF64TypeGet(MlirContext ctx) {
+  return wrap(FloatType::getF64(unwrap(ctx)));
+}
+
+/* ========================================================================== */
+/* None type.                                                                 */
+/* ========================================================================== */
+
+int mlirTypeIsANone(MlirType type) { return unwrap(type).isa<NoneType>(); }
+
+MlirType mlirNoneTypeGet(MlirContext ctx) {
+  return wrap(NoneType::get(unwrap(ctx)));
+}
+
+/* ========================================================================== */
+/* Complex type.                                                              */
+/* ========================================================================== */
+
+int mlirTypeIsAComplex(MlirType type) {
+  return unwrap(type).isa<ComplexType>();
+}
+
+MlirType mlirComplexTypeGet(MlirType elementType) {
+  return wrap(ComplexType::get(unwrap(elementType)));
+}
+
+MlirType mlirComplexTypeGetElementType(MlirType type) {
+  return wrap(unwrap(type).cast<ComplexType>().getElementType());
+}
+
+/* ========================================================================== */
+/* Shaped type.                                                               */
+/* ========================================================================== */
+
+int mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa<ShapedType>(); }
+
+MlirType mlirShapedTypeGetElementType(MlirType type) {
+  return wrap(unwrap(type).cast<ShapedType>().getElementType());
+}
+
+int mlirShapedTypeHasRank(MlirType type) {
+  return unwrap(type).cast<ShapedType>().hasRank();
+}
+
+int64_t mlirShapedTypeGetRank(MlirType type) {
+  return unwrap(type).cast<ShapedType>().getRank();
+}
+
+int mlirShapedTypeHasStaticShape(MlirType type) {
+  return unwrap(type).cast<ShapedType>().hasStaticShape();
+}
+
+int mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
+  return unwrap(type).cast<ShapedType>().isDynamicDim(
+      static_cast<unsigned>(dim));
+}
+
+int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
+  return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
+}
+
+int mlirShapedTypeIsDynamicSize(int64_t size) {
+  return ShapedType::isDynamic(size);
+}
+
+int mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
+  return ShapedType::isDynamicStrideOrOffset(val);
+}
+
+/* ========================================================================== */
+/* Vector type.                                                               */
+/* ========================================================================== */
+
+int mlirTypeIsAVector(MlirType type) { return unwrap(type).isa<VectorType>(); }
+
+MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape,
+                           MlirType elementType) {
+  return wrap(
+      VectorType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+                      unwrap(elementType)));
+}
+
+/* ========================================================================== */
+/* Ranked / Unranked tensor type.                                             */
+/* ========================================================================== */
+
+int mlirTypeIsATensor(MlirType type) { return unwrap(type).isa<TensorType>(); }
+
+int mlirTypeIsARankedTensor(MlirType type) {
+  return unwrap(type).isa<RankedTensorType>();
+}
+
+int mlirTypeIsAUnrankedTensor(MlirType type) {
+  return unwrap(type).isa<UnrankedTensorType>();
+}
+
+MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape,
+                                 MlirType elementType) {
+  return wrap(RankedTensorType::get(
+      llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+      unwrap(elementType)));
+}
+
+MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
+  return wrap(UnrankedTensorType::get(unwrap(elementType)));
+}
+
+/* ========================================================================== */
+/* Ranked / Unranked MemRef type.                                             */
+/* ========================================================================== */
+
+int mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa<MemRefType>(); }
+
+MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, int64_t *shape,
+                           intptr_t numMaps, MlirAffineMap *affineMaps,
+                           unsigned memorySpace) {
+  SmallVector<AffineMap, 1> maps;
+  (void)unwrapList(numMaps, affineMaps, maps);
+  return wrap(
+      MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+                      unwrap(elementType), maps, memorySpace));
+}
+
+MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
+                                     int64_t *shape, unsigned memorySpace) {
+  return wrap(
+      MemRefType::get(llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
+                      unwrap(elementType), llvm::None, memorySpace));
+}
+
+intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) {
+  return static_cast<intptr_t>(
+      unwrap(type).cast<MemRefType>().getAffineMaps().size());
+}
+
+MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) {
+  return wrap(unwrap(type).cast<MemRefType>().getAffineMaps()[pos]);
+}
+
+unsigned mlirMemRefTypeGetMemorySpace(MlirType type) {
+  return unwrap(type).cast<MemRefType>().getMemorySpace();
+}
+
+int mlirTypeIsAUnrankedMemRef(MlirType type) {
+  return unwrap(type).isa<UnrankedMemRefType>();
+}
+
+MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) {
+  return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace));
+}
+
+unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) {
+  return unwrap(type).cast<UnrankedMemRefType>().getMemorySpace();
+}
+
+/* ========================================================================== */
+/* Tuple type.                                                                */
+/* ========================================================================== */
+
+int mlirTypeIsATuple(MlirType type) { return unwrap(type).isa<TupleType>(); }
+
+MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
+                          MlirType *elements) {
+  SmallVector<Type, 4> types;
+  ArrayRef<Type> typeRef = unwrapList(numElements, elements, types);
+  return wrap(TupleType::get(typeRef, unwrap(ctx)));
+}
+
+intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
+  return unwrap(type).cast<TupleType>().size();
+}
+
+MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
+  return wrap(unwrap(type).cast<TupleType>().getType(static_cast<size_t>(pos)));
+}

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index d6ab3513384f..56b7ecd7fd7c 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -12,6 +12,7 @@
 
 #include "mlir-c/IR.h"
 #include "mlir-c/Registration.h"
+#include "mlir-c/StandardTypes.h"
 
 #include <assert.h>
 #include <stdio.h>
@@ -240,6 +241,145 @@ static void printFirstOfEach(MlirOperation operation) {
   fprintf(stderr, "\n");
 }
 
+/// Dumps instances of all standard types to check that the C API works
+/// correctly. Additionally, performs simple identity checks that a standard
+/// type constructed with the C API can be inspected and has the expected type.
+/// The latter achieves full coverage of the C API for standard types. Returns 0
+/// on success and a non-zero error code on failure.
+static int printStandardTypes(MlirContext ctx) {
+  // Integer types.
+  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
+  MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32);
+  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
+  if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32))
+    return 1;
+  if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32))
+    return 2;
+  if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32))
+    return 3;
+  if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32))
+    return 4;
+  if (mlirIntegerTypeGetWidth(i32) != mlirIntegerTypeGetWidth(si32))
+    return 5;
+  mlirTypeDump(i32);
+  fprintf(stderr, "\n");
+  mlirTypeDump(si32);
+  fprintf(stderr, "\n");
+  mlirTypeDump(ui32);
+  fprintf(stderr, "\n");
+
+  // Index type.
+  MlirType index = mlirIndexTypeGet(ctx);
+  if (!mlirTypeIsAIndex(index))
+    return 6;
+  mlirTypeDump(index);
+  fprintf(stderr, "\n");
+
+  // Floating-point types.
+  MlirType bf16 = mlirBF16TypeGet(ctx);
+  MlirType f16 = mlirF16TypeGet(ctx);
+  MlirType f32 = mlirF32TypeGet(ctx);
+  MlirType f64 = mlirF64TypeGet(ctx);
+  if (!mlirTypeIsABF16(bf16))
+    return 7;
+  if (!mlirTypeIsAF16(f16))
+    return 9;
+  if (!mlirTypeIsAF32(f32))
+    return 10;
+  if (!mlirTypeIsAF64(f64))
+    return 11;
+  mlirTypeDump(bf16);
+  fprintf(stderr, "\n");
+  mlirTypeDump(f16);
+  fprintf(stderr, "\n");
+  mlirTypeDump(f32);
+  fprintf(stderr, "\n");
+  mlirTypeDump(f64);
+  fprintf(stderr, "\n");
+
+  // None type.
+  MlirType none = mlirNoneTypeGet(ctx);
+  if (!mlirTypeIsANone(none))
+    return 12;
+  mlirTypeDump(none);
+  fprintf(stderr, "\n");
+
+  // Complex type.
+  MlirType cplx = mlirComplexTypeGet(f32);
+  if (!mlirTypeIsAComplex(cplx) ||
+      !mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32))
+    return 13;
+  mlirTypeDump(cplx);
+  fprintf(stderr, "\n");
+
+  // Vector (and Shaped) type. ShapedType is a common base class for vectors,
+  // memrefs and tensors; it cannot be instantiated directly, so it is tested
+  // on an instance of vector type.
+  int64_t shape[] = {2, 3};
+  MlirType vector =
+      mlirVectorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
+  if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector))
+    return 14;
+  if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) ||
+      !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 ||
+      mlirShapedTypeGetDimSize(vector, 0) != 2 ||
+      mlirShapedTypeIsDynamicDim(vector, 0) ||
+      mlirShapedTypeGetDimSize(vector, 1) != 3 ||
+      !mlirShapedTypeHasStaticShape(vector))
+    return 15;
+  mlirTypeDump(vector);
+  fprintf(stderr, "\n");
+
+  // Ranked tensor type.
+  MlirType rankedTensor =
+      mlirRankedTensorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
+  if (!mlirTypeIsATensor(rankedTensor) ||
+      !mlirTypeIsARankedTensor(rankedTensor))
+    return 16;
+  mlirTypeDump(rankedTensor);
+  fprintf(stderr, "\n");
+
+  // Unranked tensor type.
+  MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32);
+  if (!mlirTypeIsATensor(unrankedTensor) ||
+      !mlirTypeIsAUnrankedTensor(unrankedTensor) ||
+      mlirShapedTypeHasRank(unrankedTensor))
+    return 17;
+  mlirTypeDump(unrankedTensor);
+  fprintf(stderr, "\n");
+
+  // MemRef type.
+  MlirType memRef = mlirMemRefTypeContiguousGet(
+      f32, sizeof(shape) / sizeof(int64_t), shape, 2);
+  if (!mlirTypeIsAMemRef(memRef) ||
+      mlirMemRefTypeGetNumAffineMaps(memRef) != 0 ||
+      mlirMemRefTypeGetMemorySpace(memRef) != 2)
+    return 18;
+  mlirTypeDump(memRef);
+  fprintf(stderr, "\n");
+
+  // Unranked MemRef type.
+  MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, 4);
+  if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) ||
+      mlirTypeIsAMemRef(unrankedMemRef) ||
+      mlirUnrankedMemrefGetMemorySpace(unrankedMemRef) != 4)
+    return 19;
+  mlirTypeDump(unrankedMemRef);
+  fprintf(stderr, "\n");
+
+  // Tuple type.
+  MlirType types[] = {unrankedMemRef, f32};
+  MlirType tuple = mlirTupleTypeGet(ctx, 2, types);
+  if (!mlirTypeIsATuple(tuple) || mlirTupleTypeGetNumTypes(tuple) != 2 ||
+      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) ||
+      !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32))
+    return 20;
+  mlirTypeDump(tuple);
+  fprintf(stderr, "\n");
+
+  return 0;
+}
+
 int main() {
   mlirRegisterAllDialects();
   MlirContext ctx = mlirContextCreate();
@@ -293,6 +433,31 @@ int main() {
   // clang-format on
 
   mlirModuleDestroy(moduleOp);
+
+  // clang-format off
+  // CHECK-LABEL: @types
+  // CHECK: i32
+  // CHECK: si32
+  // CHECK: ui32
+  // CHECK: index
+  // CHECK: bf16
+  // CHECK: f16
+  // CHECK: f32
+  // CHECK: f64
+  // CHECK: none
+  // CHECK: complex<f32>
+  // CHECK: vector<2x3xf32>
+  // CHECK: tensor<2x3xf32>
+  // CHECK: tensor<*xf32>
+  // CHECK: memref<2x3xf32, 2>
+  // CHECK: memref<*xf32, 4>
+  // CHECK: tuple<memref<*xf32, 4>, f32>
+  // CHECK: 0
+  // clang-format on
+  fprintf(stderr, "@types");
+  int errcode = printStandardTypes(ctx);
+  fprintf(stderr, "%d\n", errcode);
+
   mlirContextDestroy(ctx);
 
   return 0;
