[Mlir-commits] [mlir] 2d0477a - [mlir][DeclarativeParser] Add basic support for optional groups in the assembly format.

River Riddle llvmlistbot at llvm.org
Fri Feb 21 15:17:36 PST 2020


Author: River Riddle
Date: 2020-02-21T15:15:31-08:00
New Revision: 2d0477a003687588886ae6e9b59b9355f8bb6b8c

URL: https://github.com/llvm/llvm-project/commit/2d0477a003687588886ae6e9b59b9355f8bb6b8c
DIFF: https://github.com/llvm/llvm-project/commit/2d0477a003687588886ae6e9b59b9355f8bb6b8c.diff

LOG: [mlir][DeclarativeParser] Add basic support for optional groups in the assembly format.

When operations have optional attributes, or optional operands (i.e. empty variadic operands), the assembly format often has an optional section to represent these arguments. This revision adds basic support for defining an "optional group" in the assembly format to support this. An optional group is defined by wrapping a set of elements in `()` followed by `?` and requires the following:

* The first element of the group must be either a literal or an operand argument.
  - This is because the first element must be optionally parsable.
* There must be exactly one argument variable within the group that is marked as the anchor of the group. The anchor is the element whose presence controls whether the group should be printed/parsed. An element is marked as the anchor by adding a trailing `^`.
* The group must only contain literals, variables, and type directives.
  - Any attribute variables may be used, but only optional attributes can be marked as the anchor.
  - Only variadic, i.e. optional, operand arguments can be used.
  - The elements of a type directive must be defined within the same optional group.

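The requirements above amount to a small set of checks over the elements of a group. The following is a minimal C++ sketch of those checks under stated assumptions: `Kind`, `GroupElement`, and `verifyOptionalGroup` are hypothetical names (not the classes in `OpFormatGen.cpp`), the diagnostics are paraphrased rather than copied from the tool, and the rule that a type directive may only refer to variables defined in the same group is omitted for brevity.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical, simplified element kinds; not the OpFormatGen.cpp hierarchy.
enum class Kind { Literal, AttributeVariable, OperandVariable, TypeDirective, AttrDict };

struct GroupElement {
  Kind kind;
  bool isAnchor = false;             // element carries a trailing `^`
  bool isOptionalOrVariadic = false; // optional attribute / variadic operand
};

// Returns an empty string if the group is valid, otherwise a diagnostic.
std::string verifyOptionalGroup(const std::vector<GroupElement> &elements) {
  if (elements.empty())
    return "optional group must contain at least one element";

  // The first element gates the group, so it must be optionally parsable.
  Kind first = elements.front().kind;
  if (first != Kind::Literal && first != Kind::OperandVariable)
    return "first element must be a literal or operand";

  int numAnchors = 0;
  for (const GroupElement &e : elements) {
    // Only literals, variables, and type directives may appear in a group.
    if (e.kind == Kind::AttrDict)
      return "only literals, types, and variables can be used";
    // Any operand inside the group must be variadic (i.e. optional).
    if (e.kind == Kind::OperandVariable && !e.isOptionalOrVariadic)
      return "only variadic operands can be used";
    if (!e.isAnchor)
      continue;
    ++numAnchors;
    // The anchor must be a variable, and an attribute anchor must be optional.
    if (e.kind != Kind::AttributeVariable && e.kind != Kind::OperandVariable)
      return "only variables can anchor a group";
    if (e.kind == Kind::AttributeVariable && !e.isOptionalOrVariadic)
      return "only optional attributes can anchor a group";
  }
  if (numAnchors == 0)
    return "no anchor element";
  if (numAnchors > 1)
    return "only one element can be the anchor";
  return "";
}
```

For instance, the ReturnOp group shown below corresponds to a variadic operand anchor followed by a literal and a type directive, which this sketch accepts.

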
An example of this can be seen with the assembly format for ReturnOp, which has a variadic number of operands.

```
def ReturnOp : ... {
  let arguments = (ins Variadic<AnyType>:$operands);

  // We only print the operands+types if there are a non-zero number
  // of operands.
  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
}
```

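To make the behavior of this format concrete, here is a toy C++ model of what the generated printer and parser do for the ReturnOp format above, ignoring the attribute dictionary. It is a sketch, not MLIR's API: `ToyReturnOp`, `printToyReturn`, and `parseToyReturn` are hypothetical, and the parser works on pre-split tokens instead of an `OpAsmParser`. Because the first element of the group is a variadic operand, the parser always tries to parse an operand list and only continues into the rest of the group when that list is non-empty, mirroring the hand-written `parseReturnOp` that this commit removes.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a ReturnOp: SSA value names plus their types.
struct ToyReturnOp {
  std::vector<std::string> operands; // e.g. "%0"
  std::vector<std::string> types;    // e.g. "i32"
};

// Join a list with ", " as MLIR prints operand and type lists.
static std::string joinComma(const std::vector<std::string> &items) {
  std::string out;
  for (std::size_t i = 0; i < items.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += items[i];
  }
  return out;
}

// Printer: the group is emitted only when the anchor ($operands) is non-empty.
std::string printToyReturn(const ToyReturnOp &op) {
  assert(op.operands.size() == op.types.size()); // one type per operand
  std::string out = "return";
  if (!op.operands.empty())
    out += " " + joinComma(op.operands) + " : " + joinComma(op.types);
  return out;
}

static bool isValueName(const std::string &tok) {
  return !tok.empty() && tok[0] == '%';
}

// Parser for the tokens after "return". Returns false on a syntax error.
// Any token is accepted as a type name; this is only a toy.
bool parseToyReturn(const std::vector<std::string> &toks, ToyReturnOp &op) {
  std::size_t pos = 0;
  // First element of the group: a (possibly empty) operand list.
  if (pos < toks.size() && isValueName(toks[pos])) {
    op.operands.push_back(toks[pos++]);
    while (pos + 1 < toks.size() && toks[pos] == "," && isValueName(toks[pos + 1])) {
      op.operands.push_back(toks[pos + 1]);
      pos += 2;
    }
  }
  // Only when the anchor is present is the rest of the group parsed.
  if (!op.operands.empty()) {
    if (pos >= toks.size() || toks[pos] != ":")
      return false;
    ++pos;
    if (pos >= toks.size())
      return false;
    op.types.push_back(toks[pos++]);
    while (pos + 1 < toks.size() && toks[pos] == ",") {
      op.types.push_back(toks[pos + 1]);
      pos += 2;
    }
    if (op.types.size() != op.operands.size())
      return false;
  }
  return pos == toks.size(); // reject trailing tokens
}
```

For example, an op with operands `%0, %1` of types `i32, i64` prints as `return %0, %1 : i32, i64`, while an op with no operands prints as just `return`.

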
Differential Revision: https://reviews.llvm.org/D74681

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 57ad925b1b07..3f877319dced 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -619,6 +619,43 @@ the variables would be `$callee`  and `$args`.
 Attribute variables are printed with their respective value type, unless that
 value type is buildable. In those cases, the type of the attribute is elided.
 
+#### Optional Groups
+
+In certain situations operations may have "optional" information, e.g.
+attributes or an empty set of variadic operands. In these situations a section
+of the assembly format can be marked as `optional` based on the presence of this
+information. An optional group is defined by wrapping a set of elements within
+`()` followed by a `?` and has the following requirements:
+
+*   The first element of the group must either be a literal or an operand.
+    -   This is because the first element must be optionally parsable.
+*   Exactly one argument variable within the group must be marked as the anchor
+    of the group.
+    -   The anchor is the element whose presence controls whether the group
+        should be printed/parsed.
+    -   An element is marked as the anchor by adding a trailing `^`.
+    -   The first element is *not* required to be the anchor of the group.
+*   Literals, variables, and type directives are the only valid elements within
+    the group.
+    -   Any attribute variable may be used, but only optional attributes can be
+        marked as the anchor.
+    -   Only variadic, i.e. optional, operand arguments can be used.
+    -   The operands to a type directive must be defined within the optional
+        group.
+
+An example of an operation with an optional group is `std.return`, which has a
+variadic number of operands.
+
+```
+def ReturnOp : ... {
+  let arguments = (ins Variadic<AnyType>:$operands);
+
+  // We only print the operands and types if there are a non-zero number
+  // of operands.
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+}
+```
+
 #### Requirements
 
 The format specification has a certain set of requirements that must be adhered

diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index fe28f8d7143f..1fc4330cefdd 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1059,6 +1059,8 @@ def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> {
   let builders = [OpBuilder<
     "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
   >];
+
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
 def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape,

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index aa0a42812342..80f85e02289b 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1736,21 +1736,6 @@ OpFoldResult UnsignedRemIOp::fold(ArrayRef<Attribute> operands) {
 // ReturnOp
 //===----------------------------------------------------------------------===//
 
-static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
-  SmallVector<OpAsmParser::OperandType, 2> opInfo;
-  SmallVector<Type, 2> types;
-  llvm::SMLoc loc = parser.getCurrentLocation();
-  return failure(parser.parseOperandList(opInfo) ||
-                 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
-                 parser.resolveOperands(opInfo, types, loc, result.operands));
-}
-
-static void print(OpAsmPrinter &p, ReturnOp op) {
-  p << "return";
-  if (op.getNumOperands() != 0)
-    p << ' ' << op.getOperands() << " : " << op.getOperandTypes();
-}
-
 static LogicalResult verify(ReturnOp op) {
   auto function = cast<FuncOp>(op.getParentOp());
 

diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index dfc8d0caffe1..a49697b932bf 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -46,7 +46,7 @@ def DirectiveFunctionalTypeInvalidA : TestFormat_Op<"functype_invalid_a", [{
 def DirectiveFunctionalTypeInvalidB : TestFormat_Op<"functype_invalid_b", [{
   functional-type
 }]>;
-// CHECK: error: expected directive, literal, or variable
+// CHECK: error: expected directive, literal, variable, or optional group
 def DirectiveFunctionalTypeInvalidC : TestFormat_Op<"functype_invalid_c", [{
   functional-type(
 }]>;
@@ -54,7 +54,7 @@ def DirectiveFunctionalTypeInvalidC : TestFormat_Op<"functype_invalid_c", [{
 def DirectiveFunctionalTypeInvalidD : TestFormat_Op<"functype_invalid_d", [{
   functional-type(operands
 }]>;
-// CHECK: error: expected directive, literal, or variable
+// CHECK: error: expected directive, literal, variable, or optional group
 def DirectiveFunctionalTypeInvalidE : TestFormat_Op<"functype_invalid_e", [{
   functional-type(operands,
 }]>;
@@ -98,7 +98,7 @@ def DirectiveResultsInvalidA : TestFormat_Op<"operands_invalid_a", [{
 def DirectiveTypeInvalidA : TestFormat_Op<"type_invalid_a", [{
   type
 }]>;
-// CHECK: error: expected directive, literal, or variable
+// CHECK: error: expected directive, literal, variable, or optional group
 def DirectiveTypeInvalidB : TestFormat_Op<"type_invalid_b", [{
   type(
 }]>;
@@ -165,7 +165,7 @@ def LiteralInvalidA : TestFormat_Op<"literal_invalid_a", [{
   `1`
 }]>;
 // CHECK: error: unexpected end of file in literal
-// CHECK: error: expected directive, literal, or variable
+// CHECK: error: expected directive, literal, variable, or optional group
 def LiteralInvalidB : TestFormat_Op<"literal_invalid_b", [{
   `
 }]>;
@@ -175,6 +175,55 @@ def LiteralValid : TestFormat_Op<"literal_valid", [{
   attr-dict
 }]>;
 
+//===----------------------------------------------------------------------===//
+// Optional Groups
+//===----------------------------------------------------------------------===//
+
+// CHECK: error: optional groups can only be used as top-level elements
+def OptionalInvalidA : TestFormat_Op<"optional_invalid_a", [{
+  type(($attr^)?) attr-dict
+}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
+// CHECK: error: expected directive, literal, variable, or optional group
+def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{
+  () attr-dict
+}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
+// CHECK: error: optional group specified no anchor element
+def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
+  ($attr)? attr-dict
+}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
+// CHECK: error: first element of an operand group must be a literal or operand
+def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
+  ($attr^)? attr-dict
+}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
+// CHECK: error: type directive can only refer to variables within the optional group
+def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{
+  (`,` $attr^ type(operands))? attr-dict
+}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
+// CHECK: error: only one element can be marked as the anchor of an optional group
+def OptionalInvalidF : TestFormat_Op<"optional_invalid_f", [{
+  ($attr^ $attr2^) attr-dict
+}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr, OptionalAttr<I64Attr>:$attr2)>;
+// CHECK: error: only optional attributes can be used to anchor an optional group
+def OptionalInvalidG : TestFormat_Op<"optional_invalid_g", [{
+  ($attr^) attr-dict
+}]>, Arguments<(ins I64Attr:$attr)>;
+// CHECK: error: only variadic operands can be used within an optional group
+def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{
+  ($arg^) attr-dict
+}]>, Arguments<(ins I64:$arg)>;
+// CHECK: error: only variables can be used to anchor an optional group
+def OptionalInvalidI : TestFormat_Op<"optional_invalid_i", [{
+  ($arg type($arg)^) attr-dict
+}]>, Arguments<(ins Variadic<I64>:$arg)>;
+// CHECK: error: only literals, types, and variables can be used within an optional group
+def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{
+  (attr-dict)
+}]>;
+// CHECK: error: expected '?' after optional group
+def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
+  ($arg^)
+}]>, Arguments<(ins Variadic<I64>:$arg)>;
+
 //===----------------------------------------------------------------------===//
 // Variables
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 62653bc2da03..b5aa24e6e394 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -58,6 +58,9 @@ class Element {
     AttributeVariable,
     OperandVariable,
     ResultVariable,
+
+    /// This element is an optional element.
+    Optional,
   };
   Element(Kind kind) : kind(kind) {}
   virtual ~Element() = default;
@@ -164,7 +167,7 @@ namespace {
 class LiteralElement : public Element {
 public:
   LiteralElement(StringRef literal)
-      : Element{Kind::Literal}, literal(literal){};
+      : Element{Kind::Literal}, literal(literal) {}
   static bool classof(const Element *element) {
     return element->getKind() == Kind::Literal;
   }
@@ -203,6 +206,36 @@ bool LiteralElement::isValidLiteral(StringRef value) {
   });
 }
 
+//===----------------------------------------------------------------------===//
+// OptionalElement
+
+namespace {
+/// This class represents a group of elements that are optionally emitted based
+/// upon an optional variable of the operation.
+class OptionalElement : public Element {
+public:
+  OptionalElement(std::vector<std::unique_ptr<Element>> &&elements,
+                  unsigned anchor)
+      : Element{Kind::Optional}, elements(std::move(elements)), anchor(anchor) {
+  }
+  static bool classof(const Element *element) {
+    return element->getKind() == Kind::Optional;
+  }
+
+  /// Return the nested elements of this grouping.
+  auto getElements() const { return llvm::make_pointee_range(elements); }
+
+  /// Return the anchor of this optional group.
+  Element *getAnchor() const { return elements[anchor].get(); }
+
+private:
+  /// The child elements of this optional.
+  std::vector<std::unique_ptr<Element>> elements;
+  /// The index of the element that acts as the anchor for the optional group.
+  unsigned anchor;
+};
+} // end anonymous namespace
+
 //===----------------------------------------------------------------------===//
 // OperationFormat
 //===----------------------------------------------------------------------===//
@@ -327,32 +360,26 @@ const char *const enumAttrParserCode = R"(
 const char *const variadicOperandParserCode = R"(
   llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation();
   (void){0}OperandsLoc;
-  SmallVector<OpAsmParser::OperandType, 4> {0}Operands;
   if (parser.parseOperandList({0}Operands))
     return failure();
 )";
 const char *const operandParserCode = R"(
   llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation();
   (void){0}OperandsLoc;
-  OpAsmParser::OperandType {0}RawOperands[1];
   if (parser.parseOperand({0}RawOperands[0]))
     return failure();
-  ArrayRef<OpAsmParser::OperandType> {0}Operands({0}RawOperands);
 )";
 
 /// The code snippet used to generate a parser call for a type list.
 ///
 /// {0}: The name for the type list.
 const char *const variadicTypeParserCode = R"(
-  SmallVector<Type, 1> {0}Types;
   if (parser.parseTypeList({0}Types))
     return failure();
 )";
 const char *const typeParserCode = R"(
-  Type {0}RawTypes[1] = {{nullptr};
   if (parser.parseType({0}RawTypes[0]))
     return failure();
-  ArrayRef<Type> {0}Types({0}RawTypes);
 )";
 
 /// The code snippet used to generate a parser call for a functional type.
@@ -363,8 +390,8 @@ const char *const functionalTypeParserCode = R"(
   FunctionType {0}__{1}_functionType;
   if (parser.parseType({0}__{1}_functionType))
     return failure();
-  ArrayRef<Type> {0}Types = {0}__{1}_functionType.getInputs();
-  ArrayRef<Type> {1}Types = {0}__{1}_functionType.getResults();
+  {0}Types = {0}__{1}_functionType.getInputs();
+  {1}Types = {0}__{1}_functionType.getResults();
 )";
 
 /// Get the name used for the type list for the given type directive operand.
@@ -388,25 +415,144 @@ static StringRef getTypeListName(Element *arg, bool &isVariadic) {
 
 /// Generate the parser for a literal value.
 static void genLiteralParser(StringRef value, OpMethodBody &body) {
-  body << "  if (parser.parse";
-
   // Handle the case of a keyword/identifier.
   if (value.front() == '_' || isalpha(value.front())) {
     body << "Keyword(\"" << value << "\")";
+    return;
+  }
+  body << (StringRef)llvm::StringSwitch<StringRef>(value)
+              .Case("->", "Arrow()")
+              .Case(":", "Colon()")
+              .Case(",", "Comma()")
+              .Case("=", "Equal()")
+              .Case("<", "Less()")
+              .Case(">", "Greater()")
+              .Case("(", "LParen()")
+              .Case(")", "RParen()")
+              .Case("[", "LSquare()")
+              .Case("]", "RSquare()");
+}
+
+/// Generate the storage code required for parsing the given element.
+static void genElementParserStorage(Element *element, OpMethodBody &body) {
+  if (auto *optional = dyn_cast<OptionalElement>(element)) {
+    for (auto &childElement : optional->getElements())
+      genElementParserStorage(&childElement, body);
+  } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
+    StringRef name = operand->getVar()->name;
+    if (operand->getVar()->isVariadic())
+      body << "  SmallVector<OpAsmParser::OperandType, 4> " << name
+           << "Operands;\n";
+    else
+      body << "  OpAsmParser::OperandType " << name << "RawOperands[1];\n"
+           << "  ArrayRef<OpAsmParser::OperandType> " << name << "Operands("
+           << name << "RawOperands);";
+  } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
+    bool variadic = false;
+    StringRef name = getTypeListName(dir->getOperand(), variadic);
+    if (variadic)
+      body << "  SmallVector<Type, 1> " << name << "Types;\n";
+    else
+      body << llvm::formatv("  Type {0}RawTypes[1];\n", name)
+           << llvm::formatv("  ArrayRef<Type> {0}Types({0}RawTypes);\n", name);
+  } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
+    bool ignored = false;
+    body << "  ArrayRef<Type> " << getTypeListName(dir->getInputs(), ignored)
+         << "Types;\n";
+    body << "  ArrayRef<Type> " << getTypeListName(dir->getResults(), ignored)
+         << "Types;\n";
+  }
+}
+
+/// Generate the parser for a single format element.
+static void genElementParser(Element *element, OpMethodBody &body,
+                             FmtContext &attrTypeCtx) {
+  /// Optional Group.
+  if (auto *optional = dyn_cast<OptionalElement>(element)) {
+    auto elements = optional->getElements();
+
+    // Generate a special optional parser for the first element to gate the
+    // parsing of the rest of the elements.
+    if (auto *literal = dyn_cast<LiteralElement>(&*elements.begin())) {
+      body << "  if (succeeded(parser.parseOptional";
+      genLiteralParser(literal->getLiteral(), body);
+      body << ")) {\n";
+    } else if (auto *opVar = dyn_cast<OperandVariable>(&*elements.begin())) {
+      genElementParser(opVar, body, attrTypeCtx);
+      body << "  if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
+    }
+
+    // Generate the rest of the elements normally.
+    for (auto &childElement : llvm::drop_begin(elements, 1))
+      genElementParser(&childElement, body, attrTypeCtx);
+    body << "  }\n";
+
+    /// Literals.
+  } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
+    body << "  if (parser.parse";
+    genLiteralParser(literal->getLiteral(), body);
+    body << ")\n    return failure();\n";
+
+    /// Arguments.
+  } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
+    const NamedAttribute *var = attr->getVar();
+
+    // Check to see if we can parse this as an enum attribute.
+    if (canFormatEnumAttr(var)) {
+      const EnumAttr &enumAttr = cast<EnumAttr>(var->attr);
+
+      // Generate the code for building an attribute for this enum.
+      std::string attrBuilderStr;
+      {
+        llvm::raw_string_ostream os(attrBuilderStr);
+        os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
+                    "attrOptional.getValue()");
+      }
+
+      body << formatv(enumAttrParserCode, var->name, enumAttr.getCppNamespace(),
+                      enumAttr.getStringToSymbolFnName(), attrBuilderStr);
+      return;
+    }
+
+    // If this attribute has a buildable type, use that when parsing the
+    // attribute.
+    std::string attrTypeStr;
+    if (Optional<Type> attrType = var->attr.getValueType()) {
+      if (Optional<StringRef> typeBuilder = attrType->getBuilderCall()) {
+        llvm::raw_string_ostream os(attrTypeStr);
+        os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
+      }
+    }
+
+    body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
+                    attrTypeStr);
+  } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
+    bool isVariadic = operand->getVar()->isVariadic();
+    body << formatv(isVariadic ? variadicOperandParserCode : operandParserCode,
+                    operand->getVar()->name);
+
+    /// Directives.
+  } else if (isa<AttrDictDirective>(element)) {
+    body << "  if (parser.parseOptionalAttrDict(result.attributes))\n"
+         << "    return failure();\n";
+  } else if (isa<OperandsDirective>(element)) {
+    body << "  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
+         << "  SmallVector<OpAsmParser::OperandType, 4> allOperands;\n"
+         << "  if (parser.parseOperandList(allOperands))\n"
+         << "    return failure();\n";
+  } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
+    bool isVariadic = false;
+    StringRef listName = getTypeListName(dir->getOperand(), isVariadic);
+    body << formatv(isVariadic ? variadicTypeParserCode : typeParserCode,
+                    listName);
+  } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
+    bool ignored = false;
+    body << formatv(functionalTypeParserCode,
+                    getTypeListName(dir->getInputs(), ignored),
+                    getTypeListName(dir->getResults(), ignored));
   } else {
-    body << (StringRef)llvm::StringSwitch<StringRef>(value)
-                .Case("->", "Arrow()")
-                .Case(":", "Colon()")
-                .Case(",", "Comma()")
-                .Case("=", "Equal()")
-                .Case("<", "Less()")
-                .Case(">", "Greater()")
-                .Case("(", "LParen()")
-                .Case(")", "RParen()")
-                .Case("[", "LSquare()")
-                .Case("]", "RSquare()");
+    llvm_unreachable("unknown format element");
   }
-  body << ")\n    return failure();\n";
 }
 
 void OperationFormat::genParser(Operator &op, OpClass &opClass) {
@@ -415,79 +561,19 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
       OpMethod::MP_Static);
   auto &body = method.body();
 
+  // Generate variables to store the operands and type within the format. This
+  // allows for referencing these variables in the presence of optional
+  // groupings.
+  for (auto &element : elements)
+    genElementParserStorage(&*element, body);
+
   // A format context used when parsing attributes with buildable types.
   FmtContext attrTypeCtx;
   attrTypeCtx.withBuilder("parser.getBuilder()");
 
   // Generate parsers for each of the elements.
-  for (auto &element : elements) {
-    /// Literals.
-    if (LiteralElement *literal = dyn_cast<LiteralElement>(element.get())) {
-      genLiteralParser(literal->getLiteral(), body);
-
-      /// Arguments.
-    } else if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
-      const NamedAttribute *var = attr->getVar();
-
-      // Check to see if we can parse this as an enum attribute.
-      if (canFormatEnumAttr(var)) {
-        const EnumAttr &enumAttr = cast<EnumAttr>(var->attr);
-
-        // Generate the code for building an attribute for this enum.
-        std::string attrBuilderStr;
-        {
-          llvm::raw_string_ostream os(attrBuilderStr);
-          os << tgfmt(enumAttr.getConstBuilderTemplate(), &attrTypeCtx,
-                      "attrOptional.getValue()");
-        }
-
-        body << formatv(enumAttrParserCode, var->name,
-                        enumAttr.getCppNamespace(),
-                        enumAttr.getStringToSymbolFnName(), attrBuilderStr);
-        continue;
-      }
-
-      // If this attribute has a buildable type, use that when parsing the
-      // attribute.
-      std::string attrTypeStr;
-      if (Optional<Type> attrType = var->attr.getValueType()) {
-        if (Optional<StringRef> typeBuilder = attrType->getBuilderCall()) {
-          llvm::raw_string_ostream os(attrTypeStr);
-          os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
-        }
-      }
-
-      body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
-                      attrTypeStr);
-    } else if (auto *operand = dyn_cast<OperandVariable>(element.get())) {
-      bool isVariadic = operand->getVar()->isVariadic();
-      body << formatv(isVariadic ? variadicOperandParserCode
-                                 : operandParserCode,
-                      operand->getVar()->name);
-
-      /// Directives.
-    } else if (isa<AttrDictDirective>(element.get())) {
-      body << "  if (parser.parseOptionalAttrDict(result.attributes))\n"
-           << "    return failure();\n";
-    } else if (isa<OperandsDirective>(element.get())) {
-      body << "  llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
-           << "  SmallVector<OpAsmParser::OperandType, 4> allOperands;\n"
-           << "  if (parser.parseOperandList(allOperands))\n"
-           << "    return failure();\n";
-    } else if (auto *dir = dyn_cast<TypeDirective>(element.get())) {
-      bool isVariadic = false;
-      StringRef listName = getTypeListName(dir->getOperand(), isVariadic);
-      body << formatv(isVariadic ? variadicTypeParserCode : typeParserCode,
-                      listName);
-    } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element.get())) {
-      bool ignored = false;
-      body << formatv(functionalTypeParserCode,
-                      getTypeListName(dir->getInputs(), ignored),
-                      getTypeListName(dir->getResults(), ignored));
-    } else {
-      llvm_unreachable("unknown format element");
-    }
-  }
+  for (auto &element : elements)
+    genElementParser(element.get(), body, attrTypeCtx);
 
   // Generate the code to resolve the operand and result types now that they
   // have been parsed.
@@ -676,7 +762,7 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
   lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
 }
 
-/// Generate the c++ for an operand to a (*-)type directive.
+/// Generate the C++ for an operand to a (*-)type directive.
 static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
   if (isa<OperandsDirective>(arg))
     return body << "getOperation()->getOperandTypes()";
@@ -689,6 +775,79 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
   return body << "ArrayRef<Type>(" << var->name << "().getType())";
 }
 
+/// Generate the code for printing the given element.
+static void genElementPrinter(Element *element, OpMethodBody &body,
+                              OperationFormat &fmt, bool &shouldEmitSpace,
+                              bool &lastWasPunctuation) {
+  if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
+    return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
+                             lastWasPunctuation);
+
+  // Emit an optional group.
+  if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
+    // Emit the check for the presence of the anchor element.
+    Element *anchor = optional->getAnchor();
+    if (AttributeVariable *attrVar = dyn_cast<AttributeVariable>(anchor))
+      body << "  if (getAttr(\"" << attrVar->getVar()->name << "\")) {\n";
+    else
+      body << "  if (!" << cast<OperandVariable>(anchor)->getVar()->name
+           << "().empty()) {\n";
+
+    // Emit each of the elements.
+    for (Element &childElement : optional->getElements())
+      genElementPrinter(&childElement, body, fmt, shouldEmitSpace,
+                        lastWasPunctuation);
+    body << "  }\n";
+    return;
+  }
+
+  // Emit the attribute dictionary.
+  if (isa<AttrDictDirective>(element)) {
+    genAttrDictPrinter(fmt, body);
+    lastWasPunctuation = false;
+    return;
+  }
+
+  // Optionally insert a space before the next element. The AttrDict printer
+  // already adds a space as necessary.
+  if (shouldEmitSpace || !lastWasPunctuation)
+    body << "  p << \" \";\n";
+  lastWasPunctuation = false;
+  shouldEmitSpace = true;
+
+  if (auto *attr = dyn_cast<AttributeVariable>(element)) {
+    const NamedAttribute *var = attr->getVar();
+
+    // If we are formatting as a enum, symbolize the attribute as a string.
+    if (canFormatEnumAttr(var)) {
+      const EnumAttr &enumAttr = cast<EnumAttr>(var->attr);
+      body << "  p << \"\\\"\" << " << enumAttr.getSymbolToStringFnName() << "("
+           << var->name << "()) << \"\\\"\";\n";
+      return;
+    }
+
+    // Elide the attribute type if it is buildable.
+    Optional<Type> attrType = var->attr.getValueType();
+    if (attrType && attrType->getBuilderCall())
+      body << "  p.printAttributeWithoutType(" << var->name << "Attr());\n";
+    else
+      body << "  p.printAttribute(" << var->name << "Attr());\n";
+  } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
+    body << "  p << " << operand->getVar()->name << "();\n";
+  } else if (isa<OperandsDirective>(element)) {
+    body << "  p << getOperation()->getOperands();\n";
+  } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
+    body << "  p << ";
+    genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
+  } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
+    body << "  p.printFunctionalType(";
+    genTypeOperandPrinter(dir->getInputs(), body) << ", ";
+    genTypeOperandPrinter(dir->getResults(), body) << ");\n";
+  } else {
+    llvm_unreachable("unknown format element");
+  }
+}
+
 void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
   auto &method = opClass.newMethod("void", "print", "OpAsmPrinter &p");
   auto &body = method.body();
@@ -706,60 +865,9 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
   // Flags for if we should emit a space, and if the last element was
   // punctuation.
   bool shouldEmitSpace = true, lastWasPunctuation = false;
-  for (auto &element : elements) {
-    // Emit a literal element.
-    if (LiteralElement *literal = dyn_cast<LiteralElement>(element.get())) {
-      genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
-                        lastWasPunctuation);
-      continue;
-    }
-
-    // Emit the attribute dictionary.
-    if (isa<AttrDictDirective>(element.get())) {
-      genAttrDictPrinter(*this, body);
-      lastWasPunctuation = false;
-      continue;
-    }
-
-    // Optionally insert a space before the next element. The AttrDict printer
-    // already adds a space as necessary.
-    if (shouldEmitSpace || !lastWasPunctuation)
-      body << "  p << \" \";\n";
-    lastWasPunctuation = false;
-    shouldEmitSpace = true;
-
-    if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
-      const NamedAttribute *var = attr->getVar();
-
-      // If we are formatting as a enum, symbolize the attribute as a string.
-      if (canFormatEnumAttr(var)) {
-        const EnumAttr &enumAttr = cast<EnumAttr>(var->attr);
-        body << "  p << \"\\\"\" << " << enumAttr.getSymbolToStringFnName()
-             << "(" << var->name << "()) << \"\\\"\";\n";
-        continue;
-      }
-
-      // Elide the attribute type if it is buildable.
-      Optional<Type> attrType = var->attr.getValueType();
-      if (attrType && attrType->getBuilderCall())
-        body << "  p.printAttributeWithoutType(" << var->name << "Attr());\n";
-      else
-        body << "  p.printAttribute(" << var->name << "Attr());\n";
-    } else if (auto *operand = dyn_cast<OperandVariable>(element.get())) {
-      body << "  p << " << operand->getVar()->name << "();\n";
-    } else if (isa<OperandsDirective>(element.get())) {
-      body << "  p << getOperation()->getOperands();\n";
-    } else if (auto *dir = dyn_cast<TypeDirective>(element.get())) {
-      body << "  p << ";
-      genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
-    } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element.get())) {
-      body << "  p.printFunctionalType(";
-      genTypeOperandPrinter(dir->getInputs(), body) << ", ";
-      genTypeOperandPrinter(dir->getResults(), body) << ");\n";
-    } else {
-      llvm_unreachable("unknown format element");
-    }
-  }
+  for (auto &element : elements)
+    genElementPrinter(element.get(), body, *this, shouldEmitSpace,
+                      lastWasPunctuation);
 }
 
 //===----------------------------------------------------------------------===//
@@ -778,8 +886,10 @@ class Token {
     // Tokens with no info.
     l_paren,
     r_paren,
+    caret,
     comma,
     equal,
+    question,
 
     // Keywords.
     keyword_start,
@@ -908,10 +1018,14 @@ Token FormatLexer::lexToken() {
     return formToken(Token::eof, tokStart);
 
   // Lex punctuation.
+  case '^':
+    return formToken(Token::caret, tokStart);
   case ',':
     return formToken(Token::comma, tokStart);
   case '=':
     return formToken(Token::equal, tokStart);
+  case '?':
+    return formToken(Token::question, tokStart);
   case '(':
     return formToken(Token::l_paren, tokStart);
   case ')':
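The lexer change only adds two single-character punctuation tokens. As a rough sketch, assuming a hypothetical `Tok` enum and `lexPunctuation`/`lexAll` helpers rather than the real `FormatLexer`, each character maps to its own token kind, so the parser can test for `^` (anchor) and `?` (end of an optional group) directly:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical token kinds mirroring the punctuation in the format lexer.
enum class Tok { l_paren, r_paren, caret, comma, equal, question, error };

// Map one punctuation character to its token kind.
Tok lexPunctuation(char c) {
  switch (c) {
  case '^': return Tok::caret;
  case ',': return Tok::comma;
  case '=': return Tok::equal;
  case '?': return Tok::question;
  case '(': return Tok::l_paren;
  case ')': return Tok::r_paren;
  default:  return Tok::error;
  }
}

// Lex a string made only of punctuation and whitespace into tokens.
std::vector<Tok> lexAll(const std::string &src) {
  std::vector<Tok> toks;
  for (char c : src) {
    if (c == ' ' || c == '\t' || c == '\n')
      continue; // whitespace separates tokens but produces none
    toks.push_back(lexPunctuation(c));
  }
  return toks;
}
```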
@@ -1026,6 +1140,12 @@ class FormatParser {
   LogicalResult parseDirective(std::unique_ptr<Element> &element,
                                bool isTopLevel);
   LogicalResult parseLiteral(std::unique_ptr<Element> &element);
+  LogicalResult parseOptional(std::unique_ptr<Element> &element,
+                              bool isTopLevel);
+  LogicalResult parseOptionalChildElement(
+      std::vector<std::unique_ptr<Element>> &childElements,
+      SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
+      Optional<unsigned> &anchorIdx);
 
   /// Parse the various different directives.
   LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
@@ -1077,6 +1197,7 @@ class FormatParser {
   llvm::SmallBitVector seenOperandTypes, seenResultTypes;
   llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
   llvm::DenseSet<const NamedAttribute *> seenAttrs;
+  llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
 };
 } // end anonymous namespace
 
@@ -1236,11 +1357,14 @@ LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
   // Literals.
   if (curToken.getKind() == Token::literal)
     return parseLiteral(element);
+  // Optionals.
+  if (curToken.getKind() == Token::l_paren)
+    return parseOptional(element, isTopLevel);
   // Variables.
   if (curToken.getKind() == Token::variable)
     return parseVariable(element, isTopLevel);
   return emitError(curToken.getLoc(),
-                   "expected directive, literal, or variable");
+                   "expected directive, literal, variable, or optional group");
 }
 
 LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
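With `(` now dispatching to `parseOptional`, the structural part of the grammar is `( element [^] ... ) ?`. The sketch below illustrates just that shape over pre-lexed string tokens. `Group` and `parseOptionalGroup` are hypothetical, errors are returned as strings instead of `LogicalResult`/`emitError`, and nested groups are simply rejected as "expected element" (the real parser rejects them later, when validating the child).

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical result of parsing one optional group; elements are kept as raw
// token strings instead of real format elements.
struct Group {
  std::vector<std::string> elements;
  std::size_t anchorIdx = 0; // index of the element followed by `^`
};

// Parse `( elem [^] elem [^] ... ) ?` starting at toks[pos]. On success, fills
// `group`, advances `pos` past the `?`, and returns an empty string; otherwise
// returns an error message.
std::string parseOptionalGroup(const std::vector<std::string> &toks,
                               std::size_t &pos, Group &group) {
  auto peek = [&]() { return pos < toks.size() ? toks[pos] : std::string(); };
  if (peek() != "(")
    return "expected '('";
  ++pos;

  std::optional<std::size_t> anchor;
  do {
    std::string tok = peek();
    if (tok.empty() || tok == "(" || tok == ")" || tok == "^" || tok == "?")
      return "expected element";
    group.elements.push_back(tok);
    ++pos;
    // A trailing `^` marks the element just parsed as the anchor.
    if (peek() == "^") {
      if (anchor)
        return "only one element can be marked as the anchor";
      anchor = group.elements.size() - 1;
      ++pos;
    }
  } while (peek() != ")");
  ++pos; // consume `)`

  if (peek() != "?")
    return "expected '?' after optional group";
  ++pos;
  if (!anchor)
    return "optional group specified no anchor element";
  group.anchorIdx = *anchor;
  return "";
}
```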
@@ -1314,6 +1438,115 @@ LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element) {
   return success();
 }
 
+LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
+                                          bool isTopLevel) {
+  llvm::SMLoc curLoc = curToken.getLoc();
+  if (!isTopLevel)
+    return emitError(curLoc, "optional groups can only be used as top-level "
+                             "elements");
+  consumeToken();
+
+  // Parse the child elements for this optional group.
+  std::vector<std::unique_ptr<Element>> elements;
+  SmallPtrSet<const NamedTypeConstraint *, 8> seenVariables;
+  Optional<unsigned> anchorIdx;
+  do {
+    if (failed(parseOptionalChildElement(elements, seenVariables, anchorIdx)))
+      return failure();
+  } while (curToken.getKind() != Token::r_paren);
+  consumeToken();
+  if (failed(parseToken(Token::question, "expected '?' after optional group")))
+    return failure();
+
+  // The optional group is required to have an anchor.
+  if (!anchorIdx)
+    return emitError(curLoc, "optional group specified no anchor element");
+
+  // The first element of the group must be one that can be parsed/printed in an
+  // optional fashion.
+  if (!isa<LiteralElement>(&*elements.front()) &&
+      !isa<OperandVariable>(&*elements.front()))
+    return emitError(curLoc, "first element of an operand group must be a "
+                             "literal or operand");
+
+  // After parsing all of the elements, ensure that all type directives refer
+  // only to elements within the group.
+  auto checkTypeOperand = [&](Element *typeEle) {
+    auto *opVar = dyn_cast<OperandVariable>(typeEle);
+    const NamedTypeConstraint *var = opVar ? opVar->getVar() : nullptr;
+    if (!seenVariables.count(var))
+      return emitError(curLoc, "type directive can only refer to variables "
+                               "within the optional group");
+    return success();
+  };
+  for (auto &ele : elements) {
+    if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
+      if (failed(checkTypeOperand(typeEle->getOperand())))
+        return failure();
+    } else if (auto *typeEle = dyn_cast<FunctionalTypeDirective>(ele.get())) {
+      if (failed(checkTypeOperand(typeEle->getInputs())) ||
+          failed(checkTypeOperand(typeEle->getResults())))
+        return failure();
+    }
+  }
+
+  optionalVariables.insert(seenVariables.begin(), seenVariables.end());
+  element = std::make_unique<OptionalElement>(std::move(elements), *anchorIdx);
+  return success();
+}
+
+LogicalResult FormatParser::parseOptionalChildElement(
+    std::vector<std::unique_ptr<Element>> &childElements,
+    SmallPtrSetImpl<const NamedTypeConstraint *> &seenVariables,
+    Optional<unsigned> &anchorIdx) {
+  llvm::SMLoc childLoc = curToken.getLoc();
+  childElements.push_back({});
+  if (failed(parseElement(childElements.back(), /*isTopLevel=*/true)))
+    return failure();
+
+  // Check to see if this element is the anchor of the optional group.
+  bool isAnchor = curToken.getKind() == Token::caret;
+  if (isAnchor) {
+    if (anchorIdx)
+      return emitError(childLoc, "only one element can be marked as the anchor "
+                                 "of an optional group");
+    anchorIdx = childElements.size() - 1;
+    consumeToken();
+  }
+
+  return TypeSwitch<Element *, LogicalResult>(childElements.back().get())
+      // All attributes can be within the optional group, but only optional
+      // attributes can be the anchor.
+      .Case([&](AttributeVariable *attrEle) {
+        if (isAnchor && !attrEle->getVar()->attr.isOptional())
+          return emitError(childLoc, "only optional attributes can be used to "
+                                     "anchor an optional group");
+        return success();
+      })
+      // Only optional-like (i.e. variadic) operands can be within an optional
+      // group.
+      .Case<OperandVariable>([&](auto *ele) {
+        if (!ele->getVar()->isVariadic())
+          return emitError(childLoc, "only variadic operands can be used within"
+                                     " an optional group");
+        seenVariables.insert(ele->getVar());
+        return success();
+      })
+      // Literals and type directives may be used, but they can't anchor the
+      // group.
+      .Case<LiteralElement, TypeDirective, FunctionalTypeDirective>(
+          [&](auto *) {
+            if (isAnchor)
+              return emitError(childLoc, "only variables can be used to anchor "
+                                         "an optional group");
+            return success();
+          })
+      .Default([&](auto *) {
+        return emitError(childLoc, "only literals, types, and variables can be "
+                                   "used within an optional group");
+      });
+}
+
 LogicalResult
 FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
                                      llvm::SMLoc loc, bool isTopLevel) {
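Taken together, `parseOptional` and `parseOptionalChildElement` enforce the rules listed in the commit message: per-child checks while parsing, then whole-group checks after `)?`. The sketch below restates those rules over a flat description of a group. `Child` and `verifyOptionalGroup` are hypothetical, the error strings paraphrase the diagnostics in the diff, and directives other than literals, operands, attributes, and `type(...)` (which the real code rejects in its `Default` case) are not modelled.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified description of one element inside an optional group.
struct Child {
  enum Kind { Literal, Operand, Attribute, Type };
  Kind kind;
  std::string name;      // variable name; for Type, the variable it refers to
  bool optional = false; // variadic operand or optional attribute
  bool isAnchor = false; // element was followed by `^`
};

// Return the first violated rule, or an empty string if `group` is valid.
std::string verifyOptionalGroup(const std::vector<Child> &group) {
  std::set<std::string> seenOperands;
  int anchors = 0;

  // Per-element checks, applied as each child would be parsed.
  for (const Child &c : group) {
    if (c.isAnchor && ++anchors > 1)
      return "only one element can be marked as the anchor of an optional group";
    switch (c.kind) {
    case Child::Attribute:
      // Any attribute may appear, but only an optional one can anchor.
      if (c.isAnchor && !c.optional)
        return "only optional attributes can be used to anchor an optional group";
      break;
    case Child::Operand:
      if (!c.optional)
        return "only variadic operands can be used within an optional group";
      seenOperands.insert(c.name);
      break;
    case Child::Literal:
    case Child::Type:
      if (c.isAnchor)
        return "only variables can be used to anchor an optional group";
      break;
    }
  }

  // Whole-group checks, applied once the closing `)?` has been parsed.
  if (anchors == 0)
    return "optional group specified no anchor element";
  if (group.front().kind != Child::Literal &&
      group.front().kind != Child::Operand)
    return "first element of an optional group must be a literal or operand";
  for (const Child &c : group)
    if (c.kind == Child::Type && !seenOperands.count(c.name))
      return "type directive can only refer to variables within the optional group";
  return "";
}
```

Deferring the type-directive check until the whole group is parsed means `type($operands)` is accepted regardless of whether it appears before or after `$operands` inside the group; only references to variables outside the group are rejected.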
@@ -1344,8 +1577,6 @@ FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
       failed(parseTypeDirectiveOperand(results)) ||
       failed(parseToken(Token::r_paren, "expected ')' after argument list")))
     return failure();
-
-  // Get the proper directive kind and create it.
   element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
                                                       std::move(results));
   return success();
