[Mlir-commits] [mlir] 7a7c069 - [mlir] Allow dense array to be parsed with type elision

Jeff Niu llvmlistbot at llvm.org
Tue Aug 30 13:29:44 PDT 2022


Author: Jeff Niu
Date: 2022-08-30T13:29:25-07:00
New Revision: 7a7c0697cd97655dfe017fb7515509f8e39c1270

URL: https://github.com/llvm/llvm-project/commit/7a7c0697cd97655dfe017fb7515509f8e39c1270
DIFF: https://github.com/llvm/llvm-project/commit/7a7c0697cd97655dfe017fb7515509f8e39c1270.diff

LOG: [mlir] Allow dense array to be parsed with type elision

This patch makes parsing dense arrays with type elision work properly.
If a ranked tensor type is supplied to `parseAttribute` on a dense
array, the element type (and the following `:`) is not parsed; it is
taken from the supplied tensor type instead. Likewise, when printing with
type elision set to `AttrTypeElision::Must`, the element type is elided.

For example, this allows

```
memref.global @z : memref<3xi32> = array<1, 2, 3>
```

Fixes #57433

Depends on D132758

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D132964

Added: 
    

Modified: 
    mlir/lib/AsmParser/AttributeParser.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/test/IR/attribute.mlir
    mlir/test/IR/invalid-builtin-attributes.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 0bc73234cfe57..f4077a50b8051 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -925,33 +925,59 @@ ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
 }
 
 /// Parse a dense array attribute.
-Attribute Parser::parseDenseArrayAttr(Type type) {
+Attribute Parser::parseDenseArrayAttr(Type attrType) {
   consumeToken(Token::kw_array);
   if (parseToken(Token::less, "expected '<' after 'array'"))
     return {};
 
-  // Only bool or integer and floating point elements divisible by bytes are
-  // supported.
   SMLoc typeLoc = getToken().getLoc();
-  if (!type && !(type = parseType()))
+  Type eltType;
+  // If an attribute type was provided, use its element type.
+  if (attrType) {
+    auto tensorType = attrType.dyn_cast<RankedTensorType>();
+    if (!tensorType) {
+      emitError(typeLoc, "dense array attribute expected ranked tensor type");
+      return {};
+    }
+    eltType = tensorType.getElementType();
+
+    // Otherwise, parse a type.
+  } else if (!(eltType = parseType())) {
     return {};
-  if (!type.isIntOrIndexOrFloat()) {
-    emitError(typeLoc, "expected integer or float type, got: ") << type;
+  }
+
+  // Only bool, integer, and float elements with a bitwidth divisible by 8 are
+  // supported. Index is rejected: `getIntOrFloatBitWidth` asserts on it.
+  if (!eltType.isIntOrFloat()) {
+    emitError(typeLoc, "expected integer or float type, got: ") << eltType;
     return {};
   }
-  if (!type.isInteger(1) && type.getIntOrFloatBitWidth() % 8 != 0) {
+  if (!eltType.isInteger(1) && eltType.getIntOrFloatBitWidth() % 8 != 0) {
     emitError(typeLoc, "element type bitwidth must be a multiple of 8");
     return {};
   }
 
+  // If a type was provided, check that it matches the parsed type.
+  auto checkProvidedType = [&](DenseArrayAttr result) -> Attribute {
+    if (attrType && result.getType() != attrType) {
+      emitError(typeLoc, "expected attribute type ")
+          << attrType << " does not match parsed type " << result.getType();
+      return {};
+    }
+    return result;
+  };
+
   // Check for empty list.
-  if (consumeIf(Token::greater))
-    return DenseArrayAttr::get(RankedTensorType::get(0, type), {});
-  if (parseToken(Token::colon, "expected ':' after dense array type"))
+  if (consumeIf(Token::greater)) {
+    return checkProvidedType(
+        DenseArrayAttr::get(RankedTensorType::get(0, eltType), {}));
+  }
+  if (!attrType &&
+      parseToken(Token::colon, "expected ':' after dense array type"))
     return {};
 
-  DenseArrayElementParser eltParser(type);
-  if (type.isIntOrIndex()) {
+  DenseArrayElementParser eltParser(eltType);
+  if (eltType.isIntOrIndex()) {
     if (parseCommaSeparatedList(
             [&] { return eltParser.parseIntegerElement(*this); }))
       return {};
@@ -962,7 +988,7 @@ Attribute Parser::parseDenseArrayAttr(Type type) {
   }
   if (parseToken(Token::greater, "expected '>' to close an array attribute"))
     return {};
-  return eltParser.getAttr();
+  return checkProvidedType(eltParser.getAttr());
 }
 
 /// Parse a dense elements attribute.

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3f20a6d0fda1c..fedb452aa2fc9 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1864,13 +1864,16 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
   } else if (auto stridedLayoutAttr = attr.dyn_cast<StridedLayoutAttr>()) {
     stridedLayoutAttr.print(os);
   } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
-    typeElision = AttrTypeElision::Must;
-    os << "array<" << denseArrayAttr.getType().getElementType();
+    os << "array<";
+    if (typeElision != AttrTypeElision::Must)
+      printType(denseArrayAttr.getType().getElementType());
     if (!denseArrayAttr.empty()) {
-      os << ": ";
+      if (typeElision != AttrTypeElision::Must)
+        os << ": ";
       printDenseArrayAttr(denseArrayAttr);
     }
     os << ">";
+    return;
   } else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
     os << "dense_resource<";
     printResourceHandle(resourceAttr.getRawHandle());

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 64051552830f1..7870606adab78 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -589,6 +589,9 @@ func.func @dense_array_attr() attributes {
     x6_bf16 = array<bf16: 1.2, 3.4>,
     x7_f16 = array<f16: 1., 3.>
   }: () -> ()
+
+  // CHECK: test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
+  test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
   return
 }
 

diff  --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 0444bef62d2e0..49acce2cf1187 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -546,3 +546,18 @@ func.func @duplicate_dictionary_attr_key() {
 
 // expected-error @below {{expected '>' to close an array attribute}}
 #attr = array<i8: 1)
+
+// -----
+
+// expected-error @below {{dense array attribute expected ranked tensor type}}
+test.typed_attr i32 = array<1>
+
+// -----
+
+// expected-error @below {{does not match parsed type}}
+test.typed_attr tensor<1xi32> = array<>
+
+// -----
+
+// expected-error @below {{does not match parsed type}}
+test.typed_attr tensor<0xi32> = array<1>

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index e75c7ea964a4b..ad15ebfdfdb64 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -460,6 +460,22 @@ TestDialect::getOperationPrinter(Operation *op) const {
   return {};
 }
 
+//===----------------------------------------------------------------------===//
+// TypedAttrOp
+//===----------------------------------------------------------------------===//
+
+/// Parse `attr`, passing the type stored in `type` to the attribute parser.
+static ParseResult parseAttrElideType(AsmParser &parser, TypeAttr type,
+                                      Attribute &attr) {
+  return parser.parseAttribute(attr, type.getValue());
+}
+
+/// Print `attr` without its type; `op` and `type` are unused here.
+static void printAttrElideType(AsmPrinter &printer, Operation *op,
+                               TypeAttr type, Attribute attr) {
+  printer.printAttributeWithoutType(attr);
+}
+
 //===----------------------------------------------------------------------===//
 // TestBranchOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 18775b7841bc0..f0dd4e29be0f8 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -270,6 +270,13 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
   );
 }
 
+def TypedAttrOp : TEST_Op<"typed_attr"> {
+  let arguments = (ins TypeAttr:$type, AnyAttr:$attr);
+  let assemblyFormat = [{
+    attr-dict $type `=` custom<AttrElideType>(ref($type), $attr)
+  }];
+}
+
 def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
   let arguments = (ins
     DenseBoolArrayAttr:$i1attr,


        


More information about the Mlir-commits mailing list