[Mlir-commits] [mlir] 508eb41 - Introduce a new Dense Array attribute

Mehdi Amini llvmlistbot at llvm.org
Tue Jun 28 05:08:37 PDT 2022


Author: Mehdi Amini
Date: 2022-06-28T12:08:25Z
New Revision: 508eb41d82ca956c30950d9a16b522a29aeeb333

URL: https://github.com/llvm/llvm-project/commit/508eb41d82ca956c30950d9a16b522a29aeeb333
DIFF: https://github.com/llvm/llvm-project/commit/508eb41d82ca956c30950d9a16b522a29aeeb333.diff

LOG: Introduce a new Dense Array attribute

This attribute is similar to DenseElementsAttr but does not support
splat. As such it has a much simpler API and does not need any smart
iterator: it exposes direct ArrayRef access.

A new syntax is introduced so that the generic printing/parsing looks
like:

  [:i64 1, -2, 3]

This attribute begins like an ArrayAttr but has a `:` token after the
opening square bracket to introduce the element type (supported are I8,
I16, I32, I64, F32, F64), followed by the comma-separated list for the data.

This is particularly convenient for attributes intended to be small,
like those referring to shapes.
For example a `transpose` operation with a `dims` attribute could be
defined as such:

  let arguments = (ins AnyTensor:$input, DenseI64ArrayAttr:$dims);
  let assemblyFormat = "$input `dims` `=` $dims attr-dict : type($input)";

And printed this way (the element type is elided in this case):

  transpose %input dims = [0, 2, 1] : tensor<2x3x4xf32>

The C++ API for dims would just directly return an ArrayRef<int64_t>.

RFC: https://discourse.llvm.org/t/rfc-introduce-a-new-dense-array-attribute/63279

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D123774

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinAttributes.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/include/mlir/IR/OpBase.td
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/Parser/AttributeParser.cpp
    mlir/lib/Parser/Parser.h
    mlir/test/IR/attribute.mlir
    mlir/test/IR/elements-attr-interface.mlir
    mlir/test/IR/invalid.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 85f6d3f4e638e..f22f66fd6ac2c 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -66,8 +66,8 @@ template <typename T>
 struct is_complex_t<std::complex<T>> : public std::true_type {};
 } // namespace detail
 
-/// An attribute that represents a reference to a dense vector or tensor object.
-///
+/// An attribute that represents a reference to a dense vector or tensor
+/// object.
 class DenseElementsAttr : public Attribute {
 public:
   using Attribute::Attribute;
@@ -743,6 +743,55 @@ class SplatElementsAttr : public DenseElementsAttr {
 //===----------------------------------------------------------------------===//
 
 namespace mlir {
+namespace detail {
+/// Base class for DenseArrayAttr that is instantiated and specialized for each
+/// supported element type below.
+template <typename T>
+class DenseArrayAttr : public DenseArrayBaseAttr {
+public:
+  using DenseArrayBaseAttr::DenseArrayBaseAttr;
+
+  /// Implicit conversion to ArrayRef<T>.
+  operator ArrayRef<T>() const;
+  ArrayRef<T> asArrayRef() { return ArrayRef<T>{*this}; }
+
+  /// Builder from ArrayRef<T>.
+  static DenseArrayAttr get(MLIRContext *context, ArrayRef<T> content);
+
+  /// Print the short form `[42, 100, -1]` without any type prefix.
+  void print(AsmPrinter &printer) const;
+  void print(raw_ostream &os) const;
+  /// Print the short form `42, 100, -1` without any braces or type prefix.
+  void printWithoutBraces(raw_ostream &os) const;
+
+  /// Parse the short form `[42, 100, -1]` without any type prefix.
+  static Attribute parse(AsmParser &parser, Type odsType);
+
+  /// Parse the short form `42, 100, -1` without any type prefix or braces.
+  static Attribute parseWithoutBraces(AsmParser &parser, Type odsType);
+
+  /// Support for isa<>/cast<>.
+  static bool classof(Attribute attr);
+};
+template <>
+void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const;
+
+extern template class DenseArrayAttr<int8_t>;
+extern template class DenseArrayAttr<int16_t>;
+extern template class DenseArrayAttr<int32_t>;
+extern template class DenseArrayAttr<int64_t>;
+extern template class DenseArrayAttr<float>;
+extern template class DenseArrayAttr<double>;
+} // namespace detail
+
+// Public name for all the supported DenseArrayAttr
+using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
+using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
+using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
+using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
+using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
+using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
+
 //===----------------------------------------------------------------------===//
 // BoolAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 2fab392088381..8e6a8d08ad264 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -144,6 +144,76 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
 // DenseIntOrFPElementsAttr
 //===----------------------------------------------------------------------===//
 
+def Builtin_DenseArrayBase : Builtin_Attr<
+    "DenseArrayBase", [ElementsAttrInterface]> {
+  let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
+  let description = [{
+    A dense array attribute is an attribute that represents a dense array of
+    primitive element types. Contrary to DenseIntOrFPElementsAttr this is a
+    flat unidimensional array which does not have a storage optimization for
+    splat. This allows to expose the raw array through a C++ API as
+    `ArrayRef<T>`. This is the base class attribute, the actual access is
+    intended to be managed through the subclasses `DenseI8ArrayAttr`,
+    `DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`,
+    `DenseF32ArrayAttr`, and `DenseF64ArrayAttr`.
+
+    Syntax:
+
+    ```
+    dense-array-attribute ::= `[` `:` (integer-type | float-type) tensor-literal `]`
+    ```
+    Examples:
+
+    ```mlir
+    [:i8]
+    [:i32 10, 42]
+    [:f64 42., 12.]
+    ```
+
+    when a specific subclass is used as argument of an operation, the declarative
+    assembly will omit the type and print directly:
+    ```
+    [1, 2, 3]
+    ```
+  }];
+  let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type,
+                        "DenseArrayBaseAttr::EltType":$eltType,
+                        ArrayRefParameter<"char">:$elements);
+  let extraClassDeclaration = [{
+    // All possible supported element type.
+    enum class EltType { I8, I16, I32, I64, F32, F64 };
+
+    /// Allow implicit conversion to ElementsAttr.
+    operator ElementsAttr() const {
+      return *this ? cast<ElementsAttr>() : nullptr;
+    }
+
+    /// ElementsAttr implementation.
+    using ContiguousIterableTypesT =
+        std::tuple<int8_t, int16_t, int32_t, int64_t, float, double>;
+    const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
+    const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
+    const int32_t *value_begin_impl(OverloadToken<int32_t>) const;
+    const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
+    const float *value_begin_impl(OverloadToken<float>) const;
+    const double *value_begin_impl(OverloadToken<double>) const;
+
+    /// Methods to support type inquiry through isa, cast, and dyn_cast.
+    EltType getElementType() const;
+    /// Printer for the short form: will dispatch to the appropriate subclass.
+    void print(AsmPrinter &printer) const;
+    void print(raw_ostream &os) const;
+    /// Print the short form `42, 100, -1` without any braces or prefix.
+    void printWithoutBraces(raw_ostream &os) const;
+  }];
+  let genAccessors = 0;
+  let skipDefaultBuilders = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// DenseIntOrFPElementsAttr
+//===----------------------------------------------------------------------===//
+
 def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr<
     "DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr"
   > {

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 4e092484a6512..807216abd9813 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1258,6 +1258,19 @@ class IntElementsAttrBase<Pred condition, string summary> :
   let convertFromStorage = "$_self";
 }
 
+class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryName> :
+    ElementsAttrBase<CPred<"$_self.isa<::mlir::" # denseAttrName # ">()">,
+                     summaryName # " dense array attribute"> {
+  let storageType = "::mlir::" # denseAttrName;
+  let returnType = "::llvm::ArrayRef<" # cppType # ">";
+}
+def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">;
+def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">;
+def DenseI64ArrayAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">;
+def DenseF32ArrayAttr : DenseArrayAttrBase<"DenseF32ArrayAttr", "float", "f32">;
+def DenseF64ArrayAttr : DenseArrayAttrBase<"DenseF64ArrayAttr", "double", "f64">;
+
 def IndexElementsAttr
     : IntElementsAttrBase<CPred<[{$_self.cast<::mlir::DenseIntElementsAttr>()
                                       .getType()

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 2c8c9b0c640e1..981097e9101b6 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1878,9 +1878,34 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
       }
       os << '>';
     }
-
+  } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
+    typeElision = AttrTypeElision::Must;
+    switch (denseArrayAttr.getElementType()) {
+    case DenseArrayBaseAttr::EltType::I8:
+      os << "[:i8 ";
+      break;
+    case DenseArrayBaseAttr::EltType::I16:
+      os << "[:i16 ";
+      break;
+    case DenseArrayBaseAttr::EltType::I32:
+      os << "[:i32 ";
+      break;
+    case DenseArrayBaseAttr::EltType::I64:
+      os << "[:i64 ";
+      break;
+    case DenseArrayBaseAttr::EltType::F32:
+      os << "[:f32 ";
+      break;
+    case DenseArrayBaseAttr::EltType::F64:
+      os << "[:f64 ";
+      break;
+    }
+    denseArrayAttr.printWithoutBraces(os);
+    os << "]";
   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
     printLocation(locAttr);
+  } else {
+    llvm::report_fatal_error("Unknown builtin attribute");
   }
   // Don't print the type if we must elide it, or if it is a None type.
   if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 4358460badf2c..5daa219f7a74e 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Types.h"
@@ -35,11 +36,11 @@ using namespace mlir::detail;
 //===----------------------------------------------------------------------===//
 
 void BuiltinDialect::registerAttributes() {
-  addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
-                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
-                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
-                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
-                UnitAttr>();
+  addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
+                DenseIntOrFPElementsAttr, DenseStringElementsAttr,
+                DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
+                IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
+                SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
 }
 
 //===----------------------------------------------------------------------===//
@@ -664,6 +665,234 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
           readBits(getData(), offset + storageWidth, bitWidth)};
 }
 
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
+  return getImpl()->eltType;
+}
+
+const int8_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
+  return cast<DenseI8ArrayAttr>().asArrayRef().begin();
+}
+const int16_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
+  return cast<DenseI16ArrayAttr>().asArrayRef().begin();
+}
+const int32_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
+  return cast<DenseI32ArrayAttr>().asArrayRef().begin();
+}
+const int64_t *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
+  return cast<DenseI64ArrayAttr>().asArrayRef().begin();
+}
+const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
+  return cast<DenseF32ArrayAttr>().asArrayRef().begin();
+}
+const double *
+DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
+  return cast<DenseF64ArrayAttr>().asArrayRef().begin();
+}
+
+void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
+  print(printer.getStream());
+}
+
+void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
+  switch (getElementType()) {
+  case DenseArrayBaseAttr::EltType::I8:
+    this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::I16:
+    this->cast<DenseI16ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::I32:
+    this->cast<DenseI32ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::I64:
+    this->cast<DenseI64ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::F32:
+    this->cast<DenseF32ArrayAttr>().printWithoutBraces(os);
+    return;
+  case DenseArrayBaseAttr::EltType::F64:
+    this->cast<DenseF64ArrayAttr>().printWithoutBraces(os);
+    return;
+  }
+  llvm_unreachable("<unknown DenseArrayBaseAttr>");
+}
+
+void DenseArrayBaseAttr::print(raw_ostream &os) const {
+  os << "[";
+  printWithoutBraces(os);
+  os << "]";
+}
+
+template <typename T>
+void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
+  print(printer.getStream());
+}
+
+template <typename T>
+void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
+  ArrayRef<T> values{*this};
+  llvm::interleaveComma(values, os);
+}
+
+/// Specialization for int8_t for forcing printing as number instead of chars.
+template <>
+void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
+  ArrayRef<int8_t> values{*this};
+  llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
+}
+
+template <typename T>
+void DenseArrayAttr<T>::print(raw_ostream &os) const {
+  os << "[";
+  printWithoutBraces(os);
+  os << "]";
+}
+
+/// Parse a single element: generic template for int types, specialized for
+/// floating points below.
+template <typename T>
+static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
+  return parser.parseInteger(value);
+}
+
+template <>
+ParseResult parseDenseArrayAttrElt<float>(AsmParser &parser, float &value) {
+  double doubleVal;
+  if (parser.parseFloat(doubleVal))
+    return failure();
+  value = doubleVal;
+  return success();
+}
+
+template <>
+ParseResult parseDenseArrayAttrElt<double>(AsmParser &parser, double &value) {
+  return parser.parseFloat(value);
+}
+
+/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
+template <typename T>
+Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
+                                                Type odsType) {
+  SmallVector<T> data;
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        T value;
+        if (parseDenseArrayAttrElt(parser, value))
+          return failure();
+        data.push_back(value);
+        return success();
+      })))
+    return {};
+  return get(parser.getContext(), data);
+}
+
+/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
+template <typename T>
+Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
+  if (parser.parseLSquare())
+    return {};
+  Attribute result = parseWithoutBraces(parser, odsType);
+  if (parser.parseRSquare())
+    return {};
+  return result;
+}
+
+/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
+template <typename T>
+DenseArrayAttr<T>::operator ArrayRef<T>() const {
+  ArrayRef<char> raw = getImpl()->elements;
+  assert((raw.size() % sizeof(T)) == 0);
+  return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
+                     raw.size() / sizeof(T));
+}
+
+namespace {
+/// Mapping from C++ element type to MLIR DenseArrayAttr internals.
+template <typename T>
+struct denseArrayAttrEltTypeBuilder;
+template <>
+struct denseArrayAttrEltTypeBuilder<int8_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 8));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int16_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 16));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int32_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 32));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<int64_t> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, IntegerType::get(context, 64));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<float> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, Float32Type::get(context));
+  }
+};
+template <>
+struct denseArrayAttrEltTypeBuilder<double> {
+  constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
+  static ShapedType getShapedType(MLIRContext *context, int64_t shape) {
+    return VectorType::get(shape, Float64Type::get(context));
+  }
+};
+} // namespace
+
+/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
+template <typename T>
+DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
+                                         ArrayRef<T> content) {
+  auto shapedType =
+      denseArrayAttrEltTypeBuilder<T>::getShapedType(context, content.size());
+  auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
+  auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
+                                 content.size() * sizeof(T));
+  return Base::get(context, shapedType, eltType, rawArray)
+      .template cast<DenseArrayAttr<T>>();
+}
+
+template <typename T>
+bool DenseArrayAttr<T>::classof(Attribute attr) {
+  return attr.isa<DenseArrayBaseAttr>() &&
+         attr.cast<DenseArrayBaseAttr>().getElementType() ==
+             denseArrayAttrEltTypeBuilder<T>::eltType;
+}
+
+namespace mlir {
+namespace detail {
+// Explicit instantiation for all the supported DenseArrayAttr.
+template class DenseArrayAttr<int8_t>;
+template class DenseArrayAttr<int16_t>;
+template class DenseArrayAttr<int32_t>;
+template class DenseArrayAttr<int64_t>;
+template class DenseArrayAttr<float>;
+template class DenseArrayAttr<double>;
+} // namespace detail
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // DenseElementsAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index efc1c226a2aa0..177420668a385 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -11,9 +11,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "Parser.h"
+
+#include "AsmParserImpl.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Parser/AsmParserState.h"
 #include "llvm/ADT/StringExtras.h"
@@ -30,6 +33,7 @@ using namespace mlir::detail;
 ///                    | float-literal (`:` float-type)?
 ///                    | string-literal (`:` type)?
 ///                    | type
+///                    | `[` `:` (integer-type | float-type) tensor-literal `]`
 ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
 ///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
 ///                    | symbol-ref-id (`::` symbol-ref-id)*
@@ -67,13 +71,16 @@ Attribute Parser::parseAttribute(Type type) {
 
   // Parse an array attribute.
   case Token::l_square: {
+    consumeToken(Token::l_square);
+    if (consumeIf(Token::colon))
+      return parseDenseArrayAttr();
     SmallVector<Attribute, 4> elements;
     auto parseElt = [&]() -> ParseResult {
       elements.push_back(parseAttribute());
       return elements.back() ? success() : failure();
     };
 
-    if (parseCommaSeparatedList(Delimiter::Square, parseElt))
+    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
       return nullptr;
     return builder.getArrayAttr(elements);
   }
@@ -812,6 +819,66 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
 // ElementsAttr Parser
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// This class provides an implementation of AsmParser, allowing to call back
+/// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
+/// in libMLIRIR.
+class CustomAsmParser : public AsmParserImpl<AsmParser> {
+public:
+  CustomAsmParser(Parser &parser)
+      : AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
+};
+} // namespace
+
+/// Parse a dense array attribute.
+Attribute Parser::parseDenseArrayAttr() {
+  auto typeLoc = getToken().getLoc();
+  auto type = parseType();
+  if (!type)
+    return {};
+  CustomAsmParser parser(*this);
+  Attribute result;
+  if (auto intType = type.dyn_cast<IntegerType>()) {
+    switch (type.getIntOrFloatBitWidth()) {
+    case 8:
+      result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    case 16:
+      result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    case 32:
+      result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    case 64:
+      result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    default:
+      emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
+      return {};
+    }
+  } else if (auto floatType = type.dyn_cast<FloatType>()) {
+    switch (type.getIntOrFloatBitWidth()) {
+    case 32:
+      result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    case 64:
+      result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
+      break;
+    default:
+      emitError(typeLoc, "expected f32 or f64 but got: ") << type;
+      return {};
+    }
+  } else {
+    emitError(typeLoc, "expected integer or float type, got: ") << type;
+    return {};
+  }
+  if (!consumeIf(Token::r_square)) {
+    emitError("expected ']' to close an array attribute");
+    return {};
+  }
+  return result;
+}
+
 /// Parse a dense elements attribute.
 Attribute Parser::parseDenseElementsAttr(Type attrType) {
   auto attribLoc = getToken().getLoc();

diff  --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 357de93d73a1a..e97c62cd91d8f 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -264,6 +264,9 @@ class Parser {
   Attribute parseDenseElementsAttr(Type attrType);
   ShapedType parseElementsLiteralType(Type type);
 
+  /// Parse a DenseArrayAttr.
+  Attribute parseDenseArrayAttr();
+
   /// Parse a sparse elements attribute.
   Attribute parseSparseElementsAttr(Type attrType);
 

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 19f8767956062..f6b274015a3c4 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -513,6 +513,45 @@ func.func @simple_scalar_example() {
   return
 }
 
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @dense_array_attr
+func.func @dense_array_attr() attributes{ 
+// CHECK-SAME: f32attr = [:f32 1.024000e+03, 4.530000e+02, -6.435000e+03],
+               f32attr = [:f32 1024., 453., -6435.],
+// CHECK-SAME: f64attr = [:f64 -1.420000e+02],
+               f64attr = [:f64 -142.],
+// CHECK-SAME: i16attr = [:i16 3, 5, -4, 10],
+               i16attr = [:i16 3, 5, -4, 10],
+// CHECK-SAME: i32attr = [:i32 1024, 453, -6435],
+               i32attr = [:i32 1024, 453, -6435],
+// CHECK-SAME: i64attr = [:i64 -142],
+               i64attr = [:i64 -142],
+// CHECK-SAME: i8attr = [:i8 1, -2, 3]
+               i8attr = [:i8 1, -2, 3]
+ } {
+// CHECK:  test.dense_array_attr
+  test.dense_array_attr
+// CHECK-SAME: i8attr = [1, -2, 3]
+               i8attr = [1, -2, 3]
+// CHECK-SAME: i16attr = [3, 5, -4, 10]
+               i16attr = [3, 5, -4, 10]
+// CHECK-SAME: i32attr = [1024, 453, -6435]
+               i32attr = [1024, 453, -6435]
+// CHECK-SAME: i64attr = [-142]
+               i64attr = [-142]
+// CHECK-SAME: f32attr = [1.024000e+03, 4.530000e+02, -6.435000e+03]
+               f32attr = [1024., 453., -6435.]
+// CHECK-SAME: f64attr = [-1.420000e+02]
+               f64attr = [-142.]
+  return
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir
index be1020d952271..a476fb1ca3b3b 100644
--- a/mlir/test/IR/elements-attr-interface.mlir
+++ b/mlir/test/IR/elements-attr-interface.mlir
@@ -5,23 +5,40 @@
 // This tests that the abstract iteration of ElementsAttr works properly, and
 // is properly failable when necessary.
 
+// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
 // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
 // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
 // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
 arith.constant #test.i64_elements<[10, 11, 12, 13, 14]> : tensor<5xi64>
 
+// expected-error@below {{Test iterating `int64_t`: 10, 11, 12, 13, 14}}
 // expected-error@below {{Test iterating `uint64_t`: 10, 11, 12, 13, 14}}
 // expected-error@below {{Test iterating `APInt`: 10, 11, 12, 13, 14}}
 // expected-error@below {{Test iterating `IntegerAttr`: 10 : i64, 11 : i64, 12 : i64, 13 : i64, 14 : i64}}
 arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
 
+// expected-error@below {{Test iterating `int64_t`: unable to iterate type}}
 // expected-error@below {{Test iterating `uint64_t`: unable to iterate type}}
 // expected-error@below {{Test iterating `APInt`: unable to iterate type}}
 // expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}}
 arith.constant opaque<"_", "0xDEADBEEF"> : tensor<5xi64>
 
 // Check that we don't crash on empty element attributes.
+// expected-error@below {{Test iterating `int64_t`: }}
 // expected-error@below {{Test iterating `uint64_t`: }}
 // expected-error@below {{Test iterating `APInt`: }}
 // expected-error@below {{Test iterating `IntegerAttr`: }}
 arith.constant dense<> : tensor<0xi64>
+
+// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i8 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i16 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i32 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}}
+arith.constant [:i64 10, 11, -12, 13, 14]
+// expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}}
+arith.constant [:f32 10., 11., -12., 13., 14.]
+// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
+arith.constant [:f64 10., 11., -12., 13., 14.]

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 2b98485bf8d21..3a8b7911638fe 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -1654,7 +1654,7 @@ func.func @foo() {} // expected-error {{expected non-empty function body}}
 
 // -----
 
-// expected-error@+1 {{expected ']'}}
+// expected-error@+1 {{expected ',' or ']'}}
 "f"() { b = [@m:
 
 // -----

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index c100aa2dbd67f..325e5d91caa9b 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -270,6 +270,22 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
   );
 }
 
+def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
+  let arguments = (ins
+    DenseI8ArrayAttr:$i8attr,
+    DenseI16ArrayAttr:$i16attr,
+    DenseI32ArrayAttr:$i32attr,
+    DenseI64ArrayAttr:$i64attr,
+    DenseF32ArrayAttr:$f32attr,
+    DenseF64ArrayAttr:$f64attr
+  );
+  let assemblyFormat = [{
+   `i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
+   `i64attr` `=` $i64attr  `f32attr` `=` $f32attr `f64attr` `=` $f64attr
+   attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test Enum Attributes
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
index 783512d72aae2..f32a49bd5bedb 100644
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -14,6 +14,17 @@
 using namespace mlir;
 using namespace test;
 
+// Helper to print one scalar value, force int8_t to print as integer instead of
+// char.
+template <typename T>
+static void printOneElement(InFlightDiagnostic &os, T value) {
+  os << llvm::formatv("{0}", value).str();
+}
+template <>
+void printOneElement<int8_t>(InFlightDiagnostic &os, int8_t value) {
+  os << llvm::formatv("{0}", static_cast<int64_t>(value)).str();
+}
+
 namespace {
 struct TestElementsAttrInterface
     : public PassWrapper<TestElementsAttrInterface, OperationPass<ModuleOp>> {
@@ -29,6 +40,31 @@ struct TestElementsAttrInterface
         auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
         if (!elementsAttr)
           continue;
+        if (auto concreteAttr =
+                attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
+          switch (concreteAttr.getElementType()) {
+          case DenseArrayBaseAttr::EltType::I8:
+            testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
+            break;
+          case DenseArrayBaseAttr::EltType::I16:
+            testElementsAttrIteration<int16_t>(op, elementsAttr, "int16_t");
+            break;
+          case DenseArrayBaseAttr::EltType::I32:
+            testElementsAttrIteration<int32_t>(op, elementsAttr, "int32_t");
+            break;
+          case DenseArrayBaseAttr::EltType::I64:
+            testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
+            break;
+          case DenseArrayBaseAttr::EltType::F32:
+            testElementsAttrIteration<float>(op, elementsAttr, "float");
+            break;
+          case DenseArrayBaseAttr::EltType::F64:
+            testElementsAttrIteration<double>(op, elementsAttr, "double");
+            break;
+          }
+          continue;
+        }
+        testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
         testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
         testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");
         testElementsAttrIteration<IntegerAttr>(op, elementsAttr, "IntegerAttr");
@@ -48,9 +84,8 @@ struct TestElementsAttrInterface
       return;
     }
 
-    llvm::interleaveComma(*values, diag, [&](T value) {
-      diag << llvm::formatv("{0}", value).str();
-    });
+    llvm::interleaveComma(*values, diag,
+                          [&](T value) { printOneElement(diag, value); });
   }
 };
 } // namespace


        


More information about the Mlir-commits mailing list