[Mlir-commits] [mlir] aabfaf9 - [mlir] Allow empty lists for DenseArrayAttr.

Adrian Kuegel llvmlistbot at llvm.org
Wed Jul 13 00:16:21 PDT 2022


Author: Adrian Kuegel
Date: 2022-07-13T09:16:09+02:00
New Revision: aabfaf901b0f2961e11bc6a25d6bc1fa75ad6866

URL: https://github.com/llvm/llvm-project/commit/aabfaf901b0f2961e11bc6a25d6bc1fa75ad6866
DIFF: https://github.com/llvm/llvm-project/commit/aabfaf901b0f2961e11bc6a25d6bc1fa75ad6866.diff

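LOG: [mlir] Allow empty lists for DenseArrayAttr.

Summary of the change (from the diff below): an empty dense array can now be
written as `[:i8]`, `[:f32]`, etc., or as `[]` where an op's assembly format
uses a DenseArrayAttr directly. Both parsers build the empty attribute
immediately when the closing `]` follows, instead of attempting to parse an
element list. The attribute is created with an empty shape when it has no
elements, and the printer emits the separating space after the element type
only when elements follow.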

Differential Revision: https://reviews.llvm.org/D129552

Added: 
    

Modified: 
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/Parser/AttributeParser.cpp
    mlir/test/IR/attribute.mlir
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 1e46b1119586..7865652d0572 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1875,24 +1875,26 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
     typeElision = AttrTypeElision::Must;
     switch (denseArrayAttr.getElementType()) {
     case DenseArrayBaseAttr::EltType::I8:
-      os << "[:i8 ";
+      os << "[:i8";
       break;
     case DenseArrayBaseAttr::EltType::I16:
-      os << "[:i16 ";
+      os << "[:i16";
       break;
     case DenseArrayBaseAttr::EltType::I32:
-      os << "[:i32 ";
+      os << "[:i32";
       break;
     case DenseArrayBaseAttr::EltType::I64:
-      os << "[:i64 ";
+      os << "[:i64";
       break;
     case DenseArrayBaseAttr::EltType::F32:
-      os << "[:f32 ";
+      os << "[:f32";
       break;
     case DenseArrayBaseAttr::EltType::F64:
-      os << "[:f64 ";
+      os << "[:f64";
       break;
     }
+    if (denseArrayAttr.getType().cast<ShapedType>().getRank())
+      os << " ";
     denseArrayAttr.printWithoutBraces(os);
     os << "]";
   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 80b91ca5bb05..218622ba9d73 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -838,6 +838,9 @@ template <typename T>
 Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
   if (parser.parseLSquare())
     return {};
+  // Handle empty list case.
+  if (succeeded(parser.parseOptionalRSquare()))
+    return get(parser.getContext(), {});
   Attribute result = parseWithoutBraces(parser, odsType);
   if (parser.parseRSquare())
     return {};
@@ -860,42 +863,48 @@ struct denseArrayAttrEltTypeBuilder;
 template <>
 struct denseArrayAttrEltTypeBuilder<int8_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, IntegerType::get(context, 8));
   }
 };
 template <>
 struct denseArrayAttrEltTypeBuilder<int16_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, IntegerType::get(context, 16));
   }
 };
 template <>
 struct denseArrayAttrEltTypeBuilder<int32_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, IntegerType::get(context, 32));
   }
 };
 template <>
 struct denseArrayAttrEltTypeBuilder<int64_t> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, IntegerType::get(context, 64));
   }
 };
 template <>
 struct denseArrayAttrEltTypeBuilder<float> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, Float32Type::get(context));
   }
 };
 template <>
 struct denseArrayAttrEltTypeBuilder<double> {
   constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
-  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+  static ShapedType getShapedType(MLIRContext *context,
+                                  ArrayRef<int64_t> shape) {
     return VectorType::get(shape, Float64Type::get(context));
   }
 };
@@ -905,8 +914,9 @@ struct denseArrayAttrEltTypeBuilder<double> {
 template <typename T>
 DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
                                          ArrayRef<T> content) {
-  auto shapedType =
-      denseArrayAttrEltTypeBuilder<T>::getShapedType(context, content.size());
+  auto size = static_cast<int64_t>(content.size());
+  auto shapedType = denseArrayAttrEltTypeBuilder<T>::getShapedType(
+      context, size ? ArrayRef<int64_t>{size} : ArrayRef<int64_t>{});
   auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
   auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
                                  content.size() * sizeof(T));

diff  --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 3de6a0d60c23..4c4a43e43903 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -844,19 +844,34 @@ Attribute Parser::parseDenseArrayAttr() {
     return {};
   CustomAsmParser parser(*this);
   Attribute result;
+  // Check for empty list.
+  bool isEmptyList = getToken().is(Token::r_square);
+
   if (auto intType = type.dyn_cast<IntegerType>()) {
     switch (type.getIntOrFloatBitWidth()) {
     case 8:
-      result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseI8ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     case 16:
-      result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseI16ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     case 32:
-      result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseI32ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     case 64:
-      result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseI64ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     default:
       emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
@@ -865,10 +880,16 @@ Attribute Parser::parseDenseArrayAttr() {
   } else if (auto floatType = type.dyn_cast<FloatType>()) {
     switch (type.getIntOrFloatBitWidth()) {
     case 32:
-      result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseF32ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     case 64:
-      result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
+      if (isEmptyList)
+        result = DenseF64ArrayAttr::get(parser.getContext(), {});
+      else
+        result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
       break;
     default:
       emitError(typeLoc, "expected f32 or f64 but got: ") << type;

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index f6b274015a3c..b556b2f322c5 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -521,7 +521,19 @@ func.func @simple_scalar_example() {
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: func @dense_array_attr
-func.func @dense_array_attr() attributes{ 
+func.func @dense_array_attr() attributes{
+// CHECK-SAME: emptyf32attr = [:f32],
+               emptyf32attr = [:f32],
+// CHECK-SAME: emptyf64attr = [:f64],
+               emptyf64attr = [:f64],
+// CHECK-SAME: emptyi16attr = [:i16],
+               emptyi16attr = [:i16],
+// CHECK-SAME: emptyi32attr = [:i32],
+               emptyi32attr = [:i32],
+// CHECK-SAME: emptyi64attr = [:i64],
+               emptyi64attr = [:i64],
+// CHECK-SAME: emptyi8attr = [:i8],
+               emptyi8attr = [:i8],
 // CHECK-SAME: f32attr = [:f32 1.024000e+03, 4.530000e+02, -6.435000e+03],
                f32attr = [:f32 1024., 453., -6435.],
 // CHECK-SAME: f64attr = [:f64 -1.420000e+02],
@@ -549,6 +561,8 @@ func.func @dense_array_attr() attributes{
                f32attr = [1024., 453., -6435.]
 // CHECK-SAME: f64attr = [-1.420000e+02]
                f64attr = [-142.]
+// CHECK-SAME: emptyattr = []
+               emptyattr = []
   return
 }
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 325e5d91caa9..b0a0cf3807fc 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -277,11 +277,13 @@ def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
     DenseI32ArrayAttr:$i32attr,
     DenseI64ArrayAttr:$i64attr,
     DenseF32ArrayAttr:$f32attr,
-    DenseF64ArrayAttr:$f64attr
+    DenseF64ArrayAttr:$f64attr,
+    DenseI32ArrayAttr:$emptyattr
   );
   let assemblyFormat = [{
    `i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
    `i64attr` `=` $i64attr  `f32attr` `=` $f32attr `f64attr` `=` $f64attr
+   `emptyattr` `=` $emptyattr
    attr-dict
   }];
 }


        


More information about the Mlir-commits mailing list