[Mlir-commits] [mlir] eaeadce - [mlir][OpFormatGen] Add initial support for regions in the custom op assembly format
River Riddle
llvmlistbot at llvm.org
Mon Aug 31 13:26:57 PDT 2020
Author: River Riddle
Date: 2020-08-31T13:26:24-07:00
New Revision: eaeadce9bd11d50cecfdf9e97ac471acd38136ee
URL: https://github.com/llvm/llvm-project/commit/eaeadce9bd11d50cecfdf9e97ac471acd38136ee
DIFF: https://github.com/llvm/llvm-project/commit/eaeadce9bd11d50cecfdf9e97ac471acd38136ee.diff
LOG: [mlir][OpFormatGen] Add initial support for regions in the custom op assembly format
This adds some initial support for regions and does not support formatting the specific arguments of a region. For now this can be achieved by using a custom directive that formats the arguments and then parses the region.
Differential Revision: https://reviews.llvm.org/D86760
Added:
Modified:
mlir/docs/OpDefinitions.md
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/OperationSupport.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Parser.h
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/op-format-spec.td
mlir/test/mlir-tblgen/op-format.mlir
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 167546e67522..1b1a2125e95d 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -681,6 +681,10 @@ The available directives are as follows:
- Represents all of the operands of an operation.
+* `regions`
+
+ - Represents all of the regions of an operation.
+
* `results`
- Represents all of the results of an operation.
@@ -700,13 +704,14 @@ The available directives are as follows:
A literal is either a keyword or punctuation surrounded by \`\`.
The following are the set of valid punctuation:
- `:`, `,`, `=`, `<`, `>`, `(`, `)`, `[`, `]`, `->`
+
+`:`, `,`, `=`, `<`, `>`, `(`, `)`, `{`, `}`, `[`, `]`, `->`
#### Variables
A variable is an entity that has been registered on the operation itself, i.e.
-an argument(attribute or operand), result, successor, etc. In the `CallOp`
-example above, the variables would be `$callee` and `$args`.
+an argument (attribute or operand), region, result, successor, etc. In the
+`CallOp` example above, the variables would be `$callee` and `$args`.
Attribute variables are printed with their respective value type, unless that
value type is buildable. In those cases, the type of the attribute is elided.
@@ -747,6 +752,9 @@ declarative parameter to `parse` method argument is detailed below:
- Single: `OpAsmParser::OperandType &`
- Optional: `Optional<OpAsmParser::OperandType> &`
- Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
+* Region Variables
+ - Single: `Region &`
+ - Variadic: `SmallVectorImpl<std::unique_ptr<Region>> &`
* Successor Variables
- Single: `Block *&`
- Variadic: `SmallVectorImpl<Block *> &`
@@ -770,6 +778,9 @@ declarative parameter to `print` method argument is detailed below:
- Single: `Value`
- Optional: `Value`
- Variadic: `OperandRange`
+* Region Variables
+ - Single: `Region &`
+ - Variadic: `MutableArrayRef<Region>`
* Successor Variables
- Single: `Block *`
- Variadic: `SuccessorRange`
@@ -788,8 +799,8 @@ of the assembly format can be marked as `optional` based on the presence of this
information. An optional group is defined by wrapping a set of elements within
`()` followed by a `?` and has the following requirements:
-* The first element of the group must either be a literal, attribute, or an
- operand.
+* The first element of the group must either be an attribute, literal, operand,
+ or region.
- This is because the first element must be optionally parsable.
* Exactly one argument variable within the group must be marked as the anchor
of the group.
@@ -797,11 +808,15 @@ information. An optional group is defined by wrapping a set of elements within
should be printed/parsed.
- An element is marked as the anchor by adding a trailing `^`.
- The first element is *not* required to be the anchor of the group.
+ - When a non-variadic region anchors a group, the group is printed only
+ if the region is non-empty.
* Literals, variables, custom directives, and type directives are the only
valid elements within the group.
- Any attribute variable may be used, but only optional attributes can be
marked as the anchor.
- Only variadic or optional operand arguments can be used.
+ - All region variables can be used. When a non-variadic region is used and
+ the group is not present, the region is left empty.
- The operands to a type directive must be defined within the optional
group.
@@ -853,18 +868,22 @@ foo.op
The format specification has a certain set of requirements that must be adhered
to:
-1. The output and operation name are never shown as they are fixed and cannot be
- altered.
-1. All operands within the operation must appear within the format, either
- individually or with the `operands` directive.
-1. All operand and result types must appear within the format using the various
- `type` directives, either individually or with the `operands` or `results`
- directives.
-1. The `attr-dict` directive must always be present.
-1. Must not contain overlapping information; e.g. multiple instances of
- 'attr-dict', types, operands, etc.
- - Note that `attr-dict` does not overlap with individual attributes. These
- attributes will simply be elided when printing the attribute dictionary.
+1. The output and operation name are never shown as they are fixed and cannot
+ be altered.
+1. All operands within the operation must appear within the format, either
+ individually or with the `operands` directive.
+1. All regions within the operation must appear within the format, either
+ individually or with the `regions` directive.
+1. All successors within the operation must appear within the format, either
+ individually or with the `successors` directive.
+1. All operand and result types must appear within the format using the various
+ `type` directives, either individually or with the `operands` or `results`
+ directives.
+1. The `attr-dict` directive must always be present.
+1. Must not contain overlapping information; e.g. multiple instances of
+ 'attr-dict', types, operands, etc.
+ - Note that `attr-dict` does not overlap with individual attributes. These
+ attributes will simply be elided when printing the attribute dictionary.
##### Type Inference
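The region rules for optional groups can be illustrated with a small model. The following is a hypothetical, self-contained C++ sketch, not the code that mlir-tblgen generates: `ToyRegion` stands in for `mlir::Region`, and a token vector stands in for the lexer. It mirrors the (`region` $region^)? format used by `test.format_region_c_op` in this patch: the group is printed only when the anchoring region is non-empty, and when the leading keyword is absent on parse, the region is left empty.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for mlir::Region: just a list of op names.
struct ToyRegion {
  std::vector<std::string> ops;
  bool empty() const { return ops.empty(); }
};

// Print the optional group `(`region` $region^)?`. The anchor is a
// non-variadic region, so the group is emitted only if the region is non-empty.
std::string printOptionalRegionGroup(const ToyRegion &region) {
  if (region.empty())
    return "";
  std::string out = "region {";
  for (const std::string &op : region.ops)
    out += " " + op;
  out += " }";
  return out;
}

// Parse the same group from a token list. If the leading `region` keyword is
// missing, the group is absent and the region stays empty. Returns false on a
// malformed group.
bool parseOptionalRegionGroup(const std::vector<std::string> &tokens,
                              ToyRegion &region) {
  std::size_t i = 0;
  if (i >= tokens.size() || tokens[i] != "region")
    return true; // Group not present: region remains empty.
  ++i;
  if (i >= tokens.size() || tokens[i] != "{")
    return false; // The keyword must be followed by a region body.
  ++i;
  while (i < tokens.size() && tokens[i] != "}")
    region.ops.push_back(tokens[i++]);
  return i < tokens.size(); // Require the closing brace.
}
```

Because the anchor is the region itself, no separate attribute or flag is needed to remember whether the group was present: region emptiness carries that information in both directions.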
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5962a1a48668..a30ffd8f75b4 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -424,8 +424,8 @@ class OpAsmParser {
Type type,
StringRef attrName,
NamedAttrList &attrs) = 0;
- OptionalParseResult parseOptionalAttribute(Attribute &result,
- StringRef attrName,
+ template <typename AttrT>
+ OptionalParseResult parseOptionalAttribute(AttrT &result, StringRef attrName,
NamedAttrList &attrs) {
return parseOptionalAttribute(result, Type(), attrName, attrs);
}
@@ -433,6 +433,7 @@ class OpAsmParser {
/// Specialized variants of `parseOptionalAttribute` that remove potential
/// ambiguities in syntax.
virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
+ Type type,
StringRef attrName,
NamedAttrList &attrs) = 0;
@@ -621,16 +622,23 @@ class OpAsmParser {
/// can only be set to true for regions attached to operations that are
/// "IsolatedFromAbove".
virtual ParseResult parseRegion(Region &region,
- ArrayRef<OperandType> arguments,
- ArrayRef<Type> argTypes,
+ ArrayRef<OperandType> arguments = {},
+ ArrayRef<Type> argTypes = {},
bool enableNameShadowing = false) = 0;
/// Parses a region if present.
virtual ParseResult parseOptionalRegion(Region &region,
- ArrayRef<OperandType> arguments,
- ArrayRef<Type> argTypes,
+ ArrayRef<OperandType> arguments = {},
+ ArrayRef<Type> argTypes = {},
bool enableNameShadowing = false) = 0;
+ /// Parses a region if present. If the region is present, a new region is
+ /// allocated and placed in `region`. If no region is present or on failure,
+ /// `region` remains untouched.
+ virtual OptionalParseResult parseOptionalRegion(
+ std::unique_ptr<Region> &region, ArrayRef<OperandType> arguments = {},
+ ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
+
/// Parse a region argument, this argument is resolved when calling
/// 'parseRegion'.
virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
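A rough model of the new `std::unique_ptr<Region>` overload of `parseOptionalRegion`, and of how the generated region-list parser later in this patch builds on it, is sketched below. `ToyParser`, `ToyRegion`, and `ToyOptionalParseResult` are hypothetical stand-ins (the last one uses `std::optional<bool>` in place of `OptionalParseResult`). The point is the tri-state contract: "not present", "parsed", or "present but malformed", with the caller's pointer left untouched in the first and last cases.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::Region.
struct ToyRegion {
  std::vector<std::string> ops;
};

// Models OptionalParseResult: std::nullopt means "not present",
// true means "parsed successfully", false means "present but malformed".
using ToyOptionalParseResult = std::optional<bool>;

class ToyParser {
public:
  explicit ToyParser(std::vector<std::string> tokens)
      : toks(std::move(tokens)) {}

  // Parses `{ op* }` into `region`. Returns false on error.
  bool parseRegion(ToyRegion &region) {
    if (!consume("{"))
      return false;
    while (pos < toks.size() && toks[pos] != "}")
      region.ops.push_back(toks[pos++]);
    return consume("}");
  }

  // Mirrors the unique_ptr overload: allocate only if a region is present,
  // and leave `region` untouched when absent or on failure.
  ToyOptionalParseResult
  parseOptionalRegion(std::unique_ptr<ToyRegion> &region) {
    if (pos >= toks.size() || toks[pos] != "{")
      return std::nullopt;
    auto newRegion = std::make_unique<ToyRegion>();
    if (!parseRegion(*newRegion))
      return false;
    region = std::move(newRegion);
    return true;
  }

  // Mirrors the shape of the generated region-list parser: an optional first
  // region followed by any number of `, region` pairs.
  bool parseRegionList(std::vector<std::unique_ptr<ToyRegion>> &regions) {
    std::unique_ptr<ToyRegion> region;
    ToyOptionalParseResult first = parseOptionalRegion(region);
    if (!first.has_value())
      return true; // No regions at all.
    if (!*first)
      return false;
    regions.push_back(std::move(region));
    while (consume(",")) {
      region = std::make_unique<ToyRegion>();
      if (!parseRegion(*region))
        return false;
      regions.push_back(std::move(region));
    }
    return true;
  }

private:
  // Consumes the next token if it equals `tok`.
  bool consume(const std::string &tok) {
    if (pos < toks.size() && toks[pos] == tok) {
      ++pos;
      return true;
    }
    return false;
  }

  std::vector<std::string> toks;
  std::size_t pos = 0;
};
```

Distinguishing "absent" from "failed" is what lets a variadic region list be empty without reporting an error.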
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index b1758ec8f5ca..b0e1205eefe6 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -414,6 +414,10 @@ struct OperationState {
/// region is null, a new empty region will be attached to the Operation.
void addRegion(std::unique_ptr<Region> &&region);
+ /// Take ownership of a set of regions that should be attached to the
+ /// Operation.
+ void addRegions(MutableArrayRef<std::unique_ptr<Region>> regions);
+
/// Get the context held by this operation state.
MLIRContext *getContext() const { return location->getContext(); }
};
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index b477a8a23900..ab84f4e8cf17 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -199,6 +199,12 @@ void OperationState::addRegion(std::unique_ptr<Region> &&region) {
regions.push_back(std::move(region));
}
+void OperationState::addRegions(
+ MutableArrayRef<std::unique_ptr<Region>> regions) {
+ for (std::unique_ptr<Region> &region : regions)
+ addRegion(std::move(region));
+}
+
//===----------------------------------------------------------------------===//
// OperandStorage
//===----------------------------------------------------------------------===//
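The ownership transfer performed by `addRegions` can be sketched with hypothetical stand-in types. This model takes a `std::vector` where the real method takes a `MutableArrayRef`. It shows that the state ends up owning the same region objects and that the caller's `unique_ptr`s are left null, so the parsed list should not be reused afterwards.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins for mlir::Region and mlir::OperationState.
struct ToyRegion {
  int numBlocks = 0;
};

struct ToyOperationState {
  std::vector<std::unique_ptr<ToyRegion>> regions;

  // Like OperationState::addRegion: take ownership of a single region.
  void addRegion(std::unique_ptr<ToyRegion> &&region) {
    regions.push_back(std::move(region));
  }

  // Like the new addRegions: move every element in, leaving each of the
  // caller's unique_ptrs null afterwards.
  void addRegions(std::vector<std::unique_ptr<ToyRegion>> &parsed) {
    for (std::unique_ptr<ToyRegion> &region : parsed)
      addRegion(std::move(region));
  }
};
```

Moving element by element (rather than taking the container) lets the caller keep its storage type, which matches how the generated parser hands over a `SmallVector` of parsed regions.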
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index b7cae2778c10..4e17ccd8022d 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -221,8 +221,9 @@ OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
return result;
}
}
-OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute) {
- return parseOptionalAttributeWithToken(Token::l_square, attribute);
+OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
+ Type type) {
+ return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
}
/// Attribute dictionary.
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index cc102b9e97a4..d6065f758fc1 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1045,7 +1045,6 @@ class CustomOpAsmParser : public OpAsmParser {
}
/// Parse an optional attribute.
- /// Template utilities to simplify specifying multiple derived overloads.
template <typename AttrT>
OptionalParseResult
parseOptionalAttributeAndAddToList(AttrT &result, Type type,
@@ -1056,25 +1055,15 @@ class CustomOpAsmParser : public OpAsmParser {
attrs.push_back(parser.builder.getNamedAttr(attrName, result));
return parseResult;
}
- template <typename AttrT>
- OptionalParseResult parseOptionalAttributeAndAddToList(AttrT &result,
- StringRef attrName,
- NamedAttrList &attrs) {
- OptionalParseResult parseResult = parser.parseOptionalAttribute(result);
- if (parseResult.hasValue() && succeeded(*parseResult))
- attrs.push_back(parser.builder.getNamedAttr(attrName, result));
- return parseResult;
- }
-
OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
- OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
+ OptionalParseResult parseOptionalAttribute(ArrayAttr &result, Type type,
StringRef attrName,
NamedAttrList &attrs) override {
- return parseOptionalAttributeAndAddToList(result, attrName, attrs);
+ return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
/// Parse a named dictionary into 'result' if it is present.
@@ -1355,6 +1344,23 @@ class CustomOpAsmParser : public OpAsmParser {
return parseRegion(region, arguments, argTypes, enableNameShadowing);
}
+ /// Parses a region if present. If the region is present, a new region is
+ /// allocated and placed in `region`. If no region is present, `region`
+ /// remains untouched.
+ OptionalParseResult
+ parseOptionalRegion(std::unique_ptr<Region> &region,
+ ArrayRef<OperandType> arguments, ArrayRef<Type> argTypes,
+ bool enableNameShadowing = false) override {
+ if (parser.getToken().isNot(Token::l_brace))
+ return llvm::None;
+ std::unique_ptr<Region> newRegion = std::make_unique<Region>();
+ if (parseRegion(*newRegion, arguments, argTypes, enableNameShadowing))
+ return failure();
+
+ region = std::move(newRegion);
+ return success();
+ }
+
/// Parse a region argument. The type of the argument will be resolved later
/// by a call to `parseRegion`.
ParseResult parseRegionArgument(OperandType &argument) override {
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 61e54be83139..c01a0a004072 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -187,7 +187,7 @@ class Parser {
/// Parse an optional attribute with the provided type.
OptionalParseResult parseOptionalAttribute(Attribute &attribute,
Type type = {});
- OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute);
+ OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute, Type type);
/// Parse an optional attribute that is demarcated by a specific token.
template <typename AttributeT>
@@ -197,8 +197,8 @@ class Parser {
if (getToken().isNot(kind))
return llvm::None;
- if (Attribute parsedAttr = parseAttribute()) {
- attr = parsedAttr.cast<ArrayAttr>();
+ if (Attribute parsedAttr = parseAttribute(type)) {
+ attr = parsedAttr.cast<AttributeT>();
return success();
}
return failure();
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 292f5ed4b641..d75422e84124 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -319,6 +319,19 @@ static ParseResult parseCustomDirectiveOperandsAndTypes(
return failure();
return success();
}
+static ParseResult parseCustomDirectiveRegions(
+ OpAsmParser &parser, Region &region,
+ SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
+ if (parser.parseRegion(region))
+ return failure();
+ if (failed(parser.parseOptionalComma()))
+ return success();
+ std::unique_ptr<Region> varRegion = std::make_unique<Region>();
+ if (parser.parseRegion(*varRegion))
+ return failure();
+ varRegions.emplace_back(std::move(varRegion));
+ return success();
+}
static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
SmallVectorImpl<Block *> &varSuccessors) {
@@ -361,6 +374,15 @@ printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
printCustomDirectiveResults(printer, operandType, optOperandType,
varOperandTypes);
}
+static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region &region,
+ MutableArrayRef<Region> varRegions) {
+ printer.printRegion(region);
+ if (!varRegions.empty()) {
+ printer << ", ";
+ for (Region &varRegion : varRegions)
+ printer.printRegion(varRegion);
+ }
+}
static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
Block *successor,
SuccessorRange varSuccessors) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0e186d0cd29b..bc26a8659831 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1161,8 +1161,13 @@ def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> {
//===----------------------------------------------------------------------===//
def TestRegionBuilderOp : TEST_Op<"region_builder">;
-def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]>,
- Arguments<(ins Variadic<AnyType>)>;
+def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]> {
+ let arguments = (ins Variadic<AnyType>);
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &state",
+ [{ build(builder, state, {}); }]>
+ ];
+}
def TestCastOp : TEST_Op<"cast">,
Arguments<(ins Variadic<AnyType>)>, Results<(outs AnyType)>;
def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
@@ -1333,6 +1338,43 @@ def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> {
let assemblyFormat = "$buildable attr-dict";
}
+// Test various mixings of region formatting.
+class FormatRegionBase<string suffix, string fmt>
+ : TEST_Op<"format_region_" # suffix # "_op"> {
+ let regions = (region AnyRegion:$region);
+ let assemblyFormat = fmt;
+}
+def FormatRegionAOp : FormatRegionBase<"a", [{
+ regions attr-dict
+}]>;
+def FormatRegionBOp : FormatRegionBase<"b", [{
+ $region attr-dict
+}]>;
+def FormatRegionCOp : FormatRegionBase<"c", [{
+ (`region` $region^)? attr-dict
+}]>;
+class FormatVariadicRegionBase<string suffix, string fmt>
+ : TEST_Op<"format_variadic_region_" # suffix # "_op"> {
+ let regions = (region VariadicRegion<AnyRegion>:$regions);
+ let assemblyFormat = fmt;
+}
+def FormatVariadicRegionAOp : FormatVariadicRegionBase<"a", [{
+ $regions attr-dict
+}]>;
+def FormatVariadicRegionBOp : FormatVariadicRegionBase<"b", [{
+ ($regions^ `found_regions`)? attr-dict
+}]>;
+class FormatRegionImplicitTerminatorBase<string suffix, string fmt>
+ : TEST_Op<"format_implicit_terminator_region_" # suffix # "_op",
+ [SingleBlockImplicitTerminator<"TestReturnOp">]> {
+ let regions = (region AnyRegion:$region);
+ let assemblyFormat = fmt;
+}
+def FormatRegionImplicitTerminatorAOp
+ : FormatRegionImplicitTerminatorBase<"a", [{
+ $region attr-dict
+}]>;
+
// Test various mixings of result type formatting.
class FormatResultBase<string suffix, string fmt>
: TEST_Op<"format_result_" # suffix # "_op"> {
@@ -1454,6 +1496,16 @@ def FormatCustomDirectiveOperandsAndTypes
}];
}
+def FormatCustomDirectiveRegions : TEST_Op<"format_custom_directive_regions"> {
+ let regions = (region AnyRegion:$region, VariadicRegion<AnyRegion>:$regions);
+ let assemblyFormat = [{
+ custom<CustomDirectiveRegions>(
+ $region, $regions
+ )
+ attr-dict
+ }];
+}
+
def FormatCustomDirectiveResults
: TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> {
let results = (outs AnyType:$result, Optional<AnyType>:$optResult,
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 8e7c8ec56a2a..60189943ddab 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -133,6 +133,28 @@ def DirectiveOperandsValid : TestFormat_Op<"operands_valid", [{
operands attr-dict
}]>;
+//===----------------------------------------------------------------------===//
+// regions
+
+// CHECK: error: 'regions' directive creates overlap in format
+def DirectiveRegionsInvalidA : TestFormat_Op<"regions_invalid_a", [{
+ regions regions attr-dict
+}]>;
+// CHECK: error: 'regions' directive creates overlap in format
+def DirectiveRegionsInvalidB : TestFormat_Op<"regions_invalid_b", [{
+ $region regions attr-dict
+}]> {
+ let regions = (region AnyRegion:$region);
+}
+// CHECK: error: 'regions' is only valid as a top-level directive
+def DirectiveRegionsInvalidC : TestFormat_Op<"regions_invalid_c", [{
+ type(regions)
+}]>;
+// CHECK-NOT: error:
+def DirectiveRegionsValid : TestFormat_Op<"regions_valid", [{
+ regions attr-dict
+}]>;
+
//===----------------------------------------------------------------------===//
// results
@@ -249,7 +271,7 @@ def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{
def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
($attr)? attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
-// CHECK: error: first element of an operand group must be an attribute, literal, or operand
+// CHECK: error: first element of an operand group must be an attribute, literal, operand, or region
def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
(type($operand) $operand^)? attr-dict
}]>, Arguments<(ins Optional<I64>:$operand)>;
@@ -290,7 +312,7 @@ def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{
// Variables
//===----------------------------------------------------------------------===//
-// CHECK: error: expected variable to refer to an argument, result, or successor
+// CHECK: error: expected variable to refer to an argument, region, result, or successor
def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{
$unknown_arg attr-dict
}]>;
@@ -330,11 +352,35 @@ def VariableInvalidH : TestFormat_Op<"variable_invalid_h", [{
def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{
(`foo` $attr^)? `:` attr-dict
}]>, Arguments<(ins OptionalAttr<ElementsAttr>:$attr)>;
-// CHECK-NOT: error:
+// CHECK: error: region 'region' is already bound
def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{
+ $region $region attr-dict
+}]> {
+ let regions = (region AnyRegion:$region);
+}
+// CHECK: error: region 'region' is already bound
+def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{
+ regions $region attr-dict
+}]> {
+ let regions = (region AnyRegion:$region);
+}
+// CHECK: error: regions can only be used at the top level
+def VariableInvalidL : TestFormat_Op<"variable_invalid_l", [{
+ type($region)
+}]> {
+ let regions = (region AnyRegion:$region);
+}
+// CHECK: error: region #0, named 'region', not found
+def VariableInvalidM : TestFormat_Op<"variable_invalid_m", [{
+ attr-dict
+}]> {
+ let regions = (region AnyRegion:$region);
+}
+// CHECK-NOT: error:
+def VariableValidA : TestFormat_Op<"variable_valid_a", [{
$attr `:` attr-dict
}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
-def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{
+def VariableValidB : TestFormat_Op<"variable_valid_b", [{
(`foo` $attr^)? `:` attr-dict
}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
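The binding rules exercised by these tests (each region bound exactly once, either individually or through `regions`) can be modeled as a small checker. The function below is a hypothetical sketch, not OpFormatGen's verifier: it only handles top-level region variables and the `regions` directive, treats every other `$` variable as unknown, and reuses the diagnostic wording from the CHECK lines above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical checker: `opRegions` are the region names declared on the op,
// `format` is a flattened list of top-level format elements. Returns the
// first error, or std::nullopt if every region is bound exactly once.
std::optional<std::string>
verifyRegionBindings(const std::vector<std::string> &opRegions,
                     const std::vector<std::string> &format) {
  bool allRegions = false;
  std::set<std::string> bound;
  for (const std::string &elt : format) {
    if (elt == "regions") {
      // `regions` may not follow another `regions` or an individual region.
      if (allRegions || !bound.empty())
        return "'regions' directive creates overlap in format";
      allRegions = true;
      continue;
    }
    if (elt.empty() || elt[0] != '$')
      continue; // Literals, attr-dict, etc. are not checked here.
    std::string name = elt.substr(1);
    bool declared = std::find(opRegions.begin(), opRegions.end(), name) !=
                    opRegions.end();
    if (!declared)
      return "expected variable to refer to an argument, region, result, or "
             "successor";
    if (allRegions || !bound.insert(name).second)
      return "region '" + name + "' is already bound";
  }
  // Every declared region must appear somewhere in the format.
  if (!allRegions) {
    for (std::size_t i = 0; i < opRegions.size(); ++i)
      if (!bound.count(opRegions[i]))
        return "region #" + std::to_string(i) + ", named '" + opRegions[i] +
               "', not found";
  }
  return std::nullopt;
}
```

Note the asymmetry the tests encode: `$region regions` is reported as an overlap by the directive, while `regions $region` is reported by the variable as already bound, because the error is raised by whichever element comes second.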
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 40eb17ac52dc..9f7c9c0f4809 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -40,6 +40,72 @@ test.format_attr_dict_w_keyword attributes {attr = 10 : i64, opt_attr = 10 : i64
// CHECK: test.format_buildable_type_op %[[I64]]
%ignored = test.format_buildable_type_op %i64
+//===----------------------------------------------------------------------===//
+// Format regions
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_region_a_op {
+// CHECK-NEXT: test.return
+test.format_region_a_op {
+ "test.return"() : () -> ()
+}
+
+// CHECK: test.format_region_b_op {
+// CHECK-NEXT: test.return
+test.format_region_b_op {
+ "test.return"() : () -> ()
+}
+
+// CHECK: test.format_region_c_op region {
+// CHECK-NEXT: test.return
+test.format_region_c_op region {
+ "test.return"() : () -> ()
+}
+// CHECK: test.format_region_c_op
+// CHECK-NOT: region {
+test.format_region_c_op
+
+// CHECK: test.format_variadic_region_a_op {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }, {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }
+test.format_variadic_region_a_op {
+ "test.return"() : () -> ()
+}, {
+ "test.return"() : () -> ()
+}
+// CHECK: test.format_variadic_region_b_op {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }, {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: } found_regions
+test.format_variadic_region_b_op {
+ "test.return"() : () -> ()
+}, {
+ "test.return"() : () -> ()
+} found_regions
+// CHECK: test.format_variadic_region_b_op
+// CHECK-NOT: {
+// CHECK-NOT: found_regions
+test.format_variadic_region_b_op
+
+// CHECK: test.format_implicit_terminator_region_a_op {
+// CHECK-NEXT: }
+test.format_implicit_terminator_region_a_op {
+ "test.return"() : () -> ()
+}
+// CHECK: test.format_implicit_terminator_region_a_op {
+// CHECK-NEXT: test.return"() {foo.attr
+test.format_implicit_terminator_region_a_op {
+ "test.return"() {foo.attr} : () -> ()
+}
+// CHECK: test.format_implicit_terminator_region_a_op {
+// CHECK-NEXT: test.return"(%[[I64]]) : (i64)
+test.format_implicit_terminator_region_a_op {
+ "test.return"(%i64) : (i64) -> ()
+}
+
//===----------------------------------------------------------------------===//
// Format results
//===----------------------------------------------------------------------===//
@@ -147,6 +213,24 @@ test.format_custom_directive_operands_and_types %i64, %i64 -> (%i64) : i64, i64
// CHECK: test.format_custom_directive_operands_and_types %[[I64]] -> (%[[I64]]) : i64 -> (i64)
test.format_custom_directive_operands_and_types %i64 -> (%i64) : i64 -> (i64)
+// CHECK: test.format_custom_directive_regions {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }
+test.format_custom_directive_regions {
+ "test.return"() : () -> ()
+}
+
+// CHECK: test.format_custom_directive_regions {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }, {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }
+test.format_custom_directive_regions {
+ "test.return"() : () -> ()
+}, {
+ "test.return"() : () -> ()
+}
+
// CHECK: test.format_custom_directive_results : i64, i64 -> (i64)
test.format_custom_directive_results : i64, i64 -> (i64)
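The implicit-terminator cases above suggest the printer elides a terminator only when nothing would be lost: a bare `test.return` disappears, while one carrying an attribute or an operand is printed in generic form. The printer code for this is not part of this excerpt, so the predicate below is an assumption inferred from the CHECK lines, written against a hypothetical `ToyTerminator` type. An elided terminator is presumably re-created on parse by `ensureTerminator`, which would explain why `TestReturnOp` gains an argument-free builder in this patch.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

// Hypothetical stand-in for the terminator of a single-block region.
struct ToyTerminator {
  std::size_t numOperands = 0;
  std::size_t numResults = 0;
  std::map<std::string, std::string> attrs;
};

// Decide whether the printer must spell out the terminator. A bare terminator
// can be elided because the parser can rebuild it; one that carries operands,
// results, or attributes cannot be rebuilt from nothing, so it is printed.
bool shouldPrintTerminator(const ToyTerminator *term) {
  if (!term)
    return true; // No terminator found: nothing to elide.
  return term->numOperands != 0 || term->numResults != 0 ||
         !term->attrs.empty();
}
```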
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 684c8ad8cb17..1542e9c55e41 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -16,6 +16,7 @@
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -48,6 +49,7 @@ class Element {
CustomDirective,
FunctionalTypeDirective,
OperandsDirective,
+ RegionsDirective,
ResultsDirective,
SuccessorsDirective,
TypeDirective,
@@ -58,6 +60,7 @@ class Element {
/// This element is an variable value.
AttributeVariable,
OperandVariable,
+ RegionVariable,
ResultVariable,
SuccessorVariable,
@@ -119,6 +122,10 @@ struct AttributeVariable
using OperandVariable =
VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>;
+/// This class represents a variable that refers to a region.
+using RegionVariable =
+ VariableElement<NamedRegion, Element::Kind::RegionVariable>;
+
/// This class represents a variable that refers to a result.
using ResultVariable =
VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
@@ -133,7 +140,8 @@ using SuccessorVariable =
namespace {
/// This class implements single kind directives.
-template <Element::Kind type> class DirectiveElement : public Element {
+template <Element::Kind type>
+class DirectiveElement : public Element {
public:
DirectiveElement() : Element(type){};
static bool classof(const Element *ele) { return ele->getKind() == type; }
@@ -142,6 +150,10 @@ template <Element::Kind type> class DirectiveElement : public Element {
/// all of the operands of an operation.
using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
+/// This class represents the `regions` directive. This directive represents
+/// all of the regions of an operation.
+using RegionsDirective = DirectiveElement<Element::Kind::RegionsDirective>;
+
/// This class represents the `results` directive. This directive represents
/// all of the results of an operation.
using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
@@ -350,13 +362,23 @@ struct OperationFormat {
: allOperands(false), allOperandTypes(false), allResultTypes(false) {
operandTypes.resize(op.getNumOperands(), TypeResolution());
resultTypes.resize(op.getNumResults(), TypeResolution());
+
+ hasImplicitTermTrait =
+ llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
+ return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
+ });
}
/// Generate the operation parser from this format.
void genParser(Operator &op, OpClass &opClass);
+ /// Generate the parser code for a specific format element.
+ void genElementParser(Element *element, OpMethodBody &body,
+ FmtContext &attrTypeCtx);
/// Generate the c++ to resolve the types of operands and results during
/// parsing.
void genParserTypeResolution(Operator &op, OpMethodBody &body);
+ /// Generate the c++ to resolve regions during parsing.
+ void genParserRegionResolution(Operator &op, OpMethodBody &body);
/// Generate the c++ to resolve successors during parsing.
void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
/// Generate the c++ to handling variadic segment size traits.
@@ -365,6 +387,10 @@ struct OperationFormat {
/// Generate the operation printer from this format.
void genPrinter(Operator &op, OpClass &opClass);
+ /// Generate the printer code for a specific format element.
+ void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
+ bool &shouldEmitSpace, bool &lastWasPunctuation);
+
/// The various elements in this format.
std::vector<std::unique_ptr<Element>> elements;
@@ -372,11 +398,18 @@ struct OperationFormat {
/// contains these, it can not contain individual type resolvers.
bool allOperands, allOperandTypes, allResultTypes;
+ /// A flag indicating if this operation has the SingleBlockImplicitTerminator
+ /// trait.
+ bool hasImplicitTermTrait;
+
/// A map of buildable types to indices.
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
/// The index of the buildable type, if valid, for every operand and result.
std::vector<TypeResolution> operandTypes, resultTypes;
+
+ /// The set of attributes explicitly used within the format.
+ SmallVector<const NamedAttribute *, 8> usedAttributes;
};
} // end anonymous namespace
@@ -541,6 +574,60 @@ const char *const functionalTypeParserCode = R"(
{1}Types = {0}__{1}_functionType.getResults();
)";
+/// The code snippet used to generate a parser call for a region list.
+///
+/// {0}: The name for the region list.
+const char *regionListParserCode = R"(
+ {
+ std::unique_ptr<::mlir::Region> region;
+ auto firstRegionResult = parser.parseOptionalRegion(region);
+ if (firstRegionResult.hasValue()) {
+ if (failed(*firstRegionResult))
+ return failure();
+ {0}Regions.emplace_back(std::move(region));
+
+ // Parse any trailing regions.
+ while (succeeded(parser.parseOptionalComma())) {
+ region = std::make_unique<::mlir::Region>();
+ if (parser.parseRegion(*region))
+ return failure();
+ {0}Regions.emplace_back(std::move(region));
+ }
+ }
+ }
+)";
+
+/// The code snippet used to ensure a list of regions have terminators.
+///
+/// {0}: The name of the region list.
+const char *regionListEnsureTerminatorParserCode = R"(
+ for (auto &region : {0}Regions)
+ ensureTerminator(*region, parser.getBuilder(), result.location);
+)";
+
+/// The code snippet used to generate a parser call for an optional region.
+///
+/// {0}: The name of the region.
+const char *optionalRegionParserCode = R"(
+ if (parser.parseOptionalRegion(*{0}Region))
+ return failure();
+)";
+
+/// The code snippet used to generate a parser call for a region.
+///
+/// {0}: The name of the region.
+const char *regionParserCode = R"(
+ if (parser.parseRegion(*{0}Region))
+ return failure();
+)";
+
+/// The code snippet used to ensure a region has a terminator.
+///
+/// {0}: The name of the region.
+const char *regionEnsureTerminatorParserCode = R"(
+ ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
+)";
+
/// The code snippet used to generate a parser call for a successor list.
///
/// {0}: The name for the successor list.
@@ -658,6 +745,10 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
"allOperands;\n";
+ } else if (isa<RegionsDirective>(element)) {
+ body << " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
+ "fullRegions;\n";
+
} else if (isa<SuccessorsDirective>(element)) {
body << " ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
@@ -680,6 +771,20 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
" (void){0}OperandsLoc;\n",
name);
+
+ } else if (auto *region = dyn_cast<RegionVariable>(element)) {
+ StringRef name = region->getVar()->name;
+ if (region->getVar()->isVariadic()) {
+ body << llvm::formatv(
+ " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
+ "{0}Regions;\n",
+ name);
+ } else {
+ body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
+ "std::make_unique<::mlir::Region>();\n",
+ name);
+ }
+
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
StringRef name = successor->getVar()->name;
if (successor->getVar()->isVariadic()) {
@@ -725,6 +830,13 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
else
body << formatv("{0}RawOperands[0]", name);
+ } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
+ StringRef name = region->getVar()->name;
+ if (region->getVar()->isVariadic())
+ body << llvm::formatv("{0}Regions", name);
+ else
+ body << llvm::formatv("*{0}Region", name);
+
} else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
StringRef name = successor->getVar()->name;
if (successor->getVar()->isVariadic())
@@ -809,9 +921,39 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
body << " }\n";
}
-/// Generate the parser for a single format element.
-static void genElementParser(Element *element, OpMethodBody &body,
- FmtContext &attrTypeCtx) {
+void OperationFormat::genParser(Operator &op, OpClass &opClass) {
+ auto &method = opClass.newMethod(
+ "::mlir::ParseResult", "parse",
+ "::mlir::OpAsmParser &parser, ::mlir::OperationState &result",
+ OpMethod::MP_Static);
+ auto &body = method.body();
+
+ // Generate variables to store the operands and type within the format. This
+ // allows for referencing these variables in the presence of optional
+ // groupings.
+ for (auto &element : elements)
+ genElementParserStorage(&*element, body);
+
+ // A format context used when parsing attributes with buildable types.
+ FmtContext attrTypeCtx;
+ attrTypeCtx.withBuilder("parser.getBuilder()");
+
+ // Generate parsers for each of the elements.
+ for (auto &element : elements)
+ genElementParser(element.get(), body, attrTypeCtx);
+
+ // Generate the code to resolve the operand/result types and successors now
+ // that they have been parsed.
+ genParserTypeResolution(op, body);
+ genParserRegionResolution(op, body);
+ genParserSuccessorResolution(op, body);
+ genParserVariadicSegmentResolution(op, body);
+
+ body << " return success();\n";
+}
+
+void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
+ FmtContext &attrTypeCtx) {
/// Optional Group.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
auto elements = optional->getElements();
@@ -829,6 +971,17 @@ static void genElementParser(Element *element, OpMethodBody &body,
} else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
genElementParser(opVar, body, attrTypeCtx);
body << " if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
+ } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
+ const NamedRegion *region = regionVar->getVar();
+ if (region->isVariadic()) {
+ genElementParser(regionVar, body, attrTypeCtx);
+ body << " if (!" << region->name << "Regions.empty()) {\n";
+ } else {
+ body << llvm::formatv(optionalRegionParserCode, region->name);
+ body << " if (!" << region->name << "Region->empty()) {\n ";
+ if (hasImplicitTermTrait)
+ body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
+ }
}
// If the anchor is a unit attribute, we don't need to print it. When
@@ -907,6 +1060,17 @@ static void genElementParser(Element *element, OpMethodBody &body,
body << llvm::formatv(optionalOperandParserCode, name);
else
body << formatv(operandParserCode, name);
+
+ } else if (auto *region = dyn_cast<RegionVariable>(element)) {
+ bool isVariadic = region->getVar()->isVariadic();
+ body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
+ region->getVar()->name);
+ if (hasImplicitTermTrait) {
+ body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
+ : regionEnsureTerminatorParserCode,
+ region->getVar()->name);
+ }
+
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
bool isVariadic = successor->getVar()->isVariadic();
body << formatv(isVariadic ? successorListParserCode : successorParserCode,
@@ -925,8 +1089,15 @@ static void genElementParser(Element *element, OpMethodBody &body,
body << " ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
<< " if (parser.parseOperandList(allOperands))\n"
<< " return failure();\n";
+
+ } else if (isa<RegionsDirective>(element)) {
+ body << llvm::formatv(regionListParserCode, "full");
+ if (hasImplicitTermTrait)
+ body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
+
} else if (isa<SuccessorsDirective>(element)) {
body << llvm::formatv(successorListParserCode, "full");
+
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -946,36 +1117,6 @@ static void genElementParser(Element *element, OpMethodBody &body,
}
}
-void OperationFormat::genParser(Operator &op, OpClass &opClass) {
- auto &method = opClass.newMethod(
- "::mlir::ParseResult", "parse",
- "::mlir::OpAsmParser &parser, ::mlir::OperationState &result",
- OpMethod::MP_Static);
- auto &body = method.body();
-
- // Generate variables to store the operands and type within the format. This
- // allows for referencing these variables in the presence of optional
- // groupings.
- for (auto &element : elements)
- genElementParserStorage(&*element, body);
-
- // A format context used when parsing attributes with buildable types.
- FmtContext attrTypeCtx;
- attrTypeCtx.withBuilder("parser.getBuilder()");
-
- // Generate parsers for each of the elements.
- for (auto &element : elements)
- genElementParser(element.get(), body, attrTypeCtx);
-
- // Generate the code to resolve the operand/result types and successors now
- // that they have been parsed.
- genParserTypeResolution(op, body);
- genParserSuccessorResolution(op, body);
- genParserVariadicSegmentResolution(op, body);
-
- body << " return success();\n";
-}
-
void OperationFormat::genParserTypeResolution(Operator &op,
OpMethodBody &body) {
// If any of type resolutions use transformed variables, make sure that the
@@ -1133,6 +1274,25 @@ void OperationFormat::genParserTypeResolution(Operator &op,
}
}
+void OperationFormat::genParserRegionResolution(Operator &op,
+ OpMethodBody &body) {
+ // Check for the case where all regions were parsed.
+ bool hasAllRegions = llvm::any_of(
+ elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
+ if (hasAllRegions) {
+ body << " result.addRegions(fullRegions);\n";
+ return;
+ }
+
+ // Otherwise, handle each region individually.
+ for (const NamedRegion &region : op.getRegions()) {
+ if (region.isVariadic())
+ body << " result.addRegions(" << region.name << "Regions);\n";
+ else
+ body << " result.addRegion(std::move(" << region.name << "Region));\n";
+ }
+}
+
void OperationFormat::genParserSuccessorResolution(Operator &op,
OpMethodBody &body) {
// Check for the case where all successors were parsed.
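`genParserRegionResolution` either hands every region to the `OperationState` at once, when the `regions` directive was used, or adds each named region individually. As a rough illustration, here is a standalone generator that emits the same strings for a given list of regions. The `ToyRegion` descriptor is a hypothetical stand-in for ODS's `NamedRegion`.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for ODS's NamedRegion.
struct ToyRegion {
  std::string name;
  bool variadic;
};

// Emits the same shape of C++ as genParserRegionResolution in the diff.
std::string genRegionResolution(const std::vector<ToyRegion> &regions,
                                bool hasAllRegions) {
  std::ostringstream body;
  // The `regions` directive collected everything into `fullRegions`.
  if (hasAllRegions) {
    body << "  result.addRegions(fullRegions);\n";
    return body.str();
  }
  // Otherwise each region has its own storage variable.
  for (const ToyRegion &region : regions) {
    if (region.variadic)
      body << "  result.addRegions(" << region.name << "Regions);\n";
    else
      body << "  result.addRegion(std::move(" << region.name << "Region));\n";
  }
  return body.str();
}
```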
@@ -1186,23 +1346,26 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
//===----------------------------------------------------------------------===//
// PrinterGen
-/// Generate the printer for the 'attr-dict' directive.
-static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
- OpMethodBody &body, bool withKeyword) {
- // Collect all of the attributes used in the format, these will be elided.
- SmallVector<const NamedAttribute *, 1> usedAttributes;
- for (auto &it : fmt.elements) {
- if (auto *attr = dyn_cast<AttributeVariable>(it.get()))
- usedAttributes.push_back(attr->getVar());
- // Collect the optional attributes.
- if (auto *opt = dyn_cast<OptionalElement>(it.get())) {
- for (auto &elem : opt->getElements()) {
- if (auto *attr = dyn_cast<AttributeVariable>(&elem))
- usedAttributes.push_back(attr->getVar());
- }
+/// The code snippet used to generate a printer call for a region of an
+/// operation that has the SingleBlockImplicitTerminator trait.
+///
+/// {0}: The name of the region.
+const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
+ {
+ bool printTerminator = true;
+ if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
+ printTerminator = !term->getMutableAttrDict().empty() ||
+ term->getNumOperands() != 0 ||
+ term->getNumResults() != 0;
}
+ p.printRegion({0}, /*printEntryBlockArgs=*/true,
+ /*printBlockTerminators=*/printTerminator);
}
+)";
+/// Generate the printer for the 'attr-dict' directive.
+static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
+ OpMethodBody &body, bool withKeyword) {
body << " p.printOptionalAttrDict" << (withKeyword ? "WithKeyword" : "")
<< "(getAttrs(), /*elidedAttrs=*/{";
// Elide the variadic segment size attributes if necessary.
@@ -1210,9 +1373,9 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
body << "\"operand_segment_sizes\", ";
if (!fmt.allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments"))
body << "\"result_segment_sizes\", ";
- llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) {
- body << "\"" << attr->name << "\"";
- });
+ llvm::interleaveComma(
+ fmt.usedAttributes, body,
+ [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
body << "});\n";
}
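For ops with the `SingleBlockImplicitTerminator` trait, the printer snippet keeps the terminator visible only when it carries attributes, operands, or results. A bare terminator is elided because the generated parser re-creates one through `ensureTerminator`. The decision on its own looks roughly like the sketch below, where `ToyTerminator` is a hypothetical summary type, not an MLIR class.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical summary of a terminator op; not an MLIR type.
struct ToyTerminator {
  std::size_t numAttrs = 0;
  std::size_t numOperands = 0;
  std::size_t numResults = 0;
};

// Mirrors the check in regionSingleBlockImplicitTerminatorPrinterCode: print
// the terminator unless it is bare (no attributes, operands, or results).
bool shouldPrintTerminator(const ToyTerminator *term) {
  if (term == nullptr) // Empty region: there is no terminator to elide.
    return true;
  return term->numAttrs != 0 || term->numOperands != 0 ||
         term->numResults != 0;
}
```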
@@ -1255,6 +1418,9 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
} else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
body << operand->getVar()->name << "()";
+ } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
+ body << region->getVar()->name << "()";
+
} else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
body << successor->getVar()->name << "()";
@@ -1277,6 +1443,24 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
body << ");\n";
}
+/// Generate the printer for a region with the given variable name.
+static void genRegionPrinter(const Twine &regionName, OpMethodBody &body,
+ bool hasImplicitTermTrait) {
+ if (hasImplicitTermTrait)
+ body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
+ regionName);
+ else
+ body << " p.printRegion(" << regionName << ");\n";
+}
+static void genVariadicRegionPrinter(const Twine &regionListName,
+ OpMethodBody &body,
+ bool hasImplicitTermTrait) {
+ body << " llvm::interleaveComma(" << regionListName
+ << ", p, [&](::mlir::Region &region) {\n ";
+ genRegionPrinter("region", body, hasImplicitTermTrait);
+ body << " });\n";
+}
+
/// Generate the C++ for an operand to a (*-)type directive.
static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
if (isa<OperandsDirective>(arg))
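Variadic regions are printed with `llvm::interleaveComma`, which produces the same comma-separated form that `regionListParserCode` parses back. The sketch below uses a small stand-in for `interleaveComma` so it builds without LLVM, and strings in place of `::mlir::Region`. Both substitutions are assumptions made for illustration.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Standalone stand-in for llvm::interleaveComma: calls `each` on every element
// and writes ", " between consecutive elements.
template <typename Range, typename EachFn>
void interleaveComma(const Range &range, std::ostream &os, EachFn each) {
  bool first = true;
  for (const auto &elt : range) {
    if (!first)
      os << ", ";
    first = false;
    each(elt);
  }
}

// Hypothetical printer with the shape of the code genVariadicRegionPrinter
// emits; each string stands in for a printed region.
std::string printRegionList(const std::vector<std::string> &regions) {
  std::ostringstream os;
  interleaveComma(regions, os, [&](const std::string &region) { os << region; });
  return os.str();
}
```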
@@ -1296,10 +1480,9 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
<< "().getType())";
}
-/// Generate the code for printing the given element.
-static void genElementPrinter(Element *element, OpMethodBody &body,
- OperationFormat &fmt, Operator &op,
- bool &shouldEmitSpace, bool &lastWasPunctuation) {
+void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
+ Operator &op, bool &shouldEmitSpace,
+ bool &lastWasPunctuation) {
if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
lastWasPunctuation);
@@ -1314,6 +1497,11 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
body << " if (" << var->name << "()) {\n";
else if (var->isVariadic())
body << " if (!" << var->name << "().empty()) {\n";
+ } else if (auto *region = dyn_cast<RegionVariable>(anchor)) {
+ const NamedRegion *var = region->getVar();
+ // TODO: Add a check for optional here when ODS supports it.
+ body << " if (!" << var->name << "().empty()) {\n";
+
} else {
body << " if (getAttr(\""
<< cast<AttributeVariable>(anchor)->getVar()->name << "\")) {\n";
@@ -1332,7 +1520,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
// Emit each of the elements.
for (Element &childElement : elements) {
if (&childElement != elidedAnchorElement) {
- genElementPrinter(&childElement, body, fmt, op, shouldEmitSpace,
+ genElementPrinter(&childElement, body, op, shouldEmitSpace,
lastWasPunctuation);
}
}
@@ -1342,7 +1530,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
// Emit the attribute dictionary.
if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
- genAttrDictPrinter(fmt, op, body, attrDict->isWithKeyword());
+ genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
lastWasPunctuation = false;
return;
}
@@ -1384,6 +1572,13 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
} else {
body << " p << " << operand->getVar()->name << "();\n";
}
+ } else if (auto *region = dyn_cast<RegionVariable>(element)) {
+ const NamedRegion *var = region->getVar();
+ if (var->isVariadic()) {
+ genVariadicRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
+ } else {
+ genRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
+ }
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
const NamedSuccessor *var = successor->getVar();
if (var->isVariadic())
@@ -1394,6 +1589,9 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
genCustomDirectivePrinter(dir, body);
} else if (isa<OperandsDirective>(element)) {
body << " p << getOperation()->getOperands();\n";
+ } else if (isa<RegionsDirective>(element)) {
+ genVariadicRegionPrinter("getOperation()->getRegions()", body,
+ hasImplicitTermTrait);
} else if (isa<SuccessorsDirective>(element)) {
body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
@@ -1426,7 +1624,7 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
// punctuation.
bool shouldEmitSpace = true, lastWasPunctuation = false;
for (auto &element : elements)
- genElementPrinter(element.get(), body, *this, op, shouldEmitSpace,
+ genElementPrinter(element.get(), body, op, shouldEmitSpace,
lastWasPunctuation);
}
@@ -1460,6 +1658,7 @@ class Token {
kw_custom,
kw_functional_type,
kw_operands,
+ kw_regions,
kw_results,
kw_successors,
kw_type,
@@ -1663,6 +1862,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
.Case("custom", Token::kw_custom)
.Case("functional-type", Token::kw_functional_type)
.Case("operands", Token::kw_operands)
+ .Case("regions", Token::kw_regions)
.Case("results", Token::kw_results)
.Case("successors", Token::kw_successors)
.Case("type", Token::kw_type)
@@ -1676,7 +1876,8 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
/// Function to find an element within the given range that has the same name as
/// 'name'.
-template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
+template <typename RangeT>
+static auto findArg(RangeT &&range, StringRef name) {
auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
return it != range.end() ? &*it : nullptr;
}
@@ -1719,6 +1920,9 @@ class FormatParser {
verifyOperands(llvm::SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
+ /// Verify the state of operation regions within the format.
+ LogicalResult verifyRegions(llvm::SMLoc loc);
+
/// Verify the state of operation results within the format.
LogicalResult
verifyResults(llvm::SMLoc loc,
@@ -1775,6 +1979,8 @@ class FormatParser {
Token tok, bool isTopLevel);
LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel);
+ LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
+ llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel);
LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
@@ -1821,11 +2027,12 @@ class FormatParser {
// The following are various bits of format state used for verification
// during parsing.
- bool hasAllOperands = false, hasAttrDict = false;
- bool hasAllSuccessors = false;
+ bool hasAttrDict = false;
+ bool hasAllRegions = false, hasAllSuccessors = false;
llvm::SmallBitVector seenOperandTypes, seenResultTypes;
+ llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
- llvm::DenseSet<const NamedAttribute *> seenAttrs;
+ llvm::DenseSet<const NamedRegion *> seenRegions;
llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
};
@@ -1867,13 +2074,11 @@ LogicalResult FormatParser::parse() {
if (failed(verifyAttributes(loc)) ||
failed(verifyResults(loc, variableTyResolver)) ||
failed(verifyOperands(loc, variableTyResolver)) ||
- failed(verifySuccessors(loc)))
+ failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
return failure();
- // Check to see if we are formatting all of the operands.
- fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
- return isa<OperandsDirective>(elt.get());
- });
+ // Collect the set of used attributes in the format.
+ fmt.usedAttributes = seenAttrs.takeVector();
return success();
}
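Switching `seenAttrs` from `DenseSet` to `SmallSetVector` changes two things visible in this diff. First, `insert` now returns a plain `bool`, hence the dropped `.second` in `parseVariable`. Second, `takeVector()` hands the attributes back as a vector in the order they were bound, which lets `fmt.usedAttributes` replace the element walk that `genAttrDictPrinter` used to perform. The simplified stand-in below is not LLVM's implementation, but it shows those semantics.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

// Simplified stand-in for llvm::SetVector: set semantics for lookups plus a
// vector that remembers insertion order. Not the LLVM implementation.
template <typename T>
class InsertionOrderedSet {
public:
  // Like SetVector::insert: returns true only if `value` was newly added.
  bool insert(const T &value) {
    if (!set_.insert(value).second)
      return false;
    vector_.push_back(value);
    return true;
  }
  bool count(const T &value) const { return set_.count(value) != 0; }
  // Like SetVector::takeVector: returns the elements in insertion order and
  // leaves the container empty.
  std::vector<T> takeVector() {
    set_.clear();
    std::vector<T> out = std::move(vector_);
    vector_.clear();
    return out;
  }

private:
  std::unordered_set<T> set_;
  std::vector<T> vector_;
};
```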
@@ -1953,7 +2158,7 @@ LogicalResult FormatParser::verifyOperands(
NamedTypeConstraint &operand = op.getOperand(i);
// Check that the operand itself is in the format.
- if (!hasAllOperands && !seenOperands.count(&operand)) {
+ if (!fmt.allOperands && !seenOperands.count(&operand)) {
return emitErrorAndNote(loc,
"operand #" + Twine(i) + ", named '" +
operand.name + "', not found",
@@ -1976,7 +2181,7 @@ LogicalResult FormatParser::verifyOperands(
// Similarly to results, allow a custom builder for resolving the type if
// we aren't using the 'operands' directive.
Optional<StringRef> builder = operand.constraint.getBuilderCall();
- if (!builder || (hasAllOperands && operand.isVariableLength())) {
+ if (!builder || (fmt.allOperands && operand.isVariableLength())) {
return emitErrorAndNote(
loc,
"type of operand #" + Twine(i) + ", named '" + operand.name +
@@ -1991,6 +2196,24 @@ LogicalResult FormatParser::verifyOperands(
return success();
}
+LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) {
+ // Check that all of the regions are within the format.
+ if (hasAllRegions)
+ return success();
+
+ for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
+ const NamedRegion &region = op.getRegion(i);
+ if (!seenRegions.count(&region)) {
+ return emitErrorAndNote(loc,
+ "region #" + Twine(i) + ", named '" +
+ region.name + "', not found",
+ "suggest adding a '$" + region.name +
+ "' directive to the custom assembly format");
+ }
+ }
+ return success();
+}
+
LogicalResult FormatParser::verifyResults(
llvm::SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
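`verifyRegions` follows the same pattern as the operand and successor checks. Unless the `regions` directive was used, every region declared on the op must be bound somewhere in the format. The toy version below works over region names and has a hypothetical signature. The real method also attaches a note suggesting the missing `$name` directive.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Toy version of FormatParser::verifyRegions. `opRegions` lists the op's
// region names in declaration order. `seen` holds the names bound by `$name`
// in the format. Returns the first error message, or nothing on success.
std::optional<std::string>
verifyRegions(const std::vector<std::string> &opRegions,
              const std::set<std::string> &seen, bool hasAllRegions) {
  // The `regions` directive covers every region at once.
  if (hasAllRegions)
    return std::nullopt;
  for (std::size_t i = 0; i != opRegions.size(); ++i) {
    if (!seen.count(opRegions[i]))
      return "region #" + std::to_string(i) + ", named '" + opRegions[i] +
             "', not found";
  }
  return std::nullopt;
}
```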
@@ -2108,7 +2331,7 @@ ConstArgument FormatParser::findSeenArg(StringRef name) {
if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
- return seenAttrs.find_as(attr) != seenAttrs.end() ? attr : nullptr;
+ return seenAttrs.count(attr) ? attr : nullptr;
return nullptr;
}
@@ -2142,7 +2365,7 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
// op.
/// Attributes
if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
- if (isTopLevel && !seenAttrs.insert(attr).second)
+ if (isTopLevel && !seenAttrs.insert(attr))
return emitError(loc, "attribute '" + name + "' is already bound");
element = std::make_unique<AttributeVariable>(attr);
return success();
@@ -2150,12 +2373,21 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
/// Operands
if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
if (isTopLevel) {
- if (hasAllOperands || !seenOperands.insert(operand).second)
+ if (fmt.allOperands || !seenOperands.insert(operand).second)
return emitError(loc, "operand '" + name + "' is already bound");
}
element = std::make_unique<OperandVariable>(operand);
return success();
}
+ /// Regions
+ if (const NamedRegion *region = findArg(op.getRegions(), name)) {
+ if (!isTopLevel)
+ return emitError(loc, "regions can only be used at the top level");
+ if (hasAllRegions || !seenRegions.insert(region).second)
+ return emitError(loc, "region '" + name + "' is already bound");
+ element = std::make_unique<RegionVariable>(region);
+ return success();
+ }
/// Results.
if (const auto *result = findArg(op.getResults(), name)) {
if (isTopLevel)
@@ -2172,8 +2404,8 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
element = std::make_unique<SuccessorVariable>(successor);
return success();
}
- return emitError(
- loc, "expected variable to refer to an argument, result, or successor");
+ return emitError(loc, "expected variable to refer to an argument, region, "
+ "result, or successor");
}
LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
@@ -2194,6 +2426,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
case Token::kw_operands:
return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
+ case Token::kw_regions:
+ return parseRegionsDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_results:
return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
case Token::kw_successors:
@@ -2247,9 +2481,10 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
// optional fashion.
Element *firstElement = &*elements.front();
if (!isa<AttributeVariable>(firstElement) &&
- !isa<LiteralElement>(firstElement) && !isa<OperandVariable>(firstElement))
+ !isa<LiteralElement>(firstElement) &&
+ !isa<OperandVariable>(firstElement) && !isa<RegionVariable>(firstElement))
return emitError(curLoc, "first element of an operand group must be an "
- "attribute, literal, or operand");
+ "attribute, literal, operand, or region");
// After parsing all of the elements, ensure that all type directives refer
// only to elements within the group.
@@ -2314,10 +2549,15 @@ LogicalResult FormatParser::parseOptionalChildElement(
seenVariables.insert(ele->getVar());
return success();
})
+ .Case<RegionVariable>([&](RegionVariable *) {
+ // TODO: When ODS has proper support for marking "optional" regions, add
+ // a check here.
+ return success();
+ })
// Literals, custom directives, and type directives may be used,
// but they can't anchor the group.
- .Case<LiteralElement, CustomDirective, TypeDirective,
- FunctionalTypeDirective>([&](Element *) {
+ .Case<LiteralElement, CustomDirective, FunctionalTypeDirective,
+ OptionalElement, TypeDirective>([&](Element *) {
if (isAnchor)
return emitError(childLoc, "only variables can be used to anchor "
"an optional group");
@@ -2401,7 +2641,7 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
return failure();
// Verify that the element can be placed within a custom directive.
- if (!isa<TypeDirective, AttributeVariable, OperandVariable,
+ if (!isa<TypeDirective, AttributeVariable, OperandVariable, RegionVariable,
SuccessorVariable>(parameters.back().get())) {
return emitError(childLoc, "only variables and types may be used as "
"parameters to a custom directive");
@@ -2433,13 +2673,27 @@ FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
LogicalResult
FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel) {
- if (isTopLevel && (hasAllOperands || !seenOperands.empty()))
- return emitError(loc, "'operands' directive creates overlap in format");
- hasAllOperands = true;
+ if (isTopLevel) {
+ if (fmt.allOperands || !seenOperands.empty())
+ return emitError(loc, "'operands' directive creates overlap in format");
+ fmt.allOperands = true;
+ }
element = std::make_unique<OperandsDirective>();
return success();
}
+LogicalResult
+FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
+ llvm::SMLoc loc, bool isTopLevel) {
+ if (!isTopLevel)
+ return emitError(loc, "'regions' is only valid as a top-level directive");
+ if (hasAllRegions || !seenRegions.empty())
+ return emitError(loc, "'regions' directive creates overlap in format");
+ hasAllRegions = true;
+ element = std::make_unique<RegionsDirective>();
+ return success();
+}
+
LogicalResult
FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, bool isTopLevel) {
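Taken together, `parseVariable` and `parseRegionsDirective` enforce that each region is bound exactly once, either individually through `$name` or collectively through `regions`, and only at the top level. The hypothetical `RegionBindings` class below models that bookkeeping. It uses strings as region identifiers and returns the same error messages as the diff, with an empty string meaning success.

```cpp
#include <cassert>
#include <set>
#include <string>

// Toy model of the region bookkeeping in FormatParser. An empty string means
// success; otherwise it is the error message the diff emits.
class RegionBindings {
public:
  // `$name` referring to a region (see parseVariable).
  std::string bindRegion(const std::string &name, bool isTopLevel) {
    if (!isTopLevel)
      return "regions can only be used at the top level";
    if (hasAllRegions_ || !seen_.insert(name).second)
      return "region '" + name + "' is already bound";
    return "";
  }
  // The `regions` directive (see parseRegionsDirective).
  std::string bindAllRegions(bool isTopLevel) {
    if (!isTopLevel)
      return "'regions' is only valid as a top-level directive";
    if (hasAllRegions_ || !seen_.empty())
      return "'regions' directive creates overlap in format";
    hasAllRegions_ = true;
    return "";
  }

private:
  bool hasAllRegions_ = false;
  std::set<std::string> seen_;
};
```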