[Mlir-commits] [mlir] 3545426 - [mlir][CAPI] Expose [u]int8 DenseElementsAttr.

Sean Silva llvmlistbot at llvm.org
Wed May 19 13:41:58 PDT 2021


Author: Sean Silva
Date: 2021-05-19T13:41:44-07:00
New Revision: 35454268cf93f5561439980d6baeb27a874a380c

URL: https://github.com/llvm/llvm-project/commit/35454268cf93f5561439980d6baeb27a874a380c
DIFF: https://github.com/llvm/llvm-project/commit/35454268cf93f5561439980d6baeb27a874a380c.diff

LOG: [mlir][CAPI] Expose [u]int8 DenseElementsAttr.

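Also, fix a small typo in the C API test (mlir/test/CAPI/ir.c) where the
"unsigned" splat variants were being created with signless integer types
(mlirIntegerTypeGet) instead of unsigned ones (mlirIntegerTypeUnsignedGet).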

Differential Revision: https://reviews.llvm.org/D102797

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index c85825c8d91d9..247de5cc0bd62 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -313,6 +313,10 @@ mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element);
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, bool element);
 MLIR_CAPI_EXPORTED MlirAttribute
+mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, uint8_t element);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, int8_t element);
+MLIR_CAPI_EXPORTED MlirAttribute
 mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, uint32_t element);
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, int32_t element);
@@ -330,6 +334,10 @@ mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, double element);
 /// data element type.
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBoolGet(
     MlirType shapedType, intptr_t numElements, const int *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt8Get(
+    MlirType shapedType, intptr_t numElements, const uint8_t *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get(
+    MlirType shapedType, intptr_t numElements, const int8_t *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get(
     MlirType shapedType, intptr_t numElements, const uint32_t *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get(
@@ -364,6 +372,10 @@ MLIR_CAPI_EXPORTED MlirAttribute
 mlirDenseElementsAttrGetSplatValue(MlirAttribute attr);
 MLIR_CAPI_EXPORTED int
 mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr);
+MLIR_CAPI_EXPORTED int8_t
+mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr);
+MLIR_CAPI_EXPORTED uint8_t
+mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr);
 MLIR_CAPI_EXPORTED int32_t
 mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr);
 MLIR_CAPI_EXPORTED uint32_t
@@ -383,6 +395,10 @@ mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr);
 /// contained by the given dense elements attribute.
 MLIR_CAPI_EXPORTED bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr,
                                                           intptr_t pos);
+MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr,
+                                                            intptr_t pos);
+MLIR_CAPI_EXPORTED uint8_t
+mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED int32_t
 mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED uint32_t

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 7580786def865..93a6eff996302 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -341,6 +341,16 @@ MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
   return wrap(
       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
 }
+MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,
+                                                 uint8_t element) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
+MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,
+                                                int8_t element) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
 MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
                                                   uint32_t element) {
   return wrap(
@@ -390,6 +400,16 @@ static MlirAttribute getDenseAttribute(MlirType shapedType,
                              llvm::makeArrayRef(elements, numElements)));
 }
 
+MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType,
+                                            intptr_t numElements,
+                                            const uint8_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
+                                           intptr_t numElements,
+                                           const int8_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
                                              intptr_t numElements,
                                              const uint32_t *elements) {
@@ -452,6 +472,12 @@ MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
 int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
 }
+int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int8_t>();
+}
+uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint8_t>();
+}
 int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
   return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
 }
@@ -482,6 +508,14 @@ bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
            pos);
 }
+int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
+  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>().begin() +
+           pos);
+}
+uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
+  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>().begin() +
+           pos);
+}
 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
   return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
            pos);

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 42dfa532727eb..be9799e249bab 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -886,6 +886,8 @@ int printBuiltinAttributes(MlirContext ctx) {
   int64_t shape[] = {1, 2};
 
   int bools[] = {0, 1};
+  uint8_t uints8[] = {0u, 1u};
+  int8_t ints8[] = {0, 1};
   uint32_t uints32[] = {0u, 1u};
   int32_t ints32[] = {0, 1};
   uint64_t uints64[] = {0u, 1u};
@@ -896,6 +898,13 @@ int printBuiltinAttributes(MlirContext ctx) {
   MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
       2, bools);
+  MlirAttribute uint8Elements = mlirDenseElementsAttrUInt8Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
+                              encoding),
+      2, uints8);
+  MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
+      2, ints8);
   MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
       mlirRankedTensorTypeGet(2, shape,
                               mlirIntegerTypeUnsignedGet(ctx, 32), encoding),
@@ -918,6 +927,8 @@ int printBuiltinAttributes(MlirContext ctx) {
       2, doubles);
 
   if (!mlirAttributeIsADenseElements(boolElements) ||
+      !mlirAttributeIsADenseElements(uint8Elements) ||
+      !mlirAttributeIsADenseElements(int8Elements) ||
       !mlirAttributeIsADenseElements(uint32Elements) ||
       !mlirAttributeIsADenseElements(int32Elements) ||
       !mlirAttributeIsADenseElements(uint64Elements) ||
@@ -927,6 +938,8 @@ int printBuiltinAttributes(MlirContext ctx) {
     return 14;
 
   if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
+      mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
+      mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
       mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
       mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
       mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
@@ -937,6 +950,8 @@ int printBuiltinAttributes(MlirContext ctx) {
     return 15;
 
   mlirAttributeDump(boolElements);
+  mlirAttributeDump(uint8Elements);
+  mlirAttributeDump(int8Elements);
   mlirAttributeDump(uint32Elements);
   mlirAttributeDump(int32Elements);
   mlirAttributeDump(uint64Elements);
@@ -944,6 +959,8 @@ int printBuiltinAttributes(MlirContext ctx) {
   mlirAttributeDump(floatElements);
   mlirAttributeDump(doubleElements);
   // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
+  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
+  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
@@ -952,20 +969,29 @@ int printBuiltinAttributes(MlirContext ctx) {
   // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
 
   MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeGet(ctx, 1), encoding), 1);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
+      1);
+  MlirAttribute splatUInt8 = mlirDenseElementsAttrUInt8SplatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
+                              encoding),
+      1);
+  MlirAttribute splatInt8 = mlirDenseElementsAttrInt8SplatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
+      1);
   MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeGet(ctx, 32), encoding), 1);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
+                              encoding),
+      1);
   MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeGet(ctx, 32), encoding), 1);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
+      1);
   MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeGet(ctx, 64), encoding), 1);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
+                              encoding),
+      1);
   MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
-      mlirRankedTensorTypeGet(2, shape,
-                              mlirIntegerTypeGet(ctx, 64), encoding), 1);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
+      1);
   MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
       mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 1.0f);
   MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
@@ -973,6 +999,10 @@ int printBuiltinAttributes(MlirContext ctx) {
 
   if (!mlirAttributeIsADenseElements(splatBool) ||
       !mlirDenseElementsAttrIsSplat(splatBool) ||
+      !mlirAttributeIsADenseElements(splatUInt8) ||
+      !mlirDenseElementsAttrIsSplat(splatUInt8) ||
+      !mlirAttributeIsADenseElements(splatInt8) ||
+      !mlirDenseElementsAttrIsSplat(splatInt8) ||
       !mlirAttributeIsADenseElements(splatUInt32) ||
       !mlirDenseElementsAttrIsSplat(splatUInt32) ||
       !mlirAttributeIsADenseElements(splatInt32) ||
@@ -988,6 +1018,8 @@ int printBuiltinAttributes(MlirContext ctx) {
     return 16;
 
   if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
+      mlirDenseElementsAttrGetUInt8SplatValue(splatUInt8) != 1 ||
+      mlirDenseElementsAttrGetInt8SplatValue(splatInt8) != 1 ||
       mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
       mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
       mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
@@ -997,6 +1029,9 @@ int printBuiltinAttributes(MlirContext ctx) {
       fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
     return 17;
 
+  uint8_t *uint8RawData =
+      (uint8_t *)mlirDenseElementsAttrGetRawData(uint8Elements);
+  int8_t *int8RawData = (int8_t *)mlirDenseElementsAttrGetRawData(int8Elements);
   uint32_t *uint32RawData =
       (uint32_t *)mlirDenseElementsAttrGetRawData(uint32Elements);
   int32_t *int32RawData =
@@ -1008,7 +1043,8 @@ int printBuiltinAttributes(MlirContext ctx) {
   float *floatRawData = (float *)mlirDenseElementsAttrGetRawData(floatElements);
   double *doubleRawData =
       (double *)mlirDenseElementsAttrGetRawData(doubleElements);
-  if (uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
+  if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
+      int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
       int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
       uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
       floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
@@ -1016,6 +1052,8 @@ int printBuiltinAttributes(MlirContext ctx) {
     return 18;
 
   mlirAttributeDump(splatBool);
+  mlirAttributeDump(splatUInt8);
+  mlirAttributeDump(splatInt8);
   mlirAttributeDump(splatUInt32);
   mlirAttributeDump(splatInt32);
   mlirAttributeDump(splatUInt64);
@@ -1023,9 +1061,11 @@ int printBuiltinAttributes(MlirContext ctx) {
   mlirAttributeDump(splatFloat);
   mlirAttributeDump(splatDouble);
   // CHECK: dense<true> : tensor<1x2xi1>
+  // CHECK: dense<1> : tensor<1x2xui8>
+  // CHECK: dense<1> : tensor<1x2xi8>
+  // CHECK: dense<1> : tensor<1x2xui32>
   // CHECK: dense<1> : tensor<1x2xi32>
-  // CHECK: dense<1> : tensor<1x2xi32>
-  // CHECK: dense<1> : tensor<1x2xi64>
+  // CHECK: dense<1> : tensor<1x2xui64>
   // CHECK: dense<1> : tensor<1x2xi64>
   // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
   // CHECK: dense<1.000000e+00> : tensor<1x2xf64>


        


More information about the Mlir-commits mailing list