[Mlir-commits] [mlir] 7f2d9c2 - [mlir][ods] Support default-valued attributes in optional groups
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Oct 16 15:01:48 PDT 2022
Author: rkayaith
Date: 2022-10-16T18:01:39-04:00
New Revision: 7f2d9c21b49c8515769fba113631df0f492d6279
URL: https://github.com/llvm/llvm-project/commit/7f2d9c21b49c8515769fba113631df0f492d6279
DIFF: https://github.com/llvm/llvm-project/commit/7f2d9c21b49c8515769fba113631df0f492d6279.diff
LOG: [mlir][ods] Support default-valued attributes in optional groups
Add support for default-valued attributes as optional-group anchors. The
attribute is considered present if it holds a non-default value.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D134993
Added:
Modified:
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format-invalid.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/test/mlir-tblgen/op-format.td
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7f329126b2892..85c5f3204b722 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2183,6 +2183,13 @@ def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> {
let assemblyFormat = "($attr^)? attr-dict";
}
+def FormatOptionalDefaultAttrs : TEST_Op<"format_optional_default_attrs"> {
+ let arguments = (ins DefaultValuedStrAttr<StrAttr, "default">:$str,
+ DefaultValuedStrAttr<SymbolNameAttr, "default">:$sym,
+ DefaultValuedAttr<SomeI64Enum, "SomeI64Enum::case5">:$e);
+ let assemblyFormat = "($str^)? ($sym^)? ($e^)? attr-dict";
+}
+
def FormatOptionalWithElse : TEST_Op<"format_optional_else"> {
let arguments = (ins UnitAttr:$isFirstBranchPresent);
let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)? attr-dict";
diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td
index b790360d6df8c..a44a2e6a0c5c2 100644
--- a/mlir/test/mlir-tblgen/op-format-invalid.td
+++ b/mlir/test/mlir-tblgen/op-format-invalid.td
@@ -369,7 +369,7 @@ def OptionalInvalidE : TestFormat_Op<[{
def OptionalInvalidF : TestFormat_Op<[{
($attr^ $attr2^)? attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr, OptionalAttr<I64Attr>:$attr2)>;
-// CHECK: error: only optional attributes can be used to anchor an optional group
+// CHECK: error: only optional or default-valued attributes can be used to anchor an optional group
def OptionalInvalidG : TestFormat_Op<[{
($attr^)? attr-dict
}]>, Arguments<(ins I64Attr:$attr)>;
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index fb390d14bc960..cddae76b38637 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -197,6 +197,15 @@ test.format_optional_enum_attr case5
// CHECK-NOT: "case5"
test.format_optional_enum_attr
+// CHECK: test.format_optional_default_attrs "foo" @foo case10
+test.format_optional_default_attrs "foo" @foo case10
+
+// CHECK: test.format_optional_default_attr
+// CHECK-NOT: "default"
+// CHECK-NOT: @default
+// CHECK-NOT: case5
+test.format_optional_default_attrs "default" @default case5
+
//===----------------------------------------------------------------------===//
// Format optional operands and results
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td
index 5a7656ee60feb..91c0011ae660c 100644
--- a/mlir/test/mlir-tblgen/op-format.td
+++ b/mlir/test/mlir-tblgen/op-format.td
@@ -71,3 +71,13 @@ def OptionalGroupA : TestFormat_Op<[{
def OptionalGroupB : TestFormat_Op<[{
(`foo`) : (`bar` $a^)? attr-dict
}]>, Arguments<(ins UnitAttr:$a)>;
+
+// Optional group anchored on a default-valued attribute:
+// CHECK-LABEL: OptionalGroupC::parse
+// CHECK: if ((*this)->getAttr("a") != ::mlir::OpBuilder((*this)->getContext()).getStringAttr("default")) {
+// CHECK-NEXT: odsPrinter << ' ';
+// CHECK-NEXT: odsPrinter.printAttributeWithoutType(getAAttr());
+// CHECK-NEXT: }
+def OptionalGroupC : TestFormat_Op<[{
+ ($a^)? attr-dict
+}]>, Arguments<(ins DefaultValuedStrAttr<StrAttr, "default">:$a)>;
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 0079600928a65..fdbcd62b4546a 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1041,7 +1041,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
/// Generate the parser for a enum attribute.
static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
- FmtContext &attrTypeCtx) {
+ FmtContext &attrTypeCtx, bool parseAsOptional) {
Attribute baseAttr = var->attr.getBaseAttr();
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
@@ -1065,7 +1065,7 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
// If the attribute is not optional, build an error message for the missing
// attribute.
std::string errorMessage;
- if (!var->attr.isOptional()) {
+ if (!parseAsOptional) {
llvm::raw_string_ostream errorMessageOS(errorMessage);
errorMessageOS
<< "return parser.emitError(loc, \"expected string or "
@@ -1082,6 +1082,43 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
validCaseKeywordsStr, errorMessage);
}
+/// Generate the parser for the attribute variable `attr`, appending the
+/// emitted C++ to `body`. The parser is selected in this order:
+///   1. enum attributes are delegated to genEnumAttrParser;
+///   2. symbol-name attributes use (optional)SymbolNameAttrParserCode;
+///   3. everything else uses the optional/generic/typed attribute parser
+///      code, passing the attribute's buildable type if it has one.
+///
+/// \param attrTypeCtx format context used to expand the attribute's type
+///        builder; it is also forwarded to the enum parser.
+/// \param parseAsOptional when true, emit the "optional" parser variants, so
+///        a missing attribute is not an error. Callers pass true for an
+///        optional-group anchor and for optional attributes in the normal
+///        generation context.
+static void genAttrParser(AttributeVariable *attr, MethodBody &body,
+ FmtContext &attrTypeCtx, bool parseAsOptional) {
+ const NamedAttribute *var = attr->getVar();
+
+ // Check to see if we can parse this as an enum attribute.
+ if (canFormatEnumAttr(var))
+ return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional);
+
+ // Check to see if we should parse this as a symbol name attribute.
+ if (shouldFormatSymbolNameAttr(var)) {
+ body << formatv(parseAsOptional ? optionalSymbolNameAttrParserCode
+ : symbolNameAttrParserCode,
+ var->name);
+ return;
+ }
+
+ // If this attribute has a buildable type, use that when parsing the
+ // attribute.
+ std::string attrTypeStr;
+ if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
+ llvm::raw_string_ostream os(attrTypeStr);
+ os << tgfmt(*typeBuilder, &attrTypeCtx);
+ } else {
+ attrTypeStr = "::mlir::Type{}";
+ }
+ // NOTE(review): the optional path below does not consult
+ // shouldBeQualified() or the storage type, unlike the required path --
+ // confirm optionalAttrParserCode handles qualified/generic attributes.
+ if (parseAsOptional) {
+ body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
+ } else {
+ // Qualified or untyped (::mlir::Attribute) storage needs the generic
+ // attribute parser; otherwise parse directly into the storage type.
+ if (attr->shouldBeQualified() ||
+ var->attr.getStorageType() == "::mlir::Attribute")
+ body << formatv(genericAttrParserCode, var->name, attrTypeStr);
+ else
+ body << formatv(attrParserCode, var->name, attrTypeStr);
+ }
+}
+
void OperationFormat::genParser(Operator &op, OpClass &opClass) {
SmallVector<MethodParameter> paramList;
paramList.emplace_back("::mlir::OpAsmParser &", "parser");
@@ -1153,7 +1190,7 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
// parsing of the rest of the elements.
FormatElement *firstElement = thenElements.front();
if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
- genElementParser(attrVar, body, attrTypeCtx);
+ genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true);
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
} else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
body << " if (::mlir::succeeded(parser.parseOptional";
@@ -1236,38 +1273,9 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
/// Arguments.
} else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
- const NamedAttribute *var = attr->getVar();
-
- // Check to see if we can parse this as an enum attribute.
- if (canFormatEnumAttr(var))
- return genEnumAttrParser(var, body, attrTypeCtx);
-
- // Check to see if we should parse this as a symbol name attribute.
- if (shouldFormatSymbolNameAttr(var)) {
- body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode
- : symbolNameAttrParserCode,
- var->name);
- return;
- }
-
- // If this attribute has a buildable type, use that when parsing the
- // attribute.
- std::string attrTypeStr;
- if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
- llvm::raw_string_ostream os(attrTypeStr);
- os << tgfmt(*typeBuilder, &attrTypeCtx);
- } else {
- attrTypeStr = "::mlir::Type{}";
- }
- if (genCtx == GenContext::Normal && var->attr.isOptional()) {
- body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
- } else {
- if (attr->shouldBeQualified() ||
- var->attr.getStorageType() == "::mlir::Attribute")
- body << formatv(genericAttrParserCode, var->name, attrTypeStr);
- else
- body << formatv(attrParserCode, var->name, attrTypeStr);
- }
+ bool parseAsOptional =
+ (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional());
+ genAttrParser(attr, body, attrTypeCtx, parseAsOptional);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
@@ -1872,8 +1880,22 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
.Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
})
- .Case<AttributeVariable>([&](AttributeVariable *attr) {
- body << "(*this)->getAttr(\"" << attr->getVar()->name << "\")";
+ .Case<AttributeVariable>([&](AttributeVariable *element) {
+ Attribute attr = element->getVar()->attr;
+ body << "(*this)->getAttr(\"" << element->getVar()->name << "\")";
+ if (attr.isOptional())
+ return; // done
+ if (attr.hasDefaultValue()) {
+ // Consider a default-valued attribute as present if it's not the
+ // default value.
+ FmtContext fctx;
+ fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
+ body << " != "
+ << tgfmt(attr.getConstBuilderTemplate(), &fctx,
+ attr.getDefaultValue());
+ return;
+ }
+ llvm_unreachable("attribute must be optional or default-valued");
});
}
@@ -3185,9 +3207,10 @@ LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
// All attributes can be within the optional group, but only optional
// attributes can be the anchor.
.Case([&](AttributeVariable *attrEle) {
- if (isAnchor && !attrEle->getVar()->attr.isOptional())
- return emitError(loc, "only optional attributes can be used to "
- "anchor an optional group");
+ Attribute attr = attrEle->getVar()->attr;
+ if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue()))
+ return emitError(loc, "only optional or default-valued attributes "
+ "can be used to anchor an optional group");
return success();
})
// Only optional-like(i.e. variadic) operands can be within an optional
More information about the Mlir-commits mailing list