[Mlir-commits] [mlir] 88c6e25 - [mlir][OpFormatGen] Add support for specifying "custom" directives.
River Riddle
llvmlistbot at llvm.org
Mon Aug 31 13:26:53 PDT 2020

Author: River Riddle
Date: 2020-08-31T13:26:23-07:00
New Revision: 88c6e25e4f0630bd9204cb02787fcb67e097a43a
URL: https://github.com/llvm/llvm-project/commit/88c6e25e4f0630bd9204cb02787fcb67e097a43a
DIFF: https://github.com/llvm/llvm-project/commit/88c6e25e4f0630bd9204cb02787fcb67e097a43a.diff
LOG: [mlir][OpFormatGen] Add support for specifying "custom" directives.
This revision adds support for custom directives to the declarative assembly format. Custom directives allow users to implement the printing and parsing of subsections of an otherwise declaratively specified format in C++. A custom directive is structured as follows:
```
custom-directive ::= `custom` `<` UserDirective `>` `(` Params `)`
```
`UserDirective` is used as a suffix for the generated method names: when parsing, `parseUserDirective` is invoked, and when printing, `printUserDirective` is invoked. The first parameter to these methods must be a reference to either the `OpAsmParser` or the `OpAsmPrinter`. The types of the remaining parameters depend on the `Params` specified in the assembly format.
Differential Revision: https://reviews.llvm.org/D84719
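As a rough illustration of the naming rule and the parameter type mapping, the following standalone C++ sketch encodes them as plain lookup functions. It does not use MLIR or mlir-tblgen: `ParamKind`, `Arity`, `customCallNames`, `parseArgType`, and `printArgType` are hypothetical helpers that only mirror the tables added to `mlir/docs/OpDefinitions.md` in this commit, and the returned strings are the C++ spellings from those tables, not real types.

```cpp
// Hypothetical standalone model of the custom directive conventions.
// Not part of mlir-tblgen; it only encodes the documented naming rule and the
// parse/print parameter type tables as strings.
#include <cassert>
#include <string>
#include <utility>

enum class ParamKind { Attribute, Operand, Successor, TypeDirective };
enum class Arity { Single, Optional, Variadic };

// custom<Name>(...) is lowered to calls to parseName(...) and printName(...).
std::pair<std::string, std::string> customCallNames(const std::string &name) {
  return {"parse" + name, "print" + name};
}

// Argument type passed to parse<UserDirective> for one declarative parameter.
// `attrStorage` is the attribute's storage type (e.g. Attribute, IntegerAttr).
// Returns "" for combinations the documentation does not list.
std::string parseArgType(ParamKind kind, Arity arity,
                         const std::string &attrStorage = "Attribute") {
  switch (kind) {
  case ParamKind::Attribute:
    // Single and optional attributes use the same reference type.
    return arity == Arity::Variadic ? "" : attrStorage + " &";
  case ParamKind::Operand:
    if (arity == Arity::Single)
      return "OpAsmParser::OperandType &";
    if (arity == Arity::Optional)
      return "Optional<OpAsmParser::OperandType> &";
    return "SmallVectorImpl<OpAsmParser::OperandType> &";
  case ParamKind::Successor:
    if (arity == Arity::Single)
      return "Block *&";
    if (arity == Arity::Variadic)
      return "SmallVectorImpl<Block *> &";
    return "";
  case ParamKind::TypeDirective:
    // Single and optional type directives both use a plain `Type &`.
    return arity == Arity::Variadic ? "SmallVectorImpl<Type> &" : "Type &";
  }
  return "";
}

// Argument type passed to print<UserDirective> for one declarative parameter.
std::string printArgType(ParamKind kind, Arity arity,
                         const std::string &attrStorage = "Attribute") {
  switch (kind) {
  case ParamKind::Attribute:
    return arity == Arity::Variadic ? "" : attrStorage;
  case ParamKind::Operand:
    // An optional operand is passed as a possibly-null Value.
    return arity == Arity::Variadic ? "OperandRange" : "Value";
  case ParamKind::Successor:
    if (arity == Arity::Single)
      return "Block *";
    if (arity == Arity::Variadic)
      return "SuccessorRange";
    return "";
  case ParamKind::TypeDirective:
    return arity == Arity::Variadic ? "TypeRange" : "Type";
  }
  return "";
}
```

For example, `parseArgType(ParamKind::Operand, Arity::Optional)` yields the `Optional<OpAsmParser::OperandType> &` parameter used by `parseCustomDirectiveOperands` in the test dialect below. Combinations the documentation does not list, such as an optional successor, return an empty string.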
Added: 
    
Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/IR/OpImplementation.h
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed: 
    
################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 418da6a857dc..167546e67522 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -664,6 +664,12 @@ The available directives are as follows:
     -   Represents the attribute dictionary of the operation, but prefixes the
         dictionary with an `attributes` keyword.
 
+*   `custom` < UserDirective > ( Params )
+
+    -   Represents a custom directive implemented by the user in C++.
+    -   See the [Custom Directives](#custom-directives) section below for more
+        details.
+
 *   `functional-type` ( inputs , results )
 
     -   Formats the `inputs` and `results` arguments as a
@@ -705,6 +711,75 @@ example above, the variables would be `$callee` and `$args`.
 Attribute variables are printed with their respective value type, unless that
 value type is buildable. In those cases, the type of the attribute is elided.
 
+#### Custom Directives
+
+The declarative assembly format specification handles the large majority of
+common cases when formatting an operation. For operations that need to specify
+parts of the operation in a form the declarative syntax does not support,
+custom directives may be used instead. A custom directive allows users to
+write C++ code that prints and parses subsections of an otherwise
+declaratively specified format. Recall the specification of a custom directive
+given above:
+
+```
+custom-directive ::= `custom` `<` UserDirective `>` `(` Params `)`
+```
+
+A custom directive has two main parts: the `UserDirective` and the `Params`. A
+custom directive is transformed into a call to a `print*` and a `parse*` method
+when generating the C++ code for the format. The `UserDirective` is an
+identifier used as a suffix to these two calls; e.g., `custom<MyDirective>(...)`
+results in calls to `parseMyDirective` and `printMyDirective` within the
+parser and printer, respectively. `Params` may be any combination of variables
+(e.g. Attribute, Operand, Successor) and type directives. The type
+directives must refer to a variable, but that variable need not also be a
+parameter to the custom directive.
+
+The arguments to the `parse<UserDirective>` method are, first, a reference to the
+`OpAsmParser` (`OpAsmParser &`), and second, a set of output parameters
+corresponding to the parameters specified in the format. The mapping of each
+declarative parameter to a `parse` method argument is detailed below:
+
+*   Attribute Variables
+    -   Single: `<Attribute-Storage-Type>(e.g. Attribute) &`
+    -   Optional: `<Attribute-Storage-Type>(e.g. Attribute) &`
+*   Operand Variables
+    -   Single: `OpAsmParser::OperandType &`
+    -   Optional: `Optional<OpAsmParser::OperandType> &`
+    -   Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
+*   Successor Variables
+    -   Single: `Block *&`
+    -   Variadic: `SmallVectorImpl<Block *> &`
+*   Type Directives
+    -   Single: `Type &`
+    -   Optional: `Type &`
+    -   Variadic: `SmallVectorImpl<Type> &`
+
+When a variable is optional, the parser should only set the value if the variable
+is present; otherwise, the value should be left as `None` or null.
+
+The arguments to the `print<UserDirective>` method are, first, a reference to the
+`OpAsmPrinter` (`OpAsmPrinter &`), and second, a set of input parameters
+corresponding to the parameters specified in the format. The mapping of each
+declarative parameter to a `print` method argument is detailed below:
+
+*   Attribute Variables
+    -   Single: `<Attribute-Storage-Type>(e.g. Attribute)`
+    -   Optional: `<Attribute-Storage-Type>(e.g. Attribute)`
+*   Operand Variables
+    -   Single: `Value`
+    -   Optional: `Value`
+    -   Variadic: `OperandRange`
+*   Successor Variables
+    -   Single: `Block *`
+    -   Variadic: `SuccessorRange`
+*   Type Directives
+    -   Single: `Type`
+    -   Optional: `Type`
+    -   Variadic: `TypeRange`
+
+When a variable is optional, the provided value may be null.
+
 #### Optional Groups
 
 In certain situations operations may have "optional" information, e.g.
@@ -722,8 +797,8 @@ information. An optional group is defined by wrapping a set of elements within
         should be printed/parsed.
     -   An element is marked as the anchor by adding a trailing `^`.
     -   The first element is *not* required to be the anchor of the group.
-*   Literals, variables, and type directives are the only valid elements within
-    the group.
+*   Literals, variables, custom directives, and type directives are the only
+    valid elements within the group.
     -   Any attribute variable may be used, but only optional attributes can be
         marked as the anchor.
     -   Only variadic or optional operand arguments can be used.
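The optional-parameter contract described above (set the output only if the element is present, otherwise leave it `None`) can be illustrated without MLIR. The sketch below is a hypothetical toy, not generated or library code: `ToyParser` reads a pre-tokenized stream, and `parseOperandsDirective` follows the same `operand (, optOperand)? -> ( varOperands )` shape as `parseCustomDirectiveOperands` in the test dialect, with `std::optional`, `std::vector`, and `std::string` standing in for `Optional`, `SmallVectorImpl`, and `OpAsmParser::OperandType`. `collectOptional` models the follow-up step the generated parser emits, which appends an optional operand to its operand list only when it was parsed.

```cpp
// Hypothetical toy model (no MLIR): a token vector stands in for OpAsmParser.
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

struct ToyParser {
  std::vector<std::string> tokens;
  std::size_t pos = 0;

  // Consumes `tok` if it is the next token; returns true if it was consumed.
  bool consumeIf(const std::string &tok) {
    if (pos < tokens.size() && tokens[pos] == tok) {
      ++pos;
      return true;
    }
    return false;
  }

  // Parses an operand token such as "%a"; leaves the stream untouched on failure.
  bool parseOperand(std::string &out) {
    if (pos >= tokens.size() || tokens[pos].empty() || tokens[pos][0] != '%')
      return false;
    out = tokens[pos++];
    return true;
  }
};

// Grammar: operand (`,` optOperand)? `->` `(` varOperands `)`
// Returns true on success (the real hooks return ParseResult instead).
bool parseOperandsDirective(ToyParser &p, std::string &operand,
                            std::optional<std::string> &optOperand,
                            std::vector<std::string> &varOperands) {
  if (!p.parseOperand(operand))
    return false;
  // Only populate the optional output when the element is actually present.
  if (p.consumeIf(",")) {
    optOperand.emplace();
    if (!p.parseOperand(*optOperand))
      return false;
  }
  if (!p.consumeIf("->") || !p.consumeIf("("))
    return false;
  // Comma-separated, possibly empty, list of operands.
  std::string next;
  if (p.parseOperand(next)) {
    varOperands.push_back(next);
    while (p.consumeIf(",")) {
      if (!p.parseOperand(next))
        return false;
      varOperands.push_back(next);
    }
  }
  return p.consumeIf(")");
}

// Models the generated follow-up: fold the optional operand into its operand
// list only when it was parsed.
std::vector<std::string> collectOptional(const std::optional<std::string> &opt) {
  std::vector<std::string> operands;
  if (opt.has_value())
    operands.push_back(*opt);
  return operands;
}
```

Keeping the optional value in a separate `std::optional` (rather than writing straight into the list) lets the user-defined parser signal "absent" without knowing how the caller stores operands, which mirrors why the generator introduces a local `Optional<OperandType>` before the call.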
diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index e0726a9901cf..5962a1a48668 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -202,6 +202,10 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p,
   llvm::interleaveComma(types, p);
   return p;
 }
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) {
+  llvm::interleaveComma(types, p);
+  return p;
+}
 inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
   llvm::interleaveComma(types, p);
   return p;
diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 8fedb088e0b7..292f5ed4b641 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -267,6 +267,108 @@ void FoldToCallOp::getCanonicalizationPatterns(
   results.insert<FoldToCallOpPattern>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// Test Format* operations
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Parsing
+
+static ParseResult parseCustomDirectiveOperands(
+    OpAsmParser &parser, OpAsmParser::OperandType &operand,
+    Optional<OpAsmParser::OperandType> &optOperand,
+    SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
+  if (parser.parseOperand(operand))
+    return failure();
+  if (succeeded(parser.parseOptionalComma())) {
+    optOperand.emplace();
+    if (parser.parseOperand(*optOperand))
+      return failure();
+  }
+  if (parser.parseArrow() || parser.parseLParen() ||
+      parser.parseOperandList(varOperands) || parser.parseRParen())
+    return failure();
+  return success();
+}
+static ParseResult
+parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
+                            Type &optOperandType,
+                            SmallVectorImpl<Type> &varOperandTypes) {
+  if (parser.parseColon())
+    return failure();
+
+  if (parser.parseType(operandType))
+    return failure();
+  if (succeeded(parser.parseOptionalComma())) {
+    if (parser.parseType(optOperandType))
+      return failure();
+  }
+  if (parser.parseArrow() || parser.parseLParen() ||
+      parser.parseTypeList(varOperandTypes) || parser.parseRParen())
+    return failure();
+  return success();
+}
+static ParseResult parseCustomDirectiveOperandsAndTypes(
+    OpAsmParser &parser, OpAsmParser::OperandType &operand,
+    Optional<OpAsmParser::OperandType> &optOperand,
+    SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
+    Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
+  if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
+      parseCustomDirectiveResults(parser, operandType, optOperandType,
+                                  varOperandTypes))
+    return failure();
+  return success();
+}
+static ParseResult
+parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
+                               SmallVectorImpl<Block *> &varSuccessors) {
+  if (parser.parseSuccessor(successor))
+    return failure();
+  if (failed(parser.parseOptionalComma()))
+    return success();
+  Block *varSuccessor;
+  if (parser.parseSuccessor(varSuccessor))
+    return failure();
+  varSuccessors.append(2, varSuccessor);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing
+
+static void printCustomDirectiveOperands(OpAsmPrinter &printer, Value operand,
+                                         Value optOperand,
+                                         OperandRange varOperands) {
+  printer << operand;
+  if (optOperand)
+    printer << ", " << optOperand;
+  printer << " -> (" << varOperands << ")";
+}
+static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType,
+                                        Type optOperandType,
+                                        TypeRange varOperandTypes) {
+  printer << " : " << operandType;
+  if (optOperandType)
+    printer << ", " << optOperandType;
+  printer << " -> (" << varOperandTypes << ")";
+}
+static void
+printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
+                                     Value optOperand, OperandRange varOperands,
+                                     Type operandType, Type optOperandType,
+                                     TypeRange varOperandTypes) {
+  printCustomDirectiveOperands(printer, operand, optOperand, varOperands);
+  printCustomDirectiveResults(printer, operandType, optOperandType,
+                              varOperandTypes);
+}
+static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
+                                           Block *successor,
+                                           SuccessorRange varSuccessors) {
+  printer << successor;
+  if (!varSuccessors.empty())
+    printer << ", " << varSuccessors.front();
+}
+
 //===----------------------------------------------------------------------===//
 // Test IsolatedRegionOp - parse passthrough region arguments.
 //===----------------------------------------------------------------------===//
diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 022732d55016..bbe246a011af 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1414,8 +1414,60 @@ def FormatOptionalUnitAttrNoElide
 }
 
 //===----------------------------------------------------------------------===//
-// AllTypesMatch type inference
+// Custom Directives
+
+def FormatCustomDirectiveOperands
+    : TEST_Op<"format_custom_directive_operands", [AttrSizedOperandSegments]> {
+  let arguments = (ins I64:$operand, Optional<I64>:$optOperand,
+                       Variadic<I64>:$varOperands);
+  let assemblyFormat = [{
+    custom<CustomDirectiveOperands>(
+      $operand, $optOperand, $varOperands
+    )
+    attr-dict
+  }];
+}
+
+def FormatCustomDirectiveOperandsAndTypes
+    : TEST_Op<"format_custom_directive_operands_and_types",
+              [AttrSizedOperandSegments]> {
+  let arguments = (ins AnyType:$operand, Optional<AnyType>:$optOperand,
+                       Variadic<AnyType>:$varOperands);
+  let assemblyFormat = [{
+    custom<CustomDirectiveOperandsAndTypes>(
+      $operand, $optOperand, $varOperands,
+      type($operand), type($optOperand), type($varOperands)
+    )
+    attr-dict
+  }];
+}
+
+def FormatCustomDirectiveResults
+    : TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> {
+  let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
+                      Variadic<AnyType>:$varResults);
+  let assemblyFormat = [{
+    custom<CustomDirectiveResults>(
+      type($result), type($optResult), type($varResults)
+    )
+    attr-dict
+  }];
+}
+
+def FormatCustomDirectiveSuccessors
+    : TEST_Op<"format_custom_directive_successors", [Terminator]> {
+  let successors = (successor AnySuccessor:$successor,
+                              VariadicSuccessor<AnySuccessor>:$successors);
+  let assemblyFormat = [{
+    custom<CustomDirectiveSuccessors>(
+      $successor, $successors
+    )
+    attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
+// AllTypesMatch type inference
 
 def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
     AllTypesMatch<["value1", "value2", "result"]>
@@ -1435,7 +1487,6 @@ def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
 
 //===----------------------------------------------------------------------===//
 // TypesMatchWith type inference
-//===----------------------------------------------------------------------===//
 
 def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
     TypesMatchWith<"result type matches operand", "value", "result", "$_self">
diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 3a3c500d76b3..8e7c8ec56a2a 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -42,6 +42,49 @@ def DirectiveAttrDictValidB : TestFormat_Op<"attrdict_valid_b", [{
   attr-dict-with-keyword
 }]>;
 
+//===----------------------------------------------------------------------===//
+// custom
+
+// CHECK: error: expected '<' before custom directive name
+def DirectiveCustomInvalidA : TestFormat_Op<"custom_invalid_a", [{
+  custom(
+}]>;
+// CHECK: error: expected custom directive name identifier
+def DirectiveCustomInvalidB : TestFormat_Op<"custom_invalid_b", [{
+  custom<>
+}]>;
+// CHECK: error: expected '>' after custom directive name
+def DirectiveCustomInvalidC : TestFormat_Op<"custom_invalid_c", [{
+  custom<MyDirective(
+}]>;
+// CHECK: error: expected '(' before custom directive parameters
+def DirectiveCustomInvalidD : TestFormat_Op<"custom_invalid_d", [{
+  custom<MyDirective>)
+}]>;
+// CHECK: error: only variables and types may be used as parameters to a custom directive
+def DirectiveCustomInvalidE : TestFormat_Op<"custom_invalid_e", [{
+  custom<MyDirective>(operands)
+}]>;
+// CHECK: error: expected ')' after custom directive parameters
+def DirectiveCustomInvalidF : TestFormat_Op<"custom_invalid_f", [{
+  custom<MyDirective>($operand<
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK: error: type directives within a custom directive may only refer to variables
+def DirectiveCustomInvalidH : TestFormat_Op<"custom_invalid_h", [{
+  custom<MyDirective>(type(operands))
+}]>;
+
+// CHECK-NOT: error
+def DirectiveCustomValidA : TestFormat_Op<"custom_valid_a", [{
+  custom<MyDirective>($operand) attr-dict
+}]>, Arguments<(ins Optional<I64>:$operand)>;
+def DirectiveCustomValidB : TestFormat_Op<"custom_valid_b", [{
+  custom<MyDirective>($operand, type($operand), type($result)) attr-dict
+}]>, Arguments<(ins I64:$operand)>, Results<(outs I64:$result)>;
+def DirectiveCustomValidC : TestFormat_Op<"custom_valid_c", [{
+  custom<MyDirective>($attr) attr-dict
+}]>, Arguments<(ins I64Attr:$attr)>;
+
 //===----------------------------------------------------------------------===//
 // functional-type
 
@@ -238,6 +281,10 @@ def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{
 def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
   ($arg^)
 }]>, Arguments<(ins Variadic<I64>:$arg)>;
+// CHECK: error: only variables can be used to anchor an optional group
+def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{
+  (custom<MyDirective>($arg)^)?
+}]>, Arguments<(ins I64:$arg)>;
 
 //===----------------------------------------------------------------------===//
 // Variables
diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 959bbdc5c6bb..96a923ba81ec 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -122,6 +122,40 @@ test.format_optional_operand_result_b_op( : ) : i64
 // CHECK: test.format_optional_operand_result_b_op : i64
 test.format_optional_operand_result_b_op : i64
 
+//===----------------------------------------------------------------------===//
+// Format custom directives
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_custom_directive_operands %[[I64]], %[[I64]] -> (%[[I64]])
+test.format_custom_directive_operands %i64, %i64 -> (%i64)
+
+// CHECK: test.format_custom_directive_operands %[[I64]] -> (%[[I64]])
+test.format_custom_directive_operands %i64 -> (%i64)
+
+// CHECK: test.format_custom_directive_operands_and_types %[[I64]], %[[I64]] -> (%[[I64]]) : i64, i64 -> (i64)
+test.format_custom_directive_operands_and_types %i64, %i64 -> (%i64) : i64, i64 -> (i64)
+
+// CHECK: test.format_custom_directive_operands_and_types %[[I64]] -> (%[[I64]]) : i64 -> (i64)
+test.format_custom_directive_operands_and_types %i64 -> (%i64) : i64 -> (i64)
+
+// CHECK: test.format_custom_directive_results : i64, i64 -> (i64)
+test.format_custom_directive_results : i64, i64 -> (i64)
+
+// CHECK: test.format_custom_directive_results : i64 -> (i64)
+test.format_custom_directive_results : i64 -> (i64)
+
+func @foo() {
+  // CHECK: test.format_custom_directive_successors ^bb1, ^bb2
+  test.format_custom_directive_successors ^bb1, ^bb2
+
+^bb1:
+  // CHECK: test.format_custom_directive_successors ^bb2
+  test.format_custom_directive_successors ^bb2
+
+^bb2:
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Format trait type inference
 //===----------------------------------------------------------------------===//
diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 3cce5262b50d..d3fdc9ec323e 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -45,6 +45,7 @@ class Element {
   enum class Kind {
     /// This element is a directive.
     AttrDictDirective,
+    CustomDirective,
     FunctionalTypeDirective,
     OperandsDirective,
     ResultsDirective,
@@ -132,8 +133,7 @@ using SuccessorVariable =
 
 namespace {
 /// This class implements single kind directives.
-template <Element::Kind type>
-class DirectiveElement : public Element {
+template <Element::Kind type> class DirectiveElement : public Element {
 public:
   DirectiveElement() : Element(type){};
   static bool classof(const Element *ele) { return ele->getKind() == type; }
@@ -164,6 +164,33 @@ class AttrDictDirective
   bool withKeyword;
 };
 
+/// This class represents a custom format directive that is implemented by the
+/// user in C++.
+class CustomDirective : public Element {
+public:
+  CustomDirective(StringRef name,
+                  std::vector<std::unique_ptr<Element>> &&arguments)
+      : Element{Kind::CustomDirective}, name(name),
+        arguments(std::move(arguments)) {}
+
+  static bool classof(const Element *element) {
+    return element->getKind() == Kind::CustomDirective;
+  }
+
+  /// Return the name of this custom directive.
+  StringRef getName() const { return name; }
+
+  /// Return the arguments to the custom directive.
+  auto getArguments() const { return llvm::make_pointee_range(arguments); }
+
+private:
+  /// The user provided name of the directive.
+  StringRef name;
+
+  /// The arguments to the custom directive.
+  std::vector<std::unique_ptr<Element>> arguments;
+};
+
 /// This class represents the `functional-type` directive. This directive takes
 /// two arguments and formats them, respectively, as the inputs and results of a
 /// FunctionType.
@@ -370,19 +397,16 @@ static bool canFormatEnumAttr(const NamedAttribute *attr) {
 
 /// The code snippet used to generate a parser call for an attribute.
 ///
-/// {0}: The storage type of the attribute.
-/// {1}: The name of the attribute.
-/// {2}: The type for the attribute.
+/// {0}: The name of the attribute.
+/// {1}: The type for the attribute.
 const char *const attrParserCode = R"(
-  {0} {1}Attr;
-  if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes))
+  if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
     return failure();
 )";
 const char *const optionalAttrParserCode = R"(
-  {0} {1}Attr;
   {
     ::mlir::OptionalParseResult parseResult =
-      parser.parseOptionalAttribute({1}Attr{2}, "{1}", result.attributes);
+      parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
     if (parseResult.hasValue() && failed(*parseResult))
       return failure();
   }
@@ -408,11 +432,11 @@ const char *const enumAttrParserCode = R"(
       return parser.emitError(loc, "invalid ")
              << "{0} attribute specification: " << attrVal;
 
-    result.addAttribute("{0}", {3});
+    {0}Attr = {3};
+    result.addAttribute("{0}", {0}Attr);
   }
 )";
 const char *const optionalEnumAttrParserCode = R"(
-  Attribute {0}Attr;
   {
     ::mlir::StringAttr attrVal;
     ::mlir::NamedAttrList attrStorage;
@@ -440,11 +464,13 @@ const char *const optionalEnumAttrParserCode = R"(
 ///
 /// {0}: The name of the operand.
 const char *const variadicOperandParserCode = R"(
+  {0}OperandsLoc = parser.getCurrentLocation();
   if (parser.parseOperandList({0}Operands))
     return failure();
 )";
 const char *const optionalOperandParserCode = R"(
   {
+    {0}OperandsLoc = parser.getCurrentLocation();
     ::mlir::OpAsmParser::OperandType operand;
     ::mlir::OptionalParseResult parseResult =
                                     parser.parseOptionalOperand(operand);
@@ -456,6 +482,7 @@ const char *const optionalOperandParserCode = R"(
   }
 )";
 const char *const operandParserCode = R"(
+  {0}OperandsLoc = parser.getCurrentLocation();
   if (parser.parseOperand({0}RawOperands[0]))
     return failure();
 )";
@@ -500,7 +527,6 @@ const char *const functionalTypeParserCode = R"(
 ///
 /// {0}: The name for the successor list.
 const char *successorListParserCode = R"(
-  ::llvm::SmallVector<::mlir::Block *, 2> {0}Successors;
   {
     ::mlir::Block *succ;
     auto firstSucc = parser.parseOptionalSuccessor(succ);
@@ -523,7 +549,6 @@ const char *successorListParserCode = R"(
 ///
 /// {0}: The name of the successor.
 const char *successorParserCode = R"(
-  ::mlir::Block *{0}Successor = nullptr;
   if (parser.parseSuccessor({0}Successor))
     return failure();
 )";
@@ -595,8 +620,34 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
 /// Generate the storage code required for parsing the given element.
 static void genElementParserStorage(Element *element, OpMethodBody &body) {
   if (auto *optional = dyn_cast<OptionalElement>(element)) {
-    for (auto &childElement : optional->getElements())
-      genElementParserStorage(&childElement, body);
+    auto elements = optional->getElements();
+
+    // If the anchor is a unit attribute, it won't be parsed directly so elide
+    // it.
+    auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
+    Element *elidedAnchorElement = nullptr;
+    if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
+      elidedAnchorElement = anchor;
+    for (auto &childElement : elements)
+      if (&childElement != elidedAnchorElement)
+        genElementParserStorage(&childElement, body);
+
+  } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
+    for (auto &paramElement : custom->getArguments())
+      genElementParserStorage(&paramElement, body);
+
+  } else if (isa<OperandsDirective>(element)) {
+    body << "  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
+            "allOperands;\n";
+
+  } else if (isa<SuccessorsDirective>(element)) {
+    body << "  ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
+
+  } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
+    const NamedAttribute *var = attr->getVar();
+    body << llvm::formatv("  {0} {1}Attr;\n", var->attr.getStorageType(),
+                          var->name);
+
   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
     StringRef name = operand->getVar()->name;
     if (operand->getVar()->isVariableLength()) {
@@ -608,10 +659,19 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
            << "  ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
            << "Operands(" << name << "RawOperands);";
     }
-    body << llvm::formatv(
-        "  ::llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation();\n"
-        "  (void){0}OperandsLoc;\n",
-        name);
+    body << llvm::formatv("  ::llvm::SMLoc {0}OperandsLoc;\n"
+                          "  (void){0}OperandsLoc;\n",
+                          name);
+  } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
+    StringRef name = successor->getVar()->name;
+    if (successor->getVar()->isVariadic()) {
+      body << llvm::formatv("  ::llvm::SmallVector<::mlir::Block *, 2> "
+                            "{0}Successors;\n",
+                            name);
+    } else {
+      body << llvm::formatv("  ::mlir::Block *{0}Successor = nullptr;\n", name);
+    }
+
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
     ArgumentLengthKind lengthKind;
     StringRef name = getTypeListName(dir->getOperand(), lengthKind);
@@ -631,6 +691,106 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
   }
 }
 
+/// Generate the parser for a parameter to a custom directive.
+static void genCustomParameterParser(Element &param, OpMethodBody &body) {
+  body << ", ";
+  if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
+    body << attr->getVar()->name << "Attr";
+
+  } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+    StringRef name = operand->getVar()->name;
+    ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
+    if (lengthKind == ArgumentLengthKind::Variadic)
+      body << llvm::formatv("{0}Operands", name);
+    else if (lengthKind == ArgumentLengthKind::Optional)
+      body << llvm::formatv("{0}Operand", name);
+    else
+      body << formatv("{0}RawOperands[0]", name);
+
+  } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
+    StringRef name = successor->getVar()->name;
+    if (successor->getVar()->isVariadic())
+      body << llvm::formatv("{0}Successors", name);
+    else
+      body << llvm::formatv("{0}Successor", name);
+
+  } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
+    ArgumentLengthKind lengthKind;
+    StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+    if (lengthKind == ArgumentLengthKind::Variadic)
+      body << llvm::formatv("{0}Types", listName);
+    else if (lengthKind == ArgumentLengthKind::Optional)
+      body << llvm::formatv("{0}Type", listName);
+    else
+      body << formatv("{0}RawTypes[0]", listName);
+  } else {
+    llvm_unreachable("unknown custom directive parameter");
+  }
+}
+
+/// Generate the parser for a custom directive.
+static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
+  body << "  {\n";
+
+  // Preprocess the directive variables.
+  // * Add a local variable for optional operands and types. This provides a
+  //   better API to the user defined parser methods.
+  // * Set the location of operand variables.
+  for (Element &param : dir->getArguments()) {
+    if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+      body << "    " << operand->getVar()->name
+           << "OperandsLoc = parser.getCurrentLocation();\n";
+      if (operand->getVar()->isOptional()) {
+        body << llvm::formatv(
+            "    llvm::Optional<::mlir::OpAsmParser::OperandType> "
+            "{0}Operand;\n",
+            operand->getVar()->name);
+      }
+    } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
+      ArgumentLengthKind lengthKind;
+      StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+      if (lengthKind == ArgumentLengthKind::Optional)
+        body << llvm::formatv("    ::mlir::Type {0}Type;\n", listName);
+    }
+  }
+
+  body << "    if (parse" << dir->getName() << "(parser";
+  for (Element &param : dir->getArguments())
+    genCustomParameterParser(param, body);
+
+  body << "))\n"
+       << "      return failure();\n";
+
+  // After parsing, add handling for any of the optional constructs.
+  for (Element &param : dir->getArguments()) {
+    if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
+      const NamedAttribute *var = attr->getVar();
+      if (var->attr.isOptional())
+        body << llvm::formatv("    if ({0}Attr)\n  ", var->name);
+
+      body << llvm::formatv(
+          "    result.attributes.addAttribute(\"{0}\", {0}Attr);", var->name);
+    } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+      const NamedTypeConstraint *var = operand->getVar();
+      if (!var->isOptional())
+        continue;
+      body << llvm::formatv("    if ({0}Operand.hasValue())\n"
+                            "      {0}Operands.push_back(*{0}Operand);\n",
+                            var->name);
+    } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
+      ArgumentLengthKind lengthKind;
+      StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+      if (lengthKind == ArgumentLengthKind::Optional) {
+        body << llvm::formatv("    if ({0}Type)\n"
+                              "      {0}Types.push_back({0}Type);\n",
+                              listName);
+      }
+    }
+  }
+
+  body << "  }\n";
+}
+
 /// Generate the parser for a single format element.
 static void genElementParser(Element *element, OpMethodBody &body,
                              FmtContext &attrTypeCtx) {
@@ -711,7 +871,7 @@ static void genElementParser(Element *element, OpMethodBody &body,
 
     body << formatv(var->attr.isOptional() ? optionalAttrParserCode
                                            : attrParserCode,
-                    var->attr.getStorageType(), var->name, attrTypeStr);
+                    var->name, attrTypeStr);
   } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
     ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
     StringRef name = operand->getVar()->name;
@@ -732,10 +892,11 @@ static void genElementParser(Element *element, OpMethodBody &body,
          << (attrDict->isWithKeyword() ? "WithKeyword" : "")
          << "(result.attributes))\n"
          << "    return failure();\n";
+  } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
+    genCustomDirectiveParser(customDir, body);
+
   } else if (isa<OperandsDirective>(element)) {
     body << "  ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
-         << "  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
-            "allOperands;\n"
          << "  if (parser.parseOperandList(allOperands))\n"
          << "    return failure();\n";
   } else if (isa<SuccessorsDirective>(element)) {
@@ -980,6 +1141,20 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
     llvm::interleaveComma(op.getOperands(), body, interleaveFn);
     body << "}));\n";
   }
+
+  if (!allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments")) {
+    body << "  result.addAttribute(\"result_segment_sizes\", "
+         << "parser.getBuilder().getI32VectorAttr({";
+    auto interleaveFn = [&](const NamedTypeConstraint &result) {
+      // If the result is variadic emit the parsed size.
+      if (result.isVariableLength())
+        body << "static_cast<int32_t>(" << result.name << "Types.size())";
+      else
+        body << "1";
+    };
+    llvm::interleaveComma(op.getResults(), body, interleaveFn);
+    body << "}));\n";
+  }
 }
 
 //===----------------------------------------------------------------------===//
@@ -1007,6 +1182,8 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
   // Elide the variadic segment size attributes if necessary.
   if (!fmt.allOperands && op.getTrait("OpTrait::AttrSizedOperandSegments"))
     body << "\"operand_segment_sizes\", ";
+  if (!fmt.allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments"))
+    body << "\"result_segment_sizes\", ";
   llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) {
     body << "\"" << attr->name << "\"";
   });
@@ -1038,6 +1215,42 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
   lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
 }
 
+/// Generate the printer for a custom directive. This emits a call to the
+/// user-provided `print<UserDirective>` function, passing the OpAsmPrinter
+/// followed by one argument per directive parameter.
+static void genCustomDirectivePrinter(CustomDirective *customDir,
+                                      OpMethodBody &body) {
+  body << "  print" << customDir->getName() << "(p";
+  for (Element &param : customDir->getArguments()) {
+    body << ", ";
+    if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
+      body << attr->getVar()->name << "Attr()";
+
+    } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+      body << operand->getVar()->name << "()";
+
+    } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
+      body << successor->getVar()->name << "()";
+
+    } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
+      auto *typeOperand = dir->getOperand();
+      auto *operand = dyn_cast<OperandVariable>(typeOperand);
+      auto *var = operand ? operand->getVar()
+                          : cast<ResultVariable>(typeOperand)->getVar();
+      if (var->isVariadic())
+        body << var->name << "().getTypes()";
+      else if (var->isOptional())
+        body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
+      else
+        body << var->name << "().getType()";
+    } else {
+      llvm_unreachable("unknown custom directive parameter");
+    }
+  }
+
+  body << ");\n";
+}
+
 /// Generate the C++ for an operand to a (*-)type directive.
 static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
   if (isa<OperandsDirective>(arg))
@@ -1145,6 +1358,8 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
       body << "  ::llvm::interleaveComma(" << var->name << "(), p);\n";
     else
       body << "  p << " << var->name << "();\n";
+  } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
+    genCustomDirectivePrinter(dir, body);
   } else if (isa<OperandsDirective>(element)) {
     body << "  p << getOperation()->getOperands();\n";
   } else if (isa<SuccessorsDirective>(element)) {
@@ -1202,12 +1417,15 @@ class Token {
     caret,
     comma,
     equal,
+    less,
+    greater,
     question,
 
     // Keywords.
     keyword_start,
     kw_attr_dict,
     kw_attr_dict_w_keyword,
+    kw_custom,
     kw_functional_type,
     kw_operands,
     kw_results,
@@ -1353,6 +1571,10 @@ Token FormatLexer::lexToken() {
     return formToken(Token::comma, tokStart);
   case '=':
     return formToken(Token::equal, tokStart);
+  case '<':
+    return formToken(Token::less, tokStart);
+  case '>':
+    return formToken(Token::greater, tokStart);
   case '?':
     return formToken(Token::question, tokStart);
   case '(':
@@ -1406,6 +1628,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
       llvm::StringSwitch<Token::Kind>(str)
           .Case("attr-dict", Token::kw_attr_dict)
           .Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
+          .Case("custom", Token::kw_custom)
           .Case("functional-type", Token::kw_functional_type)
           .Case("operands", Token::kw_operands)
           .Case("results", Token::kw_results)
@@ -1421,8 +1644,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
 
 /// Function to find an element within the given range that has the same name as
 /// 'name'.
-template <typename RangeT>
-static auto findArg(RangeT &&range, StringRef name) {
+template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
   auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
   return it != range.end() ? &*it : nullptr;
 }
@@ -1513,6 +1735,10 @@ class FormatParser {
   LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
                                        llvm::SMLoc loc, bool isTopLevel,
                                        bool withKeyword);
+  LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
+                                     llvm::SMLoc loc, bool isTopLevel);
+  LogicalResult parseCustomDirectiveParameter(
+      std::vector<std::unique_ptr<Element>> &parameters);
   LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
                                              Token tok, bool isTopLevel);
   LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
@@ -1930,6 +2156,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
   case Token::kw_attr_dict_w_keyword:
     return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
                                   /*withKeyword=*/true);
+  case Token::kw_custom:
+    return parseCustomDirective(element, dirTok.getLoc(), isTopLevel);
   case Token::kw_functional_type:
     return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
   case Token::kw_operands:
@@ -2054,15 +2282,15 @@ LogicalResult FormatParser::parseOptionalChildElement(
         seenVariables.insert(ele->getVar());
         return success();
       })
-      // Literals and type directives may be used, but they can't anchor the
-      // group.
-      .Case<LiteralElement, TypeDirective, FunctionalTypeDirective>(
-          [&](Element *) {
-            if (isAnchor)
-              return emitError(childLoc, "only variables can be used to anchor "
-                                         "an optional group");
-            return success();
-          })
+      // Literals, custom directives, and type directives may be used,
+      // but they can't anchor the group.
+      .Case<LiteralElement, CustomDirective, TypeDirective,
+            FunctionalTypeDirective>([&](Element *) {
+        if (isAnchor)
+          return emitError(childLoc, "only variables can be used to anchor "
+                                     "an optional group");
+        return success();
+      })
       .Default([&](Element *) {
         return emitError(childLoc, "only literals, types, and variables can be "
                                    "used within an optional group");
@@ -2084,6 +2312,71 @@ FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
   return success();
 }
 
+LogicalResult
+FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
+                                   llvm::SMLoc loc, bool isTopLevel) {
+  llvm::SMLoc curLoc = curToken.getLoc();
+
+  // Parse the custom directive name.
+  if (failed(
+          parseToken(Token::less, "expected '<' before custom directive name")))
+    return failure();
+
+  Token nameTok = curToken;
+  if (failed(parseToken(Token::identifier,
+                        "expected custom directive name identifier")) ||
+      failed(parseToken(Token::greater,
+                        "expected '>' after custom directive name")) ||
+      failed(parseToken(Token::l_paren,
+                        "expected '(' before custom directive parameters")))
+    return failure();
+
+  // Parse the parameters of this custom directive.
+  std::vector<std::unique_ptr<Element>> elements;
+  do {
+    if (failed(parseCustomDirectiveParameter(elements)))
+      return failure();
+    if (curToken.getKind() != Token::comma)
+      break;
+    consumeToken();
+  } while (true);
+
+  if (failed(parseToken(Token::r_paren,
+                        "expected ')' after custom directive parameters")))
+    return failure();
+
+  // After parsing all of the elements, ensure that all type directives refer
+  // only to variables.
+  for (auto &ele : elements) {
+    if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
+      if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
+        return emitError(curLoc, "type directives within a custom directive "
+                                 "may only refer to variables");
+      }
+    }
+  }
+
+  element = std::make_unique<CustomDirective>(nameTok.getSpelling(),
+                                              std::move(elements));
+  return success();
+}
+
+LogicalResult FormatParser::parseCustomDirectiveParameter(
+    std::vector<std::unique_ptr<Element>> &parameters) {
+  llvm::SMLoc childLoc = curToken.getLoc();
+  parameters.push_back({});
+  if (failed(parseElement(parameters.back(), /*isTopLevel=*/true)))
+    return failure();
+
+  // Verify that the element can be placed within a custom directive.
+  if (!isa<TypeDirective, AttributeVariable, OperandVariable,
+           SuccessorVariable>(parameters.back().get())) {
+    return emitError(childLoc, "only variables and types may be used as "
+                               "parameters to a custom directive");
+  }
+  return success();
+}
+
 LogicalResult
 FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
                                            Token tok, bool isTopLevel) {
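
For readers following the printer hunk, here is a minimal, MLIR-free sketch (not part of this commit) of the mapping that `genCustomDirectivePrinter` applies when it emits the `print<UserDirective>` call. The names `ParamKind`, `VarKind`, `Param`, and `emitCustomPrinterCall` are hypothetical stand-ins for the real `Element` subclasses; the sketch returns a string instead of writing into an `OpMethodBody`, and only models the per-parameter expressions shown in the diff.

```cpp
// Hypothetical, MLIR-free model of the string building performed by
// genCustomDirectivePrinter. None of these names exist in mlir-tblgen.
#include <cassert>
#include <string>
#include <vector>

// Which kind of format element a custom directive parameter is.
enum class ParamKind { Attribute, Operand, Successor, Type };

// Shape of the variable that a `type(...)` parameter refers to.
enum class VarKind { Single, Optional, Variadic };

struct Param {
  ParamKind kind;
  std::string name;                  // variable name, e.g. "input"
  VarKind varKind = VarKind::Single; // only consulted for ParamKind::Type
};

// Returns the C++ statement emitted for `custom<Directive>(params...)`.
std::string emitCustomPrinterCall(const std::string &directive,
                                  const std::vector<Param> &params) {
  // The printer `p` is always passed first.
  std::string out = "  print" + directive + "(p";
  for (const Param &param : params) {
    out += ", ";
    switch (param.kind) {
    case ParamKind::Attribute:
      // Attributes are passed through their `<name>Attr()` accessor.
      out += param.name + "Attr()";
      break;
    case ParamKind::Operand:
    case ParamKind::Successor:
      // Operands and successors use the plain `<name>()` accessor.
      out += param.name + "()";
      break;
    case ParamKind::Type:
      // The type expression depends on whether the variable is variadic,
      // optional, or a single value.
      if (param.varKind == VarKind::Variadic)
        out += param.name + "().getTypes()";
      else if (param.varKind == VarKind::Optional)
        out += "(" + param.name + "() ? " + param.name +
               "().getType() : Type())";
      else
        out += param.name + "().getType()";
      break;
    }
  }
  out += ");\n";
  return out;
}
```

Under this model, a format entry such as `custom<MyDirective>($input, type($input))` would produce `printMyDirective(p, input(), input().getType());`, so the user-provided `printMyDirective` is expected to accept the OpAsmPrinter followed by a value and a type, in that order. The parser side mirrors this: `parseMyDirective` receives the OpAsmParser first, then one reference per parameter.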
        
    
    