[Mlir-commits] [mlir] c48e0cf - [mlir] Remove TypedAttr and ElementsAttr from DenseArrayAttr

Jeff Niu llvmlistbot at llvm.org
Mon Dec 5 13:28:02 PST 2022


Author: Jeff Niu
Date: 2022-12-05T13:27:55-08:00
New Revision: c48e0cf03a50bb8a2043ac4bb5e9a83ff135247a

URL: https://github.com/llvm/llvm-project/commit/c48e0cf03a50bb8a2043ac4bb5e9a83ff135247a
DIFF: https://github.com/llvm/llvm-project/commit/c48e0cf03a50bb8a2043ac4bb5e9a83ff135247a.diff

LOG: [mlir] Remove TypedAttr and ElementsAttr from DenseArrayAttr

This patch removes the implementation of TypedAttr and ElementsAttr
from DenseArrayAttr and, in doing so, removes the need to store a shaped
type. The attribute now stores a size (number of elements), an MLIR type
as a discriminator, and a raw byte array.

The intent of DenseArrayAttr was not to be a drop-in replacement for DenseElementsAttr. It was meant to be a simple container of integers or floats that map to C++ types. The ElementsAttr implementation on DenseArrayAttr had many holes in it, and fixing those holes would require evolving DenseArrayAttr in a way that is incompatible with its original purpose.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D137606

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
    mlir/include/mlir/IR/BuiltinAttributes.td
    mlir/lib/AsmParser/AttributeParser.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/BuiltinAttributes.cpp
    mlir/lib/IR/BuiltinDialectBytecode.cpp
    mlir/test/IR/attribute.mlir
    mlir/test/IR/elements-attr-interface.mlir
    mlir/test/IR/invalid-builtin-attributes.mlir
    mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 64e5e0abfd763..2494773935dfd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -165,9 +165,8 @@ class GEPIndicesAdaptor {
       if (*rawConstantIter == GEPOp::kDynamicIndex)
         return *valuesIter;
 
-      return IntegerAttr::get(
-          ElementsAttr::getElementType(base->rawConstantIndices),
-          *rawConstantIter);
+      return IntegerAttr::get(base->rawConstantIndices.getElementType(),
+                              *rawConstantIter);
     }
 
     iterator &operator++() {

diff  --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 1a06c92e725ef..adb19dbcfcab2 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -155,8 +155,7 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
   }];
 }
 
-def Builtin_DenseArray : Builtin_Attr<
-    "DenseArray", [ElementsAttrInterface, TypedAttrInterface]> {
+def Builtin_DenseArray : Builtin_Attr<"DenseArray"> {
   let summary = "A dense array of integer or floating point elements.";
   let description = [{
     A dense array attribute is an attribute that represents a dense array of
@@ -195,43 +194,26 @@ def Builtin_DenseArray : Builtin_Attr<
   }];
 
   let parameters = (ins
-    AttributeSelfTypeParameter<"", "RankedTensorType">:$type,
+    "Type":$elementType,
+    "int64_t":$size,
     Builtin_DenseArrayRawDataParameter:$rawData
   );
 
   let builders = [
-    AttrBuilderWithInferredContext<(ins "RankedTensorType":$type,
+    AttrBuilderWithInferredContext<(ins "Type":$elementType, "unsigned":$size,
                                         "ArrayRef<char>":$rawData), [{
-      return $_get(type.getContext(), type, rawData);
+      return $_get(elementType.getContext(), elementType, size, rawData);
     }]>,
   ];
 
-  let extraClassDeclaration = [{
-    /// Allow implicit conversion to ElementsAttr.
-    operator ElementsAttr() const {
-      return *this ? cast<ElementsAttr>() : nullptr;
-    }
+  let genVerifyDecl = 1;
 
-    /// ElementsAttr implementation.
-    using ContiguousIterableTypesT =
-        std::tuple<bool, int8_t, int16_t, int32_t, int64_t, float, double>;
-    FailureOr<const bool *>
-    try_value_begin_impl(OverloadToken<bool>) const;
-    FailureOr<const int8_t *>
-    try_value_begin_impl(OverloadToken<int8_t>) const;
-    FailureOr<const int16_t *>
-    try_value_begin_impl(OverloadToken<int16_t>) const;
-    FailureOr<const int32_t *>
-    try_value_begin_impl(OverloadToken<int32_t>) const;
-    FailureOr<const int64_t *>
-    try_value_begin_impl(OverloadToken<int64_t>) const;
-    FailureOr<const float *>
-    try_value_begin_impl(OverloadToken<float>) const;
-    FailureOr<const double *>
-    try_value_begin_impl(OverloadToken<double>) const;
+  let extraClassDeclaration = [{
+    /// Get the number of elements in the array.
+    int64_t size() const { return getSize(); }
+    /// Return true if there are no elements in the dense array.
+    bool empty() const { return !size(); }
   }];
-
-  let genVerifyDecl = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index e10dd5e108cd2..b768be036778b 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -844,9 +844,7 @@ class DenseArrayElementParser {
   ParseResult parseFloatElement(Parser &p);
 
   /// Convert the current contents to a dense array.
-  DenseArrayAttr getAttr() {
-    return DenseArrayAttr::get(RankedTensorType::get(size, type), rawData);
-  }
+  DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); }
 
 private:
   /// Append the raw data of an APInt to the result.
@@ -934,18 +932,9 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
     return {};
 
   SMLoc typeLoc = getToken().getLoc();
-  Type eltType;
-  // If an attribute type was provided, use its element type.
-  if (attrType) {
-    auto tensorType = attrType.dyn_cast<RankedTensorType>();
-    if (!tensorType) {
-      emitError(typeLoc, "dense array attribute expected ranked tensor type");
-      return {};
-    }
-    eltType = tensorType.getElementType();
-
-    // Otherwise, parse a type.
-  } else if (!(eltType = parseType())) {
+  Type eltType = parseType();
+  if (!eltType) {
+    emitError(typeLoc, "expected an integer or floating point type");
     return {};
   }
 
@@ -960,23 +949,11 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
     return {};
   }
 
-  // If a type was provided, check that it matches the parsed type.
-  auto checkProvidedType = [&](DenseArrayAttr result) -> Attribute {
-    if (attrType && result.getType() != attrType) {
-      emitError(typeLoc, "expected attribute type ")
-          << attrType << " does not match parsed type " << result.getType();
-      return {};
-    }
-    return result;
-  };
-
   // Check for empty list.
-  if (consumeIf(Token::greater)) {
-    return checkProvidedType(
-        DenseArrayAttr::get(RankedTensorType::get(0, eltType), {}));
-  }
-  if (!attrType &&
-      parseToken(Token::colon, "expected ':' after dense array type"))
+  if (consumeIf(Token::greater))
+    return DenseArrayAttr::get(eltType, 0, {});
+
+  if (parseToken(Token::colon, "expected ':' after dense array type"))
     return {};
 
   DenseArrayElementParser eltParser(eltType);
@@ -991,7 +968,7 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
   }
   if (parseToken(Token::greater, "expected '>' to close an array attribute"))
     return {};
-  return checkProvidedType(eltParser.getAttr());
+  return eltParser.getAttr();
 }
 
 /// Parse a dense elements attribute.

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 8fa7b3ec1243c..f6916554016eb 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2197,11 +2197,9 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
     stridedLayoutAttr.print(os);
   } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
     os << "array<";
-    if (typeElision != AttrTypeElision::Must)
-      printType(denseArrayAttr.getType().getElementType());
+    printType(denseArrayAttr.getElementType());
     if (!denseArrayAttr.empty()) {
-      if (typeElision != AttrTypeElision::Must)
-        os << ": ";
+      os << ": ";
       printDenseArrayAttr(denseArrayAttr);
     }
     os << ">";

diff  --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 99f7380727774..e73ca99b0d90d 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -690,69 +690,21 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
 
 LogicalResult
 DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
-                       RankedTensorType type, ArrayRef<char> rawData) {
-  if (type.getRank() != 1)
-    return emitError() << "expected rank 1 tensor type";
-  if (!type.getElementType().isIntOrIndexOrFloat())
+                       Type elementType, int64_t size, ArrayRef<char> rawData) {
+  if (!elementType.isIntOrIndexOrFloat())
     return emitError() << "expected integer or floating point element type";
   int64_t dataSize = rawData.size();
-  int64_t size = type.getShape().front();
-  if (type.getElementType().isInteger(1)) {
-    if (size != dataSize)
-      return emitError() << "expected " << size
-                         << " bytes for i1 array but got " << dataSize;
-  } else if (size * type.getElementTypeBitWidth() != dataSize * 8) {
+  int64_t elementSize =
+      llvm::divideCeil(elementType.getIntOrFloatBitWidth(), CHAR_BIT);
+  if (size * elementSize != dataSize) {
     return emitError() << "expected data size (" << size << " elements, "
-                       << type.getElementTypeBitWidth()
-                       << " bits each) does not match: " << dataSize
+                       << elementSize
+                       << " bytes each) does not match: " << dataSize
                        << " bytes";
   }
   return success();
 }
 
-FailureOr<const bool *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<bool>) const {
-  if (auto attr = dyn_cast<DenseBoolArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const int8_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int8_t>) const {
-  if (auto attr = dyn_cast<DenseI8ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const int16_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int16_t>) const {
-  if (auto attr = dyn_cast<DenseI16ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const int32_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int32_t>) const {
-  if (auto attr = dyn_cast<DenseI32ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const int64_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int64_t>) const {
-  if (auto attr = dyn_cast<DenseI64ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const float *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<float>) const {
-  if (auto attr = dyn_cast<DenseF32ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-FailureOr<const double *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<double>) const {
-  if (auto attr = dyn_cast<DenseF64ArrayAttr>())
-    return attr.asArrayRef().begin();
-  return failure();
-}
-
 namespace {
 /// Instantiations of this class provide utilities for interacting with native
 /// data types in the context of DenseArrayAttr.
@@ -898,12 +850,11 @@ DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
 template <typename T>
 DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
                                                  ArrayRef<T> content) {
-  auto shapedType = RankedTensorType::get(
-      content.size(), DenseArrayAttrUtil<T>::getElementType(context));
+  Type elementType = DenseArrayAttrUtil<T>::getElementType(context);
   auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
                                  content.size() * sizeof(T));
-  return Base::get(context, shapedType, rawArray)
-      .template cast<DenseArrayAttrImpl<T>>();
+  return llvm::cast<DenseArrayAttrImpl<T>>(
+      Base::get(context, elementType, content.size(), rawArray));
 }
 
 template <typename T>

diff  --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
index 18f3ea8f2382e..22a563dd7b2a2 100644
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -494,17 +494,20 @@ void BuiltinDialectBytecodeInterface::write(
 
 DenseArrayAttr BuiltinDialectBytecodeInterface::readDenseArrayAttr(
     DialectBytecodeReader &reader) const {
-  RankedTensorType type;
+  Type elementType;
+  uint64_t size;
   ArrayRef<char> blob;
-  if (failed(reader.readType(type)) || failed(reader.readBlob(blob)))
+  if (failed(reader.readType(elementType)) || failed(reader.readVarInt(size)) ||
+      failed(reader.readBlob(blob)))
     return DenseArrayAttr();
-  return DenseArrayAttr::get(type, blob);
+  return DenseArrayAttr::get(elementType, size, blob);
 }
 
 void BuiltinDialectBytecodeInterface::write(
     DenseArrayAttr attr, DialectBytecodeWriter &writer) const {
   writer.writeVarInt(builtin_encoding::kDenseArrayAttr);
-  writer.writeType(attr.getType());
+  writer.writeType(attr.getElementType());
+  writer.writeVarInt(attr.getSize());
   writer.writeOwnedBlob(attr.getRawData());
 }
 

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 3081702682658..d494824ec7e7c 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -630,8 +630,6 @@ func.func @dense_array_attr() attributes {
     x7_f16 = array<f16: 1., 3.>
   }: () -> ()
 
-  // CHECK: test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
-  test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
   return
 }
 

diff  --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir
index 38b2b8aebb8a7..aac87f24ae31f 100644
--- a/mlir/test/IR/elements-attr-interface.mlir
+++ b/mlir/test/IR/elements-attr-interface.mlir
@@ -27,27 +27,6 @@ arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
 // expected-error at below {{Test iterating `IntegerAttr`: }}
 arith.constant dense<> : tensor<0xi64>
 
-// expected-error at below {{Test iterating `bool`: true, false, true, false, true, false}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i1: true, false, true, false, true, false>
-// expected-error at below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i8: 10, 11, -12, 13, 14>
-// expected-error at below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i16: 10, 11, -12, 13, 14>
-// expected-error at below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i32: 10, 11, -12, 13, 14>
-// expected-error at below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}}
-arith.constant array<i64: 10, 11, -12, 13, 14>
-// expected-error at below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<f32: 10., 11., -12., 13., 14.>
-// expected-error at below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<f64: 10., 11., -12., 13., 14.>
-
 // Check that we handle an external constant parsed from the config.
 // expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
 // expected-error at below {{Test iterating `uint64_t`: 1, 2, 3}}

diff  --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 8e57afa41ba88..5343f971537b9 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -546,18 +546,3 @@ func.func @duplicate_dictionary_attr_key() {
 
 // expected-error at below {{expected '>' to close an array attribute}}
 #attr = array<i8: 1)
-
-// -----
-
-// expected-error at below {{dense array attribute expected ranked tensor type}}
-test.typed_attr i32 = array<1>
-
-// -----
-
-// expected-error at below {{does not match parsed type}}
-test.typed_attr tensor<1xi32> = array<>
-
-// -----
-
-// expected-error at below {{does not match parsed type}}
-test.typed_attr tensor<0xi32> = array<1>

diff  --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
index cbef0bca2494d..9313f403ce1c5 100644
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -21,10 +21,6 @@ template <typename T>
 static void printOneElement(InFlightDiagnostic &os, T value) {
   os << llvm::formatv("{0}", value).str();
 }
-template <>
-void printOneElement<int8_t>(InFlightDiagnostic &os, int8_t value) {
-  os << llvm::formatv("{0}", static_cast<int64_t>(value)).str();
-}
 
 namespace {
 struct TestElementsAttrInterface
@@ -41,32 +37,6 @@ struct TestElementsAttrInterface
         auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
         if (!elementsAttr)
           continue;
-        if (auto concreteAttr = attr.getValue().dyn_cast<DenseArrayAttr>()) {
-          llvm::TypeSwitch<DenseArrayAttr>(concreteAttr)
-              .Case([&](DenseBoolArrayAttr attr) {
-                testElementsAttrIteration<bool>(op, attr, "bool");
-              })
-              .Case([&](DenseI8ArrayAttr attr) {
-                testElementsAttrIteration<int8_t>(op, attr, "int8_t");
-              })
-              .Case([&](DenseI16ArrayAttr attr) {
-                testElementsAttrIteration<int16_t>(op, attr, "int16_t");
-              })
-              .Case([&](DenseI32ArrayAttr attr) {
-                testElementsAttrIteration<int32_t>(op, attr, "int32_t");
-              })
-              .Case([&](DenseI64ArrayAttr attr) {
-                testElementsAttrIteration<int64_t>(op, attr, "int64_t");
-              })
-              .Case([&](DenseF32ArrayAttr attr) {
-                testElementsAttrIteration<float>(op, attr, "float");
-              })
-              .Case([&](DenseF64ArrayAttr attr) {
-                testElementsAttrIteration<double>(op, attr, "double");
-              });
-          testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
-          continue;
-        }
         testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
         testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
         testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");


        


More information about the Mlir-commits mailing list