[Mlir-commits] [mlir] 308d8b8 - [mlir][python] 8b/16b DenseIntElements access

Mehdi Amini llvmlistbot at llvm.org
Thu Jan 20 21:26:33 PST 2022


Author: Rahul Kayaith
Date: 2022-01-21T05:21:09Z
New Revision: 308d8b8c6618f570166bcc7dbb87f97c04bba1b2

URL: https://github.com/llvm/llvm-project/commit/308d8b8c6618f570166bcc7dbb87f97c04bba1b2
DIFF: https://github.com/llvm/llvm-project/commit/308d8b8c6618f570166bcc7dbb87f97c04bba1b2.diff

LOG: [mlir][python] 8b/16b DenseIntElements access

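This extends DenseIntElementsAttr element access in the Python bindings to support 8-bit and 16-bit integers.
It also adds the corresponding 16-bit element constructors and getters to the C API (the 8-bit variants already existed).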

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D117731

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/lib/Bindings/Python/IRAttributes.cpp
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/test/CAPI/ir.c
    mlir/test/python/ir/attributes.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index 5839cd3d2408a..973b7e99469c0 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -355,6 +355,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt8Get(
     MlirType shapedType, intptr_t numElements, const uint8_t *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get(
     MlirType shapedType, intptr_t numElements, const int8_t *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt16Get(
+    MlirType shapedType, intptr_t numElements, const uint16_t *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt16Get(
+    MlirType shapedType, intptr_t numElements, const int16_t *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get(
     MlirType shapedType, intptr_t numElements, const uint32_t *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get(
@@ -416,6 +420,10 @@ MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr,
                                                             intptr_t pos);
 MLIR_CAPI_EXPORTED uint8_t
 mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos);
+MLIR_CAPI_EXPORTED int16_t
+mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos);
+MLIR_CAPI_EXPORTED uint16_t
+mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED int32_t
 mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos);
 MLIR_CAPI_EXPORTED uint32_t

diff  --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
index fd44ffe6ba5fe..5d87641c379d8 100644
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -673,6 +673,12 @@ class PyDenseIntElementsAttribute
       if (width == 1) {
         return mlirDenseElementsAttrGetBoolValue(*this, pos);
       }
+      if (width == 8) {
+        return mlirDenseElementsAttrGetUInt8Value(*this, pos);
+      }
+      if (width == 16) {
+        return mlirDenseElementsAttrGetUInt16Value(*this, pos);
+      }
       if (width == 32) {
         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
       }
@@ -683,6 +689,12 @@ class PyDenseIntElementsAttribute
       if (width == 1) {
         return mlirDenseElementsAttrGetBoolValue(*this, pos);
       }
+      if (width == 8) {
+        return mlirDenseElementsAttrGetInt8Value(*this, pos);
+      }
+      if (width == 16) {
+        return mlirDenseElementsAttrGetInt16Value(*this, pos);
+      }
       if (width == 32) {
         return mlirDenseElementsAttrGetInt32Value(*this, pos);
       }

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index c20548bd47597..7b718da88ceef 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -426,6 +426,16 @@ MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType,
                                            const int8_t *elements) {
   return getDenseAttribute(shapedType, numElements, elements);
 }
+MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType,
+                                             intptr_t numElements,
+                                             const uint16_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType,
+                                            intptr_t numElements,
+                                            const int16_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
 MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
                                              intptr_t numElements,
                                              const uint32_t *elements) {
@@ -530,6 +540,12 @@ int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) {
 uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) {
   return unwrap(attr).cast<DenseElementsAttr>().getValues<uint8_t>()[pos];
 }
+int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<int16_t>()[pos];
+}
+uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<DenseElementsAttr>().getValues<uint16_t>()[pos];
+}
 int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
   return unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>()[pos];
 }

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index d01ccaeb0e93a..257d5e9b8683d 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -904,6 +904,8 @@ int printBuiltinAttributes(MlirContext ctx) {
   int bools[] = {0, 1};
   uint8_t uints8[] = {0u, 1u};
   int8_t ints8[] = {0, 1};
+  uint16_t uints16[] = {0u, 1u};
+  int16_t ints16[] = {0, 1};
   uint32_t uints32[] = {0u, 1u};
   int32_t ints32[] = {0, 1};
   uint64_t uints64[] = {0u, 1u};
@@ -921,6 +923,13 @@ int printBuiltinAttributes(MlirContext ctx) {
   MlirAttribute int8Elements = mlirDenseElementsAttrInt8Get(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 8), encoding),
       2, ints8);
+  MlirAttribute uint16Elements = mlirDenseElementsAttrUInt16Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 16),
+                              encoding),
+      2, uints16);
+  MlirAttribute int16Elements = mlirDenseElementsAttrInt16Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 16), encoding),
+      2, ints16);
   MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32),
                               encoding),
@@ -956,6 +965,8 @@ int printBuiltinAttributes(MlirContext ctx) {
   if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
       mlirDenseElementsAttrGetUInt8Value(uint8Elements, 1) != 1 ||
       mlirDenseElementsAttrGetInt8Value(int8Elements, 1) != 1 ||
+      mlirDenseElementsAttrGetUInt16Value(uint16Elements, 1) != 1 ||
+      mlirDenseElementsAttrGetInt16Value(int16Elements, 1) != 1 ||
       mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
       mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
       mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||

diff  --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
index 5f8dd0ad1183f..48f2d4b3df067 100644
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -292,6 +292,50 @@ def testDenseIntAttr():
     print(ShapedType(a.type).element_type)
 
 
+# CHECK-LABEL: TEST: testDenseIntAttrGetItem
+ at run
+def testDenseIntAttrGetItem():
+  def print_item(attr_asm):
+    attr = DenseIntElementsAttr(Attribute.parse(attr_asm))
+    dtype = ShapedType(attr.type).element_type
+    try:
+      item = attr[0]
+      print(f"{dtype}:", item)
+    except TypeError as e:
+      print(f"{dtype}:", e)
+
+  with Context():
+    # CHECK: i1: 1
+    print_item("dense<true> : tensor<i1>")
+    # CHECK: i8: 123
+    print_item("dense<123> : tensor<i8>")
+    # CHECK: i16: 123
+    print_item("dense<123> : tensor<i16>")
+    # CHECK: i32: 123
+    print_item("dense<123> : tensor<i32>")
+    # CHECK: i64: 123
+    print_item("dense<123> : tensor<i64>")
+    # CHECK: ui8: 123
+    print_item("dense<123> : tensor<ui8>")
+    # CHECK: ui16: 123
+    print_item("dense<123> : tensor<ui16>")
+    # CHECK: ui32: 123
+    print_item("dense<123> : tensor<ui32>")
+    # CHECK: ui64: 123
+    print_item("dense<123> : tensor<ui64>")
+    # CHECK: si8: -123
+    print_item("dense<-123> : tensor<si8>")
+    # CHECK: si16: -123
+    print_item("dense<-123> : tensor<si16>")
+    # CHECK: si32: -123
+    print_item("dense<-123> : tensor<si32>")
+    # CHECK: si64: -123
+    print_item("dense<-123> : tensor<si64>")
+
+    # CHECK: i7: Unsupported integer type
+    print_item("dense<123> : tensor<i7>")
+
+
 # CHECK-LABEL: TEST: testDenseFPAttr
 @run
 def testDenseFPAttr():


        


More information about the Mlir-commits mailing list