[Mlir-commits] [mlir] 5b0d6bf - [MLIR] Add function to create Float16 array attribute

Sean Silva llvmlistbot at llvm.org
Wed Jul 20 14:58:25 PDT 2022


Author: Tanyo Kwok
Date: 2022-07-20T21:58:15Z
New Revision: 5b0d6bf2102bd56b179238f2ce0dac11d43b4bc3

URL: https://github.com/llvm/llvm-project/commit/5b0d6bf2102bd56b179238f2ce0dac11d43b4bc3
DIFF: https://github.com/llvm/llvm-project/commit/5b0d6bf2102bd56b179238f2ce0dac11d43b4bc3.diff

LOG: [MLIR] Add function to create Float16 array attribute

This patch adds a new function, mlirDenseElementsAttrFloat16Get(),
which accepts the shaped type, the number of Float16 values, and a
pointer to an array of Float16 values, each passed as its raw 16-bit
pattern in a uint16_t.

This commit repeats the changes from https://reviews.llvm.org/D123981 and #761, but for Float16.

Differential Revision: https://reviews.llvm.org/D130069

Added: 
    

Modified: 
    mlir/include/mlir-c/BuiltinAttributes.h
    mlir/lib/CAPI/IR/BuiltinAttributes.cpp
    mlir/test/CAPI/ir.c

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
index ce4514094ecb3..050408d1f87aa 100644
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -381,6 +381,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrDoubleGet(
     MlirType shapedType, intptr_t numElements, const double *elements);
 MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBFloat16Get(
     MlirType shapedType, intptr_t numElements, const uint16_t *elements);
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrFloat16Get(
+    MlirType shapedType, intptr_t numElements, const uint16_t *elements);
 
 /// Creates a dense elements attribute with the given shaped type from string
 /// elements.

diff  --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index 759b708952e2f..ba3481ae1e2d3 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -479,6 +479,13 @@ MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType,
   const void *buffer = static_cast<const void *>(elements);
   return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
 }
+MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType,
+                                              intptr_t numElements,
+                                              const uint16_t *elements) {
+  size_t bufferSize = numElements * 2;
+  const void *buffer = static_cast<const void *>(elements);
+  return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer);
+}
 
 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
                                              intptr_t numElements,

diff  --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 9c1f95c70c296..be5366012bbe2 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -959,6 +959,7 @@ int printBuiltinAttributes(MlirContext ctx) {
   float floats[] = {0.0f, 1.0f};
   double doubles[] = {0.0, 1.0};
   uint16_t bf16s[] = {0x0, 0x3f80};
+  uint16_t f16s[] = {0x0, 0x3c00};
   MlirAttribute encoding = mlirAttributeGetNull();
   MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
@@ -1000,6 +1001,9 @@ int printBuiltinAttributes(MlirContext ctx) {
   MlirAttribute bf16Elements = mlirDenseElementsAttrBFloat16Get(
       mlirRankedTensorTypeGet(2, shape, mlirBF16TypeGet(ctx), encoding), 2,
       bf16s);
+  MlirAttribute f16Elements = mlirDenseElementsAttrFloat16Get(
+      mlirRankedTensorTypeGet(2, shape, mlirF16TypeGet(ctx), encoding), 2,
+      f16s);
 
   if (!mlirAttributeIsADenseElements(boolElements) ||
       !mlirAttributeIsADenseElements(uint8Elements) ||
@@ -1010,7 +1014,8 @@ int printBuiltinAttributes(MlirContext ctx) {
       !mlirAttributeIsADenseElements(int64Elements) ||
       !mlirAttributeIsADenseElements(floatElements) ||
       !mlirAttributeIsADenseElements(doubleElements) ||
-      !mlirAttributeIsADenseElements(bf16Elements))
+      !mlirAttributeIsADenseElements(bf16Elements) ||
+      !mlirAttributeIsADenseElements(f16Elements))
     return 14;
 
   if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
@@ -1037,6 +1042,7 @@ int printBuiltinAttributes(MlirContext ctx) {
   mlirAttributeDump(floatElements);
   mlirAttributeDump(doubleElements);
   mlirAttributeDump(bf16Elements);
+  mlirAttributeDump(f16Elements);
   // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui8>
   // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi8>
@@ -1047,6 +1053,7 @@ int printBuiltinAttributes(MlirContext ctx) {
   // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
   // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
   // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xbf16>
+  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf16>
 
   MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
       mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
@@ -1125,13 +1132,16 @@ int printBuiltinAttributes(MlirContext ctx) {
       (double *)mlirDenseElementsAttrGetRawData(doubleElements);
   uint16_t *bf16RawData =
       (uint16_t *)mlirDenseElementsAttrGetRawData(bf16Elements);
+  uint16_t *f16RawData =
+      (uint16_t *)mlirDenseElementsAttrGetRawData(f16Elements);
   if (uint8RawData[0] != 0u || uint8RawData[1] != 1u || int8RawData[0] != 0 ||
       int8RawData[1] != 1 || uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
       int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
       uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
       floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
       doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0 ||
-      bf16RawData[0] != 0 || bf16RawData[1] != 0x3f80)
+      bf16RawData[0] != 0 || bf16RawData[1] != 0x3f80 || f16RawData[0] != 0 ||
+      f16RawData[1] != 0x3c00)
     return 18;
 
   mlirAttributeDump(splatBool);
@@ -1156,9 +1166,11 @@ int printBuiltinAttributes(MlirContext ctx) {
   mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
   mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
   mlirAttributeDump(mlirElementsAttrGetValue(bf16Elements, 2, uints64));
+  mlirAttributeDump(mlirElementsAttrGetValue(f16Elements, 2, uints64));
   // CHECK: 1.000000e+00 : f32
   // CHECK: 1.000000e+00 : f64
   // CHECK: 1.000000e+00 : bf16
+  // CHECK: 1.000000e+00 : f16
 
   int64_t indices[] = {0, 1};
   int64_t one = 1;


        


More information about the Mlir-commits mailing list