[Mlir-commits] [mlir] 782a97a - [mlir][capi] Add TypeID to MLIR C-API
Daniel Resnick
llvmlistbot at llvm.org
Fri Oct 1 13:50:00 PDT 2021
Author: Daniel Resnick
Date: 2021-10-01T14:21:18-06:00
New Revision: 782a97a9776a945cf06a6defd37b227665ffe08b
URL: https://github.com/llvm/llvm-project/commit/782a97a9776a945cf06a6defd37b227665ffe08b
DIFF: https://github.com/llvm/llvm-project/commit/782a97a9776a945cf06a6defd37b227665ffe08b.diff
LOG: [mlir][capi] Add TypeID to MLIR C-API
Exposes mlir::TypeID to the C API as MlirTypeID along with various accessors
and helper functions.
Differential Revision: https://reviews.llvm.org/D110897
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/include/mlir/CAPI/IR.h
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 28a83cba0bbc0..92697a248b71b 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -60,6 +60,7 @@ DEFINE_C_API_STRUCT(MlirIdentifier, const void);
DEFINE_C_API_STRUCT(MlirLocation, const void);
DEFINE_C_API_STRUCT(MlirModule, const void);
DEFINE_C_API_STRUCT(MlirType, const void);
+DEFINE_C_API_STRUCT(MlirTypeID, const void);
DEFINE_C_API_STRUCT(MlirValue, const void);
#undef DEFINE_C_API_STRUCT
@@ -356,6 +357,11 @@ MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op,
/// Gets the context this operation is associated with
MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);
+/// Gets the type id of the operation.
+/// Returns null if the operation does not have a registered operation
+/// description.
+MLIR_CAPI_EXPORTED MlirTypeID mlirOperationGetTypeID(MlirOperation op);
+
/// Gets the name of the operation as an identifier.
MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op);
@@ -626,6 +632,9 @@ MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context,
/// Gets the context that a type was created with.
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type);
+/// Gets the type ID of the type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type);
+
/// Checks whether a type is null.
static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; }
@@ -655,6 +664,9 @@ MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute);
/// Gets the type of this attribute.
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute);
+/// Gets the type id of the attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute);
+
/// Checks whether an attribute is null.
static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
@@ -693,6 +705,21 @@ MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident,
/// Gets the string value of the identifier.
MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident);
+//===----------------------------------------------------------------------===//
+// TypeID API.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether a type id is null.
+/// This helper is defined inline in the header (internal linkage), so it is
+/// deliberately not annotated with MLIR_CAPI_EXPORTED, matching the other
+/// `static inline` null checks in this file.
+static inline bool mlirTypeIDIsNull(MlirTypeID typeID) {
+  return !typeID.ptr;
+}
+
+/// Checks if two type ids are equal.
+MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
+
+/// Returns the hash value of the type id.
+MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index ea7b265dd8efc..d5e961367e79a 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -33,6 +33,7 @@ DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier)
DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
DEFINE_C_API_METHODS(MlirType, mlir::Type)
+DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
DEFINE_C_API_METHODS(MlirValue, mlir::Value)
#endif // MLIR_CAPI_IR_H
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index eda176300dc30..ee5a5551133c9 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -23,6 +23,7 @@
#include "mlir/Parser.h"
#include "llvm/Support/Debug.h"
+#include <cstddef>
using namespace mlir;
@@ -345,6 +346,13 @@ MlirContext mlirOperationGetContext(MlirOperation op) {
return wrap(unwrap(op)->getContext());
}
+// Returns the type id stored on the operation's abstract operation, or a null
+// MlirTypeID when the operation has no abstract operation.
+MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
+  const auto *registeredInfo = unwrap(op)->getAbstractOperation();
+  if (!registeredInfo)
+    return {nullptr};
+  return wrap(registeredInfo->typeID);
+}
+
MlirIdentifier mlirOperationGetName(MlirOperation op) {
return wrap(unwrap(op)->getName().getIdentifier());
}
@@ -658,6 +666,10 @@ MlirContext mlirTypeGetContext(MlirType type) {
return wrap(unwrap(type).getContext());
}
+// Wraps the type id reported by the underlying C++ type for the C API.
+MlirTypeID mlirTypeGetTypeID(MlirType type) {
+  auto id = unwrap(type).getTypeID();
+  return wrap(id);
+}
+
bool mlirTypeEqual(MlirType t1, MlirType t2) {
return unwrap(t1) == unwrap(t2);
}
@@ -685,6 +697,10 @@ MlirType mlirAttributeGetType(MlirAttribute attribute) {
return wrap(unwrap(attribute).getType());
}
+// Wraps the type id reported by the underlying C++ attribute for the C API.
+MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
+  auto id = unwrap(attr).getTypeID();
+  return wrap(id);
+}
+
bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
return unwrap(a1) == unwrap(a2);
}
@@ -721,3 +737,15 @@ bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
return wrap(unwrap(ident).strref());
}
+
+//===----------------------------------------------------------------------===//
+// TypeID API.
+//===----------------------------------------------------------------------===//
+
+// Compares the two unwrapped type ids with operator==.
+bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
+  const auto lhs = unwrap(typeID1);
+  const auto rhs = unwrap(typeID2);
+  return lhs == rhs;
+}
+
+// Returns hash_value() of the unwrapped type id, converted to size_t.
+size_t mlirTypeIDHashValue(MlirTypeID typeID) {
+  const auto id = unwrap(typeID);
+  return hash_value(id);
+}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index d85af8fb6b700..931f72f9e76b5 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1739,6 +1739,99 @@ void testDiagnostics() {
// CHECK: more test diagnostics
}
+// Exercises the TypeID C API: type ids of types and attributes, their
+// equality and hash values, and the type ids of a registered and an
+// unregistered operation.
+// Returns 0 on success, or a distinct non-zero code identifying the first
+// expectation that failed. Operations created along the way are destroyed on
+// every exit path.
+int testTypeID(MlirContext ctx) {
+  fprintf(stderr, "@testTypeID\n");
+
+  // Test getting and comparing type and attribute type ids.
+  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
+  MlirTypeID i32ID = mlirTypeGetTypeID(i32);
+  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
+  MlirTypeID ui32ID = mlirTypeGetTypeID(ui32);
+  MlirType f32 = mlirF32TypeGet(ctx);
+  MlirTypeID f32ID = mlirTypeGetTypeID(f32);
+  MlirAttribute i32Attr = mlirIntegerAttrGet(i32, 1);
+  MlirTypeID i32AttrID = mlirAttributeGetTypeID(i32Attr);
+
+  if (mlirTypeIDIsNull(i32ID) || mlirTypeIDIsNull(ui32ID) ||
+      mlirTypeIDIsNull(f32ID) || mlirTypeIDIsNull(i32AttrID)) {
+    fprintf(stderr, "ERROR: Expected type ids to be present\n");
+    return 1;
+  }
+
+  // Signed and unsigned 32-bit integers are expected to share one type id.
+  if (!mlirTypeIDEqual(i32ID, ui32ID) ||
+      mlirTypeIDHashValue(i32ID) != mlirTypeIDHashValue(ui32ID)) {
+    fprintf(
+        stderr,
+        "ERROR: Expected different integer types to have the same type id\n");
+    return 2;
+  }
+
+  if (mlirTypeIDEqual(i32ID, f32ID) ||
+      mlirTypeIDHashValue(i32ID) == mlirTypeIDHashValue(f32ID)) {
+    fprintf(stderr,
+            "ERROR: Expected integer type id to not equal float type id\n");
+    return 3;
+  }
+
+  if (mlirTypeIDEqual(i32ID, i32AttrID) ||
+      mlirTypeIDHashValue(i32ID) == mlirTypeIDHashValue(i32AttrID)) {
+    fprintf(stderr, "ERROR: Expected integer type id to not equal integer "
+                    "attribute type id\n");
+    return 4;
+  }
+
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirType indexType = mlirIndexTypeGet(ctx);
+  MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
+
+  // Create a registered operation, which should have a type id.
+  MlirAttribute indexZeroLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
+  MlirOperationState constZeroState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), loc);
+  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
+  MlirOperation constZero = mlirOperationCreate(&constZeroState);
+
+  if (mlirOperationIsNull(constZero)) {
+    fprintf(stderr, "ERROR: Expected registered operation to be present\n");
+    return 5;
+  }
+
+  MlirTypeID registeredOpID = mlirOperationGetTypeID(constZero);
+
+  if (mlirTypeIDIsNull(registeredOpID)) {
+    fprintf(stderr,
+            "ERROR: Expected registered operation type id to be present\n");
+    // Release the operation created above before bailing out.
+    mlirOperationDestroy(constZero);
+    return 6;
+  }
+
+  // Create an unregistered operation, which should not have a type id.
+  mlirContextSetAllowUnregisteredDialects(ctx, true);
+  MlirOperationState opState =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
+  MlirOperation unregisteredOp = mlirOperationCreate(&opState);
+  if (mlirOperationIsNull(unregisteredOp)) {
+    fprintf(stderr, "ERROR: Expected unregistered operation to be present\n");
+    mlirOperationDestroy(constZero);
+    return 7;
+  }
+
+  MlirTypeID unregisteredOpID = mlirOperationGetTypeID(unregisteredOp);
+
+  if (!mlirTypeIDIsNull(unregisteredOpID)) {
+    fprintf(stderr,
+            "ERROR: Expected unregistered operation type id to be null\n");
+    mlirOperationDestroy(constZero);
+    mlirOperationDestroy(unregisteredOp);
+    return 8;
+  }
+
+  mlirOperationDestroy(constZero);
+  mlirOperationDestroy(unregisteredOp);
+
+  return 0;
+}
+
int main() {
MlirContext ctx = mlirContextCreate();
mlirRegisterAllDialects(ctx);
@@ -1768,6 +1861,9 @@ int main() {
return 11;
if (testClone())
return 12;
+ if (testTypeID(ctx)) {
+ return 13;
+ }
mlirContextDestroy(ctx);
More information about the Mlir-commits
mailing list