[Mlir-commits] [mlir] 95a33b4 - [mlir][ods] Format: allow anchors in the else elements

Jeff Niu llvmlistbot at llvm.org
Tue Sep 20 11:08:08 PDT 2022


Author: Jeff Niu
Date: 2022-09-20T11:07:50-07:00
New Revision: 95a33b455d74b8c0c112ad499c071117361dd403

URL: https://github.com/llvm/llvm-project/commit/95a33b455d74b8c0c112ad499c071117361dd403
DIFF: https://github.com/llvm/llvm-project/commit/95a33b455d74b8c0c112ad499c071117361dd403.diff

LOG: [mlir][ods] Format: allow anchors in the else elements

This patch changes optional groups to allow anchors in the 'else'
element group. When printing, the optional condition is inverted to
decide which group to print. This is useful for parsing concrete
optional elements that don't have a `parseOptional*` method or some
other way to test whether they are present.

Depends on D133805

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D133812

Added: 
    

Modified: 
    mlir/docs/AttributesAndTypes.md
    mlir/docs/OpDefinitions.md
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/test/mlir-tblgen/op-format.td
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
    mlir/tools/mlir-tblgen/FormatGen.cpp
    mlir/tools/mlir-tblgen/FormatGen.h
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md
index 748b69fd86ee0..4821e63f515f4 100644
--- a/mlir/docs/AttributesAndTypes.md
+++ b/mlir/docs/AttributesAndTypes.md
@@ -660,8 +660,9 @@ set to `llvm::None` and `Attribute` will be set to `nullptr`.
 
 Only optional parameters or directives that only capture optional parameters can
 be used in optional groups. An optional group is a set of elements optionally
-printed based on the presence of an anchor. Suppose parameter `a` is an
-`IntegerAttr`.
+printed based on the presence of an anchor. The group containing the anchor is
+printed if the anchor is present; otherwise, the other group is printed. Suppose
+parameter `a` is an `IntegerAttr`.
 
 ```
 ( `(` $a^ `)` ) : (`x`)?

diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index fe8dfc0db15bb..754fb641c9392 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -856,17 +856,18 @@ of the assembly format can be marked as `optional` based on the presence of this
 information. An optional group is defined as follows:
 
 ```
-optional-group: `(` elements `)` (`:` `(` else-elements `)`)? `?`
+optional-group: `(` then-elements `)` (`:` `(` else-elements `)`)? `?`
 ```
 
-The `elements` of an optional group have the following requirements:
+The elements of an optional group have the following requirements:
 
-*   The first element of the group must either be a attribute, literal, operand,
-    or region.
+*   The first element of `then-elements` must be either an attribute, literal,
+    operand, or region.
     -   This is because the first element must be optionally parsable.
-*   Exactly one argument variable or type directive within the group must be
-    marked as the anchor of the group.
-    -   The anchor is the element whose presence controls whether the group
+*   Exactly one argument variable or type directive within either
+    `then-elements` or `else-elements` must be marked as the anchor of the
+    group.
+    -   The anchor is the element whose presence controls which elements
         should be printed/parsed.
     -   An element is marked as the anchor by adding a trailing `^`.
     -   The first element is *not* required to be the anchor of the group.

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 80d3c5abdc84d..86ae51e3461d6 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -797,6 +797,11 @@ def OIListAllowedLiteral : TEST_Op<"oilist_allowed_literal"> {
   }];
 }
 
+def ElseAnchorOp : TEST_Op<"else_anchor"> {
+  let arguments = (ins Optional<AnyType>:$a);
+  let assemblyFormat = "`(` (`?`) : (`` $a^ `:` type($a))? `)` attr-dict";
+}
+
 // This is used to test encoding of a string attribute into an SSA name of a
 // pretty printed value name.
 def StringAttrPrettyNameOp

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 595809474ed19..8272efae32338 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -332,4 +332,17 @@ def TestTypeCustomString : Test_Type<"TestTypeCustomString"> {
                               custom<BarString>(ref($foo)) `>` }];
 }
 
+def TestTypeElseAnchor : Test_Type<"TestTypeElseAnchor"> {
+  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a);
+  let mnemonic = "else_anchor";
+  let assemblyFormat = "`<` (`?`) : ($a^)? `>`";
+}
+
+def TestTypeElseAnchorStruct : Test_Type<"TestTypeElseAnchorStruct"> {
+  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+                        OptionalParameter<"mlir::Optional<int>">:$b);
+  let mnemonic = "else_anchor_struct";
+  let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
+}
+
 #endif // TEST_TYPEDEFS

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 937c35d503769..1ecb688ccd29c 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -571,3 +571,39 @@ def TypeL : TestType<"TestN"> {
   let mnemonic = "type_l";
   let assemblyFormat = [{ custom<Foo>($a, "1") }];
 }
+
+// TYPE-LABEL: ::mlir::Type TestOType::parse
+// TYPE: if (odsParser.parseOptionalQuestion())
+// TYPE: _result_a =
+// TYPE: else
+
+// TYPE-LABEL: void TestOType::print
+// TYPE: if (!((getA())))
+// TYPE: odsPrinter << ' ' << "?"
+// TYPE: else
+// TYPE: odsPrinter.printStrippedAttrOrType(getA())
+
+def TypeM : TestType<"TestO"> {
+  let parameters = (ins OptionalParameter<"int">:$a);
+  let mnemonic = "type_m";
+  let assemblyFormat = "(`?`) : ($a^)?";
+}
+
+// TYPE-LABEL: ::mlir::Type TestPType::parse
+// TYPE: if (odsParser.parseOptionalQuestion())
+// TYPE: bool _seen_a
+// TYPE: bool _seen_b
+// TYPE: _loop_body(_paramKey))
+// TYPE: else {
+// TYPE-NEXT: }
+
+// TYPE-LABEL: void TestPType::print
+// TYPE: if (!((getA()) || (getB())))
+// TYPE-NEXT: odsPrinter << "?"
+
+def TypeN : TestType<"TestP"> {
+  let parameters = (ins OptionalParameter<"int">:$a,
+                        OptionalParameter<"int">:$b);
+  let mnemonic = "type_n";
+  let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
+}

diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index ae28910bfa5a6..68ed88c136c96 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -463,3 +463,18 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
 
 // CHECK: test.has_str_value
 test.has_str_value {}
+
+//===----------------------------------------------------------------------===//
+// ElseAnchorOp
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @else_anchor_op
+func.func @else_anchor_op(%a: !test.else_anchor<?>, %b: !test.else_anchor<5>) {
+  // CHECK: test.else_anchor(?) {a = !test.else_anchor_struct<?>}
+  test.else_anchor(?) {a = !test.else_anchor_struct<?>}
+  // CHECK: test.else_anchor(%{{.*}} : !test.else_anchor<?>) {a = !test.else_anchor_struct<a = 0>}
+  test.else_anchor(%a : !test.else_anchor<?>) {a = !test.else_anchor_struct<a = 0>}
+  // CHECK: test.else_anchor(%{{.*}} : !test.else_anchor<5>) {a = !test.else_anchor_struct<b = 0>}
+  test.else_anchor(%b : !test.else_anchor<5>) {a = !test.else_anchor_struct<b = 0>}
+  return
+}

diff  --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td
index 1fdb485e5e56c..cfa8dbcbf1834 100644
--- a/mlir/test/mlir-tblgen/op-format.td
+++ b/mlir/test/mlir-tblgen/op-format.td
@@ -40,3 +40,34 @@ def CustomStringLiteralB : TestFormat_Op<[{
 def CustomStringLiteralC : TestFormat_Op<[{
   custom<Foo>("$_builder.getStringAttr(\"foo\")") attr-dict
 }]>;
+
+//===----------------------------------------------------------------------===//
+// Optional Groups
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: OptionalGroupA::parse
+// CHECK: if (::mlir::succeeded(parser.parseOptionalQuestion())
+// CHECK-NEXT: else
+// CHECK: parser.parseOptionalOperand
+// CHECK-LABEL: OptionalGroupA::print
+// CHECK: if (!getA())
+// CHECK-NEXT: odsPrinter << ' ' << "?";
+// CHECK-NEXT: else
+// CHECK: odsPrinter << value;
+def OptionalGroupA : TestFormat_Op<[{
+  (`?`) : ($a^)? attr-dict
+}]>, Arguments<(ins Optional<I1>:$a)>;
+
+// CHECK-LABEL: OptionalGroupB::parse
+// CHECK: if (::mlir::succeeded(parser.parseOptionalKeyword("foo")))
+// CHECK-NEXT: else
+// CHECK-NEXT: result.addAttribute("a", parser.getBuilder().getUnitAttr())
+// CHECK: parser.parseKeyword("bar")
+// CHECK-LABEL: OptionalGroupB::print
+// CHECK: if (!(*this)->getAttr("a"))
+// CHECK-NEXT: odsPrinter << ' ' << "foo"
+// CHECK-NEXT: else
+// CHECK-NEXT: odsPrinter << ' ' << "bar"
+def OptionalGroupB : TestFormat_Op<[{
+  (`foo`) : (`bar` $a^)? attr-dict
+}]>, Arguments<(ins UnitAttr:$a)>;

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 7ccaccef8a5c8..655f40dbe1c47 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -656,10 +656,10 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
 
 void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
                                        MethodBody &os) {
-  ArrayRef<FormatElement *> elements =
-      el->getThenElements().drop_front(el->getParseStart());
+  ArrayRef<FormatElement *> thenElements =
+      el->getThenElements(/*parseable=*/true);
 
-  FormatElement *first = elements.front();
+  FormatElement *first = thenElements.front();
   const auto guardOn = [&](auto params) {
     os << "if (!(";
     llvm::interleave(
@@ -687,12 +687,12 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
   }
   os.indent();
 
-  // Generate the parsers for the rest of the elements.
-  for (FormatElement *element : el->getElseElements())
+  // If the first element is absent, generate the parsers for the else elements.
+  for (FormatElement *element : el->getElseElements(/*parseable=*/true))
     genElementParser(element, ctx, os);
   os.unindent() << "} else {\n";
   os.indent();
-  for (FormatElement *element : elements.drop_front())
+  for (FormatElement *element : thenElements.drop_front())
     genElementParser(element, ctx, os);
   os.unindent() << "}\n";
 }
@@ -781,12 +781,16 @@ void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
 
 /// Generate code to guard printing on the presence of any optional parameters.
 template <typename ParameterRange>
-static void guardOnAny(FmtContext &ctx, MethodBody &os,
-                       ParameterRange &&params) {
+static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
+                       bool inverted = false) {
   os << "if (";
+  if (inverted)
+    os << "!(";
   llvm::interleave(
       params, os,
       [&](ParameterElement *param) { param->genPrintGuard(ctx, os); }, " || ");
+  if (inverted)
+    os << ")";
   os << ") {\n";
   os.indent();
 }
@@ -860,12 +864,12 @@ void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
                                         MethodBody &os) {
   FormatElement *anchor = el->getAnchor();
   if (auto *param = dyn_cast<ParameterElement>(anchor)) {
-    guardOnAny(ctx, os, llvm::makeArrayRef(param));
+    guardOnAny(ctx, os, llvm::makeArrayRef(param), el->isInverted());
   } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
-    guardOnAny(ctx, os, params->getParams());
+    guardOnAny(ctx, os, params->getParams(), el->isInverted());
   } else {
     auto *strct = cast<StructDirective>(anchor);
-    guardOnAny(ctx, os, strct->getParams());
+    guardOnAny(ctx, os, strct->getParams(), el->isInverted());
   }
   // Generate the printer for the contained elements.
   {

diff  --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index 84dd03d28ac9f..5948415dc4876 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -321,35 +321,42 @@ FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
   // Parse the child elements for this optional group.
   std::vector<FormatElement *> thenElements, elseElements;
   FormatElement *anchor = nullptr;
-  do {
-    FailureOr<FormatElement *> element = parseElement(TopLevelContext);
-    if (failed(element))
-      return failure();
-    // Check for an anchor.
-    if (curToken.is(FormatToken::caret)) {
-      if (anchor)
-        return emitError(curToken.getLoc(), "only one element can be marked as "
-                                            "the anchor of an optional group");
-      anchor = *element;
-      consumeToken();
-    }
-    thenElements.push_back(*element);
-  } while (!curToken.is(FormatToken::r_paren));
+  auto parseChildElements =
+      [this, &anchor](std::vector<FormatElement *> &elements) -> LogicalResult {
+    do {
+      FailureOr<FormatElement *> element = parseElement(TopLevelContext);
+      if (failed(element))
+        return failure();
+      // Check for an anchor.
+      if (curToken.is(FormatToken::caret)) {
+        if (anchor) {
+          return emitError(curToken.getLoc(),
+                           "only one element can be marked as the anchor of an "
+                           "optional group");
+        }
+        anchor = *element;
+        consumeToken();
+      }
+      elements.push_back(*element);
+    } while (!curToken.is(FormatToken::r_paren));
+    return success();
+  };
+
+  // Parse the 'then' elements. If the anchor was found in this group, then the
+  // optional is not inverted.
+  if (failed(parseChildElements(thenElements)))
+    return failure();
   consumeToken();
+  bool inverted = !anchor;
 
   // Parse the `else` elements of this optional group.
   if (curToken.is(FormatToken::colon)) {
     consumeToken();
-    if (failed(
-            parseToken(FormatToken::l_paren,
-                       "expected '(' to start else branch of optional group")))
+    if (failed(parseToken(
+            FormatToken::l_paren,
+            "expected '(' to start else branch of optional group")) ||
+        failed(parseChildElements(elseElements)))
       return failure();
-    do {
-      FailureOr<FormatElement *> element = parseElement(TopLevelContext);
-      if (failed(element))
-        return failure();
-      elseElements.push_back(*element);
-    } while (!curToken.is(FormatToken::r_paren));
     consumeToken();
   }
   if (failed(parseToken(FormatToken::question,
@@ -367,17 +374,21 @@ FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
 
   // Get the first parsable element. It must be an element that can be
   // optionally-parsed.
-  auto parseBegin = llvm::find_if_not(thenElements, [](FormatElement *element) {
+  auto isWhitespace = [](FormatElement *element) {
     return isa<WhitespaceElement>(element);
-  });
-  if (!isa<LiteralElement, VariableElement>(*parseBegin)) {
+  };
+  auto thenParseBegin = llvm::find_if_not(thenElements, isWhitespace);
+  auto elseParseBegin = llvm::find_if_not(elseElements, isWhitespace);
+  unsigned thenParseStart = std::distance(thenElements.begin(), thenParseBegin);
+  unsigned elseParseStart = std::distance(elseElements.begin(), elseParseBegin);
+
+  if (!isa<LiteralElement, VariableElement>(*thenParseBegin)) {
     return emitError(loc, "first parsable element of an optional group must be "
                           "a literal or variable");
   }
-
-  unsigned parseStart = std::distance(thenElements.begin(), parseBegin);
   return create<OptionalElement>(std::move(thenElements),
-                                 std::move(elseElements), anchor, parseStart);
+                                 std::move(elseElements), thenParseStart,
+                                 elseParseStart, anchor, inverted);
 }
 
 FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,

diff  --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 3230a591a0206..60264adcebc9a 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -378,33 +378,48 @@ class OptionalElement : public FormatElementBase<FormatElement::Optional> {
   /// Create an optional group with the given child elements.
   OptionalElement(std::vector<FormatElement *> &&thenElements,
                   std::vector<FormatElement *> &&elseElements,
-                  FormatElement *anchor, unsigned parseStart)
+                  unsigned thenParseStart, unsigned elseParseStart,
+                  FormatElement *anchor, bool inverted)
       : thenElements(std::move(thenElements)),
-        elseElements(std::move(elseElements)), anchor(anchor),
-        parseStart(parseStart) {}
-
-  /// Return the `then` elements of the optional group.
-  ArrayRef<FormatElement *> getThenElements() const { return thenElements; }
+        elseElements(std::move(elseElements)), thenParseStart(thenParseStart),
+        elseParseStart(elseParseStart), anchor(anchor), inverted(inverted) {}
+
+  /// Return the `then` elements of the optional group. Drops the first
+  /// `thenParseStart` whitespace elements if `parseable` is true.
+  ArrayRef<FormatElement *> getThenElements(bool parseable = false) const {
+    return llvm::makeArrayRef(thenElements)
+        .drop_front(parseable ? thenParseStart : 0);
+  }
 
-  /// Return the `else` elements of the optional group.
-  ArrayRef<FormatElement *> getElseElements() const { return elseElements; }
+  /// Return the `else` elements of the optional group. Drops the first
+  /// `elseParseStart` whitespace elements if `parseable` is true.
+  ArrayRef<FormatElement *> getElseElements(bool parseable = false) const {
+    return llvm::makeArrayRef(elseElements)
+        .drop_front(parseable ? elseParseStart : 0);
+  }
 
   /// Return the anchor of the optional group.
   FormatElement *getAnchor() const { return anchor; }
 
-  /// Return the index of the first element to be parsed.
-  unsigned getParseStart() const { return parseStart; }
+  /// Return true if the optional group is inverted.
+  bool isInverted() const { return inverted; }
 
 private:
   /// The child elements emitted when the anchor is present.
   std::vector<FormatElement *> thenElements;
   /// The child elements emitted when the anchor is not present.
   std::vector<FormatElement *> elseElements;
-  /// The anchor element of the optional group.
-  FormatElement *anchor;
   /// The index of the first element that is parsed in `thenElements`. That is,
   /// the first non-whitespace element.
-  unsigned parseStart;
+  unsigned thenParseStart;
+  /// The index of the first element that is parsed in `elseElements`. That is,
+  /// the first non-whitespace element.
+  unsigned elseParseStart;
+  /// The anchor element of the optional group.
+  FormatElement *anchor;
+  /// Whether the optional group condition is inverted and the anchor element is
+  /// in the else group.
+  bool inverted;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 38a1ed35baf36..90b69c1925af2 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1119,17 +1119,43 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
                                        GenContext genCtx) {
   /// Optional Group.
   if (auto *optional = dyn_cast<OptionalElement>(element)) {
-    ArrayRef<FormatElement *> elements =
-        optional->getThenElements().drop_front(optional->getParseStart());
+    auto genElementParsers = [&](FormatElement *firstElement,
+                                 ArrayRef<FormatElement *> elements,
+                                 bool thenGroup) {
+      // If the anchor is a unit attribute, we don't need to print it. When
+      // parsing, we will add this attribute if this group is present.
+      FormatElement *elidedAnchorElement = nullptr;
+      auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
+      if (anchorAttr && anchorAttr != firstElement &&
+          anchorAttr->isUnitAttr()) {
+        elidedAnchorElement = anchorAttr;
+
+        if (!thenGroup == optional->isInverted()) {
+          // Add the anchor unit attribute to the operation state.
+          body << "    result.addAttribute(\"" << anchorAttr->getVar()->name
+               << "\", parser.getBuilder().getUnitAttr());\n";
+        }
+      }
+
+      // Generate the rest of the elements inside an optional group. Elements in
+      // an optional group after the guard are parsed as required.
+      for (FormatElement *childElement : elements)
+        if (childElement != elidedAnchorElement)
+          genElementParser(childElement, body, attrTypeCtx,
+                           GenContext::Optional);
+    };
+
+    ArrayRef<FormatElement *> thenElements =
+        optional->getThenElements(/*parseable=*/true);
 
     // Generate a special optional parser for the first element to gate the
     // parsing of the rest of the elements.
-    FormatElement *firstElement = elements.front();
+    FormatElement *firstElement = thenElements.front();
     if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
       genElementParser(attrVar, body, attrTypeCtx);
       body << "  if (" << attrVar->getVar()->name << "Attr) {\n";
     } else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
-      body << "  if (succeeded(parser.parseOptional";
+      body << "  if (::mlir::succeeded(parser.parseOptional";
       genLiteralParser(literal->getSpelling(), body);
       body << ")) {\n";
     } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
@@ -1151,31 +1177,18 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
       }
     }
 
-    // If the anchor is a unit attribute, we don't need to print it. When
-    // parsing, we will add this attribute if this group is present.
-    FormatElement *elidedAnchorElement = nullptr;
-    auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
-    if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) {
-      elidedAnchorElement = anchorAttr;
-
-      // Add the anchor unit attribute to the operation state.
-      body << "    result.addAttribute(\"" << anchorAttr->getVar()->name
-           << "\", parser.getBuilder().getUnitAttr());\n";
-    }
-
-    // Generate the rest of the elements inside an optional group. Elements in
-    // an optional group after the guard are parsed as required.
-    for (FormatElement *childElement : llvm::drop_begin(elements, 1))
-      if (childElement != elidedAnchorElement)
-        genElementParser(childElement, body, attrTypeCtx, GenContext::Optional);
+    genElementParsers(firstElement, thenElements.drop_front(),
+                      /*thenGroup=*/true);
     body << "  }";
 
     // Generate the else elements.
     auto elseElements = optional->getElseElements();
     if (!elseElements.empty()) {
       body << " else {\n";
-      for (FormatElement *childElement : elseElements)
-        genElementParser(childElement, body, attrTypeCtx);
+      ArrayRef<FormatElement *> elseElements =
+          optional->getElseElements(/*parseable=*/true);
+      genElementParsers(elseElements.front(), elseElements,
+                        /*thenGroup=*/false);
       body << "  }";
     }
     body << "\n";
@@ -1842,15 +1855,15 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
         const NamedTypeConstraint *var = element->getVar();
         std::string name = op.getGetterName(var->name);
         if (var->isOptional())
-          body << "  if (" << name << "()) {\n";
+          body << name << "()";
         else if (var->isVariadic())
-          body << "  if (!" << name << "().empty()) {\n";
+          body << "!" << name << "().empty()";
       })
       .Case<RegionVariable>([&](RegionVariable *element) {
         const NamedRegion *var = element->getVar();
         std::string name = op.getGetterName(var->name);
         // TODO: Add a check for optional regions here when ODS supports it.
-        body << "  if (!" << name << "().empty()) {\n";
+        body << "!" << name << "().empty()";
       })
       .Case<TypeDirective>([&](TypeDirective *element) {
         genOptionalGroupPrinterAnchor(element->getArg(), op, body);
@@ -1859,8 +1872,7 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
         genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
       })
       .Case<AttributeVariable>([&](AttributeVariable *attr) {
-        body << "  if ((*this)->getAttr(\"" << attr->getVar()->name
-             << "\")) {\n";
+        body << "(*this)->getAttr(\"" << attr->getVar()->name << "\")";
       });
 }
 
@@ -1912,39 +1924,45 @@ void OperationFormat::genElementPrinter(FormatElement *element,
   if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
     // Emit the check for the presence of the anchor element.
     FormatElement *anchor = optional->getAnchor();
+    body << "  if (";
+    if (optional->isInverted())
+      body << "!";
     genOptionalGroupPrinterAnchor(anchor, op, body);
+    body << ") {\n";
+    body.indent();
 
     // If the anchor is a unit attribute, we don't need to print it. When
     // parsing, we will add this attribute if this group is present.
-    auto elements = optional->getThenElements();
+    ArrayRef<FormatElement *> thenElements = optional->getThenElements();
+    ArrayRef<FormatElement *> elseElements = optional->getElseElements();
     FormatElement *elidedAnchorElement = nullptr;
     auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
-    if (anchorAttr && anchorAttr != elements.front() &&
+    if (anchorAttr && anchorAttr != thenElements.front() &&
+        (elseElements.empty() || anchorAttr != elseElements.front()) &&
         anchorAttr->isUnitAttr()) {
       elidedAnchorElement = anchorAttr;
     }
+    auto genElementPrinters = [&](ArrayRef<FormatElement *> elements) {
+      for (FormatElement *childElement : elements) {
+        if (childElement != elidedAnchorElement) {
+          genElementPrinter(childElement, body, op, shouldEmitSpace,
+                            lastWasPunctuation);
+        }
+      }
+    };
 
     // Emit each of the elements.
-    for (FormatElement *childElement : elements) {
-      if (childElement != elidedAnchorElement) {
-        genElementPrinter(childElement, body, op, shouldEmitSpace,
-                          lastWasPunctuation);
-      }
-    }
-    body << "  }";
+    genElementPrinters(thenElements);
+    body << "}";
 
     // Emit each of the else elements.
-    auto elseElements = optional->getElseElements();
     if (!elseElements.empty()) {
       body << " else {\n";
-      for (FormatElement *childElement : elseElements) {
-        genElementPrinter(childElement, body, op, shouldEmitSpace,
-                          lastWasPunctuation);
-      }
-      body << "  }";
+      genElementPrinters(elseElements);
+      body << "}";
     }
 
-    body << "\n";
+    body.unindent() << "\n";
     return;
   }
 


        


More information about the Mlir-commits mailing list