[Mlir-commits] [mlir] 6b476e2 - [mlir] Add support for parsing optional Attribute values.
River Riddle
llvmlistbot at llvm.org
Tue Jul 14 13:21:36 PDT 2020
Author: River Riddle
Date: 2020-07-14T13:14:59-07:00
New Revision: 6b476e2426e9cfa442dac5deed2ceae890513f18
URL: https://github.com/llvm/llvm-project/commit/6b476e2426e9cfa442dac5deed2ceae890513f18
DIFF: https://github.com/llvm/llvm-project/commit/6b476e2426e9cfa442dac5deed2ceae890513f18.diff
LOG: [mlir] Add support for parsing optional Attribute values.
This adds a `parseOptionalAttribute` method to the OpAsmParser that allows for parsing optional attributes, in a similar fashion to how optional types are parsed. This also enables the use of attribute values as the first element of an assembly format optional group.
Differential Revision: https://reviews.llvm.org/D83712
Added: (none)
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Parser.h
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format-spec.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed: (none)
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index a10610f87a0a..c068aac09bab 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -713,7 +713,8 @@ of the assembly format can be marked as `optional` based on the presence of this
information. An optional group is defined by wrapping a set of elements within
`()` followed by a `?` and has the following requirements:
-* The first element of the group must either be a literal or an operand.
+* The first element of the group must either be a literal, attribute, or an
+ operand.
- This is because the first element must be optionally parsable.
* Exactly one argument variable within the group must be marked as the anchor
of the group.
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 20660be4347c..0124ef5f7c0a 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -384,6 +384,17 @@ class OpAsmParser {
StringRef attrName,
NamedAttrList &attrs) = 0;
+ /// Parse an optional attribute.
+ virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
+ Type type,
+ StringRef attrName,
+ NamedAttrList &attrs) = 0;
+ OptionalParseResult parseOptionalAttribute(Attribute &result,
+ StringRef attrName,
+ NamedAttrList &attrs) {
+ return parseOptionalAttribute(result, Type(), attrName, attrs);
+ }
+
/// Parse an attribute of a specific kind and type.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index e2860b115231..1c1261e6d765 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -187,6 +187,40 @@ Attribute Parser::parseAttribute(Type type) {
}
}
+/// Parse an optional attribute with the provided type.
+OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
+ Type type) {
+ switch (getToken().getKind()) {
+ case Token::at_identifier:
+ case Token::floatliteral:
+ case Token::integer:
+ case Token::hash_identifier:
+ case Token::kw_affine_map:
+ case Token::kw_affine_set:
+ case Token::kw_dense:
+ case Token::kw_false:
+ case Token::kw_loc:
+ case Token::kw_opaque:
+ case Token::kw_sparse:
+ case Token::kw_true:
+ case Token::kw_unit:
+ case Token::l_brace:
+ case Token::l_square:
+ case Token::minus:
+ case Token::string:
+ attribute = parseAttribute(type);
+ return success(attribute != nullptr);
+
+ default:
+ // Parse an optional type attribute.
+ Type type;
+ OptionalParseResult result = parseOptionalType(type);
+ if (result.hasValue() && succeeded(*result))
+ attribute = TypeAttr::get(type);
+ return result;
+ }
+}
+
/// Attribute dictionary.
///
/// attribute-dict ::= `{` `}`
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index fc9d449ecc14..3a995a4e2b04 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1011,6 +1011,17 @@ class CustomOpAsmParser : public OpAsmParser {
return success();
}
+ /// Parse an optional attribute.
+ OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
+ StringRef attrName,
+ NamedAttrList &attrs) override {
+ OptionalParseResult parseResult =
+ parser.parseOptionalAttribute(result, type);
+ if (parseResult.hasValue() && succeeded(*parseResult))
+ attrs.push_back(parser.builder.getNamedAttr(attrName, result));
+ return parseResult;
+ }
+
/// Parse a named dictionary into 'result' if it is present.
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
if (parser.getToken().isNot(Token::l_brace))
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 3d82d622bf06..3b2c6e852544 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -184,6 +184,10 @@ class Parser {
/// Parse an arbitrary attribute with an optional type.
Attribute parseAttribute(Type type = {});
+ /// Parse an optional attribute with the provided type.
+ OptionalParseResult parseOptionalAttribute(Attribute &attribute,
+ Type type = {});
+
/// Parse an attribute dictionary.
ParseResult parseAttributeDict(NamedAttrList &attributes);
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 19e636b3df32..e73c7c3f3230 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1253,9 +1253,13 @@ def FormatAttrOp : TEST_Op<"format_attr_op"> {
}
// Test that we elide optional attributes that are within the syntax.
-def FormatOptAttrOp : TEST_Op<"format_opt_attr_op"> {
+def FormatOptAttrAOp : TEST_Op<"format_opt_attr_op_a"> {
let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
- let assemblyFormat = "(`(`$opt_attr^`)`)? attr-dict";
+ let assemblyFormat = "(`(` $opt_attr^ `)` )? attr-dict";
+}
+def FormatOptAttrBOp : TEST_Op<"format_opt_attr_op_b"> {
+ let arguments = (ins OptionalAttr<I64Attr>:$opt_attr);
+ let assemblyFormat = "($opt_attr^)? attr-dict";
}
// Test that we elide attributes that are within the syntax.
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 47255d47f8a7..3a3c500d76b3 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -206,10 +206,10 @@ def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{
def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
($attr)? attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
-// CHECK: error: first element of an operand group must be a literal or operand
+// CHECK: error: first element of an operand group must be an attribute, literal, or operand
def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
- ($attr^)? attr-dict
-}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
+ (type($operand) $operand^)? attr-dict
+}]>, Arguments<(ins Optional<I64>:$operand)>;
// CHECK: error: type directive can only refer to variables within the optional group
def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{
(`,` $attr^ type(operands))? attr-dict
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 49ac3d26f926..af5976b22706 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -12,9 +12,15 @@ test.format_literal_op keyword_$. -> :, = <> () [] {foo.some_attr}
// CHECK-NOT: {attr
test.format_attr_op 10
-// CHECK: test.format_opt_attr_op(10)
+// CHECK: test.format_opt_attr_op_a(10)
// CHECK-NOT: {opt_attr
-test.format_opt_attr_op(10)
+test.format_opt_attr_op_a(10)
+test.format_opt_attr_op_a
+
+// CHECK: test.format_opt_attr_op_b 10
+// CHECK-NOT: {opt_attr
+test.format_opt_attr_op_b 10
+test.format_opt_attr_op_b
// CHECK: test.format_attr_dict_w_keyword attributes {attr = 10 : i64}
test.format_attr_dict_w_keyword attributes {attr = 10 : i64}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 3fcbeeff1e6f..13f2a2fd96dc 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -373,6 +373,15 @@ const char *const attrParserCode = R"(
if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes))
return failure();
)";
+const char *const optionalAttrParserCode = R"(
+ {0} {1}Attr;
+ {
+ ::mlir::OptionalParseResult parseResult =
+ parser.parseOptionalAttribute({1}Attr{2}, "{1}", result.attributes);
+ if (parseResult.hasValue() && failed(*parseResult))
+ return failure();
+ }
+)";
/// The code snippet used to generate a parser call for an enum attribute.
///
@@ -397,6 +406,30 @@ const char *const enumAttrParserCode = R"(
result.addAttribute("{0}", {3});
}
)";
+const char *const optionalEnumAttrParserCode = R"(
+ Attribute {0}Attr;
+ {
+ ::mlir::StringAttr attrVal;
+ ::mlir::NamedAttrList attrStorage;
+ auto loc = parser.getCurrentLocation();
+
+ ::mlir::OptionalParseResult parseResult =
+ parser.parseOptionalAttribute(attrVal, parser.getBuilder().getNoneType(),
+ "{0}", attrStorage);
+ if (parseResult.hasValue()) {
+ if (failed(*parseResult))
+ return failure();
+
+ auto attrOptional = {1}::{2}(attrVal.getValue());
+ if (!attrOptional)
+ return parser.emitError(loc, "invalid ")
+ << "{0} attribute specification: " << attrVal;
+
+ {0}Attr = {3};
+ result.addAttribute("{0}", {0}Attr);
+ }
+ }
+)";
/// The code snippet used to generate a parser call for an operand.
///
@@ -599,11 +632,15 @@ static void genElementParser(Element *element, OpMethodBody &body,
// Generate a special optional parser for the first element to gate the
// parsing of the rest of the elements.
- if (auto *literal = dyn_cast<LiteralElement>(&*elements.begin())) {
+ Element *firstElement = &*elements.begin();
+ if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
+ genElementParser(attrVar, body, attrTypeCtx);
+ body << " if (" << attrVar->getVar()->name << "Attr) {\n";
+ } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
body << " if (succeeded(parser.parseOptional";
genLiteralParser(literal->getLiteral(), body);
body << ")) {\n";
- } else if (auto *opVar = dyn_cast<OperandVariable>(&*elements.begin())) {
+ } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
genElementParser(opVar, body, attrTypeCtx);
body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
}
@@ -635,7 +672,9 @@ static void genElementParser(Element *element, OpMethodBody &body,
"attrOptional.getValue()");
}
- body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
+ body << formatv(var->attr.isOptional() ? optionalEnumAttrParserCode
+ : enumAttrParserCode,
+ var->name, enumAttr.getCppNamespace(),
enumAttr.getStringToSymbolFnName(), attrBuilderStr);
return;
}
@@ -648,8 +687,9 @@ static void genElementParser(Element *element, OpMethodBody &body,
os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
}
- body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
- attrTypeStr);
+ body << formatv(var->attr.isOptional() ? optionalAttrParserCode
+ : attrParserCode,
+ var->attr.getStorageType(), var->name, attrTypeStr);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name;
@@ -1910,10 +1950,11 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
// The first element of the group must be one that can be parsed/printed in an
// optional fashion.
- if (!isa<LiteralElement>(&*elements.front()) &&
- !isa<OperandVariable>(&*elements.front()))
- return emitError(curLoc, "first element of an operand group must be a "
- "literal or operand");
+ Element *firstElement = &*elements.front();
+ if (!isa<AttributeVariable>(firstElement) &&
+ !isa<LiteralElement>(firstElement) && !isa<OperandVariable>(firstElement))
+ return emitError(curLoc, "first element of an operand group must be an "
+ "attribute, literal, or operand");
// After parsing all of the elements, ensure that all type directives refer
// only to elements within the group.
More information about the Mlir-commits mailing list