[Mlir-commits] [mlir] 308d8b8 - [mlir][python] 8b/16b DenseIntElements access
Mehdi Amini
llvmlistbot at llvm.org
Thu Jan 20 21:26:33 PST 2022
Author: Rahul Kayaith
Date: 2022-01-21T05:21:09Z
New Revision: 308d8b8c6618f570166bcc7dbb87f97c04bba1b2
URL: https://github.com/llvm/llvm-project/commit/308d8b8c6618f570166bcc7dbb87f97c04bba1b2
DIFF: https://github.com/llvm/llvm-project/commit/308d8b8c6618f570166bcc7dbb87f97c04bba1b2.diff
LOG: [mlir][python] 8b/16b DenseIntElements access
This extends dense attribute element access to support 8-bit and 16-bit integers.
It also extends the corresponding parts of the C API.
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D117731
Added:
Modified:
mlir/include/mlir-c/BuiltinAttributes.h
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/test/CAPI/ir.c
mlir/test/python/ir/attributes.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 5839cd3d2408a..973b7e99469c0 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -355,6 +355,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt8Get(
MlirType shapedType, intptr_t numElements, const uint8_t *elements);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get(
MlirType shapedType, intptr_t numElements, const int8_t *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt16Get(
+ MlirType shapedType, intptr_t numElements, const uint16_t *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt16Get(
+ MlirType shapedType, intptr_t numElements, const int16_t *elements);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get(
MlirType shapedType, intptr_t numElements, const uint32_t *elements);
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get(
@@ -416,6 +420,10 @@ MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr,
intptr_t pos);
MLIR_CAPI_EXPORTED uint8_t
mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos);
+MLIR_CAPI_EXPORTED int16_t
+mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos);
+MLIR_CAPI_EXPORTED uint16_t
+mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED int32_t
mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos);
MLIR_CAPI_EXPORTED uint32_t
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index fd44ffe6ba5fe..5d87641c379d8 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -673,6 +673,12 @@ class PyDenseIntElementsAttribute
if (width == 1) {
return mlirDenseElementsAttrGetBoolValue(*this, pos);
}
+ if (width == 8) {
+ return mlirDenseElementsAttrGetUInt8Value(*this, pos);
+ }
+ if (width == 16) {
+ return mlirDenseElementsAttrGetUInt16Value(*this, pos);
+ }
if (width == 32) {
return mlirDenseElementsAttrGetUInt32Value(*this, pos);
}
@@ -683,6 +689,12 @@ class PyDenseIntElementsAttribute
if (width == 1) {
return mlirDenseElementsAttrGetBoolValue(*this, pos);
}
+ if (width == 8) {
+ return mlirDenseElementsAttrGetInt8Value(*this, pos);
+ }
+ if (width == 16) {
+ return mlirDenseElementsAttrGetInt16Value(*this, pos);
+ }
if (width == 32) {
return mlirDenseElementsAttrGetInt32Value(*this, pos);
}
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index c20548bd47597..7b718da88ceef 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -426,6 +426,16 @@ MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
const int8_t *elements) {
return getDenseAttribute(shapedType, numElements, elements);
}
+MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
+ intptr_t numElements,
+ const uint16_t *elements) {
+ return getDenseAttribute(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
+ intptr_t numElements,
+ const int16_t *elements) {
+ return getDenseAttribute(shapedType, numElements, elements);
+}
MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
intptr_t numElements,
const uint32_t *elements) {
@@ -530,6 +540,12 @@ int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>()[pos];
}
+int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
+ return unwrap(attr).cast<DenseElementsAttr>().getValues<int16_t>()[pos];
+}
+uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
+ return unwrap(attr).cast<DenseElementsAttr>().getValues<uint16_t>()[pos];
+}
int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
return unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>()[pos];
}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index d01ccaeb0e93a..257d5e9b8683d 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -904,6 +904,8 @@ int printBuiltinAttributes(MlirContext ctx) {
int bools[] = {0, 1};
uint8_t uints8[] = {0u, 1u};
int8_t ints8[] = {0, 1};
+ uint16_t uints16[] = {0u, 1u};
+ int16_t ints16[] = {0, 1};
uint32_t uints32[] = {0u, 1u};
int32_t ints32[] = {0, 1};
uint64_t uints64[] = {0u, 1u};
@@ -921,6 +923,13 @@ int printBuiltinAttributes(MlirContext ctx) {
MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
2, ints8);
+ MlirAttribute uint16Elements = mlirDenseElementsAttrUInt16Get(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 16),
+ encoding),
+ 2, uints16);
+ MlirAttribute int16Elements = mlirDenseElementsAttrInt16Get(
+ mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 16), encoding),
+ 2, ints16);
MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
encoding),
@@ -956,6 +965,8 @@ int printBuiltinAttributes(MlirContext ctx) {
if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
+ mlirDenseElementsAttrGetUInt16Value(uint16Elements, 1) != 1 ||
+ mlirDenseElementsAttrGetInt16Value(int16Elements, 1) != 1 ||
mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 5f8dd0ad1183f..48f2d4b3df067 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -292,6 +292,50 @@ def testDenseIntAttr():
print(ShapedType(a.type).element_type)
+# CHECK-LABEL: TEST: testDenseIntAttrGetItem
+@run
+def testDenseIntAttrGetItem():
+ def print_item(attr_asm):
+ attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
+ dtype = ShapedType(attr.type).element_type
+ try:
+ item = attr[0]
+ print(f"{dtype}:", item)
+ except TypeError as e:
+ print(f"{dtype}:", e)
+
+ with Context():
+ # CHECK: i1: 1
+ print_item("dense<true> : tensor<i1>")
+ # CHECK: i8: 123
+ print_item("dense<123> : tensor<i8>")
+ # CHECK: i16: 123
+ print_item("dense<123> : tensor<i16>")
+ # CHECK: i32: 123
+ print_item("dense<123> : tensor<i32>")
+ # CHECK: i64: 123
+ print_item("dense<123> : tensor<i64>")
+ # CHECK: ui8: 123
+ print_item("dense<123> : tensor<ui8>")
+ # CHECK: ui16: 123
+ print_item("dense<123> : tensor<ui16>")
+ # CHECK: ui32: 123
+ print_item("dense<123> : tensor<ui32>")
+ # CHECK: ui64: 123
+ print_item("dense<123> : tensor<ui64>")
+ # CHECK: si8: -123
+ print_item("dense<-123> : tensor<si8>")
+ # CHECK: si16: -123
+ print_item("dense<-123> : tensor<si16>")
+ # CHECK: si32: -123
+ print_item("dense<-123> : tensor<si32>")
+ # CHECK: si64: -123
+ print_item("dense<-123> : tensor<si64>")
+
+ # CHECK: i7: Unsupported integer type
+ print_item("dense<123> : tensor<i7>")
+
+
# CHECK-LABEL: TEST: testDenseFPAttr
@run
def testDenseFPAttr():
More information about the Mlir-commits
mailing list