[Mlir-commits] [mlir] 7a7c069 - [mlir] Allow dense array to be parsed with type elision
Jeff Niu
llvmlistbot at llvm.org
Tue Aug 30 13:29:44 PDT 2022
Author: Jeff Niu
Date: 2022-08-30T13:29:25-07:00
New Revision: 7a7c0697cd97655dfe017fb7515509f8e39c1270
URL: https://github.com/llvm/llvm-project/commit/7a7c0697cd97655dfe017fb7515509f8e39c1270
DIFF: https://github.com/llvm/llvm-project/commit/7a7c0697cd97655dfe017fb7515509f8e39c1270.diff
LOG: [mlir] Allow dense array to be parsed with type elision
This patch makes parsing dense arrays with type elision work properly.
If a ranked tensor type is supplied to `parseAttribute` on a dense
array, the element type is taken from that tensor type and is not parsed
from the `array<...>` syntax. Moreover, when printing, if type elision is
set to `AttrTypeElision::Must`, the element type is omitted.
For example, this allows
```
memref.global @z : memref<3xi32> = array<1, 2, 3>
```
Fixes #57433
Depends on D132758
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D132964
Added:
Modified:
mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/attribute.mlir
mlir/test/IR/invalid-builtin-attributes.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 0bc73234cfe57..f4077a50b8051 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -925,33 +925,59 @@ ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
}
/// Parse a dense array attribute.
-Attribute Parser::parseDenseArrayAttr(Type type) {
+Attribute Parser::parseDenseArrayAttr(Type attrType) {
consumeToken(Token::kw_array);
if (parseToken(Token::less, "expected '<' after 'array'"))
return {};
- // Only bool or integer and floating point elements divisible by bytes are
- // supported.
SMLoc typeLoc = getToken().getLoc();
- if (!type && !(type = parseType()))
+ Type eltType;
+ // If an attribute type was provided, use its element type.
+ if (attrType) {
+ auto tensorType = attrType.dyn_cast<RankedTensorType>();
+ if (!tensorType) {
+ emitError(typeLoc, "dense array attribute expected ranked tensor type");
+ return {};
+ }
+ eltType = tensorType.getElementType();
+
+ // Otherwise, parse a type.
+ } else if (!(eltType = parseType())) {
return {};
- if (!type.isIntOrIndexOrFloat()) {
- emitError(typeLoc, "expected integer or float type, got: ") << type;
+ }
+
+ // Only bool or integer and floating point elements divisible by bytes are
+ // supported.
+ if (!eltType.isIntOrIndexOrFloat()) {
+ emitError(typeLoc, "expected integer or float type, got: ") << eltType;
return {};
}
- if (!type.isInteger(1) && type.getIntOrFloatBitWidth() % 8 != 0) {
+ if (!eltType.isInteger(1) && eltType.getIntOrFloatBitWidth() % 8 != 0) {
emitError(typeLoc, "element type bitwidth must be a multiple of 8");
return {};
}
+ // If a type was provided, check that it matches the parsed type.
+ auto checkProvidedType = [&](DenseArrayAttr result) -> Attribute {
+ if (attrType && result.getType() != attrType) {
+ emitError(typeLoc, "expected attribute type ")
+ << attrType << " does not match parsed type " << result.getType();
+ return {};
+ }
+ return result;
+ };
+
// Check for empty list.
- if (consumeIf(Token::greater))
- return DenseArrayAttr::get(RankedTensorType::get(0, type), {});
- if (parseToken(Token::colon, "expected ':' after dense array type"))
+ if (consumeIf(Token::greater)) {
+ return checkProvidedType(
+ DenseArrayAttr::get(RankedTensorType::get(0, eltType), {}));
+ }
+ if (!attrType &&
+ parseToken(Token::colon, "expected ':' after dense array type"))
return {};
- DenseArrayElementParser eltParser(type);
- if (type.isIntOrIndex()) {
+ DenseArrayElementParser eltParser(eltType);
+ if (eltType.isIntOrIndex()) {
if (parseCommaSeparatedList(
[&] { return eltParser.parseIntegerElement(*this); }))
return {};
@@ -962,7 +988,7 @@ Attribute Parser::parseDenseArrayAttr(Type type) {
}
if (parseToken(Token::greater, "expected '>' to close an array attribute"))
return {};
- return eltParser.getAttr();
+ return checkProvidedType(eltParser.getAttr());
}
/// Parse a dense elements attribute.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3f20a6d0fda1c..fedb452aa2fc9 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1864,13 +1864,16 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
} else if (auto stridedLayoutAttr = attr.dyn_cast<StridedLayoutAttr>()) {
stridedLayoutAttr.print(os);
} else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
- typeElision = AttrTypeElision::Must;
- os << "array<" << denseArrayAttr.getType().getElementType();
+ os << "array<";
+ if (typeElision != AttrTypeElision::Must)
+ printType(denseArrayAttr.getType().getElementType());
if (!denseArrayAttr.empty()) {
- os << ": ";
+ if (typeElision != AttrTypeElision::Must)
+ os << ": ";
printDenseArrayAttr(denseArrayAttr);
}
os << ">";
+ return;
} else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
os << "dense_resource<";
printResourceHandle(resourceAttr.getRawHandle());
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 64051552830f1..7870606adab78 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -589,6 +589,9 @@ func.func @dense_array_attr() attributes {
x6_bf16 = array<bf16: 1.2, 3.4>,
x7_f16 = array<f16: 1., 3.>
}: () -> ()
+
+ // CHECK: test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
+ test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
return
}
diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 0444bef62d2e0..49acce2cf1187 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -546,3 +546,18 @@ func.func @duplicate_dictionary_attr_key() {
// expected-error @below {{expected '>' to close an array attribute}}
#attr = array<i8: 1)
+
+// -----
+
+// expected-error @below {{dense array attribute expected ranked tensor type}}
+test.typed_attr i32 = array<1>
+
+// -----
+
+// expected-error @below {{does not match parsed type}}
+test.typed_attr tensor<1xi32> = array<>
+
+// -----
+
+// expected-error @below {{does not match parsed type}}
+test.typed_attr tensor<0xi32> = array<1>
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index e75c7ea964a4b..ad15ebfdfdb64 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -460,6 +460,22 @@ TestDialect::getOperationPrinter(Operation *op) const {
return {};
}
+//===----------------------------------------------------------------------===//
+// TypedAttrOp
+//===----------------------------------------------------------------------===//
+
+/// Parse an attribute with a given type.
+static ParseResult parseAttrElideType(AsmParser &parser, TypeAttr type,
+ Attribute &attr) {
+ return parser.parseAttribute(attr, type.getValue());
+}
+
+/// Print an attribute without its type.
+static void printAttrElideType(AsmPrinter &printer, Operation *op,
+ TypeAttr type, Attribute attr) {
+ printer.printAttributeWithoutType(attr);
+}
+
//===----------------------------------------------------------------------===//
// TestBranchOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 18775b7841bc0..f0dd4e29be0f8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -270,6 +270,13 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
);
}
+def TypedAttrOp : TEST_Op<"typed_attr"> {
+ let arguments = (ins TypeAttr:$type, AnyAttr:$attr);
+ let assemblyFormat = [{
+ attr-dict $type `=` custom<AttrElideType>(ref($type), $attr)
+ }];
+}
+
def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
let arguments = (ins
DenseBoolArrayAttr:$i1attr,
More information about the Mlir-commits
mailing list