[Mlir-commits] [mlir] c48e0cf - [mlir] Remove TypedAttr and ElementsAttr from DenseArrayAttr
Jeff Niu
llvmlistbot at llvm.org
Mon Dec 5 13:28:02 PST 2022
Author: Jeff Niu
Date: 2022-12-05T13:27:55-08:00
New Revision: c48e0cf03a50bb8a2043ac4bb5e9a83ff135247a
URL: https://github.com/llvm/llvm-project/commit/c48e0cf03a50bb8a2043ac4bb5e9a83ff135247a
DIFF: https://github.com/llvm/llvm-project/commit/c48e0cf03a50bb8a2043ac4bb5e9a83ff135247a.diff
LOG: [mlir] Remove TypedAttr and ElementsAttr from DenseArrayAttr
This patch removes the implementation of TypedAttr and ElementsAttr
from DenseArrayAttr and, in doing so, removes the need to store a shaped
type. The attribute now stores a size (number of elements), an MLIR type
as a discriminator, and a raw byte array.
The intent of DenseArrayAttr was not to be a drop-in replacement for DenseElementsAttr. It was meant to be a simple container of integers or floats that map to C++ types. The ElementsAttr implementation on DenseArrayAttr had many holes in it, and fixing those holes would require evolving DenseArrayAttr in a way that is incompatible with its original purpose.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D137606
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinDialectBytecode.cpp
mlir/test/IR/attribute.mlir
mlir/test/IR/elements-attr-interface.mlir
mlir/test/IR/invalid-builtin-attributes.mlir
mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 64e5e0abfd763..2494773935dfd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -165,9 +165,8 @@ class GEPIndicesAdaptor {
if (*rawConstantIter == GEPOp::kDynamicIndex)
return *valuesIter;
- return IntegerAttr::get(
- ElementsAttr::getElementType(base->rawConstantIndices),
- *rawConstantIter);
+ return IntegerAttr::get(base->rawConstantIndices.getElementType(),
+ *rawConstantIter);
}
iterator &operator++() {
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 1a06c92e725ef..adb19dbcfcab2 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -155,8 +155,7 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
}];
}
-def Builtin_DenseArray : Builtin_Attr<
- "DenseArray", [ElementsAttrInterface, TypedAttrInterface]> {
+def Builtin_DenseArray : Builtin_Attr<"DenseArray"> {
let summary = "A dense array of integer or floating point elements.";
let description = [{
A dense array attribute is an attribute that represents a dense array of
@@ -195,43 +194,26 @@ def Builtin_DenseArray : Builtin_Attr<
}];
let parameters = (ins
- AttributeSelfTypeParameter<"", "RankedTensorType">:$type,
+ "Type":$elementType,
+ "int64_t":$size,
Builtin_DenseArrayRawDataParameter:$rawData
);
let builders = [
- AttrBuilderWithInferredContext<(ins "RankedTensorType":$type,
+ AttrBuilderWithInferredContext<(ins "Type":$elementType, "unsigned":$size,
"ArrayRef<char>":$rawData), [{
- return $_get(type.getContext(), type, rawData);
+ return $_get(elementType.getContext(), elementType, size, rawData);
}]>,
];
- let extraClassDeclaration = [{
- /// Allow implicit conversion to ElementsAttr.
- operator ElementsAttr() const {
- return *this ? cast<ElementsAttr>() : nullptr;
- }
+ let genVerifyDecl = 1;
- /// ElementsAttr implementation.
- using ContiguousIterableTypesT =
- std::tuple<bool, int8_t, int16_t, int32_t, int64_t, float, double>;
- FailureOr<const bool *>
- try_value_begin_impl(OverloadToken<bool>) const;
- FailureOr<const int8_t *>
- try_value_begin_impl(OverloadToken<int8_t>) const;
- FailureOr<const int16_t *>
- try_value_begin_impl(OverloadToken<int16_t>) const;
- FailureOr<const int32_t *>
- try_value_begin_impl(OverloadToken<int32_t>) const;
- FailureOr<const int64_t *>
- try_value_begin_impl(OverloadToken<int64_t>) const;
- FailureOr<const float *>
- try_value_begin_impl(OverloadToken<float>) const;
- FailureOr<const double *>
- try_value_begin_impl(OverloadToken<double>) const;
+ let extraClassDeclaration = [{
+ /// Get the number of elements in the array.
+ int64_t size() const { return getSize(); }
+ /// Return true if there are no elements in the dense array.
+ bool empty() const { return !size(); }
}];
-
- let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index e10dd5e108cd2..b768be036778b 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -844,9 +844,7 @@ class DenseArrayElementParser {
ParseResult parseFloatElement(Parser &p);
/// Convert the current contents to a dense array.
- DenseArrayAttr getAttr() {
- return DenseArrayAttr::get(RankedTensorType::get(size, type), rawData);
- }
+ DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); }
private:
/// Append the raw data of an APInt to the result.
@@ -934,18 +932,9 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
return {};
SMLoc typeLoc = getToken().getLoc();
- Type eltType;
- // If an attribute type was provided, use its element type.
- if (attrType) {
- auto tensorType = attrType.dyn_cast<RankedTensorType>();
- if (!tensorType) {
- emitError(typeLoc, "dense array attribute expected ranked tensor type");
- return {};
- }
- eltType = tensorType.getElementType();
-
- // Otherwise, parse a type.
- } else if (!(eltType = parseType())) {
+ Type eltType = parseType();
+ if (!eltType) {
+ emitError(typeLoc, "expected an integer or floating point type");
return {};
}
@@ -960,23 +949,11 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
return {};
}
- // If a type was provided, check that it matches the parsed type.
- auto checkProvidedType = [&](DenseArrayAttr result) -> Attribute {
- if (attrType && result.getType() != attrType) {
- emitError(typeLoc, "expected attribute type ")
- << attrType << " does not match parsed type " << result.getType();
- return {};
- }
- return result;
- };
-
// Check for empty list.
- if (consumeIf(Token::greater)) {
- return checkProvidedType(
- DenseArrayAttr::get(RankedTensorType::get(0, eltType), {}));
- }
- if (!attrType &&
- parseToken(Token::colon, "expected ':' after dense array type"))
+ if (consumeIf(Token::greater))
+ return DenseArrayAttr::get(eltType, 0, {});
+
+ if (parseToken(Token::colon, "expected ':' after dense array type"))
return {};
DenseArrayElementParser eltParser(eltType);
@@ -991,7 +968,7 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) {
}
if (parseToken(Token::greater, "expected '>' to close an array attribute"))
return {};
- return checkProvidedType(eltParser.getAttr());
+ return eltParser.getAttr();
}
/// Parse a dense elements attribute.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 8fa7b3ec1243c..f6916554016eb 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2197,11 +2197,9 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
stridedLayoutAttr.print(os);
} else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
os << "array<";
- if (typeElision != AttrTypeElision::Must)
- printType(denseArrayAttr.getType().getElementType());
+ printType(denseArrayAttr.getElementType());
if (!denseArrayAttr.empty()) {
- if (typeElision != AttrTypeElision::Must)
- os << ": ";
+ os << ": ";
printDenseArrayAttr(denseArrayAttr);
}
os << ">";
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 99f7380727774..e73ca99b0d90d 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -690,69 +690,21 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
LogicalResult
DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- RankedTensorType type, ArrayRef<char> rawData) {
- if (type.getRank() != 1)
- return emitError() << "expected rank 1 tensor type";
- if (!type.getElementType().isIntOrIndexOrFloat())
+ Type elementType, int64_t size, ArrayRef<char> rawData) {
+ if (!elementType.isIntOrIndexOrFloat())
return emitError() << "expected integer or floating point element type";
int64_t dataSize = rawData.size();
- int64_t size = type.getShape().front();
- if (type.getElementType().isInteger(1)) {
- if (size != dataSize)
- return emitError() << "expected " << size
- << " bytes for i1 array but got " << dataSize;
- } else if (size * type.getElementTypeBitWidth() != dataSize * 8) {
+ int64_t elementSize =
+ llvm::divideCeil(elementType.getIntOrFloatBitWidth(), CHAR_BIT);
+ if (size * elementSize != dataSize) {
return emitError() << "expected data size (" << size << " elements, "
- << type.getElementTypeBitWidth()
- << " bits each) does not match: " << dataSize
+ << elementSize
+ << " bytes each) does not match: " << dataSize
<< " bytes";
}
return success();
}
-FailureOr<const bool *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<bool>) const {
- if (auto attr = dyn_cast<DenseBoolArrayAttr>())
- return attr.asArrayRef().begin();
- return failure();
-}
-FailureOr<const int8_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int8_t>) const {
- if (auto attr = dyn_cast<DenseI8ArrayAttr>())
- return attr.asArrayRef().begin();
- return failure();
-}
-FailureOr<const int16_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int16_t>) const {
- if (auto attr = dyn_cast<DenseI16ArrayAttr>())
- return attr.asArrayRef().begin();
- return failure();
-}
-FailureOr<const int32_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int32_t>) const {
- if (auto attr = dyn_cast<DenseI32ArrayAttr>())
- return attr.asArrayRef().begin();
- return failure();
-}
-FailureOr<const int64_t *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<int64_t>) const {
- if (auto attr = dyn_cast<DenseI64ArrayAttr>())
- return attr.asArrayRef().begin();
- return failure();
-}
-FailureOr<const float *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<float>) const {
- if (auto attr = dyn_cast<DenseF32ArrayAttr>())
- return attr.asArrayRef().begin();
- return failure();
-}
-FailureOr<const double *>
-DenseArrayAttr::try_value_begin_impl(OverloadToken<double>) const {
- if (auto attr = dyn_cast<DenseF64ArrayAttr>())
- return attr.asArrayRef().begin();
- return failure();
-}
-
namespace {
/// Instantiations of this class provide utilities for interacting with native
/// data types in the context of DenseArrayAttr.
@@ -898,12 +850,11 @@ DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
template <typename T>
DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
ArrayRef<T> content) {
- auto shapedType = RankedTensorType::get(
- content.size(), DenseArrayAttrUtil<T>::getElementType(context));
+ Type elementType = DenseArrayAttrUtil<T>::getElementType(context);
auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
content.size() * sizeof(T));
- return Base::get(context, shapedType, rawArray)
- .template cast<DenseArrayAttrImpl<T>>();
+ return llvm::cast<DenseArrayAttrImpl<T>>(
+ Base::get(context, elementType, content.size(), rawArray));
}
template <typename T>
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
index 18f3ea8f2382e..22a563dd7b2a2 100644
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -494,17 +494,20 @@ void BuiltinDialectBytecodeInterface::write(
DenseArrayAttr BuiltinDialectBytecodeInterface::readDenseArrayAttr(
DialectBytecodeReader &reader) const {
- RankedTensorType type;
+ Type elementType;
+ uint64_t size;
ArrayRef<char> blob;
- if (failed(reader.readType(type)) || failed(reader.readBlob(blob)))
+ if (failed(reader.readType(elementType)) || failed(reader.readVarInt(size)) ||
+ failed(reader.readBlob(blob)))
return DenseArrayAttr();
- return DenseArrayAttr::get(type, blob);
+ return DenseArrayAttr::get(elementType, size, blob);
}
void BuiltinDialectBytecodeInterface::write(
DenseArrayAttr attr, DialectBytecodeWriter &writer) const {
writer.writeVarInt(builtin_encoding::kDenseArrayAttr);
- writer.writeType(attr.getType());
+ writer.writeType(attr.getElementType());
+ writer.writeVarInt(attr.getSize());
writer.writeOwnedBlob(attr.getRawData());
}
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 3081702682658..d494824ec7e7c 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -630,8 +630,6 @@ func.func @dense_array_attr() attributes {
x7_f16 = array<f16: 1., 3.>
}: () -> ()
- // CHECK: test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
- test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
return
}
diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir
index 38b2b8aebb8a7..aac87f24ae31f 100644
--- a/mlir/test/IR/elements-attr-interface.mlir
+++ b/mlir/test/IR/elements-attr-interface.mlir
@@ -27,27 +27,6 @@ arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
// expected-error at below {{Test iterating `IntegerAttr`: }}
arith.constant dense<> : tensor<0xi64>
-// expected-error at below {{Test iterating `bool`: true, false, true, false, true, false}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i1: true, false, true, false, true, false>
-// expected-error at below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i8: 10, 11, -12, 13, 14>
-// expected-error at below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i16: 10, 11, -12, 13, 14>
-// expected-error at below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<i32: 10, 11, -12, 13, 14>
-// expected-error at below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}}
-arith.constant array<i64: 10, 11, -12, 13, 14>
-// expected-error at below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<f32: 10., 11., -12., 13., 14.>
-// expected-error at below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}}
-// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
-arith.constant array<f64: 10., 11., -12., 13., 14.>
-
// Check that we handle an external constant parsed from the config.
// expected-error at below {{Test iterating `int64_t`: unable to iterate type}}
// expected-error at below {{Test iterating `uint64_t`: 1, 2, 3}}
diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 8e57afa41ba88..5343f971537b9 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -546,18 +546,3 @@ func.func @duplicate_dictionary_attr_key() {
// expected-error at below {{expected '>' to close an array attribute}}
#attr = array<i8: 1)
-
-// -----
-
-// expected-error at below {{dense array attribute expected ranked tensor type}}
-test.typed_attr i32 = array<1>
-
-// -----
-
-// expected-error at below {{does not match parsed type}}
-test.typed_attr tensor<1xi32> = array<>
-
-// -----
-
-// expected-error at below {{does not match parsed type}}
-test.typed_attr tensor<0xi32> = array<1>
diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
index cbef0bca2494d..9313f403ce1c5 100644
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -21,10 +21,6 @@ template <typename T>
static void printOneElement(InFlightDiagnostic &os, T value) {
os << llvm::formatv("{0}", value).str();
}
-template <>
-void printOneElement<int8_t>(InFlightDiagnostic &os, int8_t value) {
- os << llvm::formatv("{0}", static_cast<int64_t>(value)).str();
-}
namespace {
struct TestElementsAttrInterface
@@ -41,32 +37,6 @@ struct TestElementsAttrInterface
auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
if (!elementsAttr)
continue;
- if (auto concreteAttr = attr.getValue().dyn_cast<DenseArrayAttr>()) {
- llvm::TypeSwitch<DenseArrayAttr>(concreteAttr)
- .Case([&](DenseBoolArrayAttr attr) {
- testElementsAttrIteration<bool>(op, attr, "bool");
- })
- .Case([&](DenseI8ArrayAttr attr) {
- testElementsAttrIteration<int8_t>(op, attr, "int8_t");
- })
- .Case([&](DenseI16ArrayAttr attr) {
- testElementsAttrIteration<int16_t>(op, attr, "int16_t");
- })
- .Case([&](DenseI32ArrayAttr attr) {
- testElementsAttrIteration<int32_t>(op, attr, "int32_t");
- })
- .Case([&](DenseI64ArrayAttr attr) {
- testElementsAttrIteration<int64_t>(op, attr, "int64_t");
- })
- .Case([&](DenseF32ArrayAttr attr) {
- testElementsAttrIteration<float>(op, attr, "float");
- })
- .Case([&](DenseF64ArrayAttr attr) {
- testElementsAttrIteration<double>(op, attr, "double");
- });
- testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
- continue;
- }
testElementsAttrIteration<int64_t>(op, elementsAttr, "int64_t");
testElementsAttrIteration<uint64_t>(op, elementsAttr, "uint64_t");
testElementsAttrIteration<APInt>(op, elementsAttr, "APInt");
More information about the Mlir-commits
mailing list