[Mlir-commits] [mlir] 7f2d9c2 - [mlir][ods] Support default-valued attributes in optional groups

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Oct 16 15:01:48 PDT 2022


Author: rkayaith
Date: 2022-10-16T18:01:39-04:00
New Revision: 7f2d9c21b49c8515769fba113631df0f492d6279

URL: https://github.com/llvm/llvm-project/commit/7f2d9c21b49c8515769fba113631df0f492d6279
DIFF: https://github.com/llvm/llvm-project/commit/7f2d9c21b49c8515769fba113631df0f492d6279.diff

LOG: [mlir][ods] Support default-valued attributes in optional groups

Add support for default-valued attributes as optional-group anchors. The
attribute is considered present if it holds a non-default value.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D134993

Added: 
    

Modified: 
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format-invalid.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/test/mlir-tblgen/op-format.td
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 7f329126b2892..85c5f3204b722 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2183,6 +2183,13 @@ def FormatOptionalEnumAttr : TEST_Op<"format_optional_enum_attr"> {
   let assemblyFormat = "($attr^)? attr-dict";
 }
 
+def FormatOptionalDefaultAttrs : TEST_Op<"format_optional_default_attrs"> {
+  let arguments = (ins DefaultValuedStrAttr<StrAttr, "default">:$str,
+                       DefaultValuedStrAttr<SymbolNameAttr, "default">:$sym,
+                       DefaultValuedAttr<SomeI64Enum, "SomeI64Enum::case5">:$e);
+  let assemblyFormat = "($str^)? ($sym^)? ($e^)? attr-dict"; // elide defaults
+}
+
 def FormatOptionalWithElse : TEST_Op<"format_optional_else"> {
   let arguments = (ins UnitAttr:$isFirstBranchPresent);
   let assemblyFormat = "(`then` $isFirstBranchPresent^):(`else`)? attr-dict";

diff  --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td
index b790360d6df8c..a44a2e6a0c5c2 100644
--- a/mlir/test/mlir-tblgen/op-format-invalid.td
+++ b/mlir/test/mlir-tblgen/op-format-invalid.td
@@ -369,7 +369,7 @@ def OptionalInvalidE : TestFormat_Op<[{
 def OptionalInvalidF : TestFormat_Op<[{
   ($attr^ $attr2^)? attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr, OptionalAttr<I64Attr>:$attr2)>;
-// CHECK: error: only optional attributes can be used to anchor an optional group
+// CHECK: error: only optional or default-valued attributes can be used to anchor an optional group
 def OptionalInvalidG : TestFormat_Op<[{
   ($attr^)? attr-dict
 }]>, Arguments<(ins I64Attr:$attr)>;

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index fb390d14bc960..cddae76b38637 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -197,6 +197,15 @@ test.format_optional_enum_attr case5
 // CHECK-NOT: "case5"
 test.format_optional_enum_attr
 
+// CHECK: test.format_optional_default_attrs "foo" @foo case10
+test.format_optional_default_attrs "foo" @foo case10
+
+// CHECK: test.format_optional_default_attrs
+// CHECK-NOT: "default"
+// CHECK-NOT: @default
+// CHECK-NOT: case5
+test.format_optional_default_attrs "default" @default case5
+
 //===----------------------------------------------------------------------===//
 // Format optional operands and results
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td
index 5a7656ee60feb..91c0011ae660c 100644
--- a/mlir/test/mlir-tblgen/op-format.td
+++ b/mlir/test/mlir-tblgen/op-format.td
@@ -71,3 +71,13 @@ def OptionalGroupA : TestFormat_Op<[{
 def OptionalGroupB : TestFormat_Op<[{
   (`foo`) : (`bar` $a^)? attr-dict
 }]>, Arguments<(ins UnitAttr:$a)>;
+
+// Default-valued anchor: the group is printed only for a non-default value.
+// CHECK-LABEL: OptionalGroupC::print
+//       CHECK: if ((*this)->getAttr("a") != ::mlir::OpBuilder((*this)->getContext()).getStringAttr("default")) {
+//  CHECK-NEXT:   odsPrinter << ' ';
+//  CHECK-NEXT:   odsPrinter.printAttributeWithoutType(getAAttr());
+//  CHECK-NEXT: }
+def OptionalGroupC : TestFormat_Op<[{
+  ($a^)? attr-dict
+}]>, Arguments<(ins DefaultValuedStrAttr<StrAttr, "default">:$a)>;

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 0079600928a65..fdbcd62b4546a 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1041,7 +1041,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
 
 /// Generate the parser for a enum attribute.
 static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
-                              FmtContext &attrTypeCtx) {
+                              FmtContext &attrTypeCtx, bool parseAsOptional) {
   Attribute baseAttr = var->attr.getBaseAttr();
   const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
   std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
@@ -1065,7 +1065,7 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
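-  // If the attribute is not optional, build an error message for the missing
-  // attribute.
+  // If the attribute is not being parsed as optional, build an error message
+  // for the missing attribute.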
   std::string errorMessage;
-  if (!var->attr.isOptional()) {
+  if (!parseAsOptional) {
     llvm::raw_string_ostream errorMessageOS(errorMessage);
     errorMessageOS
         << "return parser.emitError(loc, \"expected string or "
@@ -1082,6 +1082,43 @@ static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body,
                   validCaseKeywordsStr, errorMessage);
 }
 
+/// Generate the parser for `attr`; use optional forms if `parseAsOptional`.
+static void genAttrParser(AttributeVariable *attr, MethodBody &body,
+                          FmtContext &attrTypeCtx, bool parseAsOptional) {
+  const NamedAttribute *var = attr->getVar();
+
+  // Enum attributes are delegated to the dedicated enum parser.
+  if (canFormatEnumAttr(var))
+    return genEnumAttrParser(var, body, attrTypeCtx, parseAsOptional);
+
+  // Symbol-name attributes pick the optional or required symbol parser code.
+  if (shouldFormatSymbolNameAttr(var)) {
+    body << formatv(parseAsOptional ? optionalSymbolNameAttrParserCode
+                                    : symbolNameAttrParserCode,
+                    var->name);
+    return;
+  }
+
+  // Parse with the attribute's buildable type when it has one; otherwise
+  // pass a null type (`::mlir::Type{}`) to the generated parser code.
+  std::string attrTypeStr;
+  if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
+    llvm::raw_string_ostream os(attrTypeStr);
+    os << tgfmt(*typeBuilder, &attrTypeCtx);
+  } else {
+    attrTypeStr = "::mlir::Type{}";
+  }
+  if (parseAsOptional) {
+    body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
+  } else {
+    if (attr->shouldBeQualified() ||
+        var->attr.getStorageType() == "::mlir::Attribute")
+      body << formatv(genericAttrParserCode, var->name, attrTypeStr);
+    else
+      body << formatv(attrParserCode, var->name, attrTypeStr);
+  }
+}
+
 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
   SmallVector<MethodParameter> paramList;
   paramList.emplace_back("::mlir::OpAsmParser &", "parser");
@@ -1153,7 +1190,7 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
     // parsing of the rest of the elements.
     FormatElement *firstElement = thenElements.front();
     if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
-      genElementParser(attrVar, body, attrTypeCtx);
+      genAttrParser(attrVar, body, attrTypeCtx, /*parseAsOptional=*/true);
       body << "  if (" << attrVar->getVar()->name << "Attr) {\n";
     } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
       body << "  if (::mlir::succeeded(parser.parseOptional";
@@ -1236,38 +1273,9 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
 
     /// Arguments.
   } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
-    const NamedAttribute *var = attr->getVar();
-
-    // Check to see if we can parse this as an enum attribute.
-    if (canFormatEnumAttr(var))
-      return genEnumAttrParser(var, body, attrTypeCtx);
-
-    // Check to see if we should parse this as a symbol name attribute.
-    if (shouldFormatSymbolNameAttr(var)) {
-      body << formatv(var->attr.isOptional() ? optionalSymbolNameAttrParserCode
-                                             : symbolNameAttrParserCode,
-                      var->name);
-      return;
-    }
-
-    // If this attribute has a buildable type, use that when parsing the
-    // attribute.
-    std::string attrTypeStr;
-    if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
-      llvm::raw_string_ostream os(attrTypeStr);
-      os << tgfmt(*typeBuilder, &attrTypeCtx);
-    } else {
-      attrTypeStr = "::mlir::Type{}";
-    }
-    if (genCtx == GenContext::Normal && var->attr.isOptional()) {
-      body << formatv(optionalAttrParserCode, var->name, attrTypeStr);
-    } else {
-      if (attr->shouldBeQualified() ||
-          var->attr.getStorageType() == "::mlir::Attribute")
-        body << formatv(genericAttrParserCode, var->name, attrTypeStr);
-      else
-        body << formatv(attrParserCode, var->name, attrTypeStr);
-    }
+    bool parseAsOptional =
+        (genCtx == GenContext::Normal && attr->getVar()->attr.isOptional());
+    genAttrParser(attr, body, attrTypeCtx, parseAsOptional);
 
   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
@@ -1872,8 +1880,22 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
       .Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
         genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
       })
-      .Case<AttributeVariable>([&](AttributeVariable *attr) {
-        body << "(*this)->getAttr(\"" << attr->getVar()->name << "\")";
+      .Case<AttributeVariable>([&](AttributeVariable *element) {
+        Attribute attr = element->getVar()->attr;
+        body << "(*this)->getAttr(\"" << element->getVar()->name << "\")";
+        if (attr.isOptional())
+          return; // done
+        if (attr.hasDefaultValue()) {
+          // Consider a default-valued attribute as present if it's not the
+          // default value.
+          FmtContext fctx;
+          fctx.withBuilder("::mlir::OpBuilder((*this)->getContext())");
+          body << " != "
+               << tgfmt(attr.getConstBuilderTemplate(), &fctx,
+                        attr.getDefaultValue());
+          return;
+        }
+        llvm_unreachable("attribute must be optional or default-valued");
       });
 }
 
@@ -3185,9 +3207,10 @@ LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
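-      // All attributes can be within the optional group, but only optional
-      // attributes can be the anchor.
+      // All attributes can be within the optional group, but only optional or
+      // default-valued attributes can be the anchor.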
       .Case([&](AttributeVariable *attrEle) {
-        if (isAnchor && !attrEle->getVar()->attr.isOptional())
-          return emitError(loc, "only optional attributes can be used to "
-                                "anchor an optional group");
+        Attribute attr = attrEle->getVar()->attr;
+        if (isAnchor && !(attr.isOptional() || attr.hasDefaultValue()))
+          return emitError(loc, "only optional or default-valued attributes "
+                                "can be used to anchor an optional group");
         return success();
       })
       // Only optional-like(i.e. variadic) operands can be within an optional


        


More information about the Mlir-commits mailing list