[Mlir-commits] [mlir] 782a97a - [mlir][capi] Add TypeID to MLIR C-API

Daniel Resnick llvmlistbot at llvm.org
Fri Oct 1 13:50:00 PDT 2021


Author: Daniel Resnick
Date: 2021-10-01T14:21:18-06:00
New Revision: 782a97a9776a945cf06a6defd37b227665ffe08b

URL: https://github.com/llvm/llvm-project/commit/782a97a9776a945cf06a6defd37b227665ffe08b
DIFF: https://github.com/llvm/llvm-project/commit/782a97a9776a945cf06a6defd37b227665ffe08b.diff

LOG: [mlir][capi] Add TypeID to MLIR C-API

Exposes mlir::TypeID to the C API as MlirTypeID along with various accessors
and helper functions.

Differential Revision: https://reviews.llvm.org/D110897

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/include/mlir/CAPI/IR.h
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 28a83cba0bbc0..92697a248b71b 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -60,6 +60,7 @@ DEFINE_C_API_STRUCT(MlirIdentifier, const void);
 DEFINE_C_API_STRUCT(MlirLocation, const void);
 DEFINE_C_API_STRUCT(MlirModule, const void);
 DEFINE_C_API_STRUCT(MlirType, const void);
+DEFINE_C_API_STRUCT(MlirTypeID, const void);
 DEFINE_C_API_STRUCT(MlirValue, const void);
 
 #undef DEFINE_C_API_STRUCT
@@ -356,6 +357,11 @@ MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op,
 /// Gets the context this operation is associated with
 MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);
 
+/// Gets the type id of the operation.
+/// Returns null if the operation does not have a registered operation
+/// description.
+MLIR_CAPI_EXPORTED MlirTypeID mlirOperationGetTypeID(MlirOperation op);
+
 /// Gets the name of the operation as an identifier.
 MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op);
 
@@ -626,6 +632,9 @@ MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context,
 /// Gets the context that a type was created with.
 MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type);
 
+/// Gets the type ID of the type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type);
+
 /// Checks whether a type is null.
 static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; }
 
@@ -655,6 +664,9 @@ MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute);
 /// Gets the type of this attribute.
 MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute);
 
+/// Gets the type id of the attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute);
+
 /// Checks whether an attribute is null.
 static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
 
@@ -693,6 +705,21 @@ MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident,
 /// Gets the string value of the identifier.
 MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident);
 
+//===----------------------------------------------------------------------===//
+// TypeID API.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether a type id is null.
+MLIR_CAPI_EXPORTED static inline bool mlirTypeIDIsNull(MlirTypeID typeID) {
+  return !typeID.ptr;
+}
+
+/// Checks if two type ids are equal.
+MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
+
+/// Returns the hash value of the type id.
+MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
+
 #ifdef __cplusplus
 }
 #endif

diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
index ea7b265dd8efc..d5e961367e79a 100644
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -33,6 +33,7 @@ DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier)
 DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
 DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
 DEFINE_C_API_METHODS(MlirType, mlir::Type)
+DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
 DEFINE_C_API_METHODS(MlirValue, mlir::Value)
 
 #endif // MLIR_CAPI_IR_H

diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index eda176300dc30..ee5a5551133c9 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Parser.h"
 
 #include "llvm/Support/Debug.h"
+#include <cstddef>
 
 using namespace mlir;
 
@@ -345,6 +346,13 @@ MlirContext mlirOperationGetContext(MlirOperation op) {
   return wrap(unwrap(op)->getContext());
 }
 
+MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
+  if (const auto *abstractOp = unwrap(op)->getAbstractOperation()) {
+    return wrap(abstractOp->typeID);
+  }
+  return {nullptr};
+}
+
 MlirIdentifier mlirOperationGetName(MlirOperation op) {
   return wrap(unwrap(op)->getName().getIdentifier());
 }
@@ -658,6 +666,10 @@ MlirContext mlirTypeGetContext(MlirType type) {
   return wrap(unwrap(type).getContext());
 }
 
+MlirTypeID mlirTypeGetTypeID(MlirType type) {
+  return wrap(unwrap(type).getTypeID());
+}
+
 bool mlirTypeEqual(MlirType t1, MlirType t2) {
   return unwrap(t1) == unwrap(t2);
 }
@@ -685,6 +697,10 @@ MlirType mlirAttributeGetType(MlirAttribute attribute) {
   return wrap(unwrap(attribute).getType());
 }
 
+MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
+  return wrap(unwrap(attr).getTypeID());
+}
+
 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
   return unwrap(a1) == unwrap(a2);
 }
@@ -721,3 +737,15 @@ bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
   return wrap(unwrap(ident).strref());
 }
+
+//===----------------------------------------------------------------------===//
+// TypeID API.
+//===----------------------------------------------------------------------===//
+
+bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
+  return unwrap(typeID1) == unwrap(typeID2);
+}
+
+size_t mlirTypeIDHashValue(MlirTypeID typeID) {
+  return hash_value(unwrap(typeID));
+}

diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index d85af8fb6b700..931f72f9e76b5 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1739,6 +1739,99 @@ void testDiagnostics() {
   // CHECK:     more test diagnostics
 }
 
+int testTypeID(MlirContext ctx) {
+  fprintf(stderr, "@testTypeID\n");
+
+  // Test getting and comparing type and attribute type ids.
+  MlirType i32 = mlirIntegerTypeGet(ctx, 32);
+  MlirTypeID i32ID = mlirTypeGetTypeID(i32);
+  MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32);
+  MlirTypeID ui32ID = mlirTypeGetTypeID(ui32);
+  MlirType f32 = mlirF32TypeGet(ctx);
+  MlirTypeID f32ID = mlirTypeGetTypeID(f32);
+  MlirAttribute i32Attr = mlirIntegerAttrGet(i32, 1);
+  MlirTypeID i32AttrID = mlirAttributeGetTypeID(i32Attr);
+
+  if (mlirTypeIDIsNull(i32ID) || mlirTypeIDIsNull(ui32ID) ||
+      mlirTypeIDIsNull(f32ID) || mlirTypeIDIsNull(i32AttrID)) {
+    fprintf(stderr, "ERROR: Expected type ids to be present\n");
+    return 1;
+  }
+
+  if (!mlirTypeIDEqual(i32ID, ui32ID) ||
+      mlirTypeIDHashValue(i32ID) != mlirTypeIDHashValue(ui32ID)) {
+    fprintf(
+        stderr,
+        "ERROR: Expected different integer types to have the same type id\n");
+    return 2;
+  }
+
+  if (mlirTypeIDEqual(i32ID, f32ID) ||
+      mlirTypeIDHashValue(i32ID) == mlirTypeIDHashValue(f32ID)) {
+    fprintf(stderr,
+            "ERROR: Expected integer type id to not equal float type id\n");
+    return 3;
+  }
+
+  if (mlirTypeIDEqual(i32ID, i32AttrID) ||
+      mlirTypeIDHashValue(i32ID) == mlirTypeIDHashValue(i32AttrID)) {
+    fprintf(stderr, "ERROR: Expected integer type id to not equal integer "
+                    "attribute type id\n");
+    return 4;
+  }
+
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirType indexType = mlirIndexTypeGet(ctx);
+  MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
+
+  // Create a registered operation, which should have a type id.
+  MlirAttribute indexZeroLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
+  MlirOperationState constZeroState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), loc);
+  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
+  MlirOperation constZero = mlirOperationCreate(&constZeroState);
+
+  if (mlirOperationIsNull(constZero)) {
+    fprintf(stderr, "ERROR: Expected registered operation to be present\n");
+    return 5;
+  }
+
+  MlirTypeID registeredOpID = mlirOperationGetTypeID(constZero);
+
+  if (mlirTypeIDIsNull(registeredOpID)) {
+    fprintf(stderr,
+            "ERROR: Expected registered operation type id to be present\n");
+    return 6;
+  }
+
+  // Create an unregistered operation, which should not have a type id.
+  mlirContextSetAllowUnregisteredDialects(ctx, true);
+  MlirOperationState opState =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
+  MlirOperation unregisteredOp = mlirOperationCreate(&opState);
+  if (mlirOperationIsNull(unregisteredOp)) {
+    fprintf(stderr, "ERROR: Expected unregistered operation to be present\n");
+    return 7;
+  }
+
+  MlirTypeID unregisteredOpID = mlirOperationGetTypeID(unregisteredOp);
+
+  if (!mlirTypeIDIsNull(unregisteredOpID)) {
+    fprintf(stderr,
+            "ERROR: Expected unregistered operation type id to be null\n");
+    return 8;
+  }
+
+  mlirOperationDestroy(constZero);
+  mlirOperationDestroy(unregisteredOp);
+
+  return 0;
+}
+
 int main() {
   MlirContext ctx = mlirContextCreate();
   mlirRegisterAllDialects(ctx);
@@ -1768,6 +1861,9 @@ int main() {
     return 11;
   if (testClone())
     return 12;
+  if (testTypeID(ctx)) {
+    return 13;
+  }
 
   mlirContextDestroy(ctx);
 


        


More information about the Mlir-commits mailing list