[Mlir-commits] [mlir] 5bec1ea - [mlir] Added oilist primitive

Shraiysh Vaishay llvmlistbot at llvm.org
Wed Feb 16 21:40:35 PST 2022


Author: Shraiysh Vaishay
Date: 2022-02-17T11:10:24+05:30
New Revision: 5bec1ea7a74895895e7831fd951dd8130d4f3d01

URL: https://github.com/llvm/llvm-project/commit/5bec1ea7a74895895e7831fd951dd8130d4f3d01
DIFF: https://github.com/llvm/llvm-project/commit/5bec1ea7a74895895e7831fd951dd8130d4f3d01.diff

LOG: [mlir] Added oilist primitive

This patch attempts to add the `oilist` primitive proposed in the [[ https://llvm.discourse.group/t/rfc-extending-declarative-assembly-format-to-support-order-independent-variadic-segments/4388 | RFC: Extending Declarative Assembly Format to support order-independent variadic segments ]].

This element supports optional order-independent variadic segments for operations. This will allow OpenACC and OpenMP Dialects to have similar and relaxed requirements while encouraging the use of Declarative Assembly Format and avoiding code duplication.

An oilist element parses grammar of the form:
```
clause-list := clause clause-list | empty
clause := `keyword` clause1 | `otherKeyword` clause2
clause1 := <assembly-format element>
clause2 := <assembly-format element>
```

AssemblyFormat specification:
```
let assemblyFormat = [{
  oilist( `keyword` clause1
        | `otherKeyword` clause2
        ...
        )
}];
```

Example:
```
oilist( `private` `(` $arg0 `:` type($arg0) `)`
      | `nowait`
      | `reduction` custom<ReductionClause>($arg1, type($arg1)))

oilist( `private` `=` $arg0 `:` type($arg0)
      | `reduction` `=` $arg1 `:` type($arg1)
      | `firstprivate` `=` $arg2 `:` type($arg2))
```

Reviewed By: Mogball, rriddle

Differential Revision: https://reviews.llvm.org/D115215

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/test/IR/traits.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/tools/mlir-tblgen/FormatGen.cpp
    mlir/tools/mlir-tblgen/FormatGen.h
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 1058b33480073..e9aa37f5fa76c 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -619,6 +619,14 @@ The available directives are as follows:
     -   The constraints on `inputs` and `results` are the same as the `input` of
         the `type` directive.
 
+*   `oilist` ( \`keyword\` elements | \`otherKeyword\` elements ...)
+
+    -   Represents an optional order-independent list of clauses. Each clause
+        has a keyword and corresponding assembly format.
+    -   Each clause can appear 0 or 1 time (in any order).
+    -   Only literals, types and variables can be used within an oilist element.
+    -   All the variables must be optional or variadic.
+
 *   `operands`
 
     -   Represents all of the operands of an operation.

diff  --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index c0fb012975bac..e6283b52caa52 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -488,6 +488,75 @@ func @succeededResultSizeAttr() {
 
 // -----
 
+// CHECK-LABEL: @succeededOilistTrivial
+func @succeededOilistTrivial() {
+  // CHECK: test.oilist_with_keywords_only keyword
+  test.oilist_with_keywords_only keyword
+  // CHECK: test.oilist_with_keywords_only otherKeyword
+  test.oilist_with_keywords_only otherKeyword
+  // CHECK: test.oilist_with_keywords_only keyword otherKeyword
+  test.oilist_with_keywords_only keyword otherKeyword
+  // CHECK: test.oilist_with_keywords_only keyword otherKeyword
+  test.oilist_with_keywords_only otherKeyword keyword
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @succeededOilistSimple
+func @succeededOilistSimple(%arg0 : i32, %arg1 : i32, %arg2 : i32) {
+  // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32
+  test.oilist_with_simple_args keyword %arg0 : i32
+  // CHECK: test.oilist_with_simple_args otherKeyword %{{.*}} : i32
+  test.oilist_with_simple_args otherKeyword %arg0 : i32
+  // CHECK: test.oilist_with_simple_args thirdKeyword %{{.*}} : i32
+  test.oilist_with_simple_args thirdKeyword %arg0 : i32
+
+  // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32
+  test.oilist_with_simple_args keyword %arg0 : i32 otherKeyword %arg1 : i32
+  // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32
+  test.oilist_with_simple_args keyword %arg0 : i32 thirdKeyword %arg1 : i32
+  // CHECK: test.oilist_with_simple_args otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32
+  test.oilist_with_simple_args thirdKeyword %arg0 : i32 otherKeyword %arg1 : i32
+
+  // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32
+  test.oilist_with_simple_args keyword %arg0 : i32 otherKeyword %arg1 : i32 thirdKeyword %arg2 : i32
+  // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32
+  test.oilist_with_simple_args otherKeyword %arg0 : i32 keyword %arg1 : i32 thirdKeyword %arg2 : i32
+  // CHECK: test.oilist_with_simple_args keyword %{{.*}} : i32 otherKeyword %{{.*}} : i32 thirdKeyword %{{.*}} : i32
+  test.oilist_with_simple_args otherKeyword %arg0 : i32 thirdKeyword %arg1 : i32 keyword %arg2 : i32
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @succeededOilistVariadic
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+func @succeededOilistVariadic(%arg0: i32, %arg1: i32, %arg2: i32) {
+  // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32)
+  test.oilist_variadic_with_parens keyword (%arg0, %arg1 : i32, i32)
+  // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32) otherKeyword(%[[ARG2]], %[[ARG1]] : i32, i32)
+  test.oilist_variadic_with_parens otherKeyword (%arg2, %arg1 : i32, i32) keyword (%arg0, %arg1 : i32, i32)
+  // CHECK: test.oilist_variadic_with_parens keyword(%[[ARG0]], %[[ARG1]] : i32, i32) otherKeyword(%[[ARG0]], %[[ARG1]] : i32, i32) thirdKeyword(%[[ARG2]], %[[ARG0]], %[[ARG1]] : i32, i32, i32)
+  test.oilist_variadic_with_parens thirdKeyword (%arg2, %arg0, %arg1 : i32, i32, i32) keyword (%arg0, %arg1 : i32, i32) otherKeyword (%arg0, %arg1 : i32, i32)
+  return
+}
+
+// -----
+// CHECK-LABEL: succeededOilistCustom
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+func @succeededOilistCustom(%arg0: i32, %arg1: i32, %arg2: i32) {
+  // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32)
+  test.oilist_custom private (%arg0, %arg1 : i32, i32)
+  // CHECK: test.oilist_custom private(%[[ARG0]], %[[ARG1]] : i32, i32) nowait
+  test.oilist_custom private (%arg0, %arg1 : i32, i32) nowait
+  // CHECK: test.oilist_custom private(%arg0, %arg1 : i32, i32) nowait reduction (%arg1)
+  test.oilist_custom nowait reduction (%arg1) private (%arg0, %arg1 : i32, i32)
+  return
+}
+
+// -----
+
 func @failedHasDominanceScopeOutsideDominanceFreeScope() -> () {
   "test.ssacfg_region"() ({
     test.graph_region {

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 623d51295516e..f3f4d54d26e1d 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -387,6 +387,17 @@ void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
 //===----------------------------------------------------------------------===//
 // Parsing
 
+static ParseResult
+parseCustomOptionalOperand(OpAsmParser &parser,
+                           Optional<OpAsmParser::OperandType> &optOperand) {
+  if (succeeded(parser.parseOptionalLParen())) {
+    optOperand.emplace();
+    if (parser.parseOperand(*optOperand) || parser.parseRParen())
+      return failure();
+  }
+  return success();
+}
+
 static ParseResult parseCustomDirectiveOperands(
     OpAsmParser &parser, OpAsmParser::OperandType &operand,
     Optional<OpAsmParser::OperandType> &optOperand,
@@ -505,6 +516,12 @@ static ParseResult parseCustomDirectiveOptionalOperandRef(
 //===----------------------------------------------------------------------===//
 // Printing
 
+static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
+                                       Value optOperand) {
+  if (optOperand)
+    printer << "(" << optOperand << ") ";
+}
+
 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
                                          Value operand, Value optOperand,
                                          OperandRange varOperands) {

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index f5834efe9cb5a..40bec4f4807e4 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -636,6 +636,54 @@ def AttrSizedResultOp : TEST_Op<"attr_sized_results",
 // is the dialect parser and printer hooks.
 def CustomFormatFallbackOp : TEST_Op<"dialect_custom_format_fallback">;
 
+// Ops related to OIList primitive
+def OIListTrivial : TEST_Op<"oilist_with_keywords_only"> {
+  let assemblyFormat = [{
+    oilist( `keyword`
+          | `otherKeyword`) attr-dict
+  }];
+}
+
+def OIListSimple : TEST_Op<"oilist_with_simple_args", [AttrSizedOperandSegments]> {
+  let arguments = (ins Optional<AnyType>:$arg0,
+                       Optional<AnyType>:$arg1,
+                       Optional<AnyType>:$arg2);
+  let assemblyFormat = [{
+    oilist( `keyword` $arg0 `:` type($arg0)
+          | `otherKeyword` $arg1 `:` type($arg1)
+          | `thirdKeyword` $arg2 `:` type($arg2) ) attr-dict
+  }];
+}
+
+def OIListVariadic : TEST_Op<"oilist_variadic_with_parens", [AttrSizedOperandSegments]> {
+  let arguments = (ins Variadic<AnyType>:$arg0,
+                       Variadic<AnyType>:$arg1,
+                       Variadic<AnyType>:$arg2);
+  let assemblyFormat = [{
+    oilist( `keyword` `(` $arg0 `:` type($arg0) `)`
+          | `otherKeyword` `(` $arg1 `:` type($arg1) `)`
+          | `thirdKeyword` `(` $arg2 `:` type($arg2) `)`) attr-dict
+  }];
+}
+
+def OIListCustom : TEST_Op<"oilist_custom", [AttrSizedOperandSegments]> {
+  let arguments = (ins Variadic<AnyType>:$arg0,
+                       Optional<I32>:$optOperand,
+                       UnitAttr:$nowait);
+  let assemblyFormat = [{
+    oilist( `private` `(` $arg0 `:` type($arg0) `)`
+          | `nowait`
+          | `reduction` custom<CustomOptionalOperand>($optOperand)
+    ) attr-dict
+  }];
+}
+
+def OIListAllowedLiteral : TEST_Op<"oilist_allowed_literal"> {
+  let assemblyFormat = [{
+    oilist( `foo` | `bar` ) `buzz` attr-dict
+  }];
+}
+
 // This is used to test encoding of a string attribute into an SSA name of a
 // pretty printed value name.
 def StringAttrPrettyNameOp

diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 1c419424d6021..84edca8e621ac 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -344,6 +344,57 @@ def LiteralValid : TestFormat_Op<[{
   attr-dict
 }]>;
 
+//===----------------------------------------------------------------------===//
+// OIList Element
+//===----------------------------------------------------------------------===//
+
+// CHECK: error: format ambiguity because bar is used in two adjacent oilist elements.
+def OIListAdjacentOIList : TestFormat_Op<[{
+  oilist ( `foo` | `bar` ) oilist ( `bar` | `buzz` ) attr-dict
+}]>;
+// CHECK: error: expected literal, but got ')'
+def OIListErrorExpectedLiteral : TestFormat_Op<[{
+  oilist( `keyword` | ) attr-dict
+}]>;
+// CHECK: error: expected literal, but got ')'
+def OIListErrorExpectedEmpty : TestFormat_Op<[{
+  oilist() attr-dict
+}]>;
+// CHECK: error: expected literal, but got '$arg0'
+def OIListErrorNoLiteral : TestFormat_Op<[{
+  oilist( $arg0 `:` type($arg0) | $arg1 `:` type($arg1) ) attr-dict
+}], [AttrSizedOperandSegments]>, Arguments<(ins Optional<AnyType>:$arg0, Optional<AnyType>:$arg1)>;
+// CHECK: error: format ambiguity because foo is used both in oilist element and the adjacent literal.
+def OIListLiteralAmbiguity : TestFormat_Op<[{
+  oilist( `foo` | `bar` ) `foo` attr-dict
+}]>;
+// CHECK: error: expected '(' before oilist argument list
+def OIListStartingToken : TestFormat_Op<[{
+  oilist `wrong` attr-dict
+}]>;
+
+// CHECK-NOT: error
+def OIListTrivial : TestFormat_Op<[{
+  oilist(`keyword` `(` `)` | `otherkeyword` `(` `)`) attr-dict
+}]>;
+def OIListSimple : TestFormat_Op<[{
+  oilist( `keyword` $arg0 `:` type($arg0)
+        | `otherkeyword` $arg1 `:` type($arg1)
+        | `thirdkeyword` $arg2 `:` type($arg2) )
+  attr-dict
+}], [AttrSizedOperandSegments]>, Arguments<(ins Optional<AnyType>:$arg0, Optional<AnyType>:$arg1, Optional<AnyType>:$arg2)>;
+def OIListVariadic : TestFormat_Op<[{
+  oilist( `keyword` `(` $args0 `:` type($args0) `)`
+        | `otherkeyword` `(` $args1 `:` type($args1) `)`
+        | `thirdkeyword` `(` $args2 `:` type($args2) `)`)
+  attr-dict
+}], [AttrSizedOperandSegments]>, Arguments<(ins Variadic<AnyType>:$args0, Variadic<AnyType>:$args1, Variadic<AnyType>:$args2)>;
+def OIListCustom : TestFormat_Op<[{
+  oilist( `private` `(` $arg0 `:` type($arg0) `)`
+        | `nowait`
+        | `reduction` custom<ReductionClause>($arg1, type($arg1))) attr-dict
+}], [AttrSizedOperandSegments]>, Arguments<(ins Optional<AnyType>:$arg0, Optional<AnyType>:$arg1)>;
+
 //===----------------------------------------------------------------------===//
 // Optional Groups
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index a4c9dcf28981f..8d08340800c91 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -115,6 +115,8 @@ FormatToken FormatLexer::lexToken() {
     return formToken(FormatToken::r_paren, tokStart);
   case '*':
     return formToken(FormatToken::star, tokStart);
+  case '|':
+    return formToken(FormatToken::pipe, tokStart);
 
   // Ignore whitespace characters.
   case 0:
@@ -164,6 +166,7 @@ FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
           .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
           .Case("custom", FormatToken::kw_custom)
           .Case("functional-type", FormatToken::kw_functional_type)
+          .Case("oilist", FormatToken::kw_oilist)
           .Case("operands", FormatToken::kw_operands)
           .Case("params", FormatToken::kw_params)
           .Case("ref", FormatToken::kw_ref)
@@ -230,7 +233,12 @@ FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) {
         "literals may only be used in the top-level section of the format");
   }
   // Get the spelling without the surrounding backticks.
-  StringRef value = tok.getSpelling().drop_front().drop_back();
+  StringRef value = tok.getSpelling();
+  // Prevents things like `$arg0` or empty literals (when a literal is expected
+  // but not found) from getting segmentation faults.
+  if (value.size() < 2 || value[0] != '`' || value[value.size() - 1] != '`')
+    return emitError(tok.getLoc(), "expected literal, but got '" + value + "'");
+  value = value.drop_front().drop_back();
 
   // The parsed literal is a space element (`` or ` `) or a newline.
   if (value.empty() || value == " " || value == "\\n")

diff  --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 4ad591d49ebc6..741e2716f0388 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -54,6 +54,7 @@ class FormatToken {
     greater,
     question,
     star,
+    pipe,
 
     // Keywords.
     keyword_start,
@@ -61,6 +62,7 @@ class FormatToken {
     kw_attr_dict_w_keyword,
     kw_custom,
     kw_functional_type,
+    kw_oilist,
     kw_operands,
     kw_params,
     kw_qualified,
@@ -271,6 +273,7 @@ class DirectiveElement : public FormatElementBase<FormatElement::Directive> {
     AttrDict,
     Custom,
     FunctionalType,
+    OIList,
     Operands,
     Ref,
     Regions,

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 3e395c2f77310..37e62880b543c 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -185,6 +185,62 @@ class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
 
   bool shouldBeQualifiedFlag = false;
 };
+
+/// This class represents a group of order-independent optional clauses. Each
+/// clause starts with a literal element and has a corresponding parsing
+/// element. A parsing element is a continuous sequence of format elements.
+/// Each clause can appear 0 or 1 time.
+class OIListElement : public DirectiveElementBase<DirectiveElement::OIList> {
+public:
+  OIListElement(std::vector<FormatElement *> &&literalElements,
+                std::vector<std::vector<FormatElement *>> &&parsingElements)
+      : literalElements(std::move(literalElements)),
+        parsingElements(std::move(parsingElements)) {}
+
+  /// Returns a range to iterate over the LiteralElements.
+  auto getLiteralElements() const {
+    // The use of std::function is unfortunate but necessary here. Lambda
+    // functions cannot be copied but std::function can be copied. This copy
+    // constructor is used in llvm::zip.
+    std::function<LiteralElement *(FormatElement * el)>
+        literalElementCastConverter =
+            [](FormatElement *el) { return cast<LiteralElement>(el); };
+    return llvm::map_range(literalElements, literalElementCastConverter);
+  }
+
+  /// Returns a range to iterate over the parsing elements corresponding to the
+  /// clauses.
+  ArrayRef<std::vector<FormatElement *>> getParsingElements() const {
+    return parsingElements;
+  }
+
+  /// Returns a range to iterate over tuples of parsing and literal elements.
+  auto getClauses() const {
+    return llvm::zip(getLiteralElements(), getParsingElements());
+  }
+
+private:
+  /// A vector of `LiteralElement` objects. Each element stores the keyword
+  /// for one case of oilist element. For example, an oilist element along with
+  /// the `literalElements` vector:
+  /// ```
+  ///  oilist( `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>` )
+  ///  literalElements = { `keyword`, `otherKeyword` }
+  /// ```
+  std::vector<FormatElement *> literalElements;
+
+  /// A vector of valid declarative assembly format vectors. Each object in
+  /// parsing elements is a vector of elements in assembly format syntax.
+  /// For example, an oilist element along with the parsingElements vector:
+  /// ```
+  ///  oilist( `keyword` `=` `(` $arg0 `)` | `otherKeyword` `<` $arg1 `>` )
+  ///  parsingElements = {
+  ///    { `=`, `(`, $arg0, `)` },
+  ///    { `<`, $arg1, `>` }
+  ///  }
+  /// ```
+  std::vector<std::vector<FormatElement *>> parsingElements;
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -630,6 +686,19 @@ const char *successorParserCode = R"(
     return ::mlir::failure();
 )";
 
+/// The code snippet used to generate a parser for OIList
+///
+/// {0}: literal keyword corresponding to a case for oilist
+const char *oilistParserCode = R"(
+  if ({0}Clause) {
+    return parser.emitError(parser.getNameLoc())
+          << "`{0}` clause can appear at most once in the expansion of the "
+             "oilist directive";
+  }
+  {0}Clause = true;
+  result.addAttribute("{0}", UnitAttr::get(parser.getContext()));
+)";
+
 namespace {
 /// The type of length for a given parse argument.
 enum class ArgumentLengthKind {
@@ -720,6 +789,11 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
     for (FormatElement *childElement : optional->getElseElements())
       genElementParserStorage(childElement, op, body);
 
+  } else if (auto *oilist = dyn_cast<OIListElement>(element)) {
+    for (ArrayRef<FormatElement *> pelement : oilist->getParsingElements())
+      for (FormatElement *element : pelement)
+        genElementParserStorage(element, op, body);
+
   } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
     for (FormatElement *paramElement : custom->getArguments())
       genElementParserStorage(paramElement, op, body);
@@ -1104,6 +1178,31 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
     }
     body << "\n";
 
+    /// OIList Directive
+  } else if (OIListElement *oilist = dyn_cast<OIListElement>(element)) {
+    for (LiteralElement *le : oilist->getLiteralElements())
+      body << "  bool " << le->getSpelling() << "Clause = false;\n";
+
+    // Generate the parsing loop
+    body << "  while(true) {\n";
+    for (auto clause : oilist->getClauses()) {
+      LiteralElement *lelement = std::get<0>(clause);
+      ArrayRef<FormatElement *> pelement = std::get<1>(clause);
+      body << "if (succeeded(parser.parseOptional";
+      genLiteralParser(lelement->getSpelling(), body);
+      body << ")) {\n";
+      StringRef attrName = lelement->getSpelling();
+      body << formatv(oilistParserCode, attrName);
+      inferredAttributes.insert(attrName);
+      for (FormatElement *el : pelement)
+        genElementParser(el, body, attrTypeCtx);
+      body << "    } else ";
+    }
+    body << " {\n";
+    body << "    break;\n";
+    body << "  }\n";
+    body << "}\n";
+
     /// Literals.
   } else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
     body << "  if (parser.parse";
@@ -1844,6 +1943,26 @@ void OperationFormat::genElementPrinter(FormatElement *element,
     return;
   }
 
+  // Emit the OIList
+  if (auto *oilist = dyn_cast<OIListElement>(element)) {
+    genLiteralPrinter(" ", body, shouldEmitSpace, lastWasPunctuation);
+    for (auto clause : oilist->getClauses()) {
+      LiteralElement *lelement = std::get<0>(clause);
+      ArrayRef<FormatElement *> pelement = std::get<1>(clause);
+
+      body << "  if ((*this)->hasAttrOfType<UnitAttr>(\""
+           << lelement->getSpelling() << "\")) {\n";
+      genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace,
+                        lastWasPunctuation);
+      for (FormatElement *element : pelement) {
+        genElementPrinter(element, body, op, shouldEmitSpace,
+                          lastWasPunctuation);
+      }
+      body << "  }\n";
+    }
+    return;
+  }
+
   // Emit the attribute dictionary.
   if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
     genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
@@ -2061,6 +2180,9 @@ class OpFormatParser : public FormatParser {
   /// Verify the state of operation successors within the format.
   LogicalResult verifySuccessors(SMLoc loc);
 
+  LogicalResult verifyOIListElements(SMLoc loc,
+                                     ArrayRef<FormatElement *> elements);
+
   /// Given the values of an `AllTypesMatch` trait, check for inferable type
   /// resolution.
   void handleAllTypesMatchConstraint(
@@ -2087,6 +2209,8 @@ class OpFormatParser : public FormatParser {
                                                     bool withKeyword);
   FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
                                                           Context context);
+  FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
+  LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
   FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
   FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
                                                      Context context);
@@ -2157,7 +2281,8 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
   if (failed(verifyAttributes(loc, elements)) ||
       failed(verifyResults(loc, variableTyResolver)) ||
       failed(verifyOperands(loc, variableTyResolver)) ||
-      failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
+      failed(verifyRegions(loc)) || failed(verifySuccessors(loc)) ||
+      failed(verifyOIListElements(loc, elements)))
     return failure();
 
   // Collect the set of used attributes in the format.
@@ -2377,6 +2502,43 @@ LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) {
   return success();
 }
 
+LogicalResult
+OpFormatParser::verifyOIListElements(SMLoc loc,
+                                     ArrayRef<FormatElement *> elements) {
+  // Reject keywords shared with an adjacent oilist or literal element.
+  SmallVector<StringRef> prohibitedLiterals;
+  for (FormatElement *it : elements) {
+    if (auto *oilist = dyn_cast<OIListElement>(it)) {
+      if (!prohibitedLiterals.empty()) {
+        // We just saw an oilist element in last iteration. Literals should not
+        // match.
+        for (LiteralElement *literal : oilist->getLiteralElements()) {
+          if (find(prohibitedLiterals, literal->getSpelling()) !=
+              prohibitedLiterals.end()) {
+            return emitError(
+                loc, "format ambiguity because " + literal->getSpelling() +
+                         " is used in two adjacent oilist elements.");
+          }
+        }
+      }
+      for (LiteralElement *literal : oilist->getLiteralElements())
+        prohibitedLiterals.push_back(literal->getSpelling());
+    } else if (auto *literal = dyn_cast<LiteralElement>(it)) {
+      if (find(prohibitedLiterals, literal->getSpelling()) !=
+          prohibitedLiterals.end()) {
+        return emitError(
+            loc,
+            "format ambiguity because " + literal->getSpelling() +
+                " is used both in oilist element and the adjacent literal.");
+      }
+      prohibitedLiterals.clear();
+    } else {
+      prohibitedLiterals.clear();
+    }
+  }
+  return success();
+}
+
 void OpFormatParser::handleAllTypesMatchConstraint(
     ArrayRef<StringRef> values,
     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
@@ -2532,6 +2694,8 @@ OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
     return parseReferenceDirective(loc, ctx);
   case FormatToken::kw_type:
     return parseTypeDirective(loc, ctx);
+  case FormatToken::kw_oilist:
+    return parseOIListDirective(loc, ctx);
 
   default:
     return emitError(loc, "unsupported directive kind");
@@ -2675,6 +2839,91 @@ OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) {
   return create<SuccessorsDirective>();
 }
 
+FailureOr<FormatElement *>
+OpFormatParser::parseOIListDirective(SMLoc loc, Context context) {
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before oilist argument list")))
+    return failure();
+  std::vector<FormatElement *> literalElements;
+  std::vector<std::vector<FormatElement *>> parsingElements;
+  do {
+    FailureOr<FormatElement *> lelement = parseLiteral(context);
+    if (failed(lelement))
+      return failure();
+    literalElements.push_back(*lelement);
+    parsingElements.push_back(std::vector<FormatElement *>());
+    std::vector<FormatElement *> &currParsingElements = parsingElements.back();
+    while (peekToken().getKind() != FormatToken::pipe &&
+           peekToken().getKind() != FormatToken::r_paren) {
+      FailureOr<FormatElement *> pelement = parseElement(context);
+      if (failed(pelement) ||
+          failed(verifyOIListParsingElement(*pelement, loc)))
+        return failure();
+      currParsingElements.push_back(*pelement);
+    }
+    if (peekToken().getKind() == FormatToken::pipe) {
+      consumeToken();
+      continue;
+    }
+    if (peekToken().getKind() == FormatToken::r_paren) {
+      consumeToken();
+      break;
+    }
+  } while (true);
+
+  return create<OIListElement>(std::move(literalElements),
+                               std::move(parsingElements));
+}
+
+LogicalResult OpFormatParser::verifyOIListParsingElement(FormatElement *element,
+                                                         SMLoc loc) {
+  return TypeSwitch<FormatElement *, LogicalResult>(element)
+      // Only optional attributes can be within an oilist parsing group.
+      .Case([&](AttributeVariable *attrEle) {
+        if (!attrEle->getVar()->attr.isOptional())
+          return emitError(loc, "only optional attributes can be used to "
+                                "in an oilist parsing group");
+        return success();
+      })
+      // Only optional-like(i.e. variadic) operands can be within an oilist
+      // parsing group.
+      .Case([&](OperandVariable *ele) {
+        if (!ele->getVar()->isVariableLength())
+          return emitError(loc, "only variable length operands can be "
+                                "used within an oilist parsing group");
+        return success();
+      })
+      // Only optional-like(i.e. variadic) results can be within an oilist
+      // parsing group.
+      .Case([&](ResultVariable *ele) {
+        if (!ele->getVar()->isVariableLength())
+          return emitError(loc, "only variable length results can be "
+                                "used within an oilist parsing group");
+        return success();
+      })
+      .Case([&](RegionVariable *) {
+        // TODO: When ODS has proper support for marking "optional" regions, add
+        // a check here.
+        return success();
+      })
+      .Case([&](TypeDirective *ele) {
+        return verifyOIListParsingElement(ele->getArg(), loc);
+      })
+      .Case([&](FunctionalTypeDirective *ele) {
+        if (failed(verifyOIListParsingElement(ele->getInputs(), loc)))
+          return failure();
+        return verifyOIListParsingElement(ele->getResults(), loc);
+      })
+      // Literals, whitespace, custom directives and optional groups are OK.
+      .Case<LiteralElement, WhitespaceElement, CustomDirective,
+            FunctionalTypeDirective, OptionalElement>(
+          [&](FormatElement *) { return success(); })
+      .Default([&](FormatElement *) {
+        return emitError(loc, "only literals, types, and variables can be "
+                              "used within an oilist group");
+      });
+}
+
 FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
                                                               Context context) {
   if (context == TypeDirectiveContext)


        


More information about the Mlir-commits mailing list