[Mlir-commits] [mlir] 88c6e25 - [mlir][OpFormatGen] Add support for specifiy "custom" directives.
River Riddle
llvmlistbot at llvm.org
Mon Aug 31 13:26:53 PDT 2020
Author: River Riddle
Date: 2020-08-31T13:26:23-07:00
New Revision: 88c6e25e4f0630bd9204cb02787fcb67e097a43a
URL: https://github.com/llvm/llvm-project/commit/88c6e25e4f0630bd9204cb02787fcb67e097a43a
DIFF: https://github.com/llvm/llvm-project/commit/88c6e25e4f0630bd9204cb02787fcb67e097a43a.diff
LOG: [mlir][OpFormatGen] Add support for specifiy "custom" directives.
This revision adds support for custom directives to the declarative assembly format. This allows users to write C++ for printing and parsing subsections of an otherwise declaratively specified format. The custom directive is structured as follows:
```
custom-directive ::= `custom` `<` UserDirective `>` `(` Params `)`
```
`UserDirective` is used as a suffix for the names of the C++ methods invoked during parsing and printing: when parsing, `parseUserDirective` is invoked; when printing, `printUserDirective` is invoked. The first parameter to these methods must be a reference to either the `OpAsmParser` or the `OpAsmPrinter`. The types of the remaining parameters depend on the `Params` specified in the assembly format.
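A minimal sketch of the naming rule, in plain C++ with no MLIR dependency: given the text of a `custom<UserDirective>(Params)` directive, it derives the `parse`/`print` method names. The struct `CustomDirectiveHooks` and the function `getCustomDirectiveHooks` are hypothetical, not part of mlir-tblgen, and the sketch does not tokenize or validate `Params`.

```cpp
#include <cctype>
#include <optional>
#include <string>

// Hypothetical result type: the C++ hooks a custom directive maps to.
struct CustomDirectiveHooks {
  std::string parseFn; // e.g. "parseMyDirective"
  std::string printFn; // e.g. "printMyDirective"
  std::string params;  // raw text between the parentheses
};

// Recognizes `custom<UserDirective>(Params)` (no whitespace allowed) and
// returns the hook names, or std::nullopt if the text is malformed.
std::optional<CustomDirectiveHooks>
getCustomDirectiveHooks(const std::string &spec) {
  const std::string prefix = "custom<";
  if (spec.compare(0, prefix.size(), prefix) != 0)
    return std::nullopt;

  // The directive name must be a non-empty identifier followed by '>'.
  std::size_t pos = prefix.size();
  const std::size_t nameStart = pos;
  while (pos < spec.size() &&
         (std::isalnum(static_cast<unsigned char>(spec[pos])) ||
          spec[pos] == '_'))
    ++pos;
  if (pos == nameStart || pos >= spec.size() || spec[pos] != '>')
    return std::nullopt;
  const std::string name = spec.substr(nameStart, pos - nameStart);
  ++pos; // skip '>'

  // The parameters must be wrapped in '(' ... ')'.
  if (pos >= spec.size() || spec[pos] != '(' || spec.back() != ')')
    return std::nullopt;
  const std::string params = spec.substr(pos + 1, spec.size() - pos - 2);
  return CustomDirectiveHooks{"parse" + name, "print" + name, params};
}
```

For example, `custom<MyDirective>($operand)` yields `parseMyDirective` and `printMyDirective`, while malformed inputs such as `custom<>` are rejected, loosely mirroring the errors exercised in op-format-spec.td.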
Differential Revision: https://reviews.llvm.org/D84719
Added:
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/IR/OpImplementation.h
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format-spec.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 418da6a857dc..167546e67522 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -664,6 +664,12 @@ The available directives are as follows:
- Represents the attribute dictionary of the operation, but prefixes the
dictionary with an `attributes` keyword.
+* `custom` < UserDirective > ( Params )
+
+ - Represents a custom directive implemented by the user in C++.
+ - See the [Custom Directives](#custom-directives) section below for more
+ details.
+
* `functional-type` ( inputs , results )
- Formats the `inputs` and `results` arguments as a
@@ -705,6 +711,75 @@ example above, the variables would be `$callee` and `$args`.
Attribute variables are printed with their respective value type, unless that
value type is buildable. In those cases, the type of the attribute is elided.
+#### Custom Directives
+
+The declarative assembly format specification handles the large majority of
+common cases when formatting an operation. For operations that need to format
+parts of the operation in a form the declarative syntax does not support,
+custom directives may be used. A custom directive allows users to write C++
+for printing and parsing subsections of an otherwise declaratively specified
+format. Looking at the specification of a custom directive above:
+
+```
+custom-directive ::= `custom` `<` UserDirective `>` `(` Params `)`
+```
+
+A custom directive has two main parts: The `UserDirective` and the `Params`. A
+custom directive is transformed into a call to a `print*` and a `parse*` method
+when generating the C++ code for the format. The `UserDirective` is an
+identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)`
+would result in calls to `parseMyDirective` and `printMyDirective` within the
+parser and printer respectively. `Params` may be any combination of variables
+(e.g. Attribute, Operand, Successor) and type directives. The type
+directives must refer to a variable, but that variable need not also be a
+parameter to the custom directive.
+
+The arguments to the `parse<UserDirective>` method are, first, a reference to
+the `OpAsmParser` (`OpAsmParser &`), and second, a set of output parameters
+corresponding to the parameters specified in the format. The mapping of
+declarative parameter to `parse` method argument is detailed below:
+
+* Attribute Variables
+ - Single: `<Attribute-Storage-Type>(e.g. Attribute) &`
+ - Optional: `<Attribute-Storage-Type>(e.g. Attribute) &`
+* Operand Variables
+ - Single: `OpAsmParser::OperandType &`
+ - Optional: `Optional<OpAsmParser::OperandType> &`
+ - Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
+* Successor Variables
+ - Single: `Block *&`
+ - Variadic: `SmallVectorImpl<Block *> &`
+* Type Directives
+ - Single: `Type &`
+ - Optional: `Type &`
+ - Variadic: `SmallVectorImpl<Type> &`
+
+When a variable is optional, the output value should only be set if the
+variable is present. Otherwise, the value should remain `None` or null.
+
+The arguments to the `print<UserDirective>` method are, first, a reference to
+the `OpAsmPrinter` (`OpAsmPrinter &`), and second, a set of input parameters
+corresponding to the parameters specified in the format. The mapping of
+declarative parameter to `print` method argument is detailed below:
+
+* Attribute Variables
+ - Single: `<Attribute-Storage-Type>(e.g. Attribute)`
+ - Optional: `<Attribute-Storage-Type>(e.g. Attribute)`
+* Operand Variables
+ - Single: `Value`
+ - Optional: `Value`
+ - Variadic: `OperandRange`
+* Successor Variables
+ - Single: `Block *`
+ - Variadic: `SuccessorRange`
+* Type Directives
+ - Single: `Type`
+ - Optional: `Type`
+ - Variadic: `TypeRange`
+
+When a variable is optional, the provided value may be null.
+
#### Optional Groups
In certain situations operations may have "optional" information, e.g.
@@ -722,8 +797,8 @@ information. An optional group is defined by wrapping a set of elements within
should be printed/parsed.
- An element is marked as the anchor by adding a trailing `^`.
- The first element is *not* required to be the anchor of the group.
-* Literals, variables, and type directives are the only valid elements within
- the group.
+* Literals, variables, custom directives, and type directives are the only
+ valid elements within the group.
- Any attribute variable may be used, but only optional attributes can be
marked as the anchor.
- Only variadic or optional operand arguments can be used.
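The two parameter-type tables in the OpDefinitions.md hunk can be read as a lookup from (variable kind, length) to a C++ argument type. The sketch below restates them in plain C++ so the parse and print sides can be compared at a glance; the enums and functions are hypothetical, and `Attribute` stands in for whatever storage type the attribute actually declares.

```cpp
#include <cassert>
#include <string>

// Hypothetical classification of a custom directive parameter.
enum class ParamKind { Attribute, Operand, Successor, TypeDirective };
enum class Length { Single, Optional, Variadic };

// Argument type of parse<UserDirective> for one parameter (output params).
std::string parseArgType(ParamKind kind, Length length) {
  switch (kind) {
  case ParamKind::Attribute:
    return "Attribute &"; // storage type; same for single and optional
  case ParamKind::Operand:
    if (length == Length::Single)
      return "OpAsmParser::OperandType &";
    if (length == Length::Optional)
      return "Optional<OpAsmParser::OperandType> &";
    return "SmallVectorImpl<OpAsmParser::OperandType> &";
  case ParamKind::Successor:
    assert(length != Length::Optional && "no optional successor form listed");
    return length == Length::Variadic ? "SmallVectorImpl<Block *> &"
                                      : "Block *&";
  case ParamKind::TypeDirective:
    return length == Length::Variadic ? "SmallVectorImpl<Type> &" : "Type &";
  }
  return "";
}

// Argument type of print<UserDirective> for the same parameter (inputs).
std::string printArgType(ParamKind kind, Length length) {
  switch (kind) {
  case ParamKind::Attribute:
    return "Attribute"; // storage type, passed by value
  case ParamKind::Operand:
    return length == Length::Variadic ? "OperandRange" : "Value";
  case ParamKind::Successor:
    assert(length != Length::Optional && "no optional successor form listed");
    return length == Length::Variadic ? "SuccessorRange" : "Block *";
  case ParamKind::TypeDirective:
    return length == Length::Variadic ? "TypeRange" : "Type";
  }
  return "";
}
```

The asymmetry is the point of the tables: parse hooks receive mutable references they must fill in (with `Optional<>` only for operands), while print hooks receive values or ranges where an absent optional is simply null.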
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index e0726a9901cf..5962a1a48668 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -202,6 +202,10 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p,
llvm::interleaveComma(types, p);
return p;
}
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) {
+ llvm::interleaveComma(types, p);
+ return p;
+}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
llvm::interleaveComma(types, p);
return p;
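The new `TypeRange` overload is what lets `printCustomDirectiveResults` in TestDialect.cpp write `printer << varOperandTypes` directly; like its neighbours, it forwards to `llvm::interleaveComma`. A rough standard-library analogue, assuming a `", "` separator and modeling types as strings (both `interleaveComma` here and `printTypeList` are hypothetical stand-ins):

```cpp
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Stand-in for llvm::interleaveComma: writes each element, separated by ", ".
template <typename Range>
void interleaveComma(const Range &range, std::ostream &os) {
  bool first = true;
  for (const auto &element : range) {
    if (!first)
      os << ", ";
    os << element;
    first = false;
  }
}

// Analogue of `operator<<(OpAsmPrinter &, const TypeRange &)` with type
// names modeled as strings.
std::string printTypeList(const std::vector<std::string> &types) {
  std::ostringstream os;
  interleaveComma(types, os);
  return os.str();
}
```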
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 8fedb088e0b7..292f5ed4b641 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -267,6 +267,108 @@ void FoldToCallOp::getCanonicalizationPatterns(
results.insert<FoldToCallOpPattern>(context);
}
+//===----------------------------------------------------------------------===//
+// Test Format* operations
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Parsing
+
+static ParseResult parseCustomDirectiveOperands(
+ OpAsmParser &parser, OpAsmParser::OperandType &operand,
+ Optional<OpAsmParser::OperandType> &optOperand,
+ SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
+ if (parser.parseOperand(operand))
+ return failure();
+ if (succeeded(parser.parseOptionalComma())) {
+ optOperand.emplace();
+ if (parser.parseOperand(*optOperand))
+ return failure();
+ }
+ if (parser.parseArrow() || parser.parseLParen() ||
+ parser.parseOperandList(varOperands) || parser.parseRParen())
+ return failure();
+ return success();
+}
+static ParseResult
+parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
+ Type &optOperandType,
+ SmallVectorImpl<Type> &varOperandTypes) {
+ if (parser.parseColon())
+ return failure();
+
+ if (parser.parseType(operandType))
+ return failure();
+ if (succeeded(parser.parseOptionalComma())) {
+ if (parser.parseType(optOperandType))
+ return failure();
+ }
+ if (parser.parseArrow() || parser.parseLParen() ||
+ parser.parseTypeList(varOperandTypes) || parser.parseRParen())
+ return failure();
+ return success();
+}
+static ParseResult parseCustomDirectiveOperandsAndTypes(
+ OpAsmParser &parser, OpAsmParser::OperandType &operand,
+ Optional<OpAsmParser::OperandType> &optOperand,
+ SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
+ Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
+ if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
+ parseCustomDirectiveResults(parser, operandType, optOperandType,
+ varOperandTypes))
+ return failure();
+ return success();
+}
+static ParseResult
+parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
+ SmallVectorImpl<Block *> &varSuccessors) {
+ if (parser.parseSuccessor(successor))
+ return failure();
+ if (failed(parser.parseOptionalComma()))
+ return success();
+ Block *varSuccessor;
+ if (parser.parseSuccessor(varSuccessor))
+ return failure();
+ varSuccessors.append(2, varSuccessor);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Printing
+
+static void printCustomDirectiveOperands(OpAsmPrinter &printer, Value operand,
+ Value optOperand,
+ OperandRange varOperands) {
+ printer << operand;
+ if (optOperand)
+ printer << ", " << optOperand;
+ printer << " -> (" << varOperands << ")";
+}
+static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType,
+ Type optOperandType,
+ TypeRange varOperandTypes) {
+ printer << " : " << operandType;
+ if (optOperandType)
+ printer << ", " << optOperandType;
+ printer << " -> (" << varOperandTypes << ")";
+}
+static void
+printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
+ Value optOperand, OperandRange varOperands,
+ Type operandType, Type optOperandType,
+ TypeRange varOperandTypes) {
+ printCustomDirectiveOperands(printer, operand, optOperand, varOperands);
+ printCustomDirectiveResults(printer, operandType, optOperandType,
+ varOperandTypes);
+}
+static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
+ Block *successor,
+ SuccessorRange varSuccessors) {
+ printer << successor;
+ if (!varSuccessors.empty())
+ printer << ", " << varSuccessors.front();
+}
+
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
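The test hooks above follow one pattern: the parser fills an `Optional` only when the optional piece is present, and the printer emits it only when it is non-null, so the textual form round-trips. A toy reproduction of `parseCustomDirectiveOperands`/`printCustomDirectiveOperands` without MLIR, assuming SSA names are modeled as strings; `ToyParser`, `parseOperands` and `printOperands` are hypothetical.

```cpp
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for OpAsmParser: walks a string, skipping spaces.
class ToyParser {
public:
  explicit ToyParser(std::string text) : text(std::move(text)) {}

  // Consumes `token` if it comes next, like the parseOptional* methods.
  bool consumeOptional(const std::string &token) {
    skipSpaces();
    if (text.compare(pos, token.size(), token) != 0)
      return false;
    pos += token.size();
    return true;
  }

  // Parses an SSA name such as `%a` into `name`; returns false on failure.
  bool parseOperand(std::string &name) {
    skipSpaces();
    if (pos >= text.size() || text[pos] != '%')
      return false;
    std::size_t start = pos++;
    while (pos < text.size() &&
           (std::isalnum(static_cast<unsigned char>(text[pos])) ||
            text[pos] == '_'))
      ++pos;
    if (pos == start + 1)
      return false;
    name = text.substr(start, pos - start);
    return true;
  }

  bool atEnd() {
    skipSpaces();
    return pos == text.size();
  }

private:
  void skipSpaces() {
    while (pos < text.size() && text[pos] == ' ')
      ++pos;
  }

  std::string text;
  std::size_t pos = 0;
};

// Same shape as parseCustomDirectiveOperands: `%op (, %opt)? -> (%v, ...)`.
bool parseOperands(ToyParser &parser, std::string &operand,
                   std::optional<std::string> &optOperand,
                   std::vector<std::string> &varOperands) {
  if (!parser.parseOperand(operand))
    return false;
  if (parser.consumeOptional(",")) {
    optOperand.emplace(); // only engaged when the optional operand is present
    if (!parser.parseOperand(*optOperand))
      return false;
  }
  if (!parser.consumeOptional("->") || !parser.consumeOptional("("))
    return false;
  std::string next; // possibly empty, comma-separated list
  if (parser.parseOperand(next)) {
    varOperands.push_back(next);
    while (parser.consumeOptional(",")) {
      if (!parser.parseOperand(next))
        return false;
      varOperands.push_back(next);
    }
  }
  return parser.consumeOptional(")");
}

// Same shape as printCustomDirectiveOperands.
std::string printOperands(const std::string &operand,
                          const std::optional<std::string> &optOperand,
                          const std::vector<std::string> &varOperands) {
  std::string out = operand;
  if (optOperand)
    out += ", " + *optOperand;
  out += " -> (";
  for (std::size_t i = 0; i < varOperands.size(); ++i)
    out += (i ? ", " : "") + varOperands[i];
  return out + ")";
}
```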
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 022732d55016..bbe246a011af 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1414,8 +1414,60 @@ def FormatOptionalUnitAttrNoElide
}
//===----------------------------------------------------------------------===//
-// AllTypesMatch type inference
+// Custom Directives
+
+def FormatCustomDirectiveOperands
+ : TEST_Op<"format_custom_directive_operands", [AttrSizedOperandSegments]> {
+ let arguments = (ins I64:$operand, Optional<I64>:$optOperand,
+ Variadic<I64>:$varOperands);
+ let assemblyFormat = [{
+ custom<CustomDirectiveOperands>(
+ $operand, $optOperand, $varOperands
+ )
+ attr-dict
+ }];
+}
+
+def FormatCustomDirectiveOperandsAndTypes
+ : TEST_Op<"format_custom_directive_operands_and_types",
+ [AttrSizedOperandSegments]> {
+ let arguments = (ins AnyType:$operand, Optional<AnyType>:$optOperand,
+ Variadic<AnyType>:$varOperands);
+ let assemblyFormat = [{
+ custom<CustomDirectiveOperandsAndTypes>(
+ $operand, $optOperand, $varOperands,
+ type($operand), type($optOperand), type($varOperands)
+ )
+ attr-dict
+ }];
+}
+
+def FormatCustomDirectiveResults
+ : TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> {
+ let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
+ Variadic<AnyType>:$varResults);
+ let assemblyFormat = [{
+ custom<CustomDirectiveResults>(
+ type($result), type($optResult), type($varResults)
+ )
+ attr-dict
+ }];
+}
+
+def FormatCustomDirectiveSuccessors
+ : TEST_Op<"format_custom_directive_successors", [Terminator]> {
+ let successors = (successor AnySuccessor:$successor,
+ VariadicSuccessor<AnySuccessor>:$successors);
+ let assemblyFormat = [{
+ custom<CustomDirectiveSuccessors>(
+ $successor, $successors
+ )
+ attr-dict
+ }];
+}
+
//===----------------------------------------------------------------------===//
+// AllTypesMatch type inference
def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
AllTypesMatch<["value1", "value2", "result"]>
@@ -1435,7 +1487,6 @@ def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
//===----------------------------------------------------------------------===//
// TypesMatchWith type inference
-//===----------------------------------------------------------------------===//
def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
TypesMatchWith<"result type matches operand", "value", "result", "$_self">
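Because `FormatCustomDirectiveOperands` is marked `AttrSizedOperandSegments`, the generated parser must still record how many operands each group received after the custom hook returns. Per the segment-resolution code in OpFormatGen.cpp, variable-length groups (optional and variadic) record their parsed size and all other groups record 1; the same hunk adds the equivalent `result_segment_sizes` handling used by `FormatCustomDirectiveResults`. A sketch of that rule for this op's (single, optional, variadic) layout, with the helper name being hypothetical:

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper: operand_segment_sizes for an op shaped like
// test.format_custom_directive_operands, i.e. (single, optional, variadic).
// Fixed groups always count 1; variable-length groups count what the custom
// parser actually produced.
std::vector<std::int32_t>
customOperandsSegmentSizes(const std::optional<std::string> &optOperand,
                           const std::vector<std::string> &varOperands) {
  return {1, static_cast<std::int32_t>(optOperand ? 1 : 0),
          static_cast<std::int32_t>(varOperands.size())};
}
```

So `%i64, %i64 -> (%i64)` from op-format.mlir corresponds to segments `[1, 1, 1]`, and `%i64 -> (%i64)` to `[1, 0, 1]`.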
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 3a3c500d76b3..8e7c8ec56a2a 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -42,6 +42,49 @@ def DirectiveAttrDictValidB : TestFormat_Op<"attrdict_valid_b", [{
attr-dict-with-keyword
}]>;
+//===----------------------------------------------------------------------===//
+// custom
+
+// CHECK: error: expected '<' before custom directive name
+def DirectiveCustomInvalidA : TestFormat_Op<"custom_invalid_a", [{
+ custom(
+}]>;
+// CHECK: error: expected custom directive name identifier
+def DirectiveCustomInvalidB : TestFormat_Op<"custom_invalid_b", [{
+ custom<>
+}]>;
+// CHECK: error: expected '>' after custom directive name
+def DirectiveCustomInvalidC : TestFormat_Op<"custom_invalid_c", [{
+ custom<MyDirective(
+}]>;
+// CHECK: error: expected '(' before custom directive parameters
+def DirectiveCustomInvalidD : TestFormat_Op<"custom_invalid_d", [{
+ custom<MyDirective>)
+}]>;
+// CHECK: error: only variables and types may be used as parameters to a custom directive
+def DirectiveCustomInvalidE : TestFormat_Op<"custom_invalid_e", [{
+ custom<MyDirective>(operands)
+}]>;
+// CHECK: error: expected ')' after custom directive parameters
+def DirectiveCustomInvalidF : TestFormat_Op<"custom_invalid_f", [{
+ custom<MyDirective>($operand<
+}]>, Arguments<(ins I64:$operand)>;
+// CHECK: error: type directives within a custom directive may only refer to variables
+def DirectiveCustomInvalidH : TestFormat_Op<"custom_invalid_h", [{
+ custom<MyDirective>(type(operands))
+}]>;
+
+// CHECK-NOT: error
+def DirectiveCustomValidA : TestFormat_Op<"custom_valid_a", [{
+ custom<MyDirective>($operand) attr-dict
+}]>, Arguments<(ins Optional<I64>:$operand)>;
+def DirectiveCustomValidB : TestFormat_Op<"custom_valid_b", [{
+ custom<MyDirective>($operand, type($operand), type($result)) attr-dict
+}]>, Arguments<(ins I64:$operand)>, Results<(outs I64:$result)>;
+def DirectiveCustomValidC : TestFormat_Op<"custom_valid_c", [{
+ custom<MyDirective>($attr) attr-dict
+}]>, Arguments<(ins I64Attr:$attr)>;
+
//===----------------------------------------------------------------------===//
// functional-type
@@ -238,6 +281,10 @@ def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{
def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{
($arg^)
}]>, Arguments<(ins Variadic<I64>:$arg)>;
+// CHECK: error: only variables can be used to anchor an optional group
+def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{
+ (custom<MyDirective>($arg)^)?
+}]>, Arguments<(ins I64:$arg)>;
//===----------------------------------------------------------------------===//
// Variables
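The op-format-spec.td cases above pin down which parameters a custom directive accepts: plain variables, or `type(...)` applied to a variable. A deliberately simplified, string-based sketch of that check, reusing the diagnostic wording from the tests; `checkCustomDirectiveParams` is hypothetical and ignores everything the real verifier also enforces (such as whether the variable exists or optional-group anchoring).

```cpp
#include <string>
#include <vector>

// Returns an empty string if every parameter is `$name` or `type($name)`,
// otherwise the first matching diagnostic from op-format-spec.td.
std::string checkCustomDirectiveParams(const std::vector<std::string> &params) {
  const std::string typePrefix = "type(";
  for (const std::string &param : params) {
    // A plain variable such as `$operand`.
    if (!param.empty() && param[0] == '$')
      continue;
    // A type directive must wrap a variable: `type($name)`.
    if (param.compare(0, typePrefix.size(), typePrefix) == 0) {
      if (param.size() > typePrefix.size() + 2 &&
          param[typePrefix.size()] == '$' && param.back() == ')')
        continue;
      return "type directives within a custom directive may only refer to "
             "variables";
    }
    return "only variables and types may be used as parameters to a custom "
           "directive";
  }
  return "";
}
```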
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 959bbdc5c6bb..96a923ba81ec 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -122,6 +122,40 @@ test.format_optional_operand_result_b_op( : ) : i64
// CHECK: test.format_optional_operand_result_b_op : i64
test.format_optional_operand_result_b_op : i64
+//===----------------------------------------------------------------------===//
+// Format custom directives
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_custom_directive_operands %[[I64]], %[[I64]] -> (%[[I64]])
+test.format_custom_directive_operands %i64, %i64 -> (%i64)
+
+// CHECK: test.format_custom_directive_operands %[[I64]] -> (%[[I64]])
+test.format_custom_directive_operands %i64 -> (%i64)
+
+// CHECK: test.format_custom_directive_operands_and_types %[[I64]], %[[I64]] -> (%[[I64]]) : i64, i64 -> (i64)
+test.format_custom_directive_operands_and_types %i64, %i64 -> (%i64) : i64, i64 -> (i64)
+
+// CHECK: test.format_custom_directive_operands_and_types %[[I64]] -> (%[[I64]]) : i64 -> (i64)
+test.format_custom_directive_operands_and_types %i64 -> (%i64) : i64 -> (i64)
+
+// CHECK: test.format_custom_directive_results : i64, i64 -> (i64)
+test.format_custom_directive_results : i64, i64 -> (i64)
+
+// CHECK: test.format_custom_directive_results : i64 -> (i64)
+test.format_custom_directive_results : i64 -> (i64)
+
+func @foo() {
+ // CHECK: test.format_custom_directive_successors ^bb1, ^bb2
+ test.format_custom_directive_successors ^bb1, ^bb2
+
+^bb1:
+ // CHECK: test.format_custom_directive_successors ^bb2
+ test.format_custom_directive_successors ^bb2
+
+^bb2:
+ return
+}
+
//===----------------------------------------------------------------------===//
// Format trait type inference
//===----------------------------------------------------------------------===//
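The successor test round-trips because of a deliberate quirk in the TestDialect.cpp hooks: `parseCustomDirectiveSuccessors` appends the second block twice to the variadic list, and `printCustomDirectiveSuccessors` prints only its front. A toy model with blocks as label strings (already tokenized, so no punctuation handling); both function names are hypothetical:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Parsing: the first label is the single successor; if a second label is
// present, it is appended twice to the variadic successors, as in the test.
void parseSuccessors(const std::vector<std::string> &labels,
                     std::string &successor,
                     std::vector<std::string> &varSuccessors) {
  assert(!labels.empty() && "at least one successor is required");
  successor = labels[0];
  if (labels.size() > 1)
    varSuccessors.insert(varSuccessors.end(), 2, labels[1]);
}

// Printing: the single successor, then only the front of the variadic list.
std::string printSuccessors(const std::string &successor,
                            const std::vector<std::string> &varSuccessors) {
  std::string out = successor;
  if (!varSuccessors.empty())
    out += ", " + varSuccessors.front();
  return out;
}
```

So `test.format_custom_directive_successors ^bb1, ^bb2` builds an op with three successors in total, yet prints back as `^bb1, ^bb2`, matching the CHECK line.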
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 3cce5262b50d..d3fdc9ec323e 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -45,6 +45,7 @@ class Element {
enum class Kind {
/// This element is a directive.
AttrDictDirective,
+ CustomDirective,
FunctionalTypeDirective,
OperandsDirective,
ResultsDirective,
@@ -132,8 +133,7 @@ using SuccessorVariable =
namespace {
/// This class implements single kind directives.
-template <Element::Kind type>
-class DirectiveElement : public Element {
+template <Element::Kind type> class DirectiveElement : public Element {
public:
DirectiveElement() : Element(type){};
static bool classof(const Element *ele) { return ele->getKind() == type; }
@@ -164,6 +164,33 @@ class AttrDictDirective
bool withKeyword;
};
+/// This class represents a custom format directive that is implemented by the
+/// user in C++.
+class CustomDirective : public Element {
+public:
+ CustomDirective(StringRef name,
+ std::vector<std::unique_ptr<Element>> &&arguments)
+ : Element{Kind::CustomDirective}, name(name),
+ arguments(std::move(arguments)) {}
+
+ static bool classof(const Element *element) {
+ return element->getKind() == Kind::CustomDirective;
+ }
+
+ /// Return the name of this optional element.
+ StringRef getName() const { return name; }
+
+ /// Return the arguments to the custom directive.
+ auto getArguments() const { return llvm::make_pointee_range(arguments); }
+
+private:
+ /// The user provided name of the directive.
+ StringRef name;
+
+ /// The arguments to the custom directive.
+ std::vector<std::unique_ptr<Element>> arguments;
+};
+
/// This class represents the `functional-type` directive. This directive takes
/// two arguments and formats them, respectively, as the inputs and results of a
/// FunctionType.
@@ -370,19 +397,16 @@ static bool canFormatEnumAttr(const NamedAttribute *attr) {
/// The code snippet used to generate a parser call for an attribute.
///
-/// {0}: The storage type of the attribute.
-/// {1}: The name of the attribute.
-/// {2}: The type for the attribute.
+/// {0}: The name of the attribute.
+/// {1}: The type for the attribute.
const char *const attrParserCode = R"(
- {0} {1}Attr;
- if (parser.parseAttribute({1}Attr{2}, "{1}", result.attributes))
+ if (parser.parseAttribute({0}Attr{1}, "{0}", result.attributes))
return failure();
)";
const char *const optionalAttrParserCode = R"(
- {0} {1}Attr;
{
::mlir::OptionalParseResult parseResult =
- parser.parseOptionalAttribute({1}Attr{2}, "{1}", result.attributes);
+ parser.parseOptionalAttribute({0}Attr{1}, "{0}", result.attributes);
if (parseResult.hasValue() && failed(*parseResult))
return failure();
}
@@ -408,11 +432,11 @@ const char *const enumAttrParserCode = R"(
return parser.emitError(loc, "invalid ")
<< "{0} attribute specification: " << attrVal;
- result.addAttribute("{0}", {3});
+ {0}Attr = {3};
+ result.addAttribute("{0}", {0}Attr);
}
)";
const char *const optionalEnumAttrParserCode = R"(
- Attribute {0}Attr;
{
::mlir::StringAttr attrVal;
::mlir::NamedAttrList attrStorage;
@@ -440,11 +464,13 @@ const char *const optionalEnumAttrParserCode = R"(
///
/// {0}: The name of the operand.
const char *const variadicOperandParserCode = R"(
+ {0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperandList({0}Operands))
return failure();
)";
const char *const optionalOperandParserCode = R"(
{
+ {0}OperandsLoc = parser.getCurrentLocation();
::mlir::OpAsmParser::OperandType operand;
::mlir::OptionalParseResult parseResult =
parser.parseOptionalOperand(operand);
@@ -456,6 +482,7 @@ const char *const optionalOperandParserCode = R"(
}
)";
const char *const operandParserCode = R"(
+ {0}OperandsLoc = parser.getCurrentLocation();
if (parser.parseOperand({0}RawOperands[0]))
return failure();
)";
@@ -500,7 +527,6 @@ const char *const functionalTypeParserCode = R"(
///
/// {0}: The name for the successor list.
const char *successorListParserCode = R"(
- ::llvm::SmallVector<::mlir::Block *, 2> {0}Successors;
{
::mlir::Block *succ;
auto firstSucc = parser.parseOptionalSuccessor(succ);
@@ -523,7 +549,6 @@ const char *successorListParserCode = R"(
///
/// {0}: The name of the successor.
const char *successorParserCode = R"(
- ::mlir::Block *{0}Successor = nullptr;
if (parser.parseSuccessor({0}Successor))
return failure();
)";
@@ -595,8 +620,34 @@ static void genLiteralParser(StringRef value, OpMethodBody &body) {
/// Generate the storage code required for parsing the given element.
static void genElementParserStorage(Element *element, OpMethodBody &body) {
if (auto *optional = dyn_cast<OptionalElement>(element)) {
- for (auto &childElement : optional->getElements())
- genElementParserStorage(&childElement, body);
+ auto elements = optional->getElements();
+
+ // If the anchor is a unit attribute, it won't be parsed directly so elide
+ // it.
+ auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
+ Element *elidedAnchorElement = nullptr;
+ if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
+ elidedAnchorElement = anchor;
+ for (auto &childElement : elements)
+ if (&childElement != elidedAnchorElement)
+ genElementParserStorage(&childElement, body);
+
+ } else if (auto *custom = dyn_cast<CustomDirective>(element)) {
+ for (auto &paramElement : custom->getArguments())
+ genElementParserStorage(&paramElement, body);
+
+ } else if (isa<OperandsDirective>(element)) {
+ body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
+ "allOperands;\n";
+
+ } else if (isa<SuccessorsDirective>(element)) {
+ body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
+
+ } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
+ const NamedAttribute *var = attr->getVar();
+ body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
+ var->name);
+
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
StringRef name = operand->getVar()->name;
if (operand->getVar()->isVariableLength()) {
@@ -608,10 +659,19 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
<< " ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> " << name
<< "Operands(" << name << "RawOperands);";
}
- body << llvm::formatv(
- " ::llvm::SMLoc {0}OperandsLoc = parser.getCurrentLocation();\n"
- " (void){0}OperandsLoc;\n",
- name);
+ body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
+ " (void){0}OperandsLoc;\n",
+ name);
+ } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
+ StringRef name = successor->getVar()->name;
+ if (successor->getVar()->isVariadic()) {
+ body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
+ "{0}Successors;\n",
+ name);
+ } else {
+ body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
+ }
+
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
StringRef name = getTypeListName(dir->getOperand(), lengthKind);
@@ -631,6 +691,106 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
}
}
+/// Generate the parser for a parameter to a custom directive.
+static void genCustomParameterParser(Element &param, OpMethodBody &body) {
+ body << ", ";
+ if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
+ body << attr->getVar()->name << "Attr";
+
+ } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+ StringRef name = operand->getVar()->name;
+ ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
+ if (lengthKind == ArgumentLengthKind::Variadic)
+ body << llvm::formatv("{0}Operands", name);
+ else if (lengthKind == ArgumentLengthKind::Optional)
+ body << llvm::formatv("{0}Operand", name);
+ else
+ body << formatv("{0}RawOperands[0]", name);
+
+ } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
+ StringRef name = successor->getVar()->name;
+ if (successor->getVar()->isVariadic())
+ body << llvm::formatv("{0}Successors", name);
+ else
+ body << llvm::formatv("{0}Successor", name);
+
+ } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
+ ArgumentLengthKind lengthKind;
+ StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ if (lengthKind == ArgumentLengthKind::Variadic)
+ body << llvm::formatv("{0}Types", listName);
+ else if (lengthKind == ArgumentLengthKind::Optional)
+ body << llvm::formatv("{0}Type", listName);
+ else
+ body << formatv("{0}RawTypes[0]", listName);
+ } else {
+ llvm_unreachable("unknown custom directive parameter");
+ }
+}
+
+/// Generate the parser for a custom directive.
+static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
+ body << " {\n";
+
+ // Preprocess the directive variables.
+ // * Add a local variable for optional operands and types. This provides a
+ // better API to the user defined parser methods.
+ // * Set the location of operand variables.
+ for (Element &param : dir->getArguments()) {
+ if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+ body << " " << operand->getVar()->name
+ << "OperandsLoc = parser.getCurrentLocation();\n";
+ if (operand->getVar()->isOptional()) {
+ body << llvm::formatv(
+ " llvm::Optional<::mlir::OpAsmParser::OperandType> "
+ "{0}Operand;\n",
+ operand->getVar()->name);
+ }
+ } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
+ ArgumentLengthKind lengthKind;
+ StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ if (lengthKind == ArgumentLengthKind::Optional)
+ body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
+ }
+ }
+
+ body << " if (parse" << dir->getName() << "(parser";
+ for (Element &param : dir->getArguments())
+ genCustomParameterParser(param, body);
+
+ body << "))\n"
+ << " return failure();\n";
+
+ // After parsing, add handling for any of the optional constructs.
+ for (Element &param : dir->getArguments()) {
+ if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
+ const NamedAttribute *var = attr->getVar();
+ if (var->attr.isOptional())
+ body << llvm::formatv(" if ({0}Attr)\n ", var->name);
+
+ body << llvm::formatv(
+ " result.attributes.addAttribute(\"{0}\", {0}Attr);", var->name);
+ } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+ const NamedTypeConstraint *var = operand->getVar();
+ if (!var->isOptional())
+ continue;
+ body << llvm::formatv(" if ({0}Operand.hasValue())\n"
+ " {0}Operands.push_back(*{0}Operand);\n",
+ var->name);
+ } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
+ ArgumentLengthKind lengthKind;
+ StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ if (lengthKind == ArgumentLengthKind::Optional) {
+ body << llvm::formatv(" if ({0}Type)\n"
+ " {0}Types.push_back({0}Type);\n",
+ listName);
+ }
+ }
+ }
+
+ body << " }\n";
+}
+
/// Generate the parser for a single format element.
static void genElementParser(Element *element, OpMethodBody &body,
FmtContext &attrTypeCtx) {
@@ -711,7 +871,7 @@ static void genElementParser(Element *element, OpMethodBody &body,
body << formatv(var->attr.isOptional() ? optionalAttrParserCode
: attrParserCode,
- var->attr.getStorageType(), var->name, attrTypeStr);
+ var->name, attrTypeStr);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name;
@@ -732,10 +892,11 @@ static void genElementParser(Element *element, OpMethodBody &body,
<< (attrDict->isWithKeyword() ? "WithKeyword" : "")
<< "(result.attributes))\n"
<< " return failure();\n";
+ } else if (auto *customDir = dyn_cast<CustomDirective>(element)) {
+ genCustomDirectiveParser(customDir, body);
+
} else if (isa<OperandsDirective>(element)) {
body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
- << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
- "allOperands;\n"
<< " if (parser.parseOperandList(allOperands))\n"
<< " return failure();\n";
} else if (isa<SuccessorsDirective>(element)) {
@@ -980,6 +1141,20 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
llvm::interleaveComma(op.getOperands(), body, interleaveFn);
body << "}));\n";
}
+
+ if (!allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments")) {
+ body << " result.addAttribute(\"result_segment_sizes\", "
+ << "parser.getBuilder().getI32VectorAttr({";
+ auto interleaveFn = [&](const NamedTypeConstraint &result) {
+ // If the result is variadic emit the parsed size.
+ if (result.isVariableLength())
+ body << "static_cast<int32_t>(" << result.name << "Types.size())";
+ else
+ body << "1";
+ };
+ llvm::interleaveComma(op.getResults(), body, interleaveFn);
+ body << "}));\n";
+ }
}
//===----------------------------------------------------------------------===//
@@ -1007,6 +1182,8 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
// Elide the variadic segment size attributes if necessary.
if (!fmt.allOperands && op.getTrait("OpTrait::AttrSizedOperandSegments"))
body << "\"operand_segment_sizes\", ";
+ if (!fmt.allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments"))
+ body << "\"result_segment_sizes\", ";
llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) {
body << "\"" << attr->name << "\"";
});
@@ -1038,6 +1215,42 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
}
+/// Generate the printer for a custom directive. This emits a call to the
+/// user-provided `print<UserDirective>` method, passing the printer followed
+/// by the accessor for each directive parameter.
+static void genCustomDirectivePrinter(CustomDirective *customDir,
+ OpMethodBody &body) {
+ body << " print" << customDir->getName() << "(p";
+ for (Element &param : customDir->getArguments()) {
+ body << ", ";
+ if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
+ body << attr->getVar()->name << "Attr()";
+
+ } else if (auto *operand = dyn_cast<OperandVariable>(¶m)) {
+ body << operand->getVar()->name << "()";
+
+ } else if (auto *successor = dyn_cast<SuccessorVariable>(¶m)) {
+ body << successor->getVar()->name << "()";
+
+ } else if (auto *dir = dyn_cast<TypeDirective>(¶m)) {
+ auto *typeOperand = dir->getOperand();
+ auto *operand = dyn_cast<OperandVariable>(typeOperand);
+ auto *var = operand ? operand->getVar()
+ : cast<ResultVariable>(typeOperand)->getVar();
+ if (var->isVariadic())
+ body << var->name << "().getTypes()";
+ else if (var->isOptional())
+ body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name);
+ else
+ body << var->name << "().getType()";
+ } else {
+ llvm_unreachable("unknown custom directive parameter");
+ }
+ }
+
+ body << ");\n";
+}
+
/// Generate the C++ for an operand to a (*-)type directive.
static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
if (isa<OperandsDirective>(arg))
@@ -1145,6 +1358,8 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
body << " ::llvm::interleaveComma(" << var->name << "(), p);\n";
else
body << " p << " << var->name << "();\n";
+ } else if (auto *dir = dyn_cast<CustomDirective>(element)) {
+ genCustomDirectivePrinter(dir, body);
} else if (isa<OperandsDirective>(element)) {
body << " p << getOperation()->getOperands();\n";
} else if (isa<SuccessorsDirective>(element)) {
@@ -1202,12 +1417,15 @@ class Token {
caret,
comma,
equal,
+ less,
+ greater,
question,
// Keywords.
keyword_start,
kw_attr_dict,
kw_attr_dict_w_keyword,
+ kw_custom,
kw_functional_type,
kw_operands,
kw_results,
@@ -1353,6 +1571,10 @@ Token FormatLexer::lexToken() {
return formToken(Token::comma, tokStart);
case '=':
return formToken(Token::equal, tokStart);
+ case '<':
+ return formToken(Token::less, tokStart);
+ case '>':
+ return formToken(Token::greater, tokStart);
case '?':
return formToken(Token::question, tokStart);
case '(':
@@ -1406,6 +1628,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
llvm::StringSwitch<Token::Kind>(str)
.Case("attr-dict", Token::kw_attr_dict)
.Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
+ .Case("custom", Token::kw_custom)
.Case("functional-type", Token::kw_functional_type)
.Case("operands", Token::kw_operands)
.Case("results", Token::kw_results)
@@ -1421,8 +1644,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
/// Function to find an element within the given range that has the same name as
/// 'name'.
-template <typename RangeT>
-static auto findArg(RangeT &&range, StringRef name) {
+template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
return it != range.end() ? &*it : nullptr;
}
@@ -1513,6 +1735,10 @@ class FormatParser {
LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel,
bool withKeyword);
+ LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
+ llvm::SMLoc loc, bool isTopLevel);
+ LogicalResult parseCustomDirectiveParameter(
+ std::vector<std::unique_ptr<Element>> &parameters);
LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
Token tok, bool isTopLevel);
LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
@@ -1930,6 +2156,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
case Token::kw_attr_dict_w_keyword:
return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel,
/*withKeyword=*/true);
+ case Token::kw_custom:
+ return parseCustomDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_functional_type:
return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
case Token::kw_operands:
@@ -2054,15 +2282,15 @@ LogicalResult FormatParser::parseOptionalChildElement(
seenVariables.insert(ele->getVar());
return success();
})
- // Literals and type directives may be used, but they can't anchor the
- // group.
- .Case<LiteralElement, TypeDirective, FunctionalTypeDirective>(
- [&](Element *) {
- if (isAnchor)
- return emitError(childLoc, "only variables can be used to anchor "
- "an optional group");
- return success();
- })
+ // Literals, custom directives, and type directives may be used,
+ // but they can't anchor the group.
+ .Case<LiteralElement, CustomDirective, TypeDirective,
+ FunctionalTypeDirective>([&](Element *) {
+ if (isAnchor)
+ return emitError(childLoc, "only variables can be used to anchor "
+ "an optional group");
+ return success();
+ })
.Default([&](Element *) {
return emitError(childLoc, "only literals, types, and variables can be "
"used within an optional group");
@@ -2084,6 +2312,71 @@ FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
return success();
}
+LogicalResult
+FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
+ llvm::SMLoc loc, bool isTopLevel) {
+ llvm::SMLoc curLoc = curToken.getLoc();
+
+ // Parse the custom directive name.
+ if (failed(
+ parseToken(Token::less, "expected '<' before custom directive name")))
+ return failure();
+
+ Token nameTok = curToken;
+ if (failed(parseToken(Token::identifier,
+ "expected custom directive name identifier")) ||
+ failed(parseToken(Token::greater,
+ "expected '>' after custom directive name")) ||
+ failed(parseToken(Token::l_paren,
+ "expected '(' before custom directive parameters")))
+ return failure();
+
+ // Parse the parameters of this custom directive.
+ std::vector<std::unique_ptr<Element>> elements;
+ do {
+ if (failed(parseCustomDirectiveParameter(elements)))
+ return failure();
+ if (curToken.getKind() != Token::comma)
+ break;
+ consumeToken();
+ } while (true);
+
+ if (failed(parseToken(Token::r_paren,
+ "expected ')' after custom directive parameters")))
+ return failure();
+
+ // After parsing all of the elements, ensure that all type directives refer
+ // only to variables.
+ for (auto &ele : elements) {
+ if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
+ if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
+ return emitError(curLoc, "type directives within a custom directive "
+ "may only refer to variables");
+ }
+ }
+ }
+
+ element = std::make_unique<CustomDirective>(nameTok.getSpelling(),
+ std::move(elements));
+ return success();
+}
+
+LogicalResult FormatParser::parseCustomDirectiveParameter(
+ std::vector<std::unique_ptr<Element>> &parameters) {
+ llvm::SMLoc childLoc = curToken.getLoc();
+ parameters.push_back({});
+ if (failed(parseElement(parameters.back(), /*isTopLevel=*/true)))
+ return failure();
+
+ // Verify that the element can be placed within a custom directive.
+ if (!isa<TypeDirective, AttributeVariable, OperandVariable,
+ SuccessorVariable>(parameters.back().get())) {
+ return emitError(childLoc, "only variables and types may be used as "
+ "parameters to a custom directive");
+ }
+ return success();
+}
+
LogicalResult
FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
Token tok, bool isTopLevel) {
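The case analysis in `genCustomDirectivePrinter` can be modelled outside of mlir-tblgen. What follows is a minimal standalone sketch, not the generator itself. It assumes hypothetical, simplified `ParamKind`/`Param` types in place of the real `Element` hierarchy. The hypothetical `sketchCustomDirectivePrinterCall` builds the call string that the generated printer would contain for a `custom<Name>(...)` directive. It omits the leading indentation and the trailing newline that the generator emits.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for the format elements that genCustomDirectivePrinter
// accepts as custom directive parameters (not the real mlir-tblgen classes).
enum class ParamKind { Attribute, Operand, Successor, Type };

struct Param {
  ParamKind kind;
  std::string name;        // Variable name without the leading '$'.
  bool isVariadic = false; // Only consulted for ParamKind::Type.
  bool isOptional = false; // Only consulted for ParamKind::Type.
};

// Returns the C++ call the generated printer would contain for
// `custom<directiveName>(params...)`, mirroring the case analysis in the diff.
std::string sketchCustomDirectivePrinterCall(const std::string &directiveName,
                                             const std::vector<Param> &params) {
  assert(!directiveName.empty() && "custom directive requires a name");
  // The printer `p` is always passed first.
  std::string call = "print" + directiveName + "(p";
  for (const Param &param : params) {
    call += ", ";
    switch (param.kind) {
    case ParamKind::Attribute:
      // Attributes are passed through the `<name>Attr()` accessor.
      call += param.name + "Attr()";
      break;
    case ParamKind::Operand:
    case ParamKind::Successor:
      // Operands and successors are passed through their named accessor.
      call += param.name + "()";
      break;
    case ParamKind::Type:
      // type($var): variadic -> the type range, optional -> a null Type()
      // when absent, otherwise the single type.
      if (param.isVariadic)
        call += param.name + "().getTypes()";
      else if (param.isOptional)
        call += "(" + param.name + "() ? " + param.name +
                "().getType() : Type())";
      else
        call += param.name + "().getType()";
      break;
    }
  }
  call += ");";
  return call;
}
```

Consider a hypothetical `custom<MyDirective>($input, type($input))` where `input` is a single, non-optional operand. The sketch yields `printMyDirective(p, input(), input().getType());`. The user's `printMyDirective` must therefore accept the printer followed by whatever those two accessors return.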
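`parseCustomDirective` accepts the grammar `custom` `<` UserDirective `>` `(` Params `)`. A small standalone recognizer can illustrate that grammar. This is a hypothetical sketch over a plain string, not the real `FormatLexer`/`FormatParser`. The names `CustomDirectiveSketch`, `Cursor`, and `sketchParseCustomDirective` are illustrative only. The sketch accepts only `$variable` and `type($variable)` parameters and reuses the diagnostic wording from the diff where one exists. One difference from the diff: the diff checks the variable-only rule for `type(...)` after all parameters are parsed, while the sketch checks it inline.

```cpp
#include <cctype>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical result type; the real FormatParser builds a CustomDirective
// element holding the parsed child elements instead.
struct CustomDirectiveSketch {
  std::string name;                // UserDirective, e.g. "MyDirective".
  std::vector<std::string> params; // Each is "$var" or "type($var)".
  std::string error;               // Empty on success.
};

// Minimal cursor over a format string that skips whitespace between tokens.
class Cursor {
public:
  explicit Cursor(const std::string &text) : text(text) {}

  // Consumes `c` if it is the next non-space character.
  bool consume(char c) {
    skipSpace();
    if (pos < text.size() && text[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  }

  // Consumes a (possibly empty) run of [A-Za-z0-9_] characters.
  std::string identifier() {
    skipSpace();
    std::size_t start = pos;
    while (pos < text.size() &&
           (std::isalnum(static_cast<unsigned char>(text[pos])) ||
            text[pos] == '_'))
      ++pos;
    return text.substr(start, pos - start);
  }

private:
  void skipSpace() {
    while (pos < text.size() &&
           std::isspace(static_cast<unsigned char>(text[pos])))
      ++pos;
  }

  const std::string &text;
  std::size_t pos = 0;
};

// Parses `custom<Name>(param, ...)`; at least one parameter is required,
// matching the do/while loop in the diff.
CustomDirectiveSketch sketchParseCustomDirective(const std::string &format) {
  CustomDirectiveSketch result;
  Cursor cur(format);
  auto fail = [&](const std::string &msg) {
    result.error = msg;
    return result;
  };

  if (cur.identifier() != "custom")
    return fail("expected 'custom'");
  if (!cur.consume('<'))
    return fail("expected '<' before custom directive name");
  result.name = cur.identifier();
  if (result.name.empty())
    return fail("expected custom directive name identifier");
  if (!cur.consume('>'))
    return fail("expected '>' after custom directive name");
  if (!cur.consume('('))
    return fail("expected '(' before custom directive parameters");

  do {
    if (cur.consume('$')) {
      std::string var = cur.identifier();
      if (var.empty())
        return fail("expected variable name after '$'");
      result.params.push_back("$" + var);
    } else if (cur.identifier() == "type") {
      if (!cur.consume('('))
        return fail("expected '(' after 'type'");
      if (!cur.consume('$'))
        return fail("type directives within a custom directive may only "
                    "refer to variables");
      std::string var = cur.identifier();
      if (var.empty() || !cur.consume(')'))
        return fail("expected '$variable)' in type directive");
      result.params.push_back("type($" + var + ")");
    } else {
      return fail("only variables and types may be used as parameters to a "
                  "custom directive");
    }
  } while (cur.consume(','));

  if (!cur.consume(')'))
    return fail("expected ')' after custom directive parameters");
  return result;
}
```

With this sketch, `custom<MyDirective>(type(operands))` is rejected because `operands` names a directive rather than a variable. A literal such as `` `,` `` used as a parameter is rejected with the "only variables and types" diagnostic.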