[Mlir-commits] [mlir] eaeadce - [mlir][OpFormatGen] Add initial support for regions in the custom op assembly format

River Riddle llvmlistbot at llvm.org
Mon Aug 31 13:26:57 PDT 2020


Author: River Riddle
Date: 2020-08-31T13:26:24-07:00
New Revision: eaeadce9bd11d50cecfdf9e97ac471acd38136ee

URL: https://github.com/llvm/llvm-project/commit/eaeadce9bd11d50cecfdf9e97ac471acd38136ee
DIFF: https://github.com/llvm/llvm-project/commit/eaeadce9bd11d50cecfdf9e97ac471acd38136ee.diff

LOG: [mlir][OpFormatGen] Add initial support for regions in the custom op assembly format

This adds initial support for regions. It does not yet support formatting the specific arguments of a region; for now, that can be achieved by using a custom directive that formats the arguments and then parses the region.

Differential Revision: https://reviews.llvm.org/D86760

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/IR/OpImplementation.h
    mlir/include/mlir/IR/OperationSupport.h
    mlir/lib/IR/OperationSupport.cpp
    mlir/lib/Parser/AttributeParser.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Parser/Parser.h
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format-spec.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 167546e67522..1b1a2125e95d 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -681,6 +681,10 @@ The available directives are as follows:
 
     -   Represents all of the operands of an operation.
 
+*   `regions`
+
+    -   Represents all of the regions of an operation.
+
 *   `results`
 
     -   Represents all of the results of an operation.
@@ -700,13 +704,14 @@ The available directives are as follows:
 A literal is either a keyword or punctuation surrounded by \`\`.
 
 The following are the set of valid punctuation:
-  `:`, `,`, `=`, `<`, `>`, `(`, `)`, `[`, `]`, `->`
+
+`:`, `,`, `=`, `<`, `>`, `(`, `)`, `{`, `}`, `[`, `]`, `->`
 
 #### Variables
 
 A variable is an entity that has been registered on the operation itself, i.e.
-an argument(attribute or operand), result, successor, etc. In the `CallOp`
-example above, the variables would be `$callee` and `$args`.
+an argument (attribute or operand), region, result, successor, etc. In the
+`CallOp` example above, the variables would be `$callee` and `$args`.
 
 Attribute variables are printed with their respective value type, unless that
 value type is buildable. In those cases, the type of the attribute is elided.
@@ -747,6 +752,9 @@ declarative parameter to `parse` method argument is detailed below:
     -   Single: `OpAsmParser::OperandType &`
     -   Optional: `Optional<OpAsmParser::OperandType> &`
     -   Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
+*   Region Variables
+    -   Single: `Region &`
+    -   Variadic: `SmallVectorImpl<std::unique_ptr<Region>> &`
 *   Successor Variables
     -   Single: `Block *&`
     -   Variadic: `SmallVectorImpl<Block *> &`
@@ -770,6 +778,9 @@ declarative parameter to `print` method argument is detailed below:
     -   Single: `Value`
     -   Optional: `Value`
     -   Variadic: `OperandRange`
+*   Region Variables
+    -   Single: `Region &`
+    -   Variadic: `MutableArrayRef<Region>`
 *   Successor Variables
     -   Single: `Block *`
     -   Variadic: `SuccessorRange`
@@ -788,8 +799,8 @@ of the assembly format can be marked as `optional` based on the presence of this
 information. An optional group is defined by wrapping a set of elements within
 `()` followed by a `?` and has the following requirements:
 
-*   The first element of the group must either be a literal, attribute, or an
-    operand.
+*   The first element of the group must be an attribute, literal, operand, or
+    region.
     -   This is because the first element must be optionally parsable.
 *   Exactly one argument variable within the group must be marked as the anchor
     of the group.
@@ -797,11 +808,15 @@ information. An optional group is defined by wrapping a set of elements within
         should be printed/parsed.
     -   An element is marked as the anchor by adding a trailing `^`.
     -   The first element is *not* required to be the anchor of the group.
+    -   When a non-variadic region anchors a group, the group is printed only
+        if the region is non-empty.
 *   Literals, variables, custom directives, and type directives are the only
     valid elements within the group.
     -   Any attribute variable may be used, but only optional attributes can be
         marked as the anchor.
     -   Only variadic or optional operand arguments can be used.
+    -   All region variables can be used. When a non-variable-length region is
+        used and the group is not present, the region is left empty.
     -   The operands to a type directive must be defined within the optional
         group.
 
@@ -853,18 +868,22 @@ foo.op
 The format specification has a certain set of requirements that must be adhered
 to:
 
-1. The output and operation name are never shown as they are fixed and cannot be
-   altered.
-1. All operands within the operation must appear within the format, either
-   individually or with the `operands` directive.
-1. All operand and result types must appear within the format using the various
-   `type` directives, either individually or with the `operands` or `results`
-   directives.
-1. The `attr-dict` directive must always be present.
-1. Must not contain overlapping information; e.g. multiple instances of
-   'attr-dict', types, operands, etc.
-   -  Note that `attr-dict` does not overlap with individual attributes. These
-      attributes will simply be elided when printing the attribute dictionary.
+1.  The output and operation name are never shown as they are fixed and cannot
+    be altered.
+1.  All operands within the operation must appear within the format, either
+    individually or with the `operands` directive.
+1.  All regions within the operation must appear within the format, either
+    individually or with the `regions` directive.
+1.  All successors within the operation must appear within the format, either
+    individually or with the `successors` directive.
+1.  All operand and result types must appear within the format using the various
+    `type` directives, either individually or with the `operands` or `results`
+    directives.
+1.  The `attr-dict` directive must always be present.
+1.  Must not contain overlapping information; e.g. multiple instances of
+    'attr-dict', types, operands, etc.
+    -   Note that `attr-dict` does not overlap with individual attributes. These
+        attributes will simply be elided when printing the attribute dictionary.
 
 ##### Type Inference
 

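The optional-group rules documented above for a non-variadic region anchor can be illustrated with a small standalone model. This is a sketch, not MLIR code: `ToyRegion`, `printRegionGroup`, and `parseRegionGroup` are hypothetical names, a region is reduced to a list of op names, and the input is assumed to be pre-split into tokens. It mirrors the `(`region` $region^)?` format used by `test.format_region_c_op`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an MLIR region: just the names of the ops inside.
struct ToyRegion {
  std::vector<std::string> ops;
  bool empty() const { return ops.empty(); }
};

// Printer for `(`region` $region^)?`: the whole group, including the leading
// `region` keyword, is emitted only when the anchoring region is non-empty.
std::string printRegionGroup(const ToyRegion &region) {
  if (region.empty())
    return "";
  std::string out = "region {";
  for (const std::string &op : region.ops)
    out += " " + op;
  return out + " }";
}

// Parser for the same group over pre-split tokens. A missing `region` keyword
// means the group is absent: that is not an error, and the region stays empty.
// Returns false only for a malformed group (missing `{` or `}`).
bool parseRegionGroup(const std::vector<std::string> &tokens,
                      ToyRegion &region) {
  if (tokens.empty() || tokens[0] != "region")
    return true;
  if (tokens.size() < 2 || tokens[1] != "{")
    return false;
  for (std::size_t i = 2; i < tokens.size(); ++i) {
    if (tokens[i] == "}")
      return true;
    region.ops.push_back(tokens[i]);
  }
  return false;
}
```

The key design point is that absence of the group is a successful parse that leaves the region empty, which is what lets the printer drop the group entirely for an empty region and still round-trip.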
diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5962a1a48668..a30ffd8f75b4 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -424,8 +424,8 @@ class OpAsmParser {
                                                      Type type,
                                                      StringRef attrName,
                                                      NamedAttrList &attrs) = 0;
-  OptionalParseResult parseOptionalAttribute(Attribute &result,
-                                             StringRef attrName,
+  template <typename AttrT>
+  OptionalParseResult parseOptionalAttribute(AttrT &result, StringRef attrName,
                                              NamedAttrList &attrs) {
     return parseOptionalAttribute(result, Type(), attrName, attrs);
   }
@@ -433,6 +433,7 @@ class OpAsmParser {
   /// Specialized variants of `parseOptionalAttribute` that remove potential
   /// ambiguities in syntax.
   virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
+                                                     Type type,
                                                      StringRef attrName,
                                                      NamedAttrList &attrs) = 0;
 
@@ -621,16 +622,23 @@ class OpAsmParser {
   /// can only be set to true for regions attached to operations that are
   /// "IsolatedFromAbove".
   virtual ParseResult parseRegion(Region &region,
-                                  ArrayRef<OperandType> arguments,
-                                  ArrayRef<Type> argTypes,
+                                  ArrayRef<OperandType> arguments = {},
+                                  ArrayRef<Type> argTypes = {},
                                   bool enableNameShadowing = false) = 0;
 
   /// Parses a region if present.
   virtual ParseResult parseOptionalRegion(Region &region,
-                                          ArrayRef<OperandType> arguments,
-                                          ArrayRef<Type> argTypes,
+                                          ArrayRef<OperandType> arguments = {},
+                                          ArrayRef<Type> argTypes = {},
                                           bool enableNameShadowing = false) = 0;
 
+  /// Parses a region if present. If the region is present, a new region is
+  /// allocated and placed in `region`. If no region is present or on failure,
+  /// `region` remains untouched.
+  virtual OptionalParseResult parseOptionalRegion(
+      std::unique_ptr<Region> &region, ArrayRef<OperandType> arguments = {},
+      ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
+
   /// Parse a region argument, this argument is resolved when calling
   /// 'parseRegion'.
   virtual ParseResult parseRegionArgument(OperandType &argument) = 0;

diff  --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index b1758ec8f5ca..b0e1205eefe6 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -414,6 +414,10 @@ struct OperationState {
   /// region is null, a new empty region will be attached to the Operation.
   void addRegion(std::unique_ptr<Region> &&region);
 
+  /// Take ownership of a set of regions that should be attached to the
+  /// Operation.
+  void addRegions(MutableArrayRef<std::unique_ptr<Region>> regions);
+
   /// Get the context held by this operation state.
   MLIRContext *getContext() const { return location->getContext(); }
 };

diff  --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index b477a8a23900..ab84f4e8cf17 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -199,6 +199,12 @@ void OperationState::addRegion(std::unique_ptr<Region> &&region) {
   regions.push_back(std::move(region));
 }
 
+void OperationState::addRegions(
+    MutableArrayRef<std::unique_ptr<Region>> regions) {
+  for (std::unique_ptr<Region> &region : regions)
+    addRegion(std::move(region));
+}
+
 //===----------------------------------------------------------------------===//
 // OperandStorage
 //===----------------------------------------------------------------------===//

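`OperationState::addRegions` takes a `MutableArrayRef<std::unique_ptr<Region>>` and moves each element into the state, so the caller's storage keeps its length but holds null pointers afterwards. The sketch below models only that ownership transfer; `ToyRegion` and `ToyOperationState` are hypothetical stand-ins, and a `std::vector` reference plays the role of `MutableArrayRef`.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins for mlir::Region and mlir::OperationState.
struct ToyRegion {
  int id = 0;
};

struct ToyOperationState {
  std::vector<std::unique_ptr<ToyRegion>> regions;

  // Like addRegion: take ownership of one region.
  void addRegion(std::unique_ptr<ToyRegion> &&region) {
    regions.push_back(std::move(region));
  }

  // Like addRegions: move each element out of the caller's storage. The
  // parameter is a mutable view, so the caller's slots are left null.
  void addRegions(std::vector<std::unique_ptr<ToyRegion>> &regionList) {
    for (std::unique_ptr<ToyRegion> &region : regionList)
      addRegion(std::move(region));
  }
};
```

Taking a mutable view rather than a container by value lets a parser hand over a `SmallVector` of regions it filled without an extra allocation.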
diff  --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index b7cae2778c10..4e17ccd8022d 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -221,8 +221,9 @@ OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
     return result;
   }
 }
-OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute) {
-  return parseOptionalAttributeWithToken(Token::l_square, attribute);
+OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
+                                                   Type type) {
+  return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
 }
 
 /// Attribute dictionary.

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index cc102b9e97a4..d6065f758fc1 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1045,7 +1045,6 @@ class CustomOpAsmParser : public OpAsmParser {
   }
 
   /// Parse an optional attribute.
-  /// Template utilities to simplify specifying multiple derived overloads.
   template <typename AttrT>
   OptionalParseResult
   parseOptionalAttributeAndAddToList(AttrT &result, Type type,
@@ -1056,25 +1055,15 @@ class CustomOpAsmParser : public OpAsmParser {
       attrs.push_back(parser.builder.getNamedAttr(attrName, result));
     return parseResult;
   }
-  template <typename AttrT>
-  OptionalParseResult parseOptionalAttributeAndAddToList(AttrT &result,
-                                                         StringRef attrName,
-                                                         NamedAttrList &attrs) {
-    OptionalParseResult parseResult = parser.parseOptionalAttribute(result);
-    if (parseResult.hasValue() && succeeded(*parseResult))
-      attrs.push_back(parser.builder.getNamedAttr(attrName, result));
-    return parseResult;
-  }
-
   OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
                                              StringRef attrName,
                                              NamedAttrList &attrs) override {
     return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
   }
-  OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
+  OptionalParseResult parseOptionalAttribute(ArrayAttr &result, Type type,
                                              StringRef attrName,
                                              NamedAttrList &attrs) override {
-    return parseOptionalAttributeAndAddToList(result, attrName, attrs);
+    return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
   }
 
   /// Parse a named dictionary into 'result' if it is present.
@@ -1355,6 +1344,23 @@ class CustomOpAsmParser : public OpAsmParser {
     return parseRegion(region, arguments, argTypes, enableNameShadowing);
   }
 
+  /// Parses a region if present. If the region is present, a new region is
+  /// allocated and placed in `region`. If no region is present, `region`
+  /// remains untouched.
+  OptionalParseResult
+  parseOptionalRegion(std::unique_ptr<Region> &region,
+                      ArrayRef<OperandType> arguments, ArrayRef<Type> argTypes,
+                      bool enableNameShadowing = false) override {
+    if (parser.getToken().isNot(Token::l_brace))
+      return llvm::None;
+    std::unique_ptr<Region> newRegion = std::make_unique<Region>();
+    if (parseRegion(*newRegion, arguments, argTypes, enableNameShadowing))
+      return failure();
+
+    region = std::move(newRegion);
+    return success();
+  }
+
   /// Parse a region argument. The type of the argument will be resolved later
   /// by a call to `parseRegion`.
   ParseResult parseRegionArgument(OperandType &argument) override {

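The new `parseOptionalRegion(std::unique_ptr<Region> &, ...)` overload has a three-way result: no region present, a region parsed, or a region started but malformed. It parses into a freshly allocated region and only moves it into the out-parameter on success. The sketch below reproduces that contract over a token list; it is an illustrative model, not the MLIR parser. `ToyRegion` and `ToyOptionalParseResult` are hypothetical, `std::optional<bool>` stands in for `OptionalParseResult`, and C++17 is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::Region: the op names between `{` and `}`.
struct ToyRegion {
  std::vector<std::string> ops;
};

// std::optional<bool> stands in for OptionalParseResult:
//   std::nullopt -> no region present (next token is not `{`)
//   true         -> a region was parsed and handed to the caller
//   false        -> a `{` was seen but the region was malformed
using ToyOptionalParseResult = std::optional<bool>;

ToyOptionalParseResult
parseOptionalRegion(const std::vector<std::string> &tokens, std::size_t &pos,
                    std::unique_ptr<ToyRegion> &region) {
  if (pos >= tokens.size() || tokens[pos] != "{")
    return std::nullopt; // Absent: `region` and `pos` are untouched.

  // Parse into a fresh region so that a failure leaves `region` untouched.
  auto newRegion = std::make_unique<ToyRegion>();
  std::size_t i = pos + 1;
  for (; i < tokens.size() && tokens[i] != "}"; ++i)
    newRegion->ops.push_back(tokens[i]);
  if (i == tokens.size())
    return false; // Unterminated region.

  pos = i + 1;
  region = std::move(newRegion);
  return true;
}
```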
diff  --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 61e54be83139..c01a0a004072 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -187,7 +187,7 @@ class Parser {
   /// Parse an optional attribute with the provided type.
   OptionalParseResult parseOptionalAttribute(Attribute &attribute,
                                              Type type = {});
-  OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute);
+  OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute, Type type);
 
   /// Parse an optional attribute that is demarcated by a specific token.
   template <typename AttributeT>
@@ -197,8 +197,8 @@ class Parser {
     if (getToken().isNot(kind))
       return llvm::None;
 
-    if (Attribute parsedAttr = parseAttribute()) {
-      attr = parsedAttr.cast<ArrayAttr>();
+    if (Attribute parsedAttr = parseAttribute(type)) {
+      attr = parsedAttr.cast<AttributeT>();
       return success();
     }
     return failure();

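In `parseOptionalAttributeWithToken`, the template parameter `AttributeT` was previously ignored in favor of a hard-coded `cast<ArrayAttr>`; the change casts to `AttributeT` and forwards `type`. The toy below shows why the generic cast matters, using a hypothetical attribute hierarchy (`ToyAttr`, `ToyArrayAttr`, `ToyDictAttr`) and a hypothetical `parseAttribute` that maps `[` to an array and `{` to a dictionary. Unlike MLIR's `cast<>`, which asserts on a mismatch, this sketch reports failure; C++17 is assumed for `std::optional`.

```cpp
#include <cassert>
#include <memory>
#include <optional>

// Hypothetical toy attribute hierarchy standing in for mlir::Attribute.
struct ToyAttr {
  virtual ~ToyAttr() = default;
};
struct ToyArrayAttr : ToyAttr {};
struct ToyDictAttr : ToyAttr {};

// Hypothetical generic parser: `[` starts an array and `{` a dictionary.
std::shared_ptr<ToyAttr> parseAttribute(char token) {
  if (token == '[')
    return std::make_shared<ToyArrayAttr>();
  if (token == '{')
    return std::make_shared<ToyDictAttr>();
  return nullptr;
}

// Same shape as parseOptionalAttributeWithToken after this change: the
// parsed value is converted to the template parameter, not a fixed type.
template <typename AttributeT>
std::optional<bool> parseOptionalWithToken(char current, char kind,
                                           std::shared_ptr<AttributeT> &attr) {
  if (current != kind)
    return std::nullopt; // Demarcating token absent: nothing parsed.
  if (auto parsed =
          std::dynamic_pointer_cast<AttributeT>(parseAttribute(current))) {
    attr = parsed;
    return true;
  }
  return false;
}
```

With the old hard-coded cast, instantiating the template for any type other than an array would have produced the wrong kind of attribute.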
diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 292f5ed4b641..d75422e84124 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -319,6 +319,19 @@ static ParseResult parseCustomDirectiveOperandsAndTypes(
     return failure();
   return success();
 }
+static ParseResult parseCustomDirectiveRegions(
+    OpAsmParser &parser, Region &region,
+    SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
+  if (parser.parseRegion(region))
+    return failure();
+  if (failed(parser.parseOptionalComma()))
+    return success();
+  std::unique_ptr<Region> varRegion = std::make_unique<Region>();
+  if (parser.parseRegion(*varRegion))
+    return failure();
+  varRegions.emplace_back(std::move(varRegion));
+  return success();
+}
 static ParseResult
 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
                                SmallVectorImpl<Block *> &varSuccessors) {
@@ -361,6 +374,15 @@ printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
   printCustomDirectiveResults(printer, operandType, optOperandType,
                               varOperandTypes);
 }
+static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region &region,
+                                        MutableArrayRef<Region> varRegions) {
+  printer.printRegion(region);
+  if (!varRegions.empty()) {
+    printer << ", ";
+    for (Region &region : varRegions)
+      printer.printRegion(region);
+  }
+}
 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
                                            Block *successor,
                                            SuccessorRange varSuccessors) {

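The `printCustomDirectiveRegions` / `parseCustomDirectiveRegions` test helpers pair a required region with trailing variadic regions introduced by a comma. Below is a standalone string-based sketch of the same pairing. It is a generalization, not the test code: regions are assumed to be flat `{...}` strings without nesting, and every trailing region is preceded by `, ` when printing and consumed in a comma loop when parsing, so any number of trailing regions round-trips (the helpers above parse at most one). All names here are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Parses one "{...}" group starting at `pos` (toy: no nested braces).
static bool parseToyRegion(const std::string &text, std::size_t &pos,
                           std::string &body) {
  if (pos >= text.size() || text[pos] != '{')
    return false;
  std::size_t close = text.find('}', pos);
  if (close == std::string::npos)
    return false;
  body = text.substr(pos + 1, close - pos - 1);
  pos = close + 1;
  return true;
}

// Printer: the required region first, then ", " before each variadic region.
std::string printRegions(const std::string &region,
                         const std::vector<std::string> &varRegions) {
  std::string out = "{" + region + "}";
  for (const std::string &r : varRegions)
    out += ", {" + r + "}";
  return out;
}

// Parser: the first region is required; each ", " introduces another
// variadic region. The whole input must be consumed.
bool parseRegions(const std::string &text, std::string &region,
                  std::vector<std::string> &varRegions) {
  std::size_t pos = 0;
  if (!parseToyRegion(text, pos, region))
    return false;
  while (text.compare(pos, 2, ", ") == 0) {
    pos += 2;
    std::string body;
    if (!parseToyRegion(text, pos, body))
      return false;
    varRegions.push_back(body);
  }
  return pos == text.size();
}
```

Keeping the separator logic symmetric in the printer and parser is what guarantees the round trip.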
diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0e186d0cd29b..bc26a8659831 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1161,8 +1161,13 @@ def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> {
 //===----------------------------------------------------------------------===//
 
 def TestRegionBuilderOp : TEST_Op<"region_builder">;
-def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]>,
-  Arguments<(ins Variadic<AnyType>)>;
+def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]> {
+  let arguments = (ins Variadic<AnyType>);
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state",
+              [{ build(builder, state, {}); }]>
+  ];
+}
 def TestCastOp : TEST_Op<"cast">,
   Arguments<(ins Variadic<AnyType>)>, Results<(outs AnyType)>;
 def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
@@ -1333,6 +1338,43 @@ def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> {
   let assemblyFormat = "$buildable attr-dict";
 }
 
+// Test various mixings of region formatting.
+class FormatRegionBase<string suffix, string fmt>
+    : TEST_Op<"format_region_" # suffix # "_op"> {
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = fmt;
+}
+def FormatRegionAOp : FormatRegionBase<"a", [{
+  regions attr-dict
+}]>;
+def FormatRegionBOp : FormatRegionBase<"b", [{
+  $region attr-dict
+}]>;
+def FormatRegionCOp : FormatRegionBase<"c", [{
+  (`region` $region^)? attr-dict
+}]>;
+class FormatVariadicRegionBase<string suffix, string fmt>
+    : TEST_Op<"format_variadic_region_" # suffix # "_op"> {
+  let regions = (region VariadicRegion<AnyRegion>:$regions);
+  let assemblyFormat = fmt;
+}
+def FormatVariadicRegionAOp : FormatVariadicRegionBase<"a", [{
+  $regions attr-dict
+}]>;
+def FormatVariadicRegionBOp : FormatVariadicRegionBase<"b", [{
+  ($regions^ `found_regions`)? attr-dict
+}]>;
+class FormatRegionImplicitTerminatorBase<string suffix, string fmt>
+    : TEST_Op<"format_implicit_terminator_region_" # suffix # "_op",
+              [SingleBlockImplicitTerminator<"TestReturnOp">]> {
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = fmt;
+}
+def FormatFormatRegionImplicitTerminatorAOp
+    : FormatRegionImplicitTerminatorBase<"a", [{
+  $region attr-dict
+}]>;
+
 // Test various mixings of result type formatting.
 class FormatResultBase<string suffix, string fmt>
     : TEST_Op<"format_result_" # suffix # "_op"> {
@@ -1454,6 +1496,16 @@ def FormatCustomDirectiveOperandsAndTypes
   }];
 }
 
+def FormatCustomDirectiveRegions : TEST_Op<"format_custom_directive_regions"> {
+  let regions = (region AnyRegion:$region, VariadicRegion<AnyRegion>:$regions);
+  let assemblyFormat = [{
+    custom<CustomDirectiveRegions>(
+      $region, $regions
+    )
+    attr-dict
+  }];
+}
+
 def FormatCustomDirectiveResults
     : TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> {
   let results = (outs AnyType:$result, Optional<AnyType>:$optResult,

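For `$regions` on `FormatVariadicRegionAOp` and `FormatVariadicRegionBOp`, the list may be empty. The generated `regionListParserCode` in the OpFormatGen.cpp diff therefore tries `parseOptionalRegion` for the first region and only then loops on `parseOptionalComma`. The sketch below follows that control flow over a token list; the tokenization (each region is one `{...}` token, each comma a separate `,` token) and all names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical tokenization: each region is one "{...}" token.
static bool isRegionToken(const std::string &tok) {
  return tok.size() >= 2 && tok.front() == '{' && tok.back() == '}';
}

// Same control flow as the generated list parser: the first region is
// optional, so an empty list is valid; once a region has been parsed, every
// "," must be followed by another region.
bool parseRegionList(const std::vector<std::string> &tokens, std::size_t &pos,
                     std::vector<std::string> &regions) {
  if (pos >= tokens.size() || !isRegionToken(tokens[pos]))
    return true; // No leading region: the list is empty.
  regions.push_back(tokens[pos++]);
  while (pos < tokens.size() && tokens[pos] == ",") {
    ++pos;
    if (pos >= tokens.size() || !isRegionToken(tokens[pos]))
      return false; // A comma must be followed by a region.
    regions.push_back(tokens[pos++]);
  }
  return true;
}
```

An empty result is also what makes `$regions^` usable as an optional-group anchor: the group is present exactly when the list is non-empty.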
diff  --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 8e7c8ec56a2a..60189943ddab 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -133,6 +133,28 @@ def DirectiveOperandsValid : TestFormat_Op<"operands_valid", [{
   operands attr-dict
 }]>;
 
+//===----------------------------------------------------------------------===//
+// regions
+
+// CHECK: error: 'regions' directive creates overlap in format
+def DirectiveRegionsInvalidA : TestFormat_Op<"regions_invalid_a", [{
+  regions regions attr-dict
+}]>;
+// CHECK: error: 'regions' directive creates overlap in format
+def DirectiveRegionsInvalidB : TestFormat_Op<"regions_invalid_b", [{
+  $region regions attr-dict
+}]> {
+  let regions = (region AnyRegion:$region);
+}
+// CHECK: error: 'regions' is only valid as a top-level directive
+def DirectiveRegionsInvalidC : TestFormat_Op<"regions_invalid_c", [{
+  type(regions)
+}]>;
+// CHECK-NOT: error:
+def DirectiveRegionsValid : TestFormat_Op<"regions_valid", [{
+  regions attr-dict
+}]>;
+
 //===----------------------------------------------------------------------===//
 // results
 
@@ -249,7 +271,7 @@ def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{
 def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{
   ($attr)? attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
-// CHECK: error: first element of an operand group must be an attribute, literal, or operand
+// CHECK: error: first element of an operand group must be an attribute, literal, operand, or region
 def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{
   (type($operand) $operand^)? attr-dict
 }]>, Arguments<(ins Optional<I64>:$operand)>;
@@ -290,7 +312,7 @@ def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{
 // Variables
 //===----------------------------------------------------------------------===//
 
-// CHECK: error: expected variable to refer to an argument, result, or successor
+// CHECK: error: expected variable to refer to an argument, region, result, or successor
 def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{
   $unknown_arg attr-dict
 }]>;
@@ -330,11 +352,35 @@ def VariableInvalidH : TestFormat_Op<"variable_invalid_h", [{
 def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{
   (`foo` $attr^)? `:` attr-dict
 }]>, Arguments<(ins OptionalAttr<ElementsAttr>:$attr)>;
-// CHECK-NOT: error:
+// CHECK: error: region 'region' is already bound
 def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{
+  $region $region attr-dict
+}]> {
+  let regions = (region AnyRegion:$region);
+}
+// CHECK: error: region 'region' is already bound
+def VariableInvalidK : TestFormat_Op<"variable_invalid_K", [{
+  regions $region attr-dict
+}]> {
+  let regions = (region AnyRegion:$region);
+}
+// CHECK: error: regions can only be used at the top level
+def VariableInvalidL : TestFormat_Op<"variable_invalid_l", [{
+  type($region)
+}]> {
+  let regions = (region AnyRegion:$region);
+}
+// CHECK: error: region #0, named 'region', not found
+def VariableInvalidM : TestFormat_Op<"variable_invalid_m", [{
+  attr-dict
+}]> {
+  let regions = (region AnyRegion:$region);
+}
+// CHECK-NOT: error:
+def VariableValidA : TestFormat_Op<"variable_valid_a", [{
   $attr `:` attr-dict
 }]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
-def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{
+def VariableValidB : TestFormat_Op<"variable_valid_b", [{
   (`foo` $attr^)? `:` attr-dict
 }]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
 

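The CHECK lines above pin down the region rules the format verifier enforces: `regions` overlaps with any other region use, a region variable cannot be bound twice (including after `regions`), and every declared region must appear. The toy checker below reproduces those diagnostics, with messages copied from the CHECK lines. It is a sketch, not the OpFormatGen implementation: the format is assumed to be a flat list of element spellings, every `$name` is treated as a region variable, other elements such as `attr-dict` are ignored, and the top-level-only rule for `type(...)` is not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Returns "" if the region uses in `elements` are valid, otherwise an error
// message worded like the diagnostics in op-format-spec.td.
std::string verifyRegionUses(const std::vector<std::string> &elements,
                             const std::vector<std::string> &opRegions) {
  bool sawRegionsDirective = false;
  std::set<std::string> bound;
  for (const std::string &elt : elements) {
    if (elt == "regions") {
      // The directive covers every region, so it overlaps with any other use.
      if (sawRegionsDirective || !bound.empty())
        return "'regions' directive creates overlap in format";
      sawRegionsDirective = true;
      continue;
    }
    if (elt.empty() || elt[0] != '$')
      continue; // Not a variable: ignored by this toy.
    std::string name = elt.substr(1);
    if (sawRegionsDirective || bound.count(name))
      return "region '" + name + "' is already bound";
    bound.insert(name);
  }
  // Every region declared on the op must be covered by the format.
  if (!sawRegionsDirective)
    for (std::size_t i = 0; i < opRegions.size(); ++i)
      if (!bound.count(opRegions[i]))
        return "region #" + std::to_string(i) + ", named '" + opRegions[i] +
               "', not found";
  return "";
}
```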
diff  --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 40eb17ac52dc..9f7c9c0f4809 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -40,6 +40,72 @@ test.format_attr_dict_w_keyword attributes {attr = 10 : i64, opt_attr = 10 : i64
 // CHECK: test.format_buildable_type_op %[[I64]]
 %ignored = test.format_buildable_type_op %i64
 
+//===----------------------------------------------------------------------===//
+// Format regions
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_region_a_op {
+// CHECK-NEXT: test.return
+test.format_region_a_op {
+  "test.return"() : () -> ()
+}
+
+// CHECK: test.format_region_b_op {
+// CHECK-NEXT: test.return
+test.format_region_b_op {
+  "test.return"() : () -> ()
+}
+
+// CHECK: test.format_region_c_op region {
+// CHECK-NEXT: test.return
+test.format_region_c_op region {
+  "test.return"() : () -> ()
+}
+// CHECK: test.format_region_c_op
+// CHECK-NOT: region {
+test.format_region_c_op
+
+// CHECK: test.format_variadic_region_a_op {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }, {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }
+test.format_variadic_region_a_op {
+  "test.return"() : () -> ()
+}, {
+  "test.return"() : () -> ()
+}
+// CHECK: test.format_variadic_region_b_op {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }, {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: } found_regions
+test.format_variadic_region_b_op {
+  "test.return"() : () -> ()
+}, {
+  "test.return"() : () -> ()
+} found_regions
+// CHECK: test.format_variadic_region_b_op
+// CHECK-NOT: {
+// CHECK-NOT: found_regions
+test.format_variadic_region_b_op
+
+// CHECK: test.format_implicit_terminator_region_a_op {
+// CHECK-NEXT: }
+test.format_implicit_terminator_region_a_op {
+  "test.return"() : () -> ()
+}
+// CHECK: test.format_implicit_terminator_region_a_op {
+// CHECK-NEXT: test.return"() {foo.attr
+test.format_implicit_terminator_region_a_op {
+  "test.return"() {foo.attr} : () -> ()
+}
+// CHECK: test.format_implicit_terminator_region_a_op {
+// CHECK-NEXT: test.return"(%[[I64]]) : (i64)
+test.format_implicit_terminator_region_a_op {
+  "test.return"(%i64) : (i64) -> ()
+}
+
 //===----------------------------------------------------------------------===//
 // Format results
 //===----------------------------------------------------------------------===//
@@ -147,6 +213,24 @@ test.format_custom_directive_operands_and_types %i64, %i64 -> (%i64) : i64, i64
 // CHECK: test.format_custom_directive_operands_and_types %[[I64]] -> (%[[I64]]) : i64 -> (i64)
 test.format_custom_directive_operands_and_types %i64 -> (%i64) : i64 -> (i64)
 
+// CHECK: test.format_custom_directive_regions {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }
+test.format_custom_directive_regions {
+  "test.return"() : () -> ()
+}
+
+// CHECK: test.format_custom_directive_regions {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }, {
+// CHECK-NEXT: test.return
+// CHECK-NEXT: }
+test.format_custom_directive_regions {
+  "test.return"() : () -> ()
+}, {
+  "test.return"() : () -> ()
+}
+
 // CHECK: test.format_custom_directive_results : i64, i64 -> (i64)
 test.format_custom_directive_results : i64, i64 -> (i64)
 

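The implicit-terminator tests above show a bare `"test.return"() : () -> ()` being elided when printed, while a `test.return` carrying an attribute or an operand is kept. On the parsing side, `regionListEnsureTerminatorParserCode` in the OpFormatGen.cpp diff calls `ensureTerminator` to put the terminator back. The sketch below models just those two behaviors as exercised by the tests; the generated printer may check further conditions. `ToyOp`, `opsToPrint`, and `ensureToyTerminator` are hypothetical names, and a block is reduced to a list of ops.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an operation: its name and how many operands and
// attributes it carries.
struct ToyOp {
  std::string name;
  int numOperands = 0;
  int numAttrs = 0;
};

// Printer side: the trailing terminator is omitted only when it is the
// implicit terminator and carries no operands or attributes, so dropping it
// loses no information.
std::vector<ToyOp> opsToPrint(const std::vector<ToyOp> &block,
                              const std::string &implicitTerm) {
  if (!block.empty()) {
    const ToyOp &last = block.back();
    if (last.name == implicitTerm && last.numOperands == 0 &&
        last.numAttrs == 0)
      return std::vector<ToyOp>(block.begin(), block.end() - 1);
  }
  return block;
}

// Parser side, in the spirit of ensureTerminator: append the implicit
// terminator if the parsed block does not already end with it.
void ensureToyTerminator(std::vector<ToyOp> &block,
                         const std::string &implicitTerm) {
  if (block.empty() || block.back().name != implicitTerm)
    block.push_back(ToyOp{implicitTerm});
}
```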
diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 684c8ad8cb17..1542e9c55e41 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -16,6 +16,7 @@
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -48,6 +49,7 @@ class Element {
     CustomDirective,
     FunctionalTypeDirective,
     OperandsDirective,
+    RegionsDirective,
     ResultsDirective,
     SuccessorsDirective,
     TypeDirective,
@@ -58,6 +60,7 @@ class Element {
     /// This element is an variable value.
     AttributeVariable,
     OperandVariable,
+    RegionVariable,
     ResultVariable,
     SuccessorVariable,
 
@@ -119,6 +122,10 @@ struct AttributeVariable
 using OperandVariable =
     VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>;
 
+/// This class represents a variable that refers to a region.
+using RegionVariable =
+    VariableElement<NamedRegion, Element::Kind::RegionVariable>;
+
 /// This class represents a variable that refers to a result.
 using ResultVariable =
     VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
@@ -133,7 +140,8 @@ using SuccessorVariable =
 
 namespace {
 /// This class implements single kind directives.
-template <Element::Kind type> class DirectiveElement : public Element {
+template <Element::Kind type>
+class DirectiveElement : public Element {
 public:
   DirectiveElement() : Element(type){};
   static bool classof(const Element *ele) { return ele->getKind() == type; }
@@ -142,6 +150,10 @@ template <Element::Kind type> class DirectiveElement : public Element {
 /// all of the operands of an operation.
 using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
 
+/// This class represents the `regions` directive. This directive represents
+/// all of the regions of an operation.
+using RegionsDirective = DirectiveElement<Element::Kind::RegionsDirective>;
+
 /// This class represents the `results` directive. This directive represents
 /// all of the results of an operation.
 using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
@@ -350,13 +362,23 @@ struct OperationFormat {
       : allOperands(false), allOperandTypes(false), allResultTypes(false) {
     operandTypes.resize(op.getNumOperands(), TypeResolution());
     resultTypes.resize(op.getNumResults(), TypeResolution());
+
+    hasImplicitTermTrait =
+        llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
+          return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
+        });
   }
 
   /// Generate the operation parser from this format.
   void genParser(Operator &op, OpClass &opClass);
+  /// Generate the parser code for a specific format element.
+  void genElementParser(Element *element, OpMethodBody &body,
+                        FmtContext &attrTypeCtx);
   /// Generate the c++ to resolve the types of operands and results during
   /// parsing.
   void genParserTypeResolution(Operator &op, OpMethodBody &body);
+  /// Generate the c++ to resolve regions during parsing.
+  void genParserRegionResolution(Operator &op, OpMethodBody &body);
   /// Generate the c++ to resolve successors during parsing.
   void genParserSuccessorResolution(Operator &op, OpMethodBody &body);
   /// Generate the c++ to handling variadic segment size traits.
@@ -365,6 +387,10 @@ struct OperationFormat {
   /// Generate the operation printer from this format.
   void genPrinter(Operator &op, OpClass &opClass);
 
+  /// Generate the printer code for a specific format element.
+  void genElementPrinter(Element *element, OpMethodBody &body, Operator &op,
+                         bool &shouldEmitSpace, bool &lastWasPunctuation);
+
   /// The various elements in this format.
   std::vector<std::unique_ptr<Element>> elements;
 
@@ -372,11 +398,18 @@ struct OperationFormat {
   /// contains these, it can not contain individual type resolvers.
   bool allOperands, allOperandTypes, allResultTypes;
 
+  /// A flag indicating if this operation has the SingleBlockImplicitTerminator
+  /// trait.
+  bool hasImplicitTermTrait;
+
   /// A map of buildable types to indices.
   llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
 
   /// The index of the buildable type, if valid, for every operand and result.
   std::vector<TypeResolution> operandTypes, resultTypes;
+
+  /// The set of attributes explicitly used within the format.
+  SmallVector<const NamedAttribute *, 8> usedAttributes;
 };
 } // end anonymous namespace
 
@@ -541,6 +574,60 @@ const char *const functionalTypeParserCode = R"(
   {1}Types = {0}__{1}_functionType.getResults();
 )";
 
+/// The code snippet used to generate a parser call for a region list.
+///
+/// {0}: The name for the region list.
+const char *regionListParserCode = R"(
+  {
+    std::unique_ptr<::mlir::Region> region;
+    auto firstRegionResult = parser.parseOptionalRegion(region);
+    if (firstRegionResult.hasValue()) {
+      if (failed(*firstRegionResult))
+        return failure();
+      {0}Regions.emplace_back(std::move(region));
+
+      // Parse any trailing regions.
+      while (succeeded(parser.parseOptionalComma())) {
+        region = std::make_unique<::mlir::Region>();
+        if (parser.parseRegion(*region))
+          return failure();
+        {0}Regions.emplace_back(std::move(region));
+      }
+    }
+  }
+)";
+
+/// The code snippet used to ensure a list of regions have terminators.
+///
+/// {0}: The name of the region list.
+const char *regionListEnsureTerminatorParserCode = R"(
+  for (auto &region : {0}Regions)
+    ensureTerminator(*region, parser.getBuilder(), result.location);
+)";
+
+/// The code snippet used to generate a parser call for an optional region.
+///
+/// {0}: The name of the region.
+const char *optionalRegionParserCode = R"(
+  if (parser.parseOptionalRegion(*{0}Region))
+    return failure();
+)";
+
+/// The code snippet used to generate a parser call for a region.
+///
+/// {0}: The name of the region.
+const char *regionParserCode = R"(
+  if (parser.parseRegion(*{0}Region))
+    return failure();
+)";
+
+/// The code snippet used to ensure a region has a terminator.
+///
+/// {0}: The name of the region.
+const char *regionEnsureTerminatorParserCode = R"(
+  ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
+)";
+
 /// The code snippet used to generate a parser call for a successor list.
 ///
 /// {0}: The name for the successor list.
@@ -658,6 +745,10 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
     body << "  ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
             "allOperands;\n";
 
+  } else if (isa<RegionsDirective>(element)) {
+    body << "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
+            "fullRegions;\n";
+
   } else if (isa<SuccessorsDirective>(element)) {
     body << "  ::llvm::SmallVector<::mlir::Block *, 2> fullSuccessors;\n";
 
@@ -680,6 +771,20 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
     body << llvm::formatv("  ::llvm::SMLoc {0}OperandsLoc;\n"
                           "  (void){0}OperandsLoc;\n",
                           name);
+
+  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
+    StringRef name = region->getVar()->name;
+    if (region->getVar()->isVariadic()) {
+      body << llvm::formatv(
+          "  ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
+          "{0}Regions;\n",
+          name);
+    } else {
+      body << llvm::formatv("  std::unique_ptr<::mlir::Region> {0}Region = "
+                            "std::make_unique<::mlir::Region>();\n",
+                            name);
+    }
+
   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
     StringRef name = successor->getVar()->name;
     if (successor->getVar()->isVariadic()) {
@@ -725,6 +830,13 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
     else
       body << formatv("{0}RawOperands[0]", name);
 
+  } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
+    StringRef name = region->getVar()->name;
+    if (region->getVar()->isVariadic())
+      body << llvm::formatv("{0}Regions", name);
+    else
+      body << llvm::formatv("*{0}Region", name);
+
   } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
     StringRef name = successor->getVar()->name;
     if (successor->getVar()->isVariadic())
@@ -809,9 +921,39 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
   body << "  }\n";
 }
 
-/// Generate the parser for a single format element.
-static void genElementParser(Element *element, OpMethodBody &body,
-                             FmtContext &attrTypeCtx) {
+void OperationFormat::genParser(Operator &op, OpClass &opClass) {
+  auto &method = opClass.newMethod(
+      "::mlir::ParseResult", "parse",
+      "::mlir::OpAsmParser &parser, ::mlir::OperationState &result",
+      OpMethod::MP_Static);
+  auto &body = method.body();
+
+  // Generate variables to store the operands and type within the format. This
+  // allows for referencing these variables in the presence of optional
+  // groupings.
+  for (auto &element : elements)
+    genElementParserStorage(&*element, body);
+
+  // A format context used when parsing attributes with buildable types.
+  FmtContext attrTypeCtx;
+  attrTypeCtx.withBuilder("parser.getBuilder()");
+
+  // Generate parsers for each of the elements.
+  for (auto &element : elements)
+    genElementParser(element.get(), body, attrTypeCtx);
+
+  // Generate the code to resolve the operand/result types and successors now
+  // that they have been parsed.
+  genParserTypeResolution(op, body);
+  genParserRegionResolution(op, body);
+  genParserSuccessorResolution(op, body);
+  genParserVariadicSegmentResolution(op, body);
+
+  body << "  return success();\n";
+}
+
+void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
+                                       FmtContext &attrTypeCtx) {
   /// Optional Group.
   if (auto *optional = dyn_cast<OptionalElement>(element)) {
     auto elements = optional->getElements();
@@ -829,6 +971,17 @@ static void genElementParser(Element *element, OpMethodBody &body,
     } else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
       genElementParser(opVar, body, attrTypeCtx);
       body << "  if (!" << opVar->getVar()->name << "Operands.empty()) {\n";
+    } else if (auto *regionVar = dyn_cast<RegionVariable>(firstElement)) {
+      const NamedRegion *region = regionVar->getVar();
+      if (region->isVariadic()) {
+        genElementParser(regionVar, body, attrTypeCtx);
+        body << "  if (!" << region->name << "Regions.empty()) {\n";
+      } else {
+        body << llvm::formatv(optionalRegionParserCode, region->name);
+        body << "  if (!" << region->name << "Region->empty()) {\n  ";
+        if (hasImplicitTermTrait)
+          body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
+      }
     }
 
     // If the anchor is a unit attribute, we don't need to print it. When
@@ -907,6 +1060,17 @@ static void genElementParser(Element *element, OpMethodBody &body,
       body << llvm::formatv(optionalOperandParserCode, name);
     else
       body << formatv(operandParserCode, name);
+
+  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
+    bool isVariadic = region->getVar()->isVariadic();
+    body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
+                          region->getVar()->name);
+    if (hasImplicitTermTrait) {
+      body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
+                                       : regionEnsureTerminatorParserCode,
+                            region->getVar()->name);
+    }
+
   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
     bool isVariadic = successor->getVar()->isVariadic();
     body << formatv(isVariadic ? successorListParserCode : successorParserCode,
@@ -925,8 +1089,15 @@ static void genElementParser(Element *element, OpMethodBody &body,
     body << "  ::llvm::SMLoc allOperandLoc = parser.getCurrentLocation();\n"
          << "  if (parser.parseOperandList(allOperands))\n"
          << "    return failure();\n";
+
+  } else if (isa<RegionsDirective>(element)) {
+    body << llvm::formatv(regionListParserCode, "full");
+    if (hasImplicitTermTrait)
+      body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
+
   } else if (isa<SuccessorsDirective>(element)) {
     body << llvm::formatv(successorListParserCode, "full");
+
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
     ArgumentLengthKind lengthKind;
     StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@@ -946,36 +1117,6 @@ static void genElementParser(Element *element, OpMethodBody &body,
   }
 }
 
-void OperationFormat::genParser(Operator &op, OpClass &opClass) {
-  auto &method = opClass.newMethod(
-      "::mlir::ParseResult", "parse",
-      "::mlir::OpAsmParser &parser, ::mlir::OperationState &result",
-      OpMethod::MP_Static);
-  auto &body = method.body();
-
-  // Generate variables to store the operands and type within the format. This
-  // allows for referencing these variables in the presence of optional
-  // groupings.
-  for (auto &element : elements)
-    genElementParserStorage(&*element, body);
-
-  // A format context used when parsing attributes with buildable types.
-  FmtContext attrTypeCtx;
-  attrTypeCtx.withBuilder("parser.getBuilder()");
-
-  // Generate parsers for each of the elements.
-  for (auto &element : elements)
-    genElementParser(element.get(), body, attrTypeCtx);
-
-  // Generate the code to resolve the operand/result types and successors now
-  // that they have been parsed.
-  genParserTypeResolution(op, body);
-  genParserSuccessorResolution(op, body);
-  genParserVariadicSegmentResolution(op, body);
-
-  body << "  return success();\n";
-}
-
 void OperationFormat::genParserTypeResolution(Operator &op,
                                               OpMethodBody &body) {
   // If any of type resolutions use transformed variables, make sure that the
@@ -1133,6 +1274,25 @@ void OperationFormat::genParserTypeResolution(Operator &op,
   }
 }
 
+void OperationFormat::genParserRegionResolution(Operator &op,
+                                                OpMethodBody &body) {
+  // Check for the case where all regions were parsed.
+  bool hasAllRegions = llvm::any_of(
+      elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
+  if (hasAllRegions) {
+    body << "  result.addRegions(fullRegions);\n";
+    return;
+  }
+
+  // Otherwise, handle each region individually.
+  for (const NamedRegion &region : op.getRegions()) {
+    if (region.isVariadic())
+      body << "  result.addRegions(" << region.name << "Regions);\n";
+    else
+      body << "  result.addRegion(std::move(" << region.name << "Region));\n";
+  }
+}
+
 void OperationFormat::genParserSuccessorResolution(Operator &op,
                                                    OpMethodBody &body) {
   // Check for the case where all successors were parsed.
@@ -1186,23 +1346,26 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
 //===----------------------------------------------------------------------===//
 // PrinterGen
 
-/// Generate the printer for the 'attr-dict' directive.
-static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
-                               OpMethodBody &body, bool withKeyword) {
-  // Collect all of the attributes used in the format, these will be elided.
-  SmallVector<const NamedAttribute *, 1> usedAttributes;
-  for (auto &it : fmt.elements) {
-    if (auto *attr = dyn_cast<AttributeVariable>(it.get()))
-      usedAttributes.push_back(attr->getVar());
-    // Collect the optional attributes.
-    if (auto *opt = dyn_cast<OptionalElement>(it.get())) {
-      for (auto &elem : opt->getElements()) {
-        if (auto *attr = dyn_cast<AttributeVariable>(&elem))
-          usedAttributes.push_back(attr->getVar());
-      }
+/// The code snippet used to generate a printer call for a region of an
+/// operation that has the SingleBlockImplicitTerminator trait.
+///
+/// {0}: The name of the region.
+const char *regionSingleBlockImplicitTerminatorPrinterCode = R"(
+  {
+    bool printTerminator = true;
+    if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{
+      printTerminator = !term->getMutableAttrDict().empty() ||
+                        term->getNumOperands() != 0 ||
+                        term->getNumResults() != 0;
     }
+    p.printRegion({0}, /*printEntryBlockArgs=*/true,
+                  /*printBlockTerminators=*/printTerminator);
   }
+)";
 
+/// Generate the printer for the 'attr-dict' directive.
+static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
+                               OpMethodBody &body, bool withKeyword) {
   body << "  p.printOptionalAttrDict" << (withKeyword ? "WithKeyword" : "")
        << "(getAttrs(), /*elidedAttrs=*/{";
   // Elide the variadic segment size attributes if necessary.
@@ -1210,9 +1373,9 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
     body << "\"operand_segment_sizes\", ";
   if (!fmt.allResultTypes && op.getTrait("OpTrait::AttrSizedResultSegments"))
     body << "\"result_segment_sizes\", ";
-  llvm::interleaveComma(usedAttributes, body, [&](const NamedAttribute *attr) {
-    body << "\"" << attr->name << "\"";
-  });
+  llvm::interleaveComma(
+      fmt.usedAttributes, body,
+      [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
   body << "});\n";
 }
 
@@ -1255,6 +1418,9 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
     } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
       body << operand->getVar()->name << "()";
 
+    } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
+      body << region->getVar()->name << "()";
+
     } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
       body << successor->getVar()->name << "()";
 
@@ -1277,6 +1443,24 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
   body << ");\n";
 }
 
+/// Generate the printer for a region with the given variable name.
+static void genRegionPrinter(const Twine &regionName, OpMethodBody &body,
+                             bool hasImplicitTermTrait) {
+  if (hasImplicitTermTrait)
+    body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
+                          regionName);
+  else
+    body << "  p.printRegion(" << regionName << ");\n";
+}
+static void genVariadicRegionPrinter(const Twine &regionListName,
+                                     OpMethodBody &body,
+                                     bool hasImplicitTermTrait) {
+  body << "    llvm::interleaveComma(" << regionListName
+       << ", p, [&](::mlir::Region &region) {\n      ";
+  genRegionPrinter("region", body, hasImplicitTermTrait);
+  body << "    });\n";
+}
+
 /// Generate the C++ for an operand to a (*-)type directive.
 static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
   if (isa<OperandsDirective>(arg))
@@ -1296,10 +1480,9 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
               << "().getType())";
 }
 
-/// Generate the code for printing the given element.
-static void genElementPrinter(Element *element, OpMethodBody &body,
-                              OperationFormat &fmt, Operator &op,
-                              bool &shouldEmitSpace, bool &lastWasPunctuation) {
+void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
+                                        Operator &op, bool &shouldEmitSpace,
+                                        bool &lastWasPunctuation) {
   if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
     return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
                              lastWasPunctuation);
@@ -1314,6 +1497,11 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
         body << "  if (" << var->name << "()) {\n";
       else if (var->isVariadic())
         body << "  if (!" << var->name << "().empty()) {\n";
+    } else if (auto *region = dyn_cast<RegionVariable>(anchor)) {
+      const NamedRegion *var = region->getVar();
+      // TODO: Add a check for optional here when ODS supports it.
+      body << "  if (!" << var->name << "().empty()) {\n";
+
     } else {
       body << "  if (getAttr(\""
            << cast<AttributeVariable>(anchor)->getVar()->name << "\")) {\n";
@@ -1332,7 +1520,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
     // Emit each of the elements.
     for (Element &childElement : elements) {
       if (&childElement != elidedAnchorElement) {
-        genElementPrinter(&childElement, body, fmt, op, shouldEmitSpace,
+        genElementPrinter(&childElement, body, op, shouldEmitSpace,
                           lastWasPunctuation);
       }
     }
@@ -1342,7 +1530,7 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
 
   // Emit the attribute dictionary.
   if (auto *attrDict = dyn_cast<AttrDictDirective>(element)) {
-    genAttrDictPrinter(fmt, op, body, attrDict->isWithKeyword());
+    genAttrDictPrinter(*this, op, body, attrDict->isWithKeyword());
     lastWasPunctuation = false;
     return;
   }
@@ -1384,6 +1572,13 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
     } else {
       body << "  p << " << operand->getVar()->name << "();\n";
     }
+  } else if (auto *region = dyn_cast<RegionVariable>(element)) {
+    const NamedRegion *var = region->getVar();
+    if (var->isVariadic()) {
+      genVariadicRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
+    } else {
+      genRegionPrinter(var->name + "()", body, hasImplicitTermTrait);
+    }
   } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
     const NamedSuccessor *var = successor->getVar();
     if (var->isVariadic())
@@ -1394,6 +1589,9 @@ static void genElementPrinter(Element *element, OpMethodBody &body,
     genCustomDirectivePrinter(dir, body);
   } else if (isa<OperandsDirective>(element)) {
     body << "  p << getOperation()->getOperands();\n";
+  } else if (isa<RegionsDirective>(element)) {
+    genVariadicRegionPrinter("getOperation()->getRegions()", body,
+                             hasImplicitTermTrait);
   } else if (isa<SuccessorsDirective>(element)) {
     body << "  ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
   } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
@@ -1426,7 +1624,7 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
   // punctuation.
   bool shouldEmitSpace = true, lastWasPunctuation = false;
   for (auto &element : elements)
-    genElementPrinter(element.get(), body, *this, op, shouldEmitSpace,
+    genElementPrinter(element.get(), body, op, shouldEmitSpace,
                       lastWasPunctuation);
 }
 
@@ -1460,6 +1658,7 @@ class Token {
     kw_custom,
     kw_functional_type,
     kw_operands,
+    kw_regions,
     kw_results,
     kw_successors,
     kw_type,
@@ -1663,6 +1862,7 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
           .Case("custom", Token::kw_custom)
           .Case("functional-type", Token::kw_functional_type)
           .Case("operands", Token::kw_operands)
+          .Case("regions", Token::kw_regions)
           .Case("results", Token::kw_results)
           .Case("successors", Token::kw_successors)
           .Case("type", Token::kw_type)
@@ -1676,7 +1876,8 @@ Token FormatLexer::lexIdentifier(const char *tokStart) {
 
 /// Function to find an element within the given range that has the same name as
 /// 'name'.
-template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
+template <typename RangeT>
+static auto findArg(RangeT &&range, StringRef name) {
   auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; });
   return it != range.end() ? &*it : nullptr;
 }
@@ -1719,6 +1920,9 @@ class FormatParser {
   verifyOperands(llvm::SMLoc loc,
                  llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
 
+  /// Verify the state of operation regions within the format.
+  LogicalResult verifyRegions(llvm::SMLoc loc);
+
   /// Verify the state of operation results within the format.
   LogicalResult
   verifyResults(llvm::SMLoc loc,
@@ -1775,6 +1979,8 @@ class FormatParser {
                                              Token tok, bool isTopLevel);
   LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
                                        llvm::SMLoc loc, bool isTopLevel);
+  LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
+                                      llvm::SMLoc loc, bool isTopLevel);
   LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
                                       llvm::SMLoc loc, bool isTopLevel);
   LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
@@ -1821,11 +2027,12 @@ class FormatParser {
 
   // The following are various bits of format state used for verification
   // during parsing.
-  bool hasAllOperands = false, hasAttrDict = false;
-  bool hasAllSuccessors = false;
+  bool hasAttrDict = false;
+  bool hasAllRegions = false, hasAllSuccessors = false;
   llvm::SmallBitVector seenOperandTypes, seenResultTypes;
+  llvm::SmallSetVector<const NamedAttribute *, 8> seenAttrs;
   llvm::DenseSet<const NamedTypeConstraint *> seenOperands;
-  llvm::DenseSet<const NamedAttribute *> seenAttrs;
+  llvm::DenseSet<const NamedRegion *> seenRegions;
   llvm::DenseSet<const NamedSuccessor *> seenSuccessors;
   llvm::DenseSet<const NamedTypeConstraint *> optionalVariables;
 };
@@ -1867,13 +2074,11 @@ LogicalResult FormatParser::parse() {
   if (failed(verifyAttributes(loc)) ||
       failed(verifyResults(loc, variableTyResolver)) ||
       failed(verifyOperands(loc, variableTyResolver)) ||
-      failed(verifySuccessors(loc)))
+      failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
     return failure();
 
-  // Check to see if we are formatting all of the operands.
-  fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
-    return isa<OperandsDirective>(elt.get());
-  });
+  // Collect the set of used attributes in the format.
+  fmt.usedAttributes = seenAttrs.takeVector();
   return success();
 }
 
@@ -1953,7 +2158,7 @@ LogicalResult FormatParser::verifyOperands(
     NamedTypeConstraint &operand = op.getOperand(i);
 
     // Check that the operand itself is in the format.
-    if (!hasAllOperands && !seenOperands.count(&operand)) {
+    if (!fmt.allOperands && !seenOperands.count(&operand)) {
       return emitErrorAndNote(loc,
                               "operand #" + Twine(i) + ", named '" +
                                   operand.name + "', not found",
@@ -1976,7 +2181,7 @@ LogicalResult FormatParser::verifyOperands(
     // Similarly to results, allow a custom builder for resolving the type if
     // we aren't using the 'operands' directive.
     Optional<StringRef> builder = operand.constraint.getBuilderCall();
-    if (!builder || (hasAllOperands && operand.isVariableLength())) {
+    if (!builder || (fmt.allOperands && operand.isVariableLength())) {
       return emitErrorAndNote(
           loc,
           "type of operand #" + Twine(i) + ", named '" + operand.name +
@@ -1991,6 +2196,24 @@ LogicalResult FormatParser::verifyOperands(
   return success();
 }
 
+LogicalResult FormatParser::verifyRegions(llvm::SMLoc loc) {
+  // Check that all of the regions are within the format.
+  if (hasAllRegions)
+    return success();
+
+  for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
+    const NamedRegion &region = op.getRegion(i);
+    if (!seenRegions.count(&region)) {
+      return emitErrorAndNote(loc,
+                              "region #" + Twine(i) + ", named '" +
+                                  region.name + "', not found",
+                              "suggest adding a '$" + region.name +
+                                  "' directive to the custom assembly format");
+    }
+  }
+  return success();
+}
+
 LogicalResult FormatParser::verifyResults(
     llvm::SMLoc loc,
     llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
@@ -2108,7 +2331,7 @@ ConstArgument FormatParser::findSeenArg(StringRef name) {
   if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
     return seenResultTypes.test(arg - op.result_begin()) ? arg : nullptr;
   if (const NamedAttribute *attr = findArg(op.getAttributes(), name))
-    return seenAttrs.find_as(attr) != seenAttrs.end() ? attr : nullptr;
+    return seenAttrs.count(attr) ? attr : nullptr;
   return nullptr;
 }
 
@@ -2142,7 +2365,7 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
   // op.
   /// Attributes
   if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
-    if (isTopLevel && !seenAttrs.insert(attr).second)
+    if (isTopLevel && !seenAttrs.insert(attr))
       return emitError(loc, "attribute '" + name + "' is already bound");
     element = std::make_unique<AttributeVariable>(attr);
     return success();
@@ -2150,12 +2373,21 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
   /// Operands
   if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
     if (isTopLevel) {
-      if (hasAllOperands || !seenOperands.insert(operand).second)
+      if (fmt.allOperands || !seenOperands.insert(operand).second)
         return emitError(loc, "operand '" + name + "' is already bound");
     }
     element = std::make_unique<OperandVariable>(operand);
     return success();
   }
+  /// Regions
+  if (const NamedRegion *region = findArg(op.getRegions(), name)) {
+    if (!isTopLevel)
+      return emitError(loc, "regions can only be used at the top level");
+    if (hasAllRegions || !seenRegions.insert(region).second)
+      return emitError(loc, "region '" + name + "' is already bound");
+    element = std::make_unique<RegionVariable>(region);
+    return success();
+  }
   /// Results.
   if (const auto *result = findArg(op.getResults(), name)) {
     if (isTopLevel)
@@ -2172,8 +2404,8 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
     element = std::make_unique<SuccessorVariable>(successor);
     return success();
   }
-  return emitError(
-      loc, "expected variable to refer to an argument, result, or successor");
+  return emitError(loc, "expected variable to refer to an argument, region, "
+                        "result, or successor");
 }
 
 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
@@ -2194,6 +2426,8 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
     return parseFunctionalTypeDirective(element, dirTok, isTopLevel);
   case Token::kw_operands:
     return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel);
+  case Token::kw_regions:
+    return parseRegionsDirective(element, dirTok.getLoc(), isTopLevel);
   case Token::kw_results:
     return parseResultsDirective(element, dirTok.getLoc(), isTopLevel);
   case Token::kw_successors:
@@ -2247,9 +2481,10 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
   // optional fashion.
   Element *firstElement = &*elements.front();
   if (!isa<AttributeVariable>(firstElement) &&
-      !isa<LiteralElement>(firstElement) && !isa<OperandVariable>(firstElement))
+      !isa<LiteralElement>(firstElement) &&
+      !isa<OperandVariable>(firstElement) && !isa<RegionVariable>(firstElement))
     return emitError(curLoc, "first element of an operand group must be an "
-                             "attribute, literal, or operand");
+                             "attribute, literal, operand, or region");
 
   // After parsing all of the elements, ensure that all type directives refer
   // only to elements within the group.
@@ -2314,10 +2549,15 @@ LogicalResult FormatParser::parseOptionalChildElement(
         seenVariables.insert(ele->getVar());
         return success();
       })
+      .Case<RegionVariable>([&](RegionVariable *) {
+        // TODO: When ODS has proper support for marking "optional" regions, add
+        // a check here.
+        return success();
+      })
       // Literals, custom directives, and type directives may be used,
       // but they can't anchor the group.
-      .Case<LiteralElement, CustomDirective, TypeDirective,
-            FunctionalTypeDirective>([&](Element *) {
+      .Case<LiteralElement, CustomDirective, FunctionalTypeDirective,
+            OptionalElement, TypeDirective>([&](Element *) {
         if (isAnchor)
           return emitError(childLoc, "only variables can be used to anchor "
                                      "an optional group");
@@ -2401,7 +2641,7 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
     return failure();
 
   // Verify that the element can be placed within a custom directive.
-  if (!isa<TypeDirective, AttributeVariable, OperandVariable,
+  if (!isa<TypeDirective, AttributeVariable, OperandVariable, RegionVariable,
            SuccessorVariable>(parameters.back().get())) {
     return emitError(childLoc, "only variables and types may be used as "
                                "parameters to a custom directive");
@@ -2433,13 +2673,27 @@ FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
 LogicalResult
 FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
                                      llvm::SMLoc loc, bool isTopLevel) {
-  if (isTopLevel && (hasAllOperands || !seenOperands.empty()))
-    return emitError(loc, "'operands' directive creates overlap in format");
-  hasAllOperands = true;
+  if (isTopLevel) {
+    if (fmt.allOperands || !seenOperands.empty())
+      return emitError(loc, "'operands' directive creates overlap in format");
+    fmt.allOperands = true;
+  }
   element = std::make_unique<OperandsDirective>();
   return success();
 }
 
+LogicalResult
+FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
+                                    llvm::SMLoc loc, bool isTopLevel) {
+  if (!isTopLevel)
+    return emitError(loc, "'regions' is only valid as a top-level directive");
+  if (hasAllRegions || !seenRegions.empty())
+    return emitError(loc, "'regions' directive creates overlap in format");
+  hasAllRegions = true;
+  element = std::make_unique<RegionsDirective>();
+  return success();
+}
+
 LogicalResult
 FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
                                     llvm::SMLoc loc, bool isTopLevel) {
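The generated `regionListParserCode` tries an optional first region. If one is present, it keeps parsing comma-separated trailing regions. If none is present, the list stays empty. Below is a minimal standalone sketch of that control flow. `ToyRegion` and `ToyParser` are hypothetical stand-ins, not the MLIR `OpAsmParser` API. Unlike MLIR's `ParseResult` convention, where a truthy result means failure, the toy `parseRegion` returns `true` on success.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ::mlir::Region: just the text between braces.
struct ToyRegion {
  std::string body;
};

// Hypothetical minimal cursor over an input string (not the MLIR API).
class ToyParser {
public:
  explicit ToyParser(std::string input) : text(std::move(input)) {}

  // Shaped like parseOptionalRegion: no value if the next token is not '{',
  // otherwise true on success and false on a malformed region.
  std::optional<bool> parseOptionalRegion(ToyRegion &region) {
    skipSpace();
    if (pos >= text.size() || text[pos] != '{')
      return std::nullopt;
    return parseRegion(region);
  }

  // Parses '{' ... '}' without nesting, for brevity. True on success.
  bool parseRegion(ToyRegion &region) {
    skipSpace();
    if (pos >= text.size() || text[pos] != '{')
      return false;
    std::size_t close = text.find('}', pos);
    if (close == std::string::npos)
      return false;
    region.body = text.substr(pos + 1, close - pos - 1);
    pos = close + 1;
    return true;
  }

  // Consumes a ',' if it is the next token.
  bool parseOptionalComma() {
    skipSpace();
    if (pos < text.size() && text[pos] == ',') {
      ++pos;
      return true;
    }
    return false;
  }

private:
  void skipSpace() {
    while (pos < text.size() &&
           std::isspace(static_cast<unsigned char>(text[pos])))
      ++pos;
  }
  std::string text;
  std::size_t pos = 0;
};

// Same shape as the generated snippet: optional first region, then any number
// of comma-separated trailing regions. Returns false only when a region was
// started but is malformed; an absent list is valid and leaves it empty.
bool parseRegionList(ToyParser &parser,
                     std::vector<std::unique_ptr<ToyRegion>> &regions) {
  auto region = std::make_unique<ToyRegion>();
  std::optional<bool> first = parser.parseOptionalRegion(*region);
  if (!first)
    return true;
  if (!*first)
    return false;
  regions.push_back(std::move(region));
  while (parser.parseOptionalComma()) {
    region = std::make_unique<ToyRegion>();
    if (!parser.parseRegion(*region))
      return false;
    regions.push_back(std::move(region));
  }
  return true;
}
```

For example, `"{a}, {b} ,{c}"` yields three regions, and `"attr"` yields an empty list. A trailing comma that is not followed by a region, as in `"{a}, x"`, is an error.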

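For ops with the SingleBlockImplicitTerminator trait, `regionSingleBlockImplicitTerminatorPrinterCode` elides the entry block's terminator only when that terminator has no attributes, operands, or results. Otherwise it prints the terminator. An empty region leaves `printTerminator` at `true`. The rationale appears to be that the parser side recreates the terminator through `ensureTerminator`, which cannot restore such extra state. Below is a minimal sketch of just that decision. `ToyTerminator` is a hypothetical simplification of an MLIR operation.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical, simplified view of a block terminator.
struct ToyTerminator {
  std::map<std::string, std::string> attributes;
  std::vector<int> operands;
  std::vector<int> results;
};

// Mirrors the generated check: print the terminator unless it is "plain"
// (no attributes, operands, or results). A null terminator stands in for an
// empty region, where the generated code keeps printTerminator = true.
bool shouldPrintTerminator(const ToyTerminator *term) {
  if (!term)
    return true;
  return !term->attributes.empty() || !term->operands.empty() ||
         !term->results.empty();
}
```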

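`FormatParser` enforces three rules for regions. A `$name` region variable must appear at the top level and only once. The `regions` directive is also top-level only and cannot be combined with individual region variables. Finally, `verifyRegions` requires every region of the op to be covered unless `regions` was used. The class below is a hypothetical model of that bookkeeping, reusing the diff's error strings. It is not the tblgen code itself, and it omits the "suggest adding" note.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model of FormatParser's region bookkeeping.
// Methods return an empty string on success, or the error message.
class ToyRegionFormatChecker {
public:
  explicit ToyRegionFormatChecker(std::vector<std::string> names)
      : regionNames(std::move(names)) {}

  // A `$name` region variable: top level only, bound at most once.
  std::string bindRegion(const std::string &name, bool isTopLevel) {
    if (!isTopLevel)
      return "regions can only be used at the top level";
    if (hasAllRegions || !seenRegions.insert(name).second)
      return "region '" + name + "' is already bound";
    return "";
  }

  // The `regions` directive: top level only, exclusive with `$name`.
  std::string bindAllRegions(bool isTopLevel) {
    if (!isTopLevel)
      return "'regions' is only valid as a top-level directive";
    if (hasAllRegions || !seenRegions.empty())
      return "'regions' directive creates overlap in format";
    hasAllRegions = true;
    return "";
  }

  // Like verifyRegions: every region must be covered by the format.
  std::string verify() const {
    if (hasAllRegions)
      return "";
    for (std::size_t i = 0; i < regionNames.size(); ++i)
      if (!seenRegions.count(regionNames[i]))
        return "region #" + std::to_string(i) + ", named '" +
               regionNames[i] + "', not found";
    return "";
  }

private:
  std::vector<std::string> regionNames;
  std::set<std::string> seenRegions;
  bool hasAllRegions = false;
};
```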
        

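The diff changes `seenAttrs` from a `DenseSet` to a `SmallSetVector` and then moves its contents into `fmt.usedAttributes` with `takeVector()`. One plausible motivation is that a set-vector deduplicates while preserving insertion order. That keeps the generated elided-attribute list in format order instead of hash order. Below is a hypothetical minimal analogue in the same spirit as `llvm::SetVector`, pairing a set with a vector. It is not LLVM's implementation.

```cpp
#include <cassert>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical set-vector: unique elements, iterated in insertion order.
template <typename T>
class ToySetVector {
public:
  // Returns true only if the value was newly inserted.
  bool insert(const T &value) {
    if (!set.insert(value).second)
      return false;
    vector.push_back(value);
    return true;
  }

  bool count(const T &value) const { return set.count(value) != 0; }

  // Moves the ordered elements out and leaves the container empty.
  std::vector<T> takeVector() {
    std::vector<T> out = std::move(vector);
    vector.clear();
    set.clear();
    return out;
  }

private:
  std::unordered_set<T> set;
  std::vector<T> vector;
};
```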
