[Mlir-commits] [mlir] 6b476e2 - [mlir] Add support for parsing optional Attribute values.

River Riddle llvmlistbot at llvm.org
Tue Jul 14 13:21:36 PDT 2020


Author: River Riddle
Date: 2020-07-14T13:14:59-07:00
New Revision: 6b476e2426e9cfa442dac5deed2ceae890513f18

URL: https://github.com/llvm/llvm-project/commit/6b476e2426e9cfa442dac5deed2ceae890513f18
DIFF: https://github.com/llvm/llvm-project/commit/6b476e2426e9cfa442dac5deed2ceae890513f18.diff

LOG: [mlir] Add support for parsing optional Attribute values.

This adds a `parseOptionalAttribute` method to the OpAsmParser that allows for parsing optional attributes, in a similar fashion to how optional types are parsed. This also enables the use of attribute values as the first element of an assembly format optional group.

Differential Revision: https://reviews.llvm.org/D83712

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Parser/AttributeParser.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Parser/Parser.h
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index a10610f87a0a..c068aac09bab 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -713,7 +713,8 @@ of the assembly format can be marked as `optional` based on the presence of this
 information. An optional group is defined by wrapping a set of elements within
 `()` followed by a `?` and has the following requirements:
 
-*   The first element of the group must either be a literal or an operand.
+*   The first element of the group must be either a literal, an attribute,
+    or an operand.
     -   This is because the first element must be optionally parsable.
 *   Exactly one argument variable within the group must be marked as the anchor
     of the group.

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 20660be4347c..0124ef5f7c0a 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -384,6 +384,49 @@ class OpAsmParser {
                                      StringRef attrName,
                                      NamedAttrList &attrs) = 0;

+  /// Parse an optional attribute. Returns None if no attribute is present,
+  /// otherwise the result of parsing it. On success, the attribute is also
+  /// added to `attrs` under `attrName`.
+  virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
+                                                     Type type,
+                                                     StringRef attrName,
+                                                     NamedAttrList &attrs) = 0;
+  OptionalParseResult parseOptionalAttribute(Attribute &result,
+                                             StringRef attrName,
+                                             NamedAttrList &attrs) {
+    return parseOptionalAttribute(result, Type(), attrName, attrs);
+  }
+
+  /// Parse an optional attribute of a specific kind and type. The parsed
+  /// attribute is checked against `AttrType`, and an error is emitted if it
+  /// is of a different kind, instead of silently storing it in `result`.
+  template <typename AttrType>
+  OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
+                                             StringRef attrName,
+                                             NamedAttrList &attrs) {
+    llvm::SMLoc loc = getCurrentLocation();
+
+    // Parse any kind of attribute; forward "not present" and parse errors.
+    Attribute attr;
+    OptionalParseResult parseResult =
+        parseOptionalAttribute(attr, type, attrName, attrs);
+    if (!parseResult.hasValue() || failed(*parseResult))
+      return parseResult;
+
+    // Check for the right kind of attribute.
+    result = attr.dyn_cast<AttrType>();
+    if (!result)
+      return ParseResult(
+          emitError(loc, "invalid kind of attribute specified"));
+    return success();
+  }
+  template <typename AttrType>
+  OptionalParseResult parseOptionalAttribute(AttrType &result,
+                                             StringRef attrName,
+                                             NamedAttrList &attrs) {
+    return parseOptionalAttribute(result, Type(), attrName, attrs);
+  }
+
   /// Parse an attribute of a specific kind and type.
   template <typename AttrType>
   ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,

diff  --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index e2860b115231..1c1261e6d765 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -187,6 +187,40 @@ Attribute Parser::parseAttribute(Type type) {
   }
 }
 
+/// Parse an optional attribute with the provided type, if one is present.
+OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
+                                                   Type type) {
+  switch (getToken().getKind()) {
+  case Token::at_identifier:
+  case Token::floatliteral:
+  case Token::integer:
+  case Token::hash_identifier:
+  case Token::kw_affine_map:
+  case Token::kw_affine_set:
+  case Token::kw_dense:
+  case Token::kw_false:
+  case Token::kw_loc:
+  case Token::kw_opaque:
+  case Token::kw_sparse:
+  case Token::kw_true:
+  case Token::kw_unit:
+  case Token::l_brace:
+  case Token::l_square:
+  case Token::minus:
+  case Token::string:
+    attribute = parseAttribute(type);
+    return success(attribute != nullptr);
+
+  default:
+    // Otherwise, try to parse a type and wrap it in a TypeAttr.
+    Type parsedType;
+    OptionalParseResult result = parseOptionalType(parsedType);
+    if (result.hasValue() && succeeded(*result))
+      attribute = TypeAttr::get(parsedType);
+    return result;
+  }
+}
+
 /// Attribute dictionary.
 ///
 ///   attribute-dict ::= `{` `}`

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index fc9d449ecc14..3a995a4e2b04 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1011,6 +1011,17 @@ class CustomOpAsmParser : public OpAsmParser {
     return success();
   }
 
+  /// Parse an optional attribute; if parsed, add it to `attrs` as `attrName`.
+  OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
+                                             StringRef attrName,
+                                             NamedAttrList &attrs) override {
+    OptionalParseResult parseResult =
+        parser.parseOptionalAttribute(result, type);
+    if (parseResult.hasValue() && succeeded(*parseResult))
+      attrs.push_back(parser.builder.getNamedAttr(attrName, result));
+    return parseResult;
+  }
+
   /// Parse a named dictionary into 'result' if it is present.
   ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
     if (parser.getToken().isNot(Token::l_brace))

diff  --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 3d82d622bf06..3b2c6e852544 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -184,6 +184,10 @@ class Parser {
   /// Parse an arbitrary attribute with an optional type.
   Attribute parseAttribute(Type type = {});
 
+  /// Parse an attribute if one is present, with an optional type.
+  OptionalParseResult parseOptionalAttribute(Attribute &attribute,
+                                             Type type = {});
+
   /// Parse an attribute dictionary.
   ParseResult parseAttributeDict(NamedAttrList &attributes);
 

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 19e636b3df32..e73c7c3f3230 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1253,9 +1253,13 @@ def FormatAttrOp : TEST_Op<"format_attr_op"> {
 }
 
 // Test that we elide optional attributes that are within the syntax.
-def FormatOptAttrOp : TEST_Op<"format_opt_attr_op"> {
+def FormatOptAttrAOp : TEST_Op<"format_opt_attr_op_a"> {
   let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
-  let assemblyFormat = "(`(`$opt_attr^`)`)? attr-dict";
+  let assemblyFormat = "(`(` $opt_attr^ `)` )? attr-dict";
+}
+def FormatOptAttrBOp : TEST_Op<"format_opt_attr_op_b"> {
+  let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
+  let assemblyFormat = "($opt_attr^)? attr-dict"; // Attribute-first group.
 }
 
 // Test that we elide attributes that are within the syntax.

diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 47255d47f8a7..3a3c500d76b3 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -206,10 +206,10 @@ def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{
 def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
   ($attr)? attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
-// CHECK: error: first element of an operand group must be a literal or operand
+// CHECK: error: first element of an operand group must be an attribute, literal, or operand
 def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
-  ($attr^)? attr-dict
-}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
+  (type($operand) $operand^)? attr-dict
+}]>, Arguments<(ins Optional<I64>:$operand)>; // type() cannot lead a group.
 // CHECK: error: type directive can only refer to variables within the optional group
 def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{
   (`,` $attr^ type(operands))? attr-dict

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 49ac3d26f926..af5976b22706 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -12,9 +12,15 @@ test.format_literal_op keyword_$. -> :, = <> () [] {foo.some_attr}
 // CHECK-NOT: {attr
 test.format_attr_op 10
 
-// CHECK: test.format_opt_attr_op(10)
+// CHECK: test.format_opt_attr_op_a(10)
 // CHECK-NOT: {opt_attr
-test.format_opt_attr_op(10)
+test.format_opt_attr_op_a(10)
+test.format_opt_attr_op_a
+
+// CHECK: test.format_opt_attr_op_b 10
+// CHECK-NOT: {opt_attr
+test.format_opt_attr_op_b 10
+test.format_opt_attr_op_b
 
 // CHECK: test.format_attr_dict_w_keyword attributes {attr = 10 : i64}
 test.format_attr_dict_w_keyword attributes {attr = 10 : i64}

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 3fcbeeff1e6f..13f2a2fd96dc 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -373,6 +373,15 @@ const char *const attrParserCode = R"(
   if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes))
     return failure();
 )";
+const char *const optionalAttrParserCode = R"(
+  {0} {1}Attr;
+  {
+    ::mlir::OptionalParseResult parseResult =
+      parser.parseOptionalAttribute({1}Attr{2}, "{1}", result.attributes);
+    if (parseResult.hasValue() && failed(*parseResult))
+      return failure();
+  }
+)"; // Unlike attrParserCode, a missing attribute is not an error here.
 
 /// The code snippet used to generate a parser call for an enum attribute.
 ///
@@ -397,6 +406,30 @@ const char *const enumAttrParserCode = R"(
     result.addAttribute("{0}", {3});
   }
 )";
+const char *const optionalEnumAttrParserCode = R"(
+  ::mlir::Attribute {0}Attr;
+  {
+    ::mlir::Attribute attrVal;
+    ::mlir::NamedAttrList attrStorage;
+    auto loc = parser.getCurrentLocation();
+
+    ::mlir::OptionalParseResult parseResult =
+      parser.parseOptionalAttribute(attrVal, parser.getBuilder().getNoneType(),
+                                    "{0}", attrStorage);
+    if (parseResult.hasValue()) {
+      if (failed(*parseResult))
+        return failure();
+      auto strVal = attrVal.dyn_cast<::mlir::StringAttr>();
+      auto attrOptional = strVal ? {1}::{2}(strVal.getValue()) : ::llvm::None;
+      if (!attrOptional)
+        return parser.emitError(loc, "invalid ")
+               << "{0} attribute specification: " << attrVal;
+
+      {0}Attr = {3};
+      result.addAttribute("{0}", {0}Attr);
+    }
+  }
+)";
 
 /// The code snippet used to generate a parser call for an operand.
 ///
@@ -599,11 +632,15 @@ static void genElementParser(Element *element, OpMethodBody &body,
 
     // Generate a special optional parser for the first element to gate the
     // parsing of the rest of the elements.
-    if (auto *literal = dyn_cast<LiteralElement>(&*elements.begin())) {
+    Element *firstElement = &*elements.begin();
+    if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
+      genElementParser(attrVar, body, attrTypeCtx);
+      body << "  if (" << attrVar->getVar()->name << "Attr) {\n";
+    } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
       body << "  if (succeeded(parser.parseOptional";
       genLiteralParser(literal->getLiteral(), body);
       body << ")) {\n";
-    } else if (auto *opVar = dyn_cast<OperandVariable>(&*elements.begin())) {
+    } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
       genElementParser(opVar, body, attrTypeCtx);
       body << "  if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
     }
@@ -635,7 +672,9 @@ static void genElementParser(Element *element, OpMethodBody &body,
                     "attrOptional.getValue()");
       }
 
-      body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
+      body << formatv(var->attr.isOptional() ? optionalEnumAttrParserCode
+                                             : enumAttrParserCode,
+                      var->name, enumAttr.getCppNamespace(),
                       enumAttr.getStringToSymbolFnName(), attrBuilderStr);
       return;
     }
@@ -648,8 +687,9 @@ static void genElementParser(Element *element, OpMethodBody &body,
       os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
     }
 
-    body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
-                    attrTypeStr);
+    body << formatv(var->attr.isOptional() ? optionalAttrParserCode
+                                           : attrParserCode,
+                    var->attr.getStorageType(), var->name, attrTypeStr);
   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
     StringRef name = operand->getVar()->name;
@@ -1910,10 +1950,11 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
 
   // The first element of the group must be one that can be parsed/printed in an
   // optional fashion.
-  if (!isa<LiteralElement>(&*elements.front()) &&
-      !isa<OperandVariable>(&*elements.front()))
-    return emitError(curLoc, "first element of an operand group must be a "
-                             "literal or operand");
+  Element *firstElement = &*elements.front();
+  if (!isa<AttributeVariable>(firstElement) &&
+      !isa<LiteralElement>(firstElement) && !isa<OperandVariable>(firstElement))
+    return emitError(curLoc, "first element of an operand group must be an "
+                             "attribute, literal, or operand");
 
   // After parsing all of the elements, ensure that all type directives refer
   // only to elements within the group.


        


More information about the Mlir-commits mailing list