[Mlir-commits] [mlir] cec7e80 - [mlir] Make DenseArrayAttr generic
Jeff Niu
llvmlistbot at llvm.org
Tue Aug 30 13:29:38 PDT 2022
Author: Jeff Niu
Date: 2022-08-30T13:29:24-07:00
New Revision: cec7e80ebd5dafc20974f2554db1b2eb0e2175b7
URL: https://github.com/llvm/llvm-project/commit/cec7e80ebd5dafc20974f2554db1b2eb0e2175b7
DIFF: https://github.com/llvm/llvm-project/commit/cec7e80ebd5dafc20974f2554db1b2eb0e2175b7.diff
LOG: [mlir] Make DenseArrayAttr generic
This patch turns `DenseArrayBaseAttr` into a fully-functional attribute by
adding a generic parser and printer, supporting bool or integer and floating
point element types with bitwidths divisible by 8. It has been renamed
to `DenseArrayAttr`. The patch maintains the specialized subclasses,
e.g. `DenseI32ArrayAttr`, which remain the preferred API for accessing
elements in C++.
This allows `DenseArrayAttr` to hold signed and unsigned integer elements:
```
array<si8: -128, 127>
array<ui8: 255>
```
"Exotic" floating point elements:
```
array<bf16: 1.2, 3.4>
```
And integers of other bitwidths:
```
array<i24: 8388607>
```
Reviewed By: rriddle, lattner
Differential Revision: https://reviews.llvm.org/D132758
Added:
Modified:
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/test/IR/attribute.mlir
mlir/test/IR/invalid-builtin-attributes.mlir
mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 4f2b41540a148..a102311beabca 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -761,9 +761,9 @@ namespace detail {
/// Base class for DenseArrayAttr that is instantiated and specialized for each
/// supported element type below.
template <typename T>
-class DenseArrayAttr : public DenseArrayBaseAttr {
+class DenseArrayAttrImpl : public DenseArrayAttr {
public:
- using DenseArrayBaseAttr::DenseArrayBaseAttr;
+ using DenseArrayAttr::DenseArrayAttr;
/// Implicit conversion to ArrayRef<T>.
operator ArrayRef<T>() const;
@@ -773,7 +773,7 @@ class DenseArrayAttr : public DenseArrayBaseAttr {
T operator[](std::size_t index) const { return asArrayRef()[index]; }
/// Builder from ArrayRef<T>.
- static DenseArrayAttr get(MLIRContext *context, ArrayRef<T> content);
+ static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef<T> content);
/// Print the short form `[42, 100, -1]` without any type prefix.
void print(AsmPrinter &printer) const;
@@ -791,23 +791,23 @@ class DenseArrayAttr : public DenseArrayBaseAttr {
static bool classof(Attribute attr);
};
-extern template class DenseArrayAttr<bool>;
-extern template class DenseArrayAttr<int8_t>;
-extern template class DenseArrayAttr<int16_t>;
-extern template class DenseArrayAttr<int32_t>;
-extern template class DenseArrayAttr<int64_t>;
-extern template class DenseArrayAttr<float>;
-extern template class DenseArrayAttr<double>;
+extern template class DenseArrayAttrImpl<bool>;
+extern template class DenseArrayAttrImpl<int8_t>;
+extern template class DenseArrayAttrImpl<int16_t>;
+extern template class DenseArrayAttrImpl<int32_t>;
+extern template class DenseArrayAttrImpl<int64_t>;
+extern template class DenseArrayAttrImpl<float>;
+extern template class DenseArrayAttrImpl<double>;
} // namespace detail
// Public name for all the supported DenseArrayAttr
-using DenseBoolArrayAttr = detail::DenseArrayAttr<bool>;
-using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
-using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
-using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
-using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
-using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
-using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
+using DenseBoolArrayAttr = detail::DenseArrayAttrImpl<bool>;
+using DenseI8ArrayAttr = detail::DenseArrayAttrImpl<int8_t>;
+using DenseI16ArrayAttr = detail::DenseArrayAttrImpl<int16_t>;
+using DenseI32ArrayAttr = detail::DenseArrayAttrImpl<int32_t>;
+using DenseI64ArrayAttr = detail::DenseArrayAttrImpl<int64_t>;
+using DenseF32ArrayAttr = detail::DenseArrayAttrImpl<float>;
+using DenseF64ArrayAttr = detail::DenseArrayAttrImpl<double>;
//===----------------------------------------------------------------------===//
// DenseResourceElementsAttr
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 96ab4f2b92ece..2c4b92b6fa5fc 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -140,7 +140,7 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array", [
}
//===----------------------------------------------------------------------===//
-// DenseArrayBaseAttr
+// DenseArrayAttr
//===----------------------------------------------------------------------===//
def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
@@ -155,23 +155,28 @@ def Builtin_DenseArrayRawDataParameter : ArrayRefParameter<
}];
}
-def Builtin_DenseArrayBase : Builtin_Attr<
- "DenseArrayBase", [ElementsAttrInterface, TypedAttrInterface]> {
- let summary = "A dense array of i8, i16, i32, i64, f32, or f64.";
+def Builtin_DenseArray : Builtin_Attr<
+ "DenseArray", [ElementsAttrInterface, TypedAttrInterface]> {
+ let summary = "A dense array of integer or floating point elements.";
let description = [{
A dense array attribute is an attribute that represents a dense array of
primitive element types. Contrary to DenseIntOrFPElementsAttr this is a
flat unidimensional array which does not have a storage optimization for
splat. This allows to expose the raw array through a C++ API as
- `ArrayRef<T>`. This is the base class attribute, the actual access is
- intended to be managed through the subclasses `DenseI8ArrayAttr`,
- `DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`,
- `DenseF32ArrayAttr`, and `DenseF64ArrayAttr`.
+ `ArrayRef<T>` for compatible types. The element type must be bool or an
+ integer or float whose bitwidth is a multiple of 8. Bool elements are stored
+ as bytes.
+
+ This is the base class attribute. Access to C++ types is intended to be
+ managed through the subclasses `DenseI8ArrayAttr`, `DenseI16ArrayAttr`,
+ `DenseI32ArrayAttr`, `DenseI64ArrayAttr`, `DenseF32ArrayAttr`,
+ and `DenseF64ArrayAttr`.
Syntax:
```
- dense-array-attribute ::= `[` `:` (integer-type | float-type) tensor-literal `]`
+ dense-array-attribute ::= `array` `<` (integer-type | float-type)
+ (`:` tensor-literal)? `>`
```
Examples:
@@ -181,16 +186,26 @@ def Builtin_DenseArrayBase : Builtin_Attr<
array<f64: 42., 12.>
```
- when a specific subclass is used as argument of an operation, the declarative
- assembly will omit the type and print directly:
- ```
+ When a specific subclass is used as argument of an operation, the
+ declarative assembly will omit the type and print directly:
+
+ ```mlir
[1, 2, 3]
```
}];
+
let parameters = (ins
AttributeSelfTypeParameter<"", "RankedTensorType">:$type,
Builtin_DenseArrayRawDataParameter:$rawData
);
+
+ let builders = [
+ AttrBuilderWithInferredContext<(ins "RankedTensorType":$type,
+ "ArrayRef<char>":$rawData), [{
+ return $_get(type.getContext(), type, rawData);
+ }]>,
+ ];
+
let extraClassDeclaration = [{
/// Allow implicit conversion to ElementsAttr.
operator ElementsAttr() const {
@@ -207,13 +222,9 @@ def Builtin_DenseArrayBase : Builtin_Attr<
const int64_t *value_begin_impl(OverloadToken<int64_t>) const;
const float *value_begin_impl(OverloadToken<float>) const;
const double *value_begin_impl(OverloadToken<double>) const;
-
- /// Printer for the short form: will dispatch to the appropriate subclass.
- void print(AsmPrinter &printer) const;
- void print(raw_ostream &os) const;
- /// Print the short form `42, 100, -1` without any braces or prefix.
- void printWithoutBraces(raw_ostream &os) const;
}];
+
+ let genVerifyDecl = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 8fadc2933e94d..0bc73234cfe57 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -827,96 +827,142 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
}
//===----------------------------------------------------------------------===//
-// ElementsAttr Parser
+// DenseArrayAttr Parser
//===----------------------------------------------------------------------===//
namespace {
-/// This class provides an implementation of AsmParser, allowing to call back
-/// into the libMLIRIR-provided APIs for invoking attribute parsing code defined
-/// in libMLIRIR.
-class CustomAsmParser : public AsmParserImpl<AsmParser> {
+/// A generic dense array element parser. It parses integer and floating point
+/// elements.
+///
+/// Each successfully parsed element is appended to an internal byte buffer
+/// and counted; `getAttr()` then wraps that buffer in a `DenseArrayAttr`
+/// whose type is a rank-1 tensor of the counted length.
+class DenseArrayElementParser {
public:
- CustomAsmParser(Parser &parser)
- : AsmParserImpl<AsmParser>(parser.getToken().getLoc(), parser) {}
+ /// Create a parser for elements of the given integer or float `type`.
+ explicit DenseArrayElementParser(Type type) : type(type) {}
+
+ /// Parse one integer element (or `true`/`false` for `i1`) from `p` and
+ /// append it.
+ ParseResult parseIntegerElement(Parser &p);
+
+ /// Parse one floating point element from `p` and append its bit pattern.
+ ParseResult parseFloatElement(Parser &p);
+
+ /// Convert the current contents to a dense array.
+ DenseArrayAttr getAttr() {
+ return DenseArrayAttr::get(RankedTensorType::get(size, type), rawData);
+ }
+
+private:
+ /// Append the raw data of an APInt to the result.
+ void append(const APInt &data);
+
+ /// The array element type.
+ Type type;
+ /// The resultant byte array representing the contents of the array.
+ std::vector<char> rawData;
+ /// The number of elements in the array.
+ int64_t size = 0;
};
} // namespace
+/// Append the raw bytes of `data` to the buffer and count one more element.
+/// Zero-width values (e.g. `i0` elements) contribute no bytes but are still
+/// counted, so the resulting tensor type has the right length.
+void DenseArrayElementParser::append(const APInt &data) {
+ unsigned bitWidth = data.getBitWidth();
+ assert(bitWidth % 8 == 0 && "expected a whole number of bytes");
+ if (bitWidth != 0) {
+ // Skip the copy for zero-width values: `rawData.data()` may be null while
+ // the buffer is still empty, and passing a null pointer to memcpy is
+ // undefined behavior even for a zero-length copy.
+ unsigned byteSize = bitWidth / 8;
+ size_t offset = rawData.size();
+ rawData.insert(rawData.end(), byteSize, 0);
+ llvm::StoreIntToMemory(
+ data, reinterpret_cast<uint8_t *>(rawData.data() + offset), byteSize);
+ }
+ ++size;
+}
+
+/// Parse one integer element (or `true`/`false` for `i1`) and append it.
+/// `i1` elements are stored as one byte each, so an integer literal parsed
+/// for an `i1` array is zero-extended to 8 bits to match the representation
+/// used for `true`/`false`.
+ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
+ bool isNegative = p.consumeIf(Token::minus);
+
+ // Parse an integer literal as an APInt.
+ Optional<APInt> value;
+ StringRef spelling = p.getToken().getSpelling();
+ if (p.getToken().isAny(Token::kw_true, Token::kw_false)) {
+ if (!type.isInteger(1))
+ return p.emitError("expected i1 type for 'true' or 'false' values");
+ value = APInt(/*numBits=*/8, p.getToken().is(Token::kw_true),
+ !type.isUnsignedInteger());
+ p.consumeToken();
+ } else if (p.consumeIf(Token::integer)) {
+ value = buildAttributeAPInt(type, isNegative, spelling);
+ if (!value)
+ return p.emitError("integer constant out of range");
+ // A literal for `i1` comes back 1 bit wide; widen it to the byte-sized
+ // storage, otherwise no bytes would be appended for this element.
+ if (type.isInteger(1))
+ value = value->zext(8);
+ } else {
+ return p.emitError("expected integer literal");
+ }
+ append(*value);
+ return success();
+}
+
+/// Parse one floating point element and append its bit pattern.
+/// Integer literals go through `parseFloatFromIntegerLiteral`; float literals
+/// are read as a `double` and, unless the element type is `f64`, converted to
+/// the element type's semantics with round-to-nearest-ties-to-even.
+ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
+ bool isNegative = p.consumeIf(Token::minus);
+
+ Token token = p.getToken();
+ Optional<APFloat> result;
+ auto floatType = type.cast<FloatType>();
+ if (p.consumeIf(Token::integer)) {
+ // Parse an integer literal as a float.
+ if (p.parseFloatFromIntegerLiteral(result, token, isNegative,
+ floatType.getFloatSemantics(),
+ floatType.getWidth()))
+ return failure();
+ } else if (p.consumeIf(Token::floatliteral)) {
+ // Parse a floating point literal.
+ Optional<double> val = token.getFloatingPointValue();
+ // Report the failure rather than returning a bare failure() with no
+ // diagnostic. The token has already been consumed, so anchor the error at
+ // its location.
+ if (!val)
+ return p.emitError(token.getLoc(),
+ "floating point value too large for attribute");
+ result = APFloat(isNegative ? -*val : *val);
+ if (!type.isF64()) {
+ bool unused;
+ result->convert(floatType.getFloatSemantics(),
+ APFloat::rmNearestTiesToEven, &unused);
+ }
+ } else {
+ return p.emitError("expected integer or floating point literal");
+ }
+
+ append(result->bitcastToAPInt());
+ return success();
+}
+
/// Parse a dense array attribute.
Attribute Parser::parseDenseArrayAttr(Type type) {
consumeToken(Token::kw_array);
+ if (parseToken(Token::less, "expected '<' after 'array'"))
+ return {};
+
+ // Only bool or integer and floating point elements divisible by bytes are
+ // supported. `index` is rejected: it has no fixed bitwidth, and the
+ // bitwidth check below cannot be applied to it.
SMLoc typeLoc = getToken().getLoc();
- if (parseToken(Token::less, "expected '<' after 'array'") ||
- (!type && !(type = parseType())))
+ if (!type && !(type = parseType()))
+ return {};
+ if (!type.isIntOrFloat()) {
+ emitError(typeLoc, "expected integer or float type, got: ") << type;
+ return {};
+ }
+ if (!type.isInteger(1) && type.getIntOrFloatBitWidth() % 8 != 0) {
+ emitError(typeLoc, "element type bitwidth must be a multiple of 8");
return {};
- CustomAsmParser parser(*this);
- Attribute result;
+ }
+
// Check for empty list.
- bool isEmptyList = getToken().is(Token::greater);
- if (!isEmptyList &&
- parseToken(Token::colon, "expected ':' after dense array type"))
+ if (consumeIf(Token::greater))
+ return DenseArrayAttr::get(RankedTensorType::get(0, type), {});
+ if (parseToken(Token::colon, "expected ':' after dense array type"))
return {};
- if (auto intType = type.dyn_cast<IntegerType>()) {
- switch (type.getIntOrFloatBitWidth()) {
- case 1:
- if (isEmptyList)
- result = DenseBoolArrayAttr::get(parser.getContext(), {});
- else
- result = DenseBoolArrayAttr::parseWithoutBraces(parser, Type{});
- break;
- case 8:
- if (isEmptyList)
- result = DenseI8ArrayAttr::get(parser.getContext(), {});
- else
- result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{});
- break;
- case 16:
- if (isEmptyList)
- result = DenseI16ArrayAttr::get(parser.getContext(), {});
- else
- result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{});
- break;
- case 32:
- if (isEmptyList)
- result = DenseI32ArrayAttr::get(parser.getContext(), {});
- else
- result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{});
- break;
- case 64:
- if (isEmptyList)
- result = DenseI64ArrayAttr::get(parser.getContext(), {});
- else
- result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
- break;
- default:
- emitError(typeLoc, "expected i1, i8, i16, i32, or i64 but got: ") << type;
- return {};
- }
- } else if (auto floatType = type.dyn_cast<FloatType>()) {
- switch (type.getIntOrFloatBitWidth()) {
- case 32:
- if (isEmptyList)
- result = DenseF32ArrayAttr::get(parser.getContext(), {});
- else
- result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{});
- break;
- case 64:
- if (isEmptyList)
- result = DenseF64ArrayAttr::get(parser.getContext(), {});
- else
- result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{});
- break;
- default:
- emitError(typeLoc, "expected f32 or f64 but got: ") << type;
+ DenseArrayElementParser eltParser(type);
+ if (type.isIntOrIndex()) {
+ if (parseCommaSeparatedList(
+ [&] { return eltParser.parseIntegerElement(*this); }))
return {};
- }
} else {
- emitError(typeLoc, "expected integer or float type, got: ") << type;
- return {};
+ if (parseCommaSeparatedList(
+ [&] { return eltParser.parseFloatElement(*this); }))
+ return {};
}
if (parseToken(Token::greater, "expected '>' to close an array attribute"))
return {};
- return result;
+ return eltParser.getAttr();
}
/// Parse a dense elements attribute.
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
index c50096bb1c1b8..b02484ab9cc5b 100644
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -383,7 +383,7 @@ MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size,
// Accessors.
intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) {
- return unwrap(attr).cast<DenseArrayBaseAttr>().size();
+ return unwrap(attr).cast<DenseArrayAttr>().size();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3bb67fcbc39a8..3f20a6d0fda1c 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1476,6 +1476,9 @@ class AsmPrinter::Impl {
void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
bool allowHex);
+ /// Print a dense array attribute.
+ void printDenseArrayAttr(DenseArrayAttr attr);
+
void printDialectAttribute(Attribute attr);
void printDialectType(Type type);
@@ -1860,12 +1863,13 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
}
} else if (auto stridedLayoutAttr = attr.dyn_cast<StridedLayoutAttr>()) {
stridedLayoutAttr.print(os);
- } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
+ } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
typeElision = AttrTypeElision::Must;
os << "array<" << denseArrayAttr.getType().getElementType();
- if (!denseArrayAttr.empty())
+ if (!denseArrayAttr.empty()) {
os << ": ";
- denseArrayAttr.printWithoutBraces(os);
+ printDenseArrayAttr(denseArrayAttr);
+ }
os << ">";
} else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
os << "dense_resource<";
@@ -1890,11 +1894,11 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
/// Print the integer element of a DenseElementsAttr.
static void printDenseIntElement(const APInt &value, raw_ostream &os,
- bool isSigned) {
- if (value.getBitWidth() == 1)
+ Type type) {
+ if (type.isInteger(1))
os << (value.getBoolValue() ? "true" : "false");
else
- value.print(os, isSigned);
+ value.print(os, !type.isUnsignedInteger());
}
static void
@@ -1988,14 +1992,13 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
// printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
// and hence was replaced.
if (complexElementType.isa<IntegerType>()) {
- bool isSigned = !complexElementType.isUnsignedInteger();
auto valueIt = attr.value_begin<std::complex<APInt>>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
auto complexValue = *(valueIt + index);
os << "(";
- printDenseIntElement(complexValue.real(), os, isSigned);
+ printDenseIntElement(complexValue.real(), os, complexElementType);
os << ",";
- printDenseIntElement(complexValue.imag(), os, isSigned);
+ printDenseIntElement(complexValue.imag(), os, complexElementType);
os << ")";
});
} else {
@@ -2010,10 +2013,9 @@ void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
});
}
} else if (elementType.isIntOrIndex()) {
- bool isSigned = !elementType.isUnsignedInteger();
auto valueIt = attr.value_begin<APInt>();
printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
- printDenseIntElement(*(valueIt + index), os, isSigned);
+ printDenseIntElement(*(valueIt + index), os, elementType);
});
} else {
assert(elementType.isa<FloatType>() && "unexpected element type");
@@ -2031,6 +2033,29 @@ void AsmPrinter::Impl::printDenseStringElementsAttr(
printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
}
+/// Print the comma-separated elements of a dense array attribute, decoding
+/// each element from the raw byte buffer. `i1` elements occupy one byte each;
+/// other elements occupy `bitwidth / 8` bytes.
+void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
+ Type type = attr.getElementType();
+ unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth();
+ unsigned byteSize = bitwidth / 8;
+ ArrayRef<char> data = attr.getRawData();
+
+ auto printElementAt = [&](unsigned i) {
+ APInt value(bitwidth, 0);
+ // Zero-width elements (e.g. `i0`) have no backing bytes, and the raw data
+ // pointer may then be null; only read memory when there is something to
+ // read.
+ if (byteSize != 0)
+ llvm::LoadIntFromMemory(
+ value, reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i),
+ byteSize);
+ // Print the data as-is or as a float.
+ if (type.isIntOrIndex()) {
+ printDenseIntElement(value, getStream(), type);
+ } else {
+ APFloat fltVal(type.cast<FloatType>().getFloatSemantics(), value);
+ printFloatValue(fltVal, getStream());
+ }
+ };
+ llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
+ printElementAt);
+}
+
void AsmPrinter::Impl::printType(Type type) {
if (!type) {
os << "<<NULL TYPE>>";
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index dab80f1f942a8..8d060ef233872 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -741,50 +741,50 @@ DenseElementsAttr::ComplexIntElementIterator::operator*() const {
// DenseArrayAttr
//===----------------------------------------------------------------------===//
-const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken<bool>) const {
+/// Verify the invariants of a DenseArrayAttr. `type` must be a statically
+/// shaped rank-1 tensor whose element type is an integer or float that is
+/// either `i1` or has a bitwidth that is a multiple of 8. `rawData` must hold
+/// exactly one byte per `i1` element, or `bitwidth / 8` bytes per element
+/// otherwise.
+LogicalResult
+DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ RankedTensorType type, ArrayRef<char> rawData) {
+ if (type.getRank() != 1)
+ return emitError() << "expected rank 1 tensor type";
+ if (!type.hasStaticShape())
+ return emitError() << "expected static shape";
+ Type elementType = type.getElementType();
+ // `index` has no fixed bitwidth, so its elements cannot be laid out as raw
+ // bytes.
+ if (!elementType.isIntOrFloat())
+ return emitError() << "expected integer or floating point element type";
+ // Only `i1` (one byte per element) or element types that fill whole bytes
+ // can be addressed in the raw buffer; e.g. `i7` would otherwise pass the
+ // total-size check below while being unreadable element by element.
+ if (!elementType.isInteger(1) &&
+ elementType.getIntOrFloatBitWidth() % 8 != 0)
+ return emitError() << "element type bitwidth must be a multiple of 8";
+ int64_t dataSize = rawData.size();
+ int64_t size = type.getShape().front();
+ if (elementType.isInteger(1)) {
+ if (size != dataSize)
+ return emitError() << "expected " << size
+ << " bytes for i1 array but got " << dataSize;
+ } else if (size * type.getElementTypeBitWidth() != dataSize * 8) {
+ return emitError() << "expected data size (" << size << " elements, "
+ << type.getElementTypeBitWidth()
+ << " bits each) does not match: " << dataSize
+ << " bytes";
+ }
+ return success();
+}
+
+const bool *DenseArrayAttr::value_begin_impl(OverloadToken<bool>) const {
return cast<DenseBoolArrayAttr>().asArrayRef().begin();
}
-const int8_t *
-DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
+const int8_t *DenseArrayAttr::value_begin_impl(OverloadToken<int8_t>) const {
return cast<DenseI8ArrayAttr>().asArrayRef().begin();
}
-const int16_t *
-DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
+const int16_t *DenseArrayAttr::value_begin_impl(OverloadToken<int16_t>) const {
return cast<DenseI16ArrayAttr>().asArrayRef().begin();
}
-const int32_t *
-DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
+const int32_t *DenseArrayAttr::value_begin_impl(OverloadToken<int32_t>) const {
return cast<DenseI32ArrayAttr>().asArrayRef().begin();
}
-const int64_t *
-DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
+const int64_t *DenseArrayAttr::value_begin_impl(OverloadToken<int64_t>) const {
return cast<DenseI64ArrayAttr>().asArrayRef().begin();
}
-const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
+const float *DenseArrayAttr::value_begin_impl(OverloadToken<float>) const {
return cast<DenseF32ArrayAttr>().asArrayRef().begin();
}
-const double *
-DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
+const double *DenseArrayAttr::value_begin_impl(OverloadToken<double>) const {
return cast<DenseF64ArrayAttr>().asArrayRef().begin();
}
-void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
- print(printer.getStream());
-}
-
-void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
- llvm::TypeSwitch<DenseArrayBaseAttr>(*this)
- .Case<DenseBoolArrayAttr, DenseI8ArrayAttr, DenseI16ArrayAttr,
- DenseI32ArrayAttr, DenseI64ArrayAttr, DenseF32ArrayAttr,
- DenseF64ArrayAttr>([&](auto attr) { attr.printWithoutBraces(os); });
-}
-
-void DenseArrayBaseAttr::print(raw_ostream &os) const {
- os << "[";
- printWithoutBraces(os);
- os << "]";
-}
-
namespace {
/// Instantiations of this class provide utilities for interacting with native
/// data types in the context of DenseArrayAttr.
@@ -869,19 +869,19 @@ struct DenseArrayAttrUtil<double> {
} // namespace
template <typename T>
-void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
+void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const {
print(printer.getStream());
}
template <typename T>
-void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
+void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const {
llvm::interleaveComma(asArrayRef(), os, [&](T value) {
DenseArrayAttrUtil<T>::printElement(os, value);
});
}
template <typename T>
-void DenseArrayAttr<T>::print(raw_ostream &os) const {
+void DenseArrayAttrImpl<T>::print(raw_ostream &os) const {
os << "[";
printWithoutBraces(os);
os << "]";
@@ -889,8 +889,8 @@ void DenseArrayAttr<T>::print(raw_ostream &os) const {
/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
template <typename T>
-Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
- Type odsType) {
+Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser,
+ Type odsType) {
SmallVector<T> data;
if (failed(parser.parseCommaSeparatedList([&]() {
T value;
@@ -905,7 +905,7 @@ Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
template <typename T>
-Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
+Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) {
if (parser.parseLSquare())
return {};
// Handle empty list case.
@@ -919,7 +919,7 @@ Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
template <typename T>
-DenseArrayAttr<T>::operator ArrayRef<T>() const {
+DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
ArrayRef<char> raw = getRawData();
assert((raw.size() % sizeof(T)) == 0);
return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
@@ -928,19 +928,19 @@ DenseArrayAttr<T>::operator ArrayRef<T>() const {
/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
template <typename T>
-DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
- ArrayRef<T> content) {
+DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
+ ArrayRef<T> content) {
auto shapedType = RankedTensorType::get(
content.size(), DenseArrayAttrUtil<T>::getElementType(context));
auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
content.size() * sizeof(T));
return Base::get(context, shapedType, rawArray)
- .template cast<DenseArrayAttr<T>>();
+ .template cast<DenseArrayAttrImpl<T>>();
}
template <typename T>
-bool DenseArrayAttr<T>::classof(Attribute attr) {
- if (auto denseArray = attr.dyn_cast<DenseArrayBaseAttr>())
+bool DenseArrayAttrImpl<T>::classof(Attribute attr) {
+ if (auto denseArray = attr.dyn_cast<DenseArrayAttr>())
return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
return false;
}
@@ -948,13 +948,13 @@ bool DenseArrayAttr<T>::classof(Attribute attr) {
namespace mlir {
namespace detail {
// Explicit instantiation for all the supported DenseArrayAttr.
-template class DenseArrayAttr<bool>;
-template class DenseArrayAttr<int8_t>;
-template class DenseArrayAttr<int16_t>;
-template class DenseArrayAttr<int32_t>;
-template class DenseArrayAttr<int64_t>;
-template class DenseArrayAttr<float>;
-template class DenseArrayAttr<double>;
+template class DenseArrayAttrImpl<bool>;
+template class DenseArrayAttrImpl<int8_t>;
+template class DenseArrayAttrImpl<int16_t>;
+template class DenseArrayAttrImpl<int32_t>;
+template class DenseArrayAttrImpl<int64_t>;
+template class DenseArrayAttrImpl<float>;
+template class DenseArrayAttrImpl<double>;
} // namespace detail
} // namespace mlir
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index cf2d5332351e2..64051552830f1 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -569,6 +569,26 @@ func.func @dense_array_attr() attributes {
f64attr = [-142.]
// CHECK-SAME: emptyattr = []
emptyattr = []
+
+ // CHECK: array.sizes
+ // CHECK-SAME: i0 = array<i0: 0, 0>
+ // CHECK-SAME: ui0 = array<ui0: 0, 0>
+ // CHECK-SAME: si0 = array<si0: 0, 0>
+ // CHECK-SAME: i24 = array<i24: -42, 42, 8388607>
+ // CHECK-SAME: ui24 = array<ui24: 16777215>
+ // CHECK-SAME: si24 = array<si24: -8388608>
+ // CHECK-SAME: bf16 = array<bf16: 1.2{{[0-9]+}}e+00, 3.4{{[0-9]+}}e+00>
+ // CHECK-SAME: f16 = array<f16: 1.{{[0-9]+}}e+00, 3.{{[0-9]+}}e+00>
+ "array.sizes"() {
+ x0_i0 = array<i0: 0, 0>,
+ x1_ui0 = array<ui0: 0, 0>,
+ x2_si0 = array<si0: 0, 0>,
+ x3_i24 = array<i24: -42, 42, 8388607>,
+ x4_ui24 = array<ui24: 16777215>,
+ x5_si24 = array<si24: -8388608>,
+ x6_bf16 = array<bf16: 1.2, 3.4>,
+ x7_f16 = array<f16: 1., 3.>
+ }: () -> ()
return
}
diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 8095119393f0a..0444bef62d2e0 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -521,3 +521,28 @@ func.func @duplicate_dictionary_attr_key() {
// expected-error@+1 {{`dense_resource` expected a shaped type}}
#attr = dense_resource<resource> : i32
+
+// -----
+
+// expected-error@below {{expected '<' after 'array'}}
+#attr = array
+
+// -----
+
+// expected-error@below {{expected integer or float type}}
+#attr = array<vector<i32>>
+
+// -----
+
+// expected-error@below {{element type bitwidth must be a multiple of 8}}
+#attr = array<i7>
+
+// -----
+
+// expected-error@below {{expected ':' after dense array type}}
+#attr = array<i8)
+
+// -----
+
+// expected-error@below {{expected '>' to close an array attribute}}
+#attr = array<i8: 1)
diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
index 6257799824c6f..23fde121682cc 100644
--- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
+++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp
@@ -41,9 +41,8 @@ struct TestElementsAttrInterface
auto elementsAttr = attr.getValue().dyn_cast<ElementsAttr>();
if (!elementsAttr)
continue;
- if (auto concreteAttr =
- attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
- llvm::TypeSwitch<DenseArrayBaseAttr>(concreteAttr)
+ if (auto concreteAttr = attr.getValue().dyn_cast<DenseArrayAttr>()) {
+ llvm::TypeSwitch<DenseArrayAttr>(concreteAttr)
.Case([&](DenseBoolArrayAttr attr) {
testElementsAttrIteration<bool>(op, attr, "bool");
})
More information about the Mlir-commits
mailing list