[Mlir-commits] [mlir] 3545426 - [mlir][CAPI] Expose [u]int8 DenseElementsAttr.
Sean Silva
llvmlistbot at llvm.org
Wed May 19 13:41:58 PDT 2021
Author: Sean Silva
Date: 2021-05-19T13:41:44-07:00
New Revision: 35454268cf93f5561439980d6baeb27a874a380c
URL: https://github.com/llvm/llvm-project/commit/35454268cf93f5561439980d6baeb27a874a380c
DIFF: https://github.com/llvm/llvm-project/commit/35454268cf93f5561439980d6baeb27a874a380c.diff
LOG: [mlir][CAPI] Expose [u]int8 DenseElementsAttr.
Also, fix a small typo where the "unsigned" splat variants were not
being created with an unsigned type.
Differential Revision: https://reviews.llvm.org/D102797
Added:
Modified:
mlir/include/mlir-c/BuiltinAttributes.h
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/test/CAPI/ir.c
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index c85825c8d91d9..247de5cc0bd62 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -313,6 +313,10 @@ mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element);
MLIR_CAPI_EXPORTED MlirAttribute
mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, bool element);
MLIR_CAPI_EXPORTED MlirAttribute
+mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, uint8_t element);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, int8_t element);
+MLIR_CAPI_EXPORTED MlirAttribute
mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, uint32_t element);
MLIR_CAPI_EXPORTED MlirAttribute
mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, int32_t element);
@@ -330,6 +334,10 @@ mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, double element);
/// data element type.
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBoolGet(
MlirType shapedType, intptr_t numElements, const int *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt8Get(
+ MlirType shapedType, intptr_t numElements, const uint8_t *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get(
+ MlirType shapedType, intptr_t numElements, const int8_t *elements);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get(
MlirType shapedType, intptr_t numElements, const uint32_t *elements);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get(
@@ -364,6 +372,10 @@ MLIR_CAPI_EXPORTED MlirAttribute
mlirDenseElementsAttrGetSplatValue(MlirAttribute attr);
MLIR_CAPI_EXPORTED int
mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr);
+MLIR_CAPI_EXPORTED int8_t
+mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr);
+MLIR_CAPI_EXPORTED uint8_t
+mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr);
MLIR_CAPI_EXPORTED int32_t
mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr);
MLIR_CAPI_EXPORTED uint32_t
@@ -383,6 +395,10 @@ mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr);
/// contained by the given dense elements attribute.
MLIR_CAPI_EXPORTED bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr,
intptr_t pos);
+MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr,
+ intptr_t pos);
+MLIR_CAPI_EXPORTED uint8_t
+mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED int32_t
mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED uint32_t
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 7580786def865..93a6eff996302 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -341,6 +341,16 @@ MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
return wrap(
DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
}
+MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType,
+ uint8_t element) {
+ return wrap(
+ DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
+MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType,
+ int8_t element) {
+ return wrap(
+ DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
uint32_t element) {
return wrap(
@@ -390,6 +400,16 @@ static MlirAttribute getDenseAttribute(MlirType shapedType,
llvm::makeArrayRef(elements, numElements)));
}
+MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType,
+ intptr_t numElements,
+ const uint8_t *elements) {
+ return getDenseAttribute(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
+ intptr_t numElements,
+ const int8_t *elements) {
+ return getDenseAttribute(shapedType, numElements, elements);
+}
MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
intptr_t numElements,
const uint32_t *elements) {
@@ -452,6 +472,12 @@ MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
}
+int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) {
+ return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int8_t>();
+}
+uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) {
+ return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint8_t>();
+}
int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
}
@@ -482,6 +508,14 @@ bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
pos);
}
+int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
+ return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int8_t>().begin() +
+ pos);
+}
+uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
+ return *(unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>().begin() +
+ pos);
+}
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
pos);
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 42dfa532727eb..be9799e249bab 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -886,6 +886,8 @@ int printBuiltinAttributes(MlirContext ctx) {
int64_t shape[] = {1, 2};
int bools[] = {0, 1};
+ uint8_t uints8[] = {0u, 1u};
+ int8_t ints8[] = {0, 1};
uint32_t uints32[] = {0u, 1u};
int32_t ints32[] = {0, 1};
uint64_t uints64[] = {0u, 1u};
@@ -896,6 +898,13 @@ int printBuiltinAttributes(MlirContext ctx) {
MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
2, bools);
+ MlirAttribute uint8Elements = mlirDenseElementsAttrUInt8Get(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
+ encoding),
+ 2, uints8);
+ MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
+ 2, ints8);
MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
mlirRankedTensorTypeGet(2, shape,
mlirIntegerTypeUnsignedGet(ctx, 32), encoding),
@@ -918,6 +927,8 @@ int printBuiltinAttributes(MlirContext ctx) {
2, doubles);
if (!mlirAttributeIsADenseElements(boolElements) ||
+ !mlirAttributeIsADenseElements(uint8Elements) ||
+ !mlirAttributeIsADenseElements(int8Elements) ||
!mlirAttributeIsADenseElements(uint32Elements) ||
!mlirAttributeIsADenseElements(int32Elements) ||
!mlirAttributeIsADenseElements(uint64Elements) ||
@@ -927,6 +938,8 @@ int printBuiltinAttributes(MlirContext ctx) {
return 14;
if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
+ mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
+ mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
@@ -937,6 +950,8 @@ int printBuiltinAttributes(MlirContext ctx) {
return 15;
mlirAttributeDump(boolElements);
+ mlirAttributeDump(uint8Elements);
+ mlirAttributeDump(int8Elements);
mlirAttributeDump(uint32Elements);
mlirAttributeDump(int32Elements);
mlirAttributeDump(uint64Elements);
@@ -944,6 +959,8 @@ int printBuiltinAttributes(MlirContext ctx) {
mlirAttributeDump(floatElements);
mlirAttributeDump(doubleElements);
// CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
+ // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
+ // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
// CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
@@ -952,20 +969,29 @@ int printBuiltinAttributes(MlirContext ctx) {
// CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
- mlirRankedTensorTypeGet(2, shape,
- mlirIntegerTypeGet(ctx, 1), encoding), 1);
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
+ 1);
+ MlirAttribute splatUInt8 = mlirDenseElementsAttrUInt8SplatGet(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 8),
+ encoding),
+ 1);
+ MlirAttribute splatInt8 = mlirDenseElementsAttrInt8SplatGet(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
+ 1);
MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
- mlirRankedTensorTypeGet(2, shape,
- mlirIntegerTypeGet(ctx, 32), encoding), 1);
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
+ encoding),
+ 1);
MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
- mlirRankedTensorTypeGet(2, shape,
- mlirIntegerTypeGet(ctx, 32), encoding), 1);
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
+ 1);
MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
- mlirRankedTensorTypeGet(2, shape,
- mlirIntegerTypeGet(ctx, 64), encoding), 1);
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64),
+ encoding),
+ 1);
MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
- mlirRankedTensorTypeGet(2, shape,
- mlirIntegerTypeGet(ctx, 64), encoding), 1);
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
+ 1);
MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 1.0f);
MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
@@ -973,6 +999,10 @@ int printBuiltinAttributes(MlirContext ctx) {
if (!mlirAttributeIsADenseElements(splatBool) ||
!mlirDenseElementsAttrIsSplat(splatBool) ||
+ !mlirAttributeIsADenseElements(splatUInt8) ||
+ !mlirDenseElementsAttrIsSplat(splatUInt8) ||
+ !mlirAttributeIsADenseElements(splatInt8) ||
+ !mlirDenseElementsAttrIsSplat(splatInt8) ||
!mlirAttributeIsADenseElements(splatUInt32) ||
!mlirDenseElementsAttrIsSplat(splatUInt32) ||
!mlirAttributeIsADenseElements(splatInt32) ||
@@ -988,6 +1018,8 @@ int printBuiltinAttributes(MlirContext ctx) {
return 16;
if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
+ mlirDenseElementsAttrGetUInt8SplatValue(splatUInt8) != 1 ||
+ mlirDenseElementsAttrGetInt8SplatValue(splatInt8) != 1 ||
mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
@@ -997,6 +1029,9 @@ int printBuiltinAttributes(MlirContext ctx) {
fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
return 17;
+ uint8_t *uint8RawData =
+ (uint8_t *)mlirDenseElementsAttrGetRawData(uint8Elements);
+ int8_t *int8RawData = (int8_t *)mlirDenseElementsAttrGetRawData(int8Elements);
uint32_t *uint32RawData =
(uint32_t *)mlirDenseElementsAttrGetRawData(uint32Elements);
int32_t *int32RawData =
@@ -1008,7 +1043,8 @@ int printBuiltinAttributes(MlirContext ctx) {
float *floatRawData = (float *)mlirDenseElementsAttrGetRawData(floatElements);
double *doubleRawData =
(double *)mlirDenseElementsAttrGetRawData(doubleElements);
- if (uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
+ if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
+ int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
@@ -1016,6 +1052,8 @@ int printBuiltinAttributes(MlirContext ctx) {
return 18;
mlirAttributeDump(splatBool);
+ mlirAttributeDump(splatUInt8);
+ mlirAttributeDump(splatInt8);
mlirAttributeDump(splatUInt32);
mlirAttributeDump(splatInt32);
mlirAttributeDump(splatUInt64);
@@ -1023,9 +1061,11 @@ int printBuiltinAttributes(MlirContext ctx) {
mlirAttributeDump(splatFloat);
mlirAttributeDump(splatDouble);
// CHECK: dense<true> : tensor<1x2xi1>
+ // CHECK: dense<1> : tensor<1x2xui8>
+ // CHECK: dense<1> : tensor<1x2xi8>
+ // CHECK: dense<1> : tensor<1x2xui32>
// CHECK: dense<1> : tensor<1x2xi32>
- // CHECK: dense<1> : tensor<1x2xi32>
- // CHECK: dense<1> : tensor<1x2xi64>
+ // CHECK: dense<1> : tensor<1x2xui64>
// CHECK: dense<1> : tensor<1x2xi64>
// CHECK: dense<1.000000e+00> : tensor<1x2xf32>
// CHECK: dense<1.000000e+00> : tensor<1x2xf64>
More information about the Mlir-commits mailing list