[Mlir-commits] [mlir] 9a2fdc3 - [MLIR] Attribute and type formats in ODS

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 8 09:38:32 PST 2021


Author: Jeff Niu
Date: 2021-11-08T17:38:28Z
New Revision: 9a2fdc369dae21a6bb1d2145479a76d3775c980b

URL: https://github.com/llvm/llvm-project/commit/9a2fdc369dae21a6bb1d2145479a76d3775c980b
DIFF: https://github.com/llvm/llvm-project/commit/9a2fdc369dae21a6bb1d2145479a76d3775c980b.diff

LOG: [MLIR] Attribute and type formats in ODS

Add declarative assembly formats for attributes and types. Define an
`assemblyFormat` field in an attribute or type def that has a `mnemonic` to
generate a parser and printer.

```tablegen
def MyAttr : AttrDef<MyDialect, "MyAttr"> {
  let parameters = (ins "int64_t":$count, "AffineMap":$map);
  let mnemonic = "my_attr";
  let assemblyFormat = "`<` $count `,` $map `>`";
}
```
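
Per the `DialectImplementation.h` changes in this commit, generated parsers
call `FieldParser<T>::parse` for each parameter. The dispatch pattern can be
sketched in standalone C++ as below. This is only an illustration, not the
generated code: `ToyParser`, `MyParams`, and `parseMyParams` are hypothetical
names, and `std::optional` stands in for `FailureOr`.

```cpp
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <type_traits>

// Hypothetical stand-in for DialectAsmParser: a cursor over a string.
struct ToyParser {
  std::string text;
  std::size_t pos = 0;

  void skipSpace() {
    while (pos < text.size() &&
           std::isspace(static_cast<unsigned char>(text[pos])))
      ++pos;
  }
  // Consume `c` if it is the next non-space character.
  bool consume(char c) {
    skipSpace();
    if (pos < text.size() && text[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  }
};

// Primary template, left undefined so unsupported types fail to compile.
template <typename T, typename = T>
struct FieldParser;

// Selected for any integral type via enable_if.
template <typename IntT>
struct FieldParser<IntT,
                   std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
  static std::optional<IntT> parse(ToyParser &p) {
    p.skipSpace();
    std::size_t start = p.pos;
    IntT value = 0;
    while (p.pos < p.text.size() &&
           std::isdigit(static_cast<unsigned char>(p.text[p.pos])))
      value = value * 10 + (p.text[p.pos++] - '0');
    if (p.pos == start)
      return std::nullopt; // No digits were found.
    return value;
  }
};

// Full specialization for a double-quoted string.
template <>
struct FieldParser<std::string> {
  static std::optional<std::string> parse(ToyParser &p) {
    if (!p.consume('"'))
      return std::nullopt;
    std::size_t end = p.text.find('"', p.pos);
    if (end == std::string::npos)
      return std::nullopt;
    std::string value = p.text.substr(p.pos, end - p.pos);
    p.pos = end + 1;
    return value;
  }
};

struct MyParams {
  int count;
  std::string name;
};

// Mirrors a format such as "`<` $count `,` $name `>`": literals are matched
// directly and each variable dispatches on its C++ type.
std::optional<MyParams> parseMyParams(ToyParser &p) {
  if (!p.consume('<'))
    return std::nullopt;
  auto count = FieldParser<int>::parse(p);
  if (!count || !p.consume(','))
    return std::nullopt;
  auto name = FieldParser<std::string>::parse(p);
  if (!name || !p.consume('>'))
    return std::nullopt;
  return MyParams{*count, *name};
}
```

Leaving the primary template undefined means a parameter type without a
matching specialization is rejected at compile time.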

Use `struct` to define a comma-separated list of key-value pairs:

```tablegen
def MyType : TypeDef<MyDialect, "MyType"> {
  let parameters = (ins "int":$one, "int":$two, "int":$three);
  let mnemonic = "my_type";
  let assemblyFormat = "`<` $three `:` struct($one, $two) `>`";
}
```

Use `struct(params)` to capture all parameters.
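
According to the documentation added in this commit, a `struct` directive
accepts its key-value pairs in any order when parsing but prints them in a
fixed order. A minimal standalone sketch of that behaviour follows. It is not
the code that mlir-tblgen emits: `parseStruct` and `printStruct` are
hypothetical helpers, values are limited to non-negative integers, and
`std::optional` stands in for `FailureOr`.

```cpp
#include <cctype>
#include <cstddef>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Parse "k = v, k = v, ..." where every key in `expected` must appear exactly
// once, in any order. Returns std::nullopt on a duplicate, unknown, or
// missing key, or on malformed input.
std::optional<std::map<std::string, int>>
parseStruct(const std::string &text, const std::vector<std::string> &expected) {
  std::size_t pos = 0;
  auto skipSpace = [&] {
    while (pos < text.size() &&
           std::isspace(static_cast<unsigned char>(text[pos])))
      ++pos;
  };
  auto consume = [&](char c) {
    skipSpace();
    if (pos < text.size() && text[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  };

  std::set<std::string> remaining(expected.begin(), expected.end());
  std::map<std::string, int> values;
  // Exactly expected.size() pairs are read, so a missing key surfaces as a
  // missing comma.
  for (std::size_t i = 0; i < expected.size(); ++i) {
    if (i != 0 && !consume(','))
      return std::nullopt;
    skipSpace();
    std::size_t start = pos;
    while (pos < text.size() &&
           (std::isalnum(static_cast<unsigned char>(text[pos])) ||
            text[pos] == '_'))
      ++pos;
    std::string key = text.substr(start, pos - start);
    // erase() returns 0 when the key is unknown or was already seen.
    if (remaining.erase(key) == 0)
      return std::nullopt;
    if (!consume('='))
      return std::nullopt;
    skipSpace();
    start = pos;
    int value = 0;
    while (pos < text.size() &&
           std::isdigit(static_cast<unsigned char>(text[pos])))
      value = value * 10 + (text[pos++] - '0');
    if (pos == start)
      return std::nullopt; // No digits were found.
    values[key] = value;
  }
  skipSpace();
  if (pos != text.size())
    return std::nullopt; // Trailing input.
  return values;
}

// Print the pairs in the fixed order given by `order`.
std::string printStruct(const std::map<std::string, int> &values,
                        const std::vector<std::string> &order) {
  std::string out;
  for (std::size_t i = 0; i < order.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += order[i] + " = " + std::to_string(values.at(order[i]));
  }
  return out;
}
```

Tracking the not-yet-seen keys in one set lets a single check reject both
duplicate and unknown keys, which matches the combined "duplicate or unknown
struct parameter" diagnostic exercised by the tests in this commit.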

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D111594

Added: 
    mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
    mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
    mlir/test/mlir-tblgen/attr-or-type-format.mlir
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
    mlir/tools/mlir-tblgen/FormatGen.cpp
    mlir/tools/mlir-tblgen/FormatGen.h

Modified: 
    mlir/docs/Tutorials/DefiningAttributesAndTypes.md
    mlir/include/mlir/IR/DialectImplementation.h
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/include/mlir/TableGen/AttrOrTypeDef.h
    mlir/include/mlir/TableGen/Dialect.h
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestAttributes.cpp
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/lib/Dialect/Test/TestTypes.h
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/tools/mlir-tblgen/CMakeLists.txt
    mlir/tools/mlir-tblgen/OpFormatGen.cpp
    utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 30d6a6e9412e8..0f8edc5bf1ae7 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -382,3 +382,172 @@ the things named `*Type` are generally now named `*Attr`.
 
 Aside from that, all of the interfaces for uniquing and storage construction are
 all the same.
+
+## Defining Custom Parsers and Printers using Assembly Formats
+
+Attributes and types defined in ODS with a mnemonic can define an
+`assemblyFormat` to declaratively describe custom parsers and printers. The
+assembly format consists of literals, variables, and directives.
+
+* A literal is a keyword or valid punctuation enclosed in backticks, e.g.
+  `` `keyword` `` or `` `<` ``.
+* A variable is a parameter name preceded by a dollar sign, e.g. `$param0`,
+  which captures one attribute or type parameter.
+* A directive is a keyword followed by an optional argument list that defines
+  special parser and printer behaviour.
+
+```tablegen
+// An example type with an assembly format.
+def MyType : TypeDef<My_Dialect, "MyType"> {
+  // Define a mnemonic to allow the dialect's parser hook to call into the
+  // generated parser.
+  let mnemonic = "my_type";
+
+  // Define two parameters whose C++ types are indicated in string literals.
+  let parameters = (ins "int":$count, "AffineMap":$map);
+
+  // Define the assembly format. Surround the format with less `<` and greater
+  // `>` so that MLIR's printers use the pretty format.
+  let assemblyFormat = "`<` $count `,` `map` `=` $map `>`";
+}
+```
+
+The declarative assembly format for `MyType` results in the following format
+in the IR:
+
+```mlir
+!my_dialect.my_type<42, map = affine_map<(i, j) -> (j, i)>>
+```
+
+### Parameter Parsing and Printing
+
+For many basic parameter types, no additional work is needed to define how
+these parameters are parsed or printed.
+
+* The default printer for any parameter is `$_printer << $_self`,
+  where `$_self` is the C++ value of the parameter and `$_printer` is a
+  `DialectAsmPrinter`.
+* The default parser for a parameter is
+  `FieldParser<$cppClass>::parse($_parser)`, where `$cppClass` is the C++ type
+  of the parameter and `$_parser` is a `DialectAsmParser`.
+
+Printing and parsing behaviour can be added to additional C++ types by
+overloading these functions or by defining a `parser` and `printer` in an ODS
+parameter class.
+
+Example of overloading:
+
+```c++
+using MyParameter = std::pair<int, int>;
+
+DialectAsmPrinter &operator<<(DialectAsmPrinter &printer, MyParameter param) {
+  return printer << param.first << " * " << param.second;
+}
+
+template <> struct FieldParser<MyParameter> {
+  static FailureOr<MyParameter> parse(DialectAsmParser &parser) {
+    int a, b;
+    if (parser.parseInteger(a) || parser.parseStar() ||
+        parser.parseInteger(b))
+      return failure();
+    return MyParameter(a, b);
+  }
+};
+```
+
+Example of using ODS parameter classes:
+
+```tablegen
+def MyParameter : TypeParameter<"std::pair<int, int>", "pair of ints"> {
+  let printer = [{ $_printer << $_self.first << " * " << $_self.second }];
+  let parser = [{ [&]() -> FailureOr<std::pair<int, int>> {
+    int a, b;
+    if ($_parser.parseInteger(a) || $_parser.parseStar() ||
+        $_parser.parseInteger(b))
+      return failure();
+    return std::make_pair(a, b);
+  }() }];
+}
+```
+
+A type using this parameter with the assembly format `` `<` $myParam `>` ``
+will look as follows in the IR:
+
+```mlir
+!my_dialect.my_type<42 * 24>
+```
+
+#### Non-POD Parameters
+
+Parameters that aren't plain-old-data (e.g. references) may need to define a
+`cppStorageType` to contain the data until it is copied into the allocator.
+For example, `StringRefParameter` uses `std::string` as its storage type,
+whereas `ArrayRefParameter` uses `SmallVector` as its storage type. The parsers
+for these parameters are expected to return `FailureOr<$cppStorageType>`.
+
+### Assembly Format Directives
+
+Attribute and type assembly formats have the following directives:
+
+* `params`: capture all parameters of an attribute or type.
+* `struct`: generate a "struct-like" parser and printer for a list of key-value
+  pairs.
+
+#### `params` Directive
+
+This directive is used to refer to all parameters of an attribute or type.
+When used as a top-level directive, `params` generates a parser and printer for
+a comma-separated list of the parameters. For example:
+
+```tablegen
+def MyPairType : TypeDef<My_Dialect, "MyPairType"> {
+  let parameters = (ins "int":$a, "int":$b);
+  let mnemonic = "pair";
+  let assemblyFormat = "`<` params `>`";
+}
+```
+
+In the IR, this type will appear as:
+
+```mlir
+!my_dialect.pair<42, 24>
+```
+
+The `params` directive can also be passed to other directives, such as `struct`,
+as an argument that refers to all parameters in place of explicitly listing all
+parameters as variables.
+
+#### `struct` Directive
+
+The `struct` directive accepts a list of variables to capture and will generate
+a parser and printer for a comma-separated list of key-value pairs. The
+variables are printed in the order they are specified in the argument list **but
+can be parsed in any order**. For example:
+
+```tablegen
+def MyStructType : TypeDef<My_Dialect, "MyStructType"> {
+  let parameters = (ins StringRefParameter<>:$sym_name,
+                        "int":$a, "int":$b, "int":$c);
+  let mnemonic = "struct";
+  let assemblyFormat = "`<` $sym_name `->` struct($a, $b, $c) `>`";
+}
+```
+
+In the IR, this type can appear with any permutation of the order of the
+parameters captured in the directive.
+
+```mlir
+!my_dialect.struct<"foo" -> a = 1, b = 2, c = 3>
+!my_dialect.struct<"foo" -> b = 2, c = 3, a = 1>
+```
+
+Passing `params` as the only argument to `struct` makes the directive capture
+all the parameters of the attribute or type. For the same type above, an
+assembly format of `` `<` struct(params) `>` `` will result in:
+
+```mlir
+!my_dialect.struct<b = 2, sym_name = "foo", c = 3, a = 1>
+```
+
+The order in which the parameters are printed is the order in which they are
+declared in the attribute's or type's `parameters` list.

diff  --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index 728e24605c29f..9302eb9146d4b 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -47,6 +47,74 @@ class DialectAsmParser : public AsmParser {
   virtual StringRef getFullSymbolSpec() const = 0;
 };
 
+//===----------------------------------------------------------------------===//
+// Parse Fields
+//===----------------------------------------------------------------------===//
+
+/// Provide a template class that can be specialized by users to dispatch to
+/// parsers. Auto-generated parsers generate calls to `FieldParser<T>::parse`,
+/// where `T` is the parameter storage type, to parse custom types.
+template <typename T, typename = T>
+struct FieldParser;
+
+/// Parse an attribute.
+template <typename AttributeT>
+struct FieldParser<
+    AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
+                                 AttributeT>> {
+  static FailureOr<AttributeT> parse(DialectAsmParser &parser) {
+    AttributeT value;
+    if (parser.parseAttribute(value))
+      return failure();
+    return value;
+  }
+};
+
+/// Parse any integer.
+template <typename IntT>
+struct FieldParser<IntT,
+                   std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
+  static FailureOr<IntT> parse(DialectAsmParser &parser) {
+    IntT value;
+    if (parser.parseInteger(value))
+      return failure();
+    return value;
+  }
+};
+
+/// Parse a string.
+template <>
+struct FieldParser<std::string> {
+  static FailureOr<std::string> parse(DialectAsmParser &parser) {
+    std::string value;
+    if (parser.parseString(&value))
+      return failure();
+    return value;
+  }
+};
+
+/// Parse any container that supports back insertion as a list.
+template <typename ContainerT>
+struct FieldParser<
+    ContainerT, std::enable_if_t<std::is_member_function_pointer<
+                                     decltype(&ContainerT::push_back)>::value,
+                                 ContainerT>> {
+  using ElementT = typename ContainerT::value_type;
+  static FailureOr<ContainerT> parse(DialectAsmParser &parser) {
+    ContainerT elements;
+    auto elementParser = [&]() {
+      auto element = FieldParser<ElementT>::parse(parser);
+      if (failed(element))
+        return failure();
+      elements.push_back(element.getValue());
+      return success();
+    };
+    if (parser.parseCommaSeparatedList(elementParser))
+      return failure();
+    return elements;
+  }
+};
+
 } // end namespace mlir
 
-#endif
+#endif // MLIR_IR_DIALECTIMPLEMENTATION_H

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index e63a2672bf316..37bf1d233c2b3 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2886,6 +2886,11 @@ class AttrOrTypeDef<string valueType, string name, list<Trait> defTraits,
   code printer = ?;
   code parser = ?;
 
+  // Custom assembly format. Requires 'mnemonic' to be specified. Cannot be
+  // specified at the same time as either 'printer' or 'parser'. The generated
+  // printer requires 'genAccessors' to be true.
+  string assemblyFormat = ?;
+
   // If set, generate accessors for each parameter.
   bit genAccessors = 1;
 
@@ -2964,10 +2969,22 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
   string cppType = type;
   // The C++ type of the accessor for this parameter.
   string cppAccessorType = !if(!empty(accessorType), type, accessorType);
+  // The C++ storage type of this parameter if it is a reference, e.g.
+  // `std::string` for `StringRef` or `SmallVector` for `ArrayRef`.
+  string cppStorageType = ?;
   // One-line human-readable description of the argument.
   string summary = desc;
   // The format string for the asm syntax (documentation only).
   string syntax = ?;
+  // The default parameter parser is `::mlir::FieldParser<T>::parse($_parser)`,
+  // which returns `FailureOr<T>`. Specialize `FieldParser` to support parsing
+  // for your type. Or you can provide a custom parser. For attributes, "$_type"
+  // will be replaced with the required attribute type.
+  string parser = ?;
+  // The default parameter printer is `$_printer << $_self`. Overload the stream
+  // operator of `DialectAsmPrinter` as necessary to print your type. Or you can
+  // provide a custom printer.
+  string printer = ?;
 }
 class AttrParameter<string type, string desc, string accessorType = "">
  : AttrOrTypeParameter<type, desc, accessorType>;
@@ -2978,6 +2995,8 @@ class TypeParameter<string type, string desc, string accessorType = "">
 class StringRefParameter<string desc = ""> :
     AttrOrTypeParameter<"::llvm::StringRef", desc> {
   let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+  let printer = [{$_printer << '"' << $_self << '"';}];
+  let cppStorageType = "std::string";
 }
 
 // For APFloats, which require comparison.
@@ -2990,6 +3009,7 @@ class APFloatParameter<string desc> :
 class ArrayRefParameter<string arrayOf, string desc = ""> :
     AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
   let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+  let cppStorageType = "::llvm::SmallVector<" # arrayOf # ">";
 }
 
 // For classes which require allocation and have their own allocateInto method.

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index db1f7a3c071d2..34e6cd08ea3c7 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -182,10 +182,10 @@ operator<<(AsmPrinterT &p, const TypeRange &types) {
   llvm::interleaveComma(types, p);
   return p;
 }
-template <typename AsmPrinterT>
+template <typename AsmPrinterT, typename ElementT>
 inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                         AsmPrinterT &>
-operator<<(AsmPrinterT &p, ArrayRef<Type> types) {
+operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
   llvm::interleaveComma(types, p);
   return p;
 }

diff  --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 2029c0e624cd3..09294c2fa8081 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -101,6 +101,9 @@ class AttrOrTypeDef {
   // None. Otherwise, returns the contents of that code block.
   Optional<StringRef> getParserCode() const;
 
+  // Returns the custom assembly format, if one was specified.
+  Optional<StringRef> getAssemblyFormat() const;
+
   // Returns true if the accessors based on the parameters should be generated.
   bool genAccessors() const;
 
@@ -199,6 +202,15 @@ class AttrOrTypeParameter {
   // Get the C++ accessor type of this parameter.
   StringRef getCppAccessorType() const;
 
+  // Get the C++ storage type of this parameter.
+  StringRef getCppStorageType() const;
+
+  // Get an optional C++ parameter parser.
+  Optional<StringRef> getParser() const;
+
+  // Get an optional C++ parameter printer.
+  Optional<StringRef> getPrinter() const;
+
   // Get a description of this parameter for documentation purposes.
   Optional<StringRef> getSummary() const;
 

diff  --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 2de0d9b0406eb..3030d6556b5bd 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -1,3 +1,4 @@
+//===- Dialect.h - Dialect class --------------------------------*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.

diff  --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 2a0ad96ea4e93..f43949c30a222 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -132,6 +132,10 @@ Optional<StringRef> AttrOrTypeDef::getParserCode() const {
   return def->getValueAsOptionalString("parser");
 }
 
+Optional<StringRef> AttrOrTypeDef::getAssemblyFormat() const {
+  return def->getValueAsOptionalString("assemblyFormat");
+}
+
 bool AttrOrTypeDef::genAccessors() const {
   return def->getValueAsBit("genAccessors");
 }
@@ -219,6 +223,32 @@ StringRef AttrOrTypeParameter::getCppAccessorType() const {
   return getCppType();
 }
 
+StringRef AttrOrTypeParameter::getCppStorageType() const {
+  if (auto *param = dyn_cast<llvm::DefInit>(def->getArg(index))) {
+    if (auto type = param->getDef()->getValueAsOptionalString("cppStorageType"))
+      return *type;
+  }
+  return getCppType();
+}
+
+Optional<StringRef> AttrOrTypeParameter::getParser() const {
+  auto *parameterType = def->getArg(index);
+  if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
+    if (auto parser = param->getDef()->getValueAsOptionalString("parser"))
+      return *parser;
+  }
+  return {};
+}
+
+Optional<StringRef> AttrOrTypeParameter::getPrinter() const {
+  auto *parameterType = def->getArg(index);
+  if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
+    if (auto printer = param->getDef()->getValueAsOptionalString("printer"))
+      return *printer;
+  }
+  return {};
+}
+
 Optional<StringRef> AttrOrTypeParameter::getSummary() const {
   auto *parameterType = def->getArg(index);
   if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 3062fd6c65ca9..06e2599def7df 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -116,4 +116,44 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
   );
 }
 
+def TestParamOne : AttrParameter<"int64_t", ""> {}
+
+def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> {
+  let printer = "$_printer << '\"' << $_self << '\"'";
+}
+
+def TestParamFour : ArrayRefParameter<"int", ""> {
+  let cppStorageType = "llvm::SmallVector<int>";
+  let parser = "::parseIntArray($_parser)";
+  let printer = "::printIntArray($_printer, $_self)";
+}
+
+def TestAttrWithFormat : Test_Attr<"TestAttrWithFormat"> {
+  let parameters = (
+    ins
+    TestParamOne:$one,
+    TestParamTwo:$two,
+    "::mlir::IntegerAttr":$three,
+    TestParamFour:$four
+  );
+
+  let mnemonic = "attr_with_format";
+  let assemblyFormat = "`<` $one `:` struct($two, $four) `:` $three `>`";
+  let genVerifyDecl = 1;
+}
+
+def TestAttrUgly : Test_Attr<"TestAttrUgly"> {
+  let parameters = (ins "::mlir::Attribute":$attr);
+
+  let mnemonic = "attr_ugly";
+  let assemblyFormat = "`begin` $attr `end`";
+}
+
+def TestAttrParams: Test_Attr<"TestAttrParams"> {
+  let parameters = (ins "int":$v0, "int":$v1);
+
+  let mnemonic = "attr_params";
+  let assemblyFormat = "`<` params `>`";
+}
+
 #endif // TEST_ATTRDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 9cd9c574a7bf2..e0c8ebbb7b93f 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -16,9 +16,11 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/bit.h"
 
 using namespace mlir;
 using namespace test;
@@ -127,6 +129,36 @@ TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+LogicalResult
+TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                               int64_t one, std::string two, IntegerAttr three,
+                               ArrayRef<int> four) {
+  if (four.size() != static_cast<unsigned>(one))
+    return emitError() << "expected 'one' to equal 'four.size()'";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Utility Functions for Generated Attributes
+//===----------------------------------------------------------------------===//
+
+static FailureOr<SmallVector<int>> parseIntArray(DialectAsmParser &parser) {
+  SmallVector<int> ints;
+  if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
+        ints.push_back(0);
+        return parser.parseInteger(ints.back());
+      }) ||
+      parser.parseRSquare())
+    return failure();
+  return ints;
+}
+
+static void printIntArray(DialectAsmPrinter &printer, ArrayRef<int> ints) {
+  printer << '[';
+  llvm::interleaveComma(ints, printer);
+  printer << ']';
+}
+
 //===----------------------------------------------------------------------===//
 // TestSubElementsAccessAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index aef9baa894737..66008f12d574b 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -15,6 +15,7 @@
 
 // To get the test dialect def.
 include "TestOps.td"
+include "TestAttrDefs.td"
 include "mlir/IR/BuiltinTypes.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
 
@@ -189,4 +190,44 @@ def TestTypeWithTrait : Test_Type<"TestTypeWithTrait", [TestTypeTrait]> {
   let mnemonic = "test_type_with_trait";
 }
 
+// Type with assembly format.
+def TestTypeWithFormat : Test_Type<"TestTypeWithFormat"> {
+  let parameters = (
+    ins
+    TestParamOne:$one,
+    TestParamTwo:$two,
+    "::mlir::Attribute":$three
+  );
+
+  let mnemonic = "type_with_format";
+  let assemblyFormat = "`<` $one `,` struct($three, $two) `>`";
+}
+
+// Test dispatch to FieldParser
+def TestTypeNoParser : Test_Type<"TestTypeNoParser"> {
+  let parameters = (
+    ins
+    "uint32_t":$one,
+    ArrayRefParameter<"int64_t">:$two,
+    StringRefParameter<>:$three,
+    "::test::CustomParam":$four
+  );
+
+  let mnemonic = "no_parser";
+  let assemblyFormat = "`<` $one `,` `[` $two `]` `,` $three `,` $four `>`";
+}
+
+def TestTypeStructCaptureAll : Test_Type<"TestStructTypeCaptureAll"> {
+  let parameters = (
+    ins
+    "int":$v0,
+    "int":$v1,
+    "int":$v2,
+    "int":$v3
+  );
+
+  let mnemonic = "struct_capture_all";
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
 #endif // TEST_TYPEDEFS

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 9da2e1713d9d0..7614ae401d1f0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -38,8 +38,38 @@ struct FieldInfo {
   }
 };
 
+/// A custom type for a test type parameter.
+struct CustomParam {
+  int value;
+
+  bool operator==(const CustomParam &other) const {
+    return other.value == value;
+  }
+};
+
+inline llvm::hash_code hash_value(const test::CustomParam &param) {
+  return llvm::hash_value(param.value);
+}
+
 } // namespace test
 
+namespace mlir {
+template <>
+struct FieldParser<test::CustomParam> {
+  static FailureOr<test::CustomParam> parse(DialectAsmParser &parser) {
+    auto value = FieldParser<int>::parse(parser);
+    if (failed(value))
+      return failure();
+    return test::CustomParam{value.getValue()};
+  }
+};
+} // end namespace mlir
+
+inline mlir::DialectAsmPrinter &operator<<(mlir::DialectAsmPrinter &printer,
+                                           const test::CustomParam &param) {
+  return printer << param.value;
+}
+
 #include "TestTypeInterfaces.h.inc"
 
 #define GET_TYPEDEF_CLASSES
@@ -52,17 +82,19 @@ namespace test {
 struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
   using KeyTy = ::llvm::StringRef;
 
-  explicit TestRecursiveTypeStorage(::llvm::StringRef key) : name(key), body(::mlir::Type()) {}
+  explicit TestRecursiveTypeStorage(::llvm::StringRef key)
+      : name(key), body(::mlir::Type()) {}
 
   bool operator==(const KeyTy &other) const { return name == other; }
 
-  static TestRecursiveTypeStorage *construct(::mlir::TypeStorageAllocator &allocator,
-                                             const KeyTy &key) {
+  static TestRecursiveTypeStorage *
+  construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
     return new (allocator.allocate<TestRecursiveTypeStorage>())
         TestRecursiveTypeStorage(allocator.copyInto(key));
   }
 
-  ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator, ::mlir::Type newBody) {
+  ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator,
+                               ::mlir::Type newBody) {
     // Cannot set a different body than before.
     if (body && body != newBody)
       return ::mlir::failure();
@@ -79,11 +111,13 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
 /// type, potentially itself. This requires the body to be mutated separately
 /// from type creation.
 class TestRecursiveType
-    : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type, TestRecursiveTypeStorage> {
+    : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
+                                    TestRecursiveTypeStorage> {
 public:
   using Base::Base;
 
-  static TestRecursiveType get(::mlir::MLIRContext *ctx, ::llvm::StringRef name) {
+  static TestRecursiveType get(::mlir::MLIRContext *ctx,
+                               ::llvm::StringRef name) {
     return Base::get(ctx, name);
   }
 

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
new file mode 100644
index 0000000000000..372aef6dfa3e5
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -0,0 +1,76 @@
+// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include -asmformat-error-is-fatal=false %s 2>&1 | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "TestDialect";
+  let cppNamespace = "::test";
+}
+
+class InvalidType<string name, string asm> : TypeDef<Test_Dialect, name> {
+  let mnemonic = asm;
+}
+
+/// Test format is missing a parameter capture.
+def InvalidTypeA : InvalidType<"InvalidTypeA", "invalid_a"> {
+  let parameters = (ins "int":$v0, "int":$v1);
+  // CHECK: format is missing reference to parameter: v1
+  let assemblyFormat = "`<` $v0 `>`";
+}
+
+/// Test format has duplicate parameter captures.
+def InvalidTypeB : InvalidType<"InvalidTypeB", "invalid_b"> {
+  let parameters = (ins "int":$v0, "int":$v1);
+  // CHECK: duplicate parameter 'v0'
+  let assemblyFormat = "`<` $v0 `,` $v1 `,` $v0 `>`";
+}
+
+/// Test format has invalid syntax.
+def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
+  let parameters = (ins "int":$v0, "int":$v1);
+  // CHECK: expected literal, directive, or variable
+  let assemblyFormat = "`<` $v0, $v1 `>`";
+}
+
+/// Test struct directive has invalid syntax.
+def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
+  let parameters = (ins "int":$v0);
+  // CHECK: literals may only be used in the top-level section of the format
+  // CHECK: expected a variable in `struct` argument list
+  let assemblyFormat = "`<` struct($v0, `,`) `>`";
+}
+
+/// Test struct directive cannot capture zero parameters.
+def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
+  let parameters = (ins "int":$v0);
+  // CHECK: `struct` argument list expected a variable or directive
+  let assemblyFormat = "`<` struct() $v0 `>`";
+}
+
+/// Test capture parameter that does not exist.
+def InvalidTypeF : InvalidType<"InvalidTypeF", "invalid_f"> {
+  let parameters = (ins "int":$v0);
+  // CHECK: InvalidTypeF has no parameter named 'v1'
+  let assemblyFormat = "`<` $v0 $v1 `>`";
+}
+
+/// Test duplicate capture of parameter in capture-all struct.
+def InvalidTypeG : InvalidType<"InvalidTypeG", "invalid_g"> {
+  let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
+  // CHECK: duplicate parameter 'v0'
+  let assemblyFormat = "`<` struct(params) $v0 `>`";
+}
+
+/// Test capture-all struct duplicate capture.
+def InvalidTypeH : InvalidType<"InvalidTypeH", "invalid_h"> {
+  let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
+  // CHECK: `params` captures duplicate parameter: v0
+  let assemblyFormat = "`<` $v0 struct(params) `>`";
+}
+
+/// Test capture of parameter after `params` directive.
+def InvalidTypeI : InvalidType<"InvalidTypeI", "invalid_i"> {
+  let parameters = (ins "int":$v0);
+  // CHECK: duplicate parameter 'v0'
+  let assemblyFormat = "`<` params $v0 `>`";
+}

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
new file mode 100644
index 0000000000000..f403f6f4b059d
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @test_roundtrip_parameter_parsers
+// CHECK: !test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo">
+// CHECK: !test.type_with_format<2147, three = "hi", two = "hi">
+func private @test_roundtrip_parameter_parsers(!test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo">) -> !test.type_with_format<2147, two = "hi", three = "hi">
+attributes {
+  // CHECK: #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64>
+  attr0 = #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64>,
+  // CHECK: #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8>
+  attr1 = #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8>,
+  // CHECK: #test<"attr_ugly begin 5 : index end">
+  attr2 = #test<"attr_ugly begin 5 : index end">,
+  // CHECK: #test.attr_params<42, 24>
+  attr3 = #test.attr_params<42, 24>
+}
+
+// CHECK-LABEL: @test_roundtrip_default_parsers_struct
+// CHECK: !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
+// CHECK: !test.struct_capture_all<v0 = 0, v1 = 1, v2 = 2, v3 = 3>
+func private @test_roundtrip_default_parsers_struct(!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>) -> !test.struct_capture_all<v3 = 3, v1 = 1, v2 = 2, v0 = 0>

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.mlir b/mlir/test/mlir-tblgen/attr-or-type-format.mlir
new file mode 100644
index 0000000000000..3ff638c5f640f
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.mlir
@@ -0,0 +1,127 @@
+// RUN: mlir-opt --split-input-file %s --verify-diagnostics
+
+func private @test_ugly_attr_cannot_be_pretty() -> () attributes {
+  // expected-error @+1 {{expected 'begin'}}
+  attr = #test.attr_ugly
+}
+
+// -----
+
+func private @test_ugly_attr_no_mnemonic() -> () attributes {
+  // expected-error @+1 {{expected valid keyword}}
+  attr = #test<"">
+}
+
+// -----
+
+func private @test_ugly_attr_parser_dispatch() -> () attributes {
+  // expected-error @+1 {{expected 'begin'}}
+  attr = #test<"attr_ugly">
+}
+
+// -----
+
+func private @test_ugly_attr_missing_parameter() -> () attributes {
+  // expected-error @+2 {{failed to parse TestAttrUgly parameter 'attr'}}
+  // expected-error @+1 {{expected non-function type}}
+  attr = #test<"attr_ugly begin">
+}
+
+// -----
+
+func private @test_ugly_attr_missing_literal() -> () attributes {
+  // expected-error@+1 {{expected 'end'}}
+  attr = #test<"attr_ugly begin \"string_attr\"">
+}
+
+// -----
+
+func private @test_pretty_attr_expects_less() -> () attributes {
+  // expected-error@+1 {{expected '<'}}
+  attr = #test.attr_with_format
+}
+
+// -----
+
+func private @test_pretty_attr_missing_param() -> () attributes {
+  // expected-error@+2 {{expected integer value}}
+  // expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'one'}}
+  attr = #test.attr_with_format<>
+}
+
+// -----
+
+func private @test_parse_invalid_param() -> () attributes {
+  // Test parameter parser failure is propagated
+  // expected-error@+2 {{expected integer value}}
+  // expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'one'}}
+  attr = #test.attr_with_format<"hi">
+}
+
+// -----
+
+func private @test_pretty_attr_invalid_syntax() -> () attributes {
+  // expected-error@+1 {{expected ':'}}
+  attr = #test.attr_with_format<42>
+}
+
+// -----
+
+func private @test_struct_missing_key() -> () attributes {
+  // expected-error@+2 {{expected valid keyword}}
+  // expected-error@+1 {{expected a parameter name in struct}}
+  attr = #test.attr_with_format<42 :>
+}
+
+// -----
+
+func private @test_struct_unknown_key() -> () attributes {
+  // expected-error@+1 {{duplicate or unknown struct parameter}}
+  attr = #test.attr_with_format<42 : nine = "foo">
+}
+
+// -----
+
+func private @test_struct_duplicate_key() -> () attributes {
+  // expected-error@+1 {{duplicate or unknown struct parameter}}
+  attr = #test.attr_with_format<42 : two = "foo", two = "bar">
+}
+
+// -----
+
+func private @test_struct_not_enough_values() -> () attributes {
+  // expected-error@+1 {{expected ','}}
+  attr = #test.attr_with_format<42 : two = "foo">
+}
+
+// -----
+
+func private @test_parse_param_after_struct() -> () attributes {
+  // expected-error@+2 {{expected non-function type}}
+  // expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'three'}}
+  attr = #test.attr_with_format<42 : two = "foo", four = [1, 2, 3] : >
+}
+
+// -----
+
+// expected-error@+1 {{expected '<'}}
+func private @test_invalid_type() -> !test.type_with_format
+
+// -----
+
+// expected-error@+2 {{expected integer value}}
+// expected-error@+1 {{failed to parse TestTypeWithFormat parameter 'one'}}
+func private @test_pretty_type_invalid_param() -> !test.type_with_format<>
+
+// -----
+
+// expected-error@+2 {{expected ':'}}
+// expected-error@+1 {{failed to parse TestTypeWithFormat parameter 'three'}}
+func private @test_type_syntax_error() -> !test.type_with_format<42, two = "hi", three = #test.attr_with_format<42>>
+
+// -----
+
+func private @test_verifier_fails() -> () attributes {
+  // expected-error@+1 {{expected 'one' to equal 'four.size()'}}
+  attr = #test.attr_with_format<42 : two = "hello", four = [1, 2, 3] : 42 : i64>
+}

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
new file mode 100644
index 0000000000000..2d426935fa415
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -0,0 +1,394 @@
+// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=ATTR
+// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=TYPE
+
+include "mlir/IR/OpBase.td"
+
+/// Test that attribute and type printers and parsers are correctly generated.
+def Test_Dialect : Dialect {
+  let name = "TestDialect";
+  let cppNamespace = "::test";
+}
+
+class TestAttr<string name> : AttrDef<Test_Dialect, name>;
+class TestType<string name> : TypeDef<Test_Dialect, name>;
+
+def AttrParamA : AttrParameter<"TestParamA", "an attribute param A"> {
+  let parser = "::parseAttrParamA($_parser, $_type)";
+  let printer = "::printAttrParamA($_printer, $_self)";
+}
+
+def AttrParamB : AttrParameter<"TestParamB", "an attribute param B"> {
+  let parser = "$_type ? ::parseAttrWithType($_parser, $_type) : ::parseAttrWithout($_parser)";
+  let printer = "::printAttrB($_printer, $_self)";
+}
+
+def TypeParamA : TypeParameter<"TestParamC", "a type param C"> {
+  let parser = "::parseTypeParamC($_parser)";
+  let printer = "$_printer << $_self";
+}
+
+def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
+  let parser = "someFcnCall()";
+  let printer = "myPrinter($_self)";
+}
+
+/// Check simple attribute parser and printer are generated correctly.
+
+// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::DialectAsmParser &parser,
+// ATTR:                                    ::mlir::Type attrType) {
+// ATTR:   FailureOr<IntegerAttr> _result_value;
+// ATTR:   FailureOr<TestParamA> _result_complex;
+// ATTR:   if (parser.parseKeyword("hello"))
+// ATTR:     return {};
+// ATTR:   if (parser.parseEqual())
+// ATTR:     return {};
+// ATTR:   _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// ATTR:   if (failed(_result_value))
+// ATTR:     return {};
+// ATTR:   if (parser.parseComma())
+// ATTR:     return {};
+// ATTR:   _result_complex = ::parseAttrParamA(parser, attrType);
+// ATTR:   if (failed(_result_complex))
+// ATTR:     return {};
+// ATTR:   if (parser.parseRParen())
+// ATTR:     return {};
+// ATTR:   return TestAAttr::get(parser.getContext(),
+// ATTR:                         _result_value.getValue(),
+// ATTR:                         _result_complex.getValue());
+// ATTR: }
+
+// ATTR: void TestAAttr::print(::mlir::DialectAsmPrinter &printer) const {
+// ATTR:   printer << "attr_a";
+// ATTR:   printer << ' ' << "hello";
+// ATTR:   printer << ' ' << "=";
+// ATTR:   printer << ' ';
+// ATTR:   printer << getValue();
+// ATTR:   printer << ",";
+// ATTR:   printer << ' ';
+// ATTR:   ::printAttrParamA(printer, getComplex());
+// ATTR:   printer << ")";
+// ATTR: }
+
+def AttrA : TestAttr<"TestA"> {
+  let parameters = (ins
+      "IntegerAttr":$value,
+      AttrParamA:$complex
+  );
+
+  let mnemonic = "attr_a";
+  let assemblyFormat = "`hello` `=` $value `,` $complex `)`";
+}
+
+/// Test simple struct parser and printer are generated correctly.
+
+// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::DialectAsmParser &parser,
+// ATTR:                                    ::mlir::Type attrType) {
+// ATTR:   bool _seen_v0 = false;
+// ATTR:   bool _seen_v1 = false;
+// ATTR:   for (unsigned _index = 0; _index < 2; ++_index) {
+// ATTR:     StringRef _paramKey;
+// ATTR:     if (parser.parseKeyword(&_paramKey))
+// ATTR:       return {};
+// ATTR:     if (parser.parseEqual())
+// ATTR:       return {};
+// ATTR:     if (!_seen_v0 && _paramKey == "v0") {
+// ATTR:       _seen_v0 = true;
+// ATTR:       _result_v0 = ::parseAttrParamA(parser, attrType);
+// ATTR:       if (failed(_result_v0))
+// ATTR:         return {};
+// ATTR:     } else if (!_seen_v1 && _paramKey == "v1") {
+// ATTR:       _seen_v1 = true;
+// ATTR:       _result_v1 = attrType ? ::parseAttrWithType(parser, attrType) : ::parseAttrWithout(parser);
+// ATTR:       if (failed(_result_v1))
+// ATTR:         return {};
+// ATTR:     } else {
+// ATTR:       return {};
+// ATTR:     }
+// ATTR:     if ((_index != 2 - 1) && parser.parseComma())
+// ATTR:       return {};
+// ATTR:   }
+// ATTR:   return TestBAttr::get(parser.getContext(),
+// ATTR:                         _result_v0.getValue(),
+// ATTR:                         _result_v1.getValue());
+// ATTR: }
+
+// ATTR: void TestBAttr::print(::mlir::DialectAsmPrinter &printer) const {
+// ATTR:   printer << "v0";
+// ATTR:   printer << ' ' << "=";
+// ATTR:   printer << ' ';
+// ATTR:   ::printAttrParamA(printer, getV0());
+// ATTR:   printer << ",";
+// ATTR:   printer << ' ' << "v1";
+// ATTR:   printer << ' ' << "=";
+// ATTR:   printer << ' ';
+// ATTR:   ::printAttrB(printer, getV1());
+// ATTR: }
+
+def AttrB : TestAttr<"TestB"> {
+  let parameters = (ins
+      AttrParamA:$v0,
+      AttrParamB:$v1
+  );
+
+  let mnemonic = "attr_b";
+  let assemblyFormat = "`{` struct($v0, $v1) `}`";
+}
+
+/// Test attribute with capture-all params has correct parser and printer.
+
+// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::DialectAsmParser &parser,
+// ATTR:                                    ::mlir::Type attrType) {
+// ATTR:   ::mlir::FailureOr<int> _result_v0;
+// ATTR:   ::mlir::FailureOr<int> _result_v1;
+// ATTR:   _result_v0 = ::mlir::FieldParser<int>::parse(parser);
+// ATTR:   if (failed(_result_v0))
+// ATTR:     return {};
+// ATTR:   if (parser.parseComma())
+// ATTR:     return {};
+// ATTR:   _result_v1 = ::mlir::FieldParser<int>::parse(parser);
+// ATTR:   if (failed(_result_v1))
+// ATTR:     return {};
+// ATTR:   return TestFAttr::get(parser.getContext(),
+// ATTR:     _result_v0.getValue(),
+// ATTR:     _result_v1.getValue());
+// ATTR: }
+
+// ATTR: void TestFAttr::print(::mlir::DialectAsmPrinter &printer) const {
+// ATTR:   printer << "attr_c";
+// ATTR:   printer << ' ';
+// ATTR:   printer << getV0();
+// ATTR:   printer << ",";
+// ATTR:   printer << ' ';
+// ATTR:   printer << getV1();
+// ATTR: }
+
+def AttrC : TestAttr<"TestF"> {
+  let parameters = (ins "int":$v0, "int":$v1);
+
+  let mnemonic = "attr_c";
+  let assemblyFormat = "params";
+}
+
+/// Test type parser and printer that mix variables and struct are generated
+/// correctly.
+
+// TYPE: ::mlir::Type TestCType::parse(::mlir::DialectAsmParser &parser) {
+// TYPE:  FailureOr<IntegerAttr> _result_value;
+// TYPE:  FailureOr<TestParamC> _result_complex;
+// TYPE:  if (parser.parseKeyword("foo"))
+// TYPE:    return {};
+// TYPE:  if (parser.parseComma())
+// TYPE:    return {};
+// TYPE:  if (parser.parseColon())
+// TYPE:    return {};
+// TYPE:  if (parser.parseKeyword("bob"))
+// TYPE:    return {};
+// TYPE:  if (parser.parseKeyword("bar"))
+// TYPE:    return {};
+// TYPE:  _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// TYPE:  if (failed(_result_value))
+// TYPE:    return {};
+// TYPE:  bool _seen_complex = false;
+// TYPE:  for (unsigned _index = 0; _index < 1; ++_index) {
+// TYPE:    StringRef _paramKey;
+// TYPE:    if (parser.parseKeyword(&_paramKey))
+// TYPE:      return {};
+// TYPE:    if (!_seen_complex && _paramKey == "complex") {
+// TYPE:      _seen_complex = true;
+// TYPE:      _result_complex = ::parseTypeParamC(parser);
+// TYPE:      if (failed(_result_complex))
+// TYPE:        return {};
+// TYPE:    } else {
+// TYPE:      return {};
+// TYPE:    }
+// TYPE:    if ((_index != 1 - 1) && parser.parseComma())
+// TYPE:      return {};
+// TYPE:  }
+// TYPE:  if (parser.parseRParen())
+// TYPE:    return {};
+// TYPE:  }
+
+// TYPE: void TestCType::print(::mlir::DialectAsmPrinter &printer) const {
+// TYPE:   printer << "type_c";
+// TYPE:   printer << ' ' << "foo";
+// TYPE:   printer << ",";
+// TYPE:   printer << ' ' << ":";
+// TYPE:   printer << ' ' << "bob";
+// TYPE:   printer << ' ' << "bar";
+// TYPE:   printer << ' ';
+// TYPE:   printer << getValue();
+// TYPE:   printer << ' ' << "complex";
+// TYPE:   printer << ' ' << "=";
+// TYPE:   printer << ' ';
+// TYPE:   printer << getComplex();
+// TYPE:   printer << ")";
+// TYPE: }
+
+def TypeA : TestType<"TestC"> {
+  let parameters = (ins
+      "IntegerAttr":$value,
+      TypeParamA:$complex
+  );
+
+  let mnemonic = "type_c";
+  let assemblyFormat = "`foo` `,` `:` `bob` `bar` $value struct($complex) `)`";
+}
+
+/// Test type parser and printer with mix of variables and struct are generated
+/// correctly.
+
+// TYPE: ::mlir::Type TestDType::parse(::mlir::DialectAsmParser &parser) {
+// TYPE:   _result_v0 = ::parseTypeParamC(parser);
+// TYPE:   if (failed(_result_v0))
+// TYPE:     return {};
+// TYPE:   bool _seen_v1 = false;
+// TYPE:   bool _seen_v2 = false;
+// TYPE:   for (unsigned _index = 0; _index < 2; ++_index) {
+// TYPE:     StringRef _paramKey;
+// TYPE:     if (parser.parseKeyword(&_paramKey))
+// TYPE:       return {};
+// TYPE:     if (parser.parseEqual())
+// TYPE:       return {};
+// TYPE:     if (!_seen_v1 && _paramKey == "v1") {
+// TYPE:       _seen_v1 = true;
+// TYPE:       _result_v1 = someFcnCall();
+// TYPE:       if (failed(_result_v1))
+// TYPE:         return {};
+// TYPE:     } else if (!_seen_v2 && _paramKey == "v2") {
+// TYPE:       _seen_v2 = true;
+// TYPE:       _result_v2 = ::parseTypeParamC(parser);
+// TYPE:       if (failed(_result_v2))
+// TYPE:         return {};
+// TYPE:     } else  {
+// TYPE:       return {};
+// TYPE:     }
+// TYPE:     if ((_index != 2 - 1) && parser.parseComma())
+// TYPE:       return {};
+// TYPE:   }
+// TYPE:   _result_v3 = someFcnCall();
+// TYPE:   if (failed(_result_v3))
+// TYPE:     return {};
+// TYPE:   return TestDType::get(parser.getContext(),
+// TYPE:                         _result_v0.getValue(),
+// TYPE:                         _result_v1.getValue(),
+// TYPE:                         _result_v2.getValue(),
+// TYPE:                         _result_v3.getValue());
+// TYPE: }
+
+// TYPE: void TestDType::print(::mlir::DialectAsmPrinter &printer) const {
+// TYPE:   printer << getV0();
+// TYPE:   myPrinter(getV1());
+// TYPE:   printer << ' ' << "v2";
+// TYPE:   printer << ' ' << "=";
+// TYPE:   printer << ' ';
+// TYPE:   printer << getV2();
+// TYPE:   myPrinter(getV3());
+// TYPE: }
+
+def TypeB : TestType<"TestD"> {
+  let parameters = (ins
+      TypeParamA:$v0,
+      TypeParamB:$v1,
+      TypeParamA:$v2,
+      TypeParamB:$v3
+  );
+
+  let mnemonic = "type_d";
+  let assemblyFormat = "`<` `foo` `:` $v0 `,` struct($v1, $v2) `,` $v3 `>`";
+}
+
+/// Type test with two struct directives has correctly generated parser and
+/// printer.
+
+// TYPE: ::mlir::Type TestEType::parse(::mlir::DialectAsmParser &parser) {
+// TYPE:   FailureOr<IntegerAttr> _result_v0;
+// TYPE:   FailureOr<IntegerAttr> _result_v1;
+// TYPE:   FailureOr<IntegerAttr> _result_v2;
+// TYPE:   FailureOr<IntegerAttr> _result_v3;
+// TYPE:   bool _seen_v0 = false;
+// TYPE:   bool _seen_v2 = false;
+// TYPE:   for (unsigned _index = 0; _index < 2; ++_index) {
+// TYPE:     StringRef _paramKey;
+// TYPE:     if (parser.parseKeyword(&_paramKey))
+// TYPE:       return {};
+// TYPE:     if (parser.parseEqual())
+// TYPE:       return {};
+// TYPE:     if (!_seen_v0 && _paramKey == "v0") {
+// TYPE:       _seen_v0 = true;
+// TYPE:       _result_v0 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// TYPE:       if (failed(_result_v0))
+// TYPE:         return {};
+// TYPE:     } else if (!_seen_v2 && _paramKey == "v2") {
+// TYPE:       _seen_v2 = true;
+// TYPE:       _result_v2 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// TYPE:       if (failed(_result_v2))
+// TYPE:         return {};
+// TYPE:     } else  {
+// TYPE:       return {};
+// TYPE:     }
+// TYPE:     if ((_index != 2 - 1) && parser.parseComma())
+// TYPE:       return {};
+// TYPE:   }
+// TYPE:   bool _seen_v1 = false;
+// TYPE:   bool _seen_v3 = false;
+// TYPE:   for (unsigned _index = 0; _index < 2; ++_index) {
+// TYPE:     StringRef _paramKey;
+// TYPE:     if (parser.parseKeyword(&_paramKey))
+// TYPE:       return {};
+// TYPE:     if (parser.parseEqual())
+// TYPE:       return {};
+// TYPE:     if (!_seen_v1 && _paramKey == "v1") {
+// TYPE:       _seen_v1 = true;
+// TYPE:       _result_v1 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// TYPE:       if (failed(_result_v1))
+// TYPE:         return {};
+// TYPE:     } else if (!_seen_v3 && _paramKey == "v3") {
+// TYPE:       _seen_v3 = true;
+// TYPE:       _result_v3 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// TYPE:       if (failed(_result_v3))
+// TYPE:         return {};
+// TYPE:     } else  {
+// TYPE:       return {};
+// TYPE:     }
+// TYPE:     if ((_index != 2 - 1) && parser.parseComma())
+// TYPE:       return {};
+// TYPE:   }
+// TYPE:   return TestEType::get(parser.getContext(),
+// TYPE:     _result_v0.getValue(),
+// TYPE:     _result_v1.getValue(),
+// TYPE:     _result_v2.getValue(),
+// TYPE:     _result_v3.getValue());
+// TYPE: }
+
+// TYPE: void TestEType::print(::mlir::DialectAsmPrinter &printer) const {
+// TYPE:   printer << "v0";
+// TYPE:   printer << ' ' << "=";
+// TYPE:   printer << ' ';
+// TYPE:   printer << getV0();
+// TYPE:   printer << ",";
+// TYPE:   printer << ' ' << "v2";
+// TYPE:   printer << ' ' << "=";
+// TYPE:   printer << ' ';
+// TYPE:   printer << getV2();
+// TYPE:   printer << "v1";
+// TYPE:   printer << ' ' << "=";
+// TYPE:   printer << ' ';
+// TYPE:   printer << getV1();
+// TYPE:   printer << ",";
+// TYPE:   printer << ' ' << "v3";
+// TYPE:   printer << ' ' << "=";
+// TYPE:   printer << ' ';
+// TYPE:   printer << getV3();
+// TYPE: }
+
+def TypeC : TestType<"TestE"> {
+  let parameters = (ins
+      "IntegerAttr":$v0,
+      "IntegerAttr":$v1,
+      "IntegerAttr":$v2,
+      "IntegerAttr":$v3
+  );
+
+  let mnemonic = "type_e";
+  let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`";
+}
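
Note on the struct parsers checked above: each generated `struct` loop runs once per expected parameter, reads a key, and accepts it only if it names a parameter that has not been seen yet. The standalone sketch below mirrors that control flow outside of MLIR. It is an illustration only, not generated output: `TwoFields`, `parseStructSketch`, and the pre-tokenized `entries` vector are hypothetical stand-ins for the attribute storage, the generated `parse` method, and the `DialectAsmParser` token stream.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical result type standing in for an attribute with parameters
// `v0` and `v1`.
struct TwoFields {
  int v0;
  int v1;
};

// Sketch of the struct-parsing loop. `entries` is an assumed, already
// tokenized list of `key = value` pairs; the real parser reads tokens from
// a DialectAsmParser instead.
std::optional<TwoFields>
parseStructSketch(const std::vector<std::pair<std::string, int>> &entries) {
  constexpr unsigned numParams = 2;
  // The loop expects exactly one entry per parameter.
  if (entries.size() != numParams)
    return std::nullopt;

  bool seenV0 = false;
  bool seenV1 = false;
  TwoFields result{0, 0};
  for (unsigned index = 0; index < numParams; ++index) {
    const auto &[key, value] = entries[index];
    if (!seenV0 && key == "v0") {
      seenV0 = true;
      result.v0 = value;
    } else if (!seenV1 && key == "v1") {
      seenV1 = true;
      result.v1 = value;
    } else {
      // Duplicate or unknown key: the single else branch covers both cases.
      return std::nullopt;
    }
  }
  // Every iteration accepted a distinct key, so both flags are set here.
  assert(seenV0 && seenV1 && "each key accepted exactly once");
  return result;
}
```

Because a key is matched by name rather than by position, `v1 = 1, v0 = 0` and `v0 = 0, v1 = 1` produce the same result in this sketch, which is the behaviour the `struct_capture_all` roundtrip test relies on. A repeated key and an unrecognized key both fall into the final `else`, matching the shared "duplicate or unknown struct parameter" diagnostic.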

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index a1b0836a55d7b..5e4cb4d73a392 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "AttrOrTypeFormatGen.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/TableGen/AttrOrTypeDef.h"
 #include "mlir/TableGen/CodeGenHelpers.h"
@@ -24,6 +25,17 @@
 using namespace mlir;
 using namespace mlir::tblgen;
 
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+std::string mlir::tblgen::getParameterAccessorName(StringRef name) {
+  assert(!name.empty() && "parameter has empty name");
+  auto ret = "get" + name.str();
+  ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name
+  return ret;
+}
+
 /// Find all the AttrOrTypeDef for the specified dialect. If no dialect
 /// specified and can only find one dialect's defs, use that.
 static void collectAllDefs(StringRef selectedDialect,
@@ -399,7 +411,8 @@ void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
        << "    }\n";
 
     // If mnemonic specified, emit print/parse declarations.
-    if (def.getParserCode() || def.getPrinterCode() || !params.empty()) {
+    if (def.getParserCode() || def.getPrinterCode() ||
+        def.getAssemblyFormat() || !params.empty()) {
       os << llvm::formatv(defDeclParsePrintStr, valueType,
                           isAttrGenerator ? ", ::mlir::Type type" : "");
     }
@@ -410,10 +423,8 @@ void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
     def.getParameters(parameters);
 
     for (AttrOrTypeParameter &parameter : parameters) {
-      SmallString<16> name = parameter.getName();
-      name[0] = llvm::toUpper(name[0]);
-      os << formatv("    {0} get{1}() const;\n", parameter.getCppAccessorType(),
-                    name);
+      os << formatv("    {0} {1}() const;\n", parameter.getCppAccessorType(),
+                    getParameterAccessorName(parameter.getName()));
     }
   }
 
@@ -700,8 +711,32 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
 }
 
 void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
+  auto printerCode = def.getPrinterCode();
+  auto parserCode = def.getParserCode();
+  auto assemblyFormat = def.getAssemblyFormat();
+  if (assemblyFormat && (printerCode || parserCode)) {
+    // Custom assembly format cannot be specified at the same time as either
+    // custom printer or parser code.
+    PrintFatalError(def.getLoc(),
+                    def.getName() + ": assembly format cannot be specified at "
+                                    "the same time as printer or parser code");
+  }
+
+  // Generate a parser and printer based on the assembly format, if specified.
+  if (assemblyFormat) {
+    // A custom assembly format requires accessors to be generated for the
+    // generated printer.
+    if (!def.genAccessors()) {
+      PrintFatalError(def.getLoc(),
+                      def.getName() +
+                          ": the generated printer from 'assemblyFormat' "
+                          "requires 'genAccessors' to be true");
+    }
+    return generateAttrOrTypeFormat(def, os);
+  }
+
   // Emit the printer code, if specified.
-  if (Optional<StringRef> printerCode = def.getPrinterCode()) {
+  if (printerCode) {
     // Both the mnenomic and printerCode must be defined (for parity with
     // parserCode).
     os << "void " << def.getCppClassName()
@@ -717,7 +752,7 @@ void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
   }
 
   // Emit the parser code, if specified.
-  if (Optional<StringRef> parserCode = def.getParserCode()) {
+  if (parserCode) {
     FmtContext fmtCtxt;
     fmtCtxt.addSubst("_parser", "parser")
         .addSubst("_ctxt", "parser.getContext()");
@@ -857,11 +892,10 @@ void DefGenerator::emitDefDef(const AttrOrTypeDef &def) {
           paramStorageName = param.getName();
         }
 
-        SmallString<16> name = param.getName();
-        name[0] = llvm::toUpper(name[0]);
-        os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
-                      param.getCppAccessorType(), name, paramStorageName,
-                      def.getCppClassName());
+        os << formatv("{0} {3}::{1}() const {{ return getImpl()->{2}; }\n",
+                      param.getCppAccessorType(),
+                      getParameterAccessorName(param.getName()),
+                      paramStorageName, def.getCppClassName());
       }
     }
   }
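
Note on `getParameterAccessorName`, introduced in the hunks above: it prefixes the parameter name with `get` and uppercases the first character of the name, so the declaration and definition emitters now share one naming rule. The sketch below reproduces that rule in standalone C++ for illustration. `accessorNameSketch` is a hypothetical name, and `std::string` and `std::toupper` stand in for the `StringRef` and `llvm::toUpper` used in the patch.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Standalone sketch of the accessor naming rule; not the MLIR helper itself.
std::string accessorNameSketch(const std::string &name) {
  assert(!name.empty() && "parameter has empty name");
  std::string ret = "get" + name;
  // Index 3 is the first character after the "get" prefix.
  ret[3] = static_cast<char>(std::toupper(static_cast<unsigned char>(ret[3])));
  return ret;
}
```

Under this rule a parameter named `value` gets the accessor `getValue` and `v0` gets `getV0`, which matches the `getValue()` and `getV0()` calls in the generated printers checked by `attr-or-type-format.td`.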

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
new file mode 100644
index 0000000000000..52e921f2fb27a
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -0,0 +1,781 @@
+//===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "AttrOrTypeFormatGen.h"
+#include "FormatGen.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/TableGen/AttrOrTypeDef.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using llvm::formatv;
+
+//===----------------------------------------------------------------------===//
+// Element
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// This class represents a single format element.
+class Element {
+public:
+  /// LLVM-style RTTI.
+  enum class Kind {
+    /// This element is a directive.
+    ParamsDirective,
+    StructDirective,
+
+    /// This element is a literal.
+    Literal,
+
+    /// This element is a variable.
+    Variable,
+  };
+  Element(Kind kind) : kind(kind) {}
+  virtual ~Element() = default;
+
+  /// Return the kind of this element.
+  Kind getKind() const { return kind; }
+
+private:
+  /// The kind of this element.
+  Kind kind;
+};
+
+/// This class represents an instance of a literal element.
+class LiteralElement : public Element {
+public:
+  LiteralElement(StringRef literal)
+      : Element(Kind::Literal), literal(literal) {}
+
+  static bool classof(const Element *el) {
+    return el->getKind() == Kind::Literal;
+  }
+
+  /// Get the literal spelling.
+  StringRef getSpelling() const { return literal; }
+
+private:
+  /// The spelling of the literal for this element.
+  StringRef literal;
+};
+
+/// This class represents an instance of a variable element. A variable refers
+/// to an attribute or type parameter.
+class VariableElement : public Element {
+public:
+  VariableElement(AttrOrTypeParameter param)
+      : Element(Kind::Variable), param(param) {}
+
+  static bool classof(const Element *el) {
+    return el->getKind() == Kind::Variable;
+  }
+
+  /// Get the parameter in the element.
+  const AttrOrTypeParameter &getParam() const { return param; }
+
+private:
+  AttrOrTypeParameter param;
+};
+
+/// Base class for a directive that contains references to multiple variables.
+template <Element::Kind ElementKind>
+class ParamsDirectiveBase : public Element {
+public:
+  using Base = ParamsDirectiveBase<ElementKind>;
+
+  ParamsDirectiveBase(SmallVector<std::unique_ptr<Element>> &&params)
+      : Element(ElementKind), params(std::move(params)) {}
+
+  static bool classof(const Element *el) {
+    return el->getKind() == ElementKind;
+  }
+
+  /// Get the parameters contained in this directive.
+  auto getParams() const {
+    return llvm::map_range(params, [](auto &el) {
+      return cast<VariableElement>(el.get())->getParam();
+    });
+  }
+
+  /// Get the number of parameters.
+  unsigned getNumParams() const { return params.size(); }
+
+  /// Take all of the parameters from this directive.
+  SmallVector<std::unique_ptr<Element>> takeParams() {
+    return std::move(params);
+  }
+
+private:
+  /// The parameters captured by this directive.
+  SmallVector<std::unique_ptr<Element>> params;
+};
+
+/// This class represents a `params` directive that refers to all parameters
+/// of an attribute or type. When used as a top-level directive, it generates
+/// a format of the form:
+///
+///   (param-value (`,` param-value)*)?
+///
+/// When used as an argument to another directive that accepts variables,
+/// `params` can be used in place of manually listing all parameters of an
+/// attribute or type.
+class ParamsDirective
+    : public ParamsDirectiveBase<Element::Kind::ParamsDirective> {
+public:
+  using Base::Base;
+};
+
+/// This class represents a `struct` directive that generates a struct format
+/// of the form:
+///
+///   `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
+///
+class StructDirective
+    : public ParamsDirectiveBase<Element::Kind::StructDirective> {
+public:
+  using Base::Base;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Format Strings
+//===----------------------------------------------------------------------===//
+
+/// Format for defining an attribute parser.
+///
+/// $0: The attribute C++ class name.
+static const char *const attrParserDefn = R"(
+::mlir::Attribute $0::parse(::mlir::DialectAsmParser &$_parser,
+                             ::mlir::Type $_type) {
+)";
+
+/// Format for defining a type parser.
+///
+/// $0: The type C++ class name.
+static const char *const typeParserDefn = R"(
+::mlir::Type $0::parse(::mlir::DialectAsmParser &$_parser) {
+)";
+
+/// Default parser for attribute or type parameters.
+static const char *const defaultParameterParser =
+    "::mlir::FieldParser<$0>::parse($_parser)";
+
+/// Default printer for attribute or type parameters.
+static const char *const defaultParameterPrinter = "$_printer << $_self";
+
+/// Print an error when failing to parse an element.
+///
+/// $0: The parameter C++ class name.
+static const char *const parseErrorStr =
+    "$_parser.emitError($_parser.getCurrentLocation(), ";
+
+/// Format for defining an attribute or type printer.
+///
+/// $0: The attribute or type C++ class name.
+/// $1: The attribute or type mnemonic.
+static const char *const attrOrTypePrinterDefn = R"(
+void $0::print(::mlir::DialectAsmPrinter &$_printer) const {
+  $_printer << "$1";
+)";
+
+/// Loop declaration for struct parser.
+///
+/// $0: Number of expected parameters.
+static const char *const structParseLoopStart = R"(
+  for (unsigned _index = 0; _index < $0; ++_index) {
+    StringRef _paramKey;
+    if ($_parser.parseKeyword(&_paramKey)) {
+      $_parser.emitError($_parser.getCurrentLocation(),
+                         "expected a parameter name in struct");
+      return {};
+    }
+)";
+
+/// Terminator code segment for the struct parser loop. Check for duplicate or
+/// unknown parameters. Parse a comma except on the last element.
+///
+/// {0}: Code template for printing an error.
+/// {1}: Number of elements in the struct.
+static const char *const structParseLoopEnd = R"({{
+      {0}"duplicate or unknown struct parameter name: ") << _paramKey;
+      return {{};
+    }
+    if ((_index != {1} - 1) && parser.parseComma())
+      return {{};
+  }
+)";
+
+/// Code format to parse a variable. Separate by lines because variable parsers
+/// may be generated inside other directives, which requires indentation.
+///
+/// {0}: The parameter name.
+/// {1}: The parse code for the parameter.
+/// {2}: Code template for printing an error.
+/// {3}: Name of the attribute or type.
+/// {4}: C++ class of the parameter.
+static const char *const variableParser[] = {
+    "  // Parse variable '{0}'",
+    "  _result_{0} = {1};",
+    "  if (failed(_result_{0})) {{",
+    "    {2}\"failed to parse {3} parameter '{0}' which is to be a `{4}`\");",
+    "    return {{};",
+    "  }",
+};
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+/// Get a list of an attribute's or type's parameters. These can be wrapper
+/// objects around `AttrOrTypeParameter` or string inits.
+static auto getParameters(const AttrOrTypeDef &def) {
+  SmallVector<AttrOrTypeParameter> params;
+  def.getParameters(params);
+  return params;
+}
+
+//===----------------------------------------------------------------------===//
+// AttrOrTypeFormat
+//===----------------------------------------------------------------------===//
+
+namespace {
+class AttrOrTypeFormat {
+public:
+  AttrOrTypeFormat(const AttrOrTypeDef &def,
+                   std::vector<std::unique_ptr<Element>> &&elements)
+      : def(def), elements(std::move(elements)) {}
+
+  /// Generate the attribute or type parser.
+  void genParser(raw_ostream &os);
+  /// Generate the attribute or type printer.
+  void genPrinter(raw_ostream &os);
+
+private:
+  /// Generate the parser code for a specific format element.
+  void genElementParser(Element *el, FmtContext &ctx, raw_ostream &os);
+  /// Generate the parser code for a literal.
+  void genLiteralParser(StringRef value, FmtContext &ctx, raw_ostream &os,
+                        unsigned indent = 0);
+  /// Generate the parser code for a variable.
+  void genVariableParser(const AttrOrTypeParameter &param, FmtContext &ctx,
+                         raw_ostream &os, unsigned indent = 0);
+  /// Generate the parser code for a `params` directive.
+  void genParamsParser(ParamsDirective *el, FmtContext &ctx, raw_ostream &os);
+  /// Generate the parser code for a `struct` directive.
+  void genStructParser(StructDirective *el, FmtContext &ctx, raw_ostream &os);
+
+  /// Generate the printer code for a specific format element.
+  void genElementPrinter(Element *el, FmtContext &ctx, raw_ostream &os);
+  /// Generate the printer code for a literal.
+  void genLiteralPrinter(StringRef value, FmtContext &ctx, raw_ostream &os);
+  /// Generate the printer code for a variable.
+  void genVariablePrinter(const AttrOrTypeParameter &param, FmtContext &ctx,
+                          raw_ostream &os);
+  /// Generate the printer code for a `params` directive.
+  void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, raw_ostream &os);
+  /// Generate the printer code for a `struct` directive.
+  void genStructPrinter(StructDirective *el, FmtContext &ctx, raw_ostream &os);
+
+  /// The ODS definition of the attribute or type whose format is being used to
+  /// generate a parser and printer.
+  const AttrOrTypeDef &def;
+  /// The list of top-level format elements returned by the assembly format
+  /// parser.
+  std::vector<std::unique_ptr<Element>> elements;
+
+  /// Flags for printing spaces.
+  bool shouldEmitSpace;
+  bool lastWasPunctuation;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// ParserGen
+//===----------------------------------------------------------------------===//
+
+void AttrOrTypeFormat::genParser(raw_ostream &os) {
+  FmtContext ctx;
+  ctx.addSubst("_parser", "parser");
+
+  /// Generate the definition.
+  if (isa<AttrDef>(def)) {
+    ctx.addSubst("_type", "attrType");
+    os << tgfmt(attrParserDefn, &ctx, def.getCppClassName());
+  } else {
+    os << tgfmt(typeParserDefn, &ctx, def.getCppClassName());
+  }
+
+  /// Declare variables to store all of the parameters. Allocated parameters
+  /// such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
+  /// FailureOr<T> to defer type construction for parameters that are parsed in
+  /// a loop (parsers return FailureOr anyway).
+  SmallVector<AttrOrTypeParameter> params = getParameters(def);
+  for (const AttrOrTypeParameter &param : params) {
+    os << formatv("  ::mlir::FailureOr<{0}> _result_{1};\n",
+                  param.getCppStorageType(), param.getName());
+  }
+
+  /// Store the initial location of the parser.
+  ctx.addSubst("_loc", "loc");
+  os << tgfmt("  ::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
+              "  (void) $_loc;\n",
+              &ctx);
+
+  /// Generate call to each parameter parser.
+  for (auto &el : elements)
+    genElementParser(el.get(), ctx, os);
+
+  /// Generate call to the attribute or type builder. Use the checked getter
+  /// if one was generated.
+  if (def.genVerifyDecl()) {
+    os << tgfmt("  return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
+                &ctx, def.getCppClassName());
+  } else {
+    os << tgfmt("  return $0::get($_parser.getContext()", &ctx,
+                def.getCppClassName());
+  }
+  for (const AttrOrTypeParameter &param : params)
+    os << formatv(",\n    _result_{0}.getValue()", param.getName());
+  os << ");\n}\n\n";
+}
+
+void AttrOrTypeFormat::genElementParser(Element *el, FmtContext &ctx,
+                                        raw_ostream &os) {
+  if (auto *literal = dyn_cast<LiteralElement>(el))
+    return genLiteralParser(literal->getSpelling(), ctx, os);
+  if (auto *var = dyn_cast<VariableElement>(el))
+    return genVariableParser(var->getParam(), ctx, os);
+  if (auto *params = dyn_cast<ParamsDirective>(el))
+    return genParamsParser(params, ctx, os);
+  if (auto *strct = dyn_cast<StructDirective>(el))
+    return genStructParser(strct, ctx, os);
+
+  llvm_unreachable("unknown format element");
+}
+
+void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx,
+                                        raw_ostream &os, unsigned indent) {
+  os.indent(indent) << "  // Parse literal '" << value << "'\n";
+  os.indent(indent) << tgfmt("  if ($_parser.parse", &ctx);
+  if (value.front() == '_' || isalpha(value.front())) {
+    os << "Keyword(\"" << value << "\")";
+  } else {
+    os << StringSwitch<StringRef>(value)
+              .Case("->", "Arrow")
+              .Case(":", "Colon")
+              .Case(",", "Comma")
+              .Case("=", "Equal")
+              .Case("<", "Less")
+              .Case(">", "Greater")
+              .Case("{", "LBrace")
+              .Case("}", "RBrace")
+              .Case("(", "LParen")
+              .Case(")", "RParen")
+              .Case("[", "LSquare")
+              .Case("]", "RSquare")
+              .Case("?", "Question")
+              .Case("+", "Plus")
+              .Case("*", "Star")
+       << "()";
+  }
+  os << ")\n";
+  // Parser will emit an error
+  os.indent(indent) << "    return {};\n";
+}
+
+void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter &param,
+                                         FmtContext &ctx, raw_ostream &os,
+                                         unsigned indent) {
+  /// Check for a custom parser. Use the default attribute parser otherwise.
+  auto customParser = param.getParser();
+  auto parser =
+      customParser ? *customParser : StringRef(defaultParameterParser);
+  for (const char *line : variableParser) {
+    os.indent(indent) << formatv(line, param.getName(),
+                                 tgfmt(parser, &ctx, param.getCppStorageType()),
+                                 tgfmt(parseErrorStr, &ctx), def.getName(),
+                                 param.getCppType())
+                      << "\n";
+  }
+}
+
+void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
+                                       raw_ostream &os) {
+  os << "  // Parse parameter list\n";
+  llvm::interleave(
+      el->getParams(), [&](auto param) { genVariableParser(param, ctx, os); },
+      [&]() { genLiteralParser(",", ctx, os); });
+}
+
+void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx,
+                                       raw_ostream &os) {
+  os << "  // Parse parameter struct\n";
+
+  /// Declare a "seen" variable for each key.
+  for (const AttrOrTypeParameter &param : el->getParams())
+    os << formatv("  bool _seen_{0} = false;\n", param.getName());
+
+  /// Generate the parsing loop.
+  os << tgfmt(structParseLoopStart, &ctx, el->getNumParams());
+  genLiteralParser("=", ctx, os, 2);
+  os << "    ";
+  for (const AttrOrTypeParameter &param : el->getParams()) {
+    os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
+                  "      _seen_{0} = true;\n",
+                  param.getName());
+    genVariableParser(param, ctx, os, 4);
+    os << "    } else ";
+  }
+
+  /// Duplicate or unknown parameter.
+  os << formatv(structParseLoopEnd, tgfmt(parseErrorStr, &ctx),
+                el->getNumParams());
+
+  /// Because the loop loops N times and each non-failing iteration sets 1 of
+  /// N flags, successfully exiting the loop means that all parameters have been
+  /// seen. `parseOptionalComma` would cause issues with any formats that use
+  /// "struct(...) `,`" because structs aren't bounded by braces.
+}
+
+//===----------------------------------------------------------------------===//
+// PrinterGen
+//===----------------------------------------------------------------------===//
+
+void AttrOrTypeFormat::genPrinter(raw_ostream &os) {
+  FmtContext ctx;
+  ctx.addSubst("_printer", "printer");
+
+  /// Generate the definition.
+  os << tgfmt(attrOrTypePrinterDefn, &ctx, def.getCppClassName(),
+              *def.getMnemonic());
+
+  /// Generate printers.
+  shouldEmitSpace = true;
+  lastWasPunctuation = false;
+  for (auto &el : elements)
+    genElementPrinter(el.get(), ctx, os);
+
+  os << "}\n\n";
+}
+
+void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx,
+                                         raw_ostream &os) {
+  if (auto *literal = dyn_cast<LiteralElement>(el))
+    return genLiteralPrinter(literal->getSpelling(), ctx, os);
+  if (auto *params = dyn_cast<ParamsDirective>(el))
+    return genParamsPrinter(params, ctx, os);
+  if (auto *strct = dyn_cast<StructDirective>(el))
+    return genStructPrinter(strct, ctx, os);
+  if (auto *var = dyn_cast<VariableElement>(el))
+    return genVariablePrinter(var->getParam(), ctx, os);
+
+  llvm_unreachable("unknown format element");
+}
+
+void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
+                                         raw_ostream &os) {
+  /// Don't insert a space before certain punctuation.
+  bool needSpace =
+      shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
+  os << tgfmt("  $_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
+              value);
+
+  /// Update the flags.
+  shouldEmitSpace =
+      value.size() != 1 || !StringRef("<({[").contains(value.front());
+  lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
+}
+
+void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter &param,
+                                          FmtContext &ctx, raw_ostream &os) {
+  /// Insert a space before the next parameter, if necessary.
+  if (shouldEmitSpace || !lastWasPunctuation)
+    os << tgfmt("  $_printer << ' ';\n", &ctx);
+  shouldEmitSpace = true;
+  lastWasPunctuation = false;
+
+  ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
+  os << "  ";
+  if (auto printer = param.getPrinter())
+    os << tgfmt(*printer, &ctx) << ";\n";
+  else
+    os << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
+}
+
+void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
+                                        raw_ostream &os) {
+  llvm::interleave(
+      el->getParams(), [&](auto param) { genVariablePrinter(param, ctx, os); },
+      [&]() { genLiteralPrinter(",", ctx, os); });
+}
+
+void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
+                                        raw_ostream &os) {
+  llvm::interleave(
+      el->getParams(),
+      [&](auto param) {
+        genLiteralPrinter(param.getName(), ctx, os);
+        genLiteralPrinter("=", ctx, os);
+        os << tgfmt("  $_printer << ' ';\n", &ctx);
+        genVariablePrinter(param, ctx, os);
+      },
+      [&]() { genLiteralPrinter(",", ctx, os); });
+}
+
+//===----------------------------------------------------------------------===//
+// FormatParser
+//===----------------------------------------------------------------------===//
+
+namespace {
+class FormatParser {
+public:
+  FormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
+      : lexer(mgr, def.getLoc()[0]), curToken(lexer.lexToken()), def(def),
+        seenParams(def.getNumParameters()) {}
+
+  /// Parse the attribute or type format and create the format elements.
+  FailureOr<AttrOrTypeFormat> parse();
+
+private:
+  /// The current context of the parser when parsing an element.
+  enum ParserContext {
+    /// The element is being parsed in the default context - at the top of the
+    /// format.
+    TopLevelContext,
+    /// The element is being parsed as a child to a `struct` directive.
+    StructDirective,
+  };
+
+  /// Emit an error.
+  LogicalResult emitError(const Twine &msg) {
+    lexer.emitError(curToken.getLoc(), msg);
+    return failure();
+  }
+
+  /// Parse an expected token.
+  LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
+    if (curToken.getKind() != kind)
+      return emitError(msg);
+    consumeToken();
+    return success();
+  }
+
+  /// Advance the lexer to the next token.
+  void consumeToken() {
+    assert(curToken.getKind() != FormatToken::eof &&
+           curToken.getKind() != FormatToken::error &&
+           "shouldn't advance past EOF or errors");
+    curToken = lexer.lexToken();
+  }
+
+  /// Parse any element.
+  FailureOr<std::unique_ptr<Element>> parseElement(ParserContext ctx);
+  /// Parse a literal element.
+  FailureOr<std::unique_ptr<Element>> parseLiteral(ParserContext ctx);
+  /// Parse a variable element.
+  FailureOr<std::unique_ptr<Element>> parseVariable(ParserContext ctx);
+  /// Parse a directive.
+  FailureOr<std::unique_ptr<Element>> parseDirective(ParserContext ctx);
+  /// Parse a `params` directive.
+  FailureOr<std::unique_ptr<Element>> parseParamsDirective();
+  /// Parse a `struct` directive.
+  FailureOr<std::unique_ptr<Element>> parseStructDirective();
+
+  /// The current format lexer.
+  FormatLexer lexer;
+  /// The current token in the stream.
+  FormatToken curToken;
+  /// Attribute or type tablegen def.
+  const AttrOrTypeDef &def;
+
+  /// Seen attribute or type parameters.
+  llvm::BitVector seenParams;
+};
+} // end anonymous namespace
+
+FailureOr<AttrOrTypeFormat> FormatParser::parse() {
+  std::vector<std::unique_ptr<Element>> elements;
+  elements.reserve(16);
+
+  /// Parse the format elements.
+  while (curToken.getKind() != FormatToken::eof) {
+    auto element = parseElement(TopLevelContext);
+    if (failed(element))
+      return failure();
+
+    /// Add the format element and continue.
+    elements.push_back(std::move(*element));
+  }
+
+  /// Check that all parameters have been seen.
+  SmallVector<AttrOrTypeParameter> params = getParameters(def);
+  for (auto it : llvm::enumerate(params)) {
+    if (!seenParams.test(it.index())) {
+      return emitError("format is missing reference to parameter: " +
+                       it.value().getName());
+    }
+  }
+
+  return AttrOrTypeFormat(def, std::move(elements));
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseElement(ParserContext ctx) {
+  if (curToken.getKind() == FormatToken::literal)
+    return parseLiteral(ctx);
+  if (curToken.getKind() == FormatToken::variable)
+    return parseVariable(ctx);
+  if (curToken.isKeyword())
+    return parseDirective(ctx);
+
+  return emitError("expected literal, directive, or variable");
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseLiteral(ParserContext ctx) {
+  if (ctx != TopLevelContext) {
+    return emitError(
+        "literals may only be used in the top-level section of the format");
+  }
+
+  /// Get the literal spelling without the surrounding "`".
+  auto value = curToken.getSpelling().drop_front().drop_back();
+  if (!isValidLiteral(value))
+    return emitError("literal '" + value + "' is not valid");
+
+  consumeToken();
+  return {std::make_unique<LiteralElement>(value)};
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseVariable(ParserContext ctx) {
+  /// Get the parameter name without the preceding "$".
+  auto name = curToken.getSpelling().drop_front();
+
+  /// Lookup the parameter.
+  SmallVector<AttrOrTypeParameter> params = getParameters(def);
+  auto *it = llvm::find_if(
+      params, [&](auto &param) { return param.getName() == name; });
+
+  /// Check that the parameter reference is valid.
+  if (it == params.end())
+    return emitError(def.getName() + " has no parameter named '" + name + "'");
+  auto idx = std::distance(params.begin(), it);
+  if (seenParams.test(idx))
+    return emitError("duplicate parameter '" + name + "'");
+  seenParams.set(idx);
+
+  consumeToken();
+  return {std::make_unique<VariableElement>(*it)};
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseDirective(ParserContext ctx) {
+
+  switch (curToken.getKind()) {
+  case FormatToken::kw_params:
+    return parseParamsDirective();
+  case FormatToken::kw_struct:
+    if (ctx != TopLevelContext) {
+      return emitError(
+          "`struct` may only be used in the top-level section of the format");
+    }
+    return parseStructDirective();
+  default:
+    return emitError("unknown directive in format: " + curToken.getSpelling());
+  }
+}
+
+FailureOr<std::unique_ptr<Element>> FormatParser::parseParamsDirective() {
+  consumeToken();
+  /// Collect all of the attribute's or type's parameters.
+  SmallVector<AttrOrTypeParameter> params = getParameters(def);
+  SmallVector<std::unique_ptr<Element>> vars;
+  /// Ensure that none of the parameters have already been captured.
+  for (auto it : llvm::enumerate(params)) {
+    if (seenParams.test(it.index())) {
+      return emitError("`params` captures duplicate parameter: " +
+                       it.value().getName());
+    }
+    seenParams.set(it.index());
+    vars.push_back(std::make_unique<VariableElement>(it.value()));
+  }
+  return {std::make_unique<ParamsDirective>(std::move(vars))};
+}
+
+FailureOr<std::unique_ptr<Element>> FormatParser::parseStructDirective() {
+  consumeToken();
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before `struct` argument list")))
+    return failure();
+
+  /// Parse variables captured by `struct`.
+  SmallVector<std::unique_ptr<Element>> vars;
+
+  /// Parse first captured parameter or a `params` directive.
+  FailureOr<std::unique_ptr<Element>> var = parseElement(StructDirective);
+  if (failed(var) || !isa<VariableElement, ParamsDirective>(*var))
+    return emitError("`struct` argument list expected a variable or directive");
+  if (isa<VariableElement>(*var)) {
+    /// Parse any other parameters.
+    vars.push_back(std::move(*var));
+    while (curToken.getKind() == FormatToken::comma) {
+      consumeToken();
+      var = parseElement(StructDirective);
+      if (failed(var) || !isa<VariableElement>(*var))
+        return emitError("expected a variable in `struct` argument list");
+      vars.push_back(std::move(*var));
+    }
+  } else {
+    /// `struct(params)` captures all parameters in the attribute or type.
+    vars = cast<ParamsDirective>(var->get())->takeParams();
+  }
+
+  if (curToken.getKind() != FormatToken::r_paren)
+    return emitError("expected ')' at the end of an argument list");
+
+  consumeToken();
+  return {std::make_unique<::StructDirective>(std::move(vars))};
+}
+
+//===----------------------------------------------------------------------===//
+// Interface
+//===----------------------------------------------------------------------===//
+
+void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
+                                            raw_ostream &os) {
+  llvm::SourceMgr mgr;
+  mgr.AddNewSourceBuffer(
+      llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()),
+      llvm::SMLoc());
+
+  /// Parse the custom assembly format.
+  FormatParser parser(mgr, def);
+  FailureOr<AttrOrTypeFormat> format = parser.parse();
+  if (failed(format)) {
+    if (formatErrorIsFatal)
+      PrintFatalError(def.getLoc(), "failed to parse assembly format");
+    return;
+  }
+
+  /// Generate the parser and printer.
+  format->genParser(os);
+  format->genPrinter(os);
+}
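
The comment at the end of `genStructParser` argues that a loop running N times, where each successful iteration sets one of N "seen" flags, can only finish once every parameter has been parsed. A minimal standalone sketch of that idea follows. It is not the generated code: the templates `structParseLoopStart` and `structParseLoopEnd` are not shown in this part of the patch, so the control flow is inferred from the generator, and `Cursor`, `skipSpace`, `consume`, `readWhile`, and `parseStruct` are hypothetical names. Values are limited to non-negative integers to keep it short.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical cursor over the input text; not part of the patch.
struct Cursor {
  const std::string &text;
  std::size_t pos;
};

static void skipSpace(Cursor &c) {
  while (c.pos < c.text.size() &&
         std::isspace(static_cast<unsigned char>(c.text[c.pos])))
    ++c.pos;
}

// Consume `ch` if it is the next non-space character.
static bool consume(Cursor &c, char ch) {
  skipSpace(c);
  if (c.pos >= c.text.size() || c.text[c.pos] != ch)
    return false;
  ++c.pos;
  return true;
}

// Read the longest run of characters accepted by `pred` (possibly empty).
template <typename Pred> static std::string readWhile(Cursor &c, Pred pred) {
  skipSpace(c);
  std::size_t start = c.pos;
  while (c.pos < c.text.size() &&
         pred(static_cast<unsigned char>(c.text[c.pos])))
    ++c.pos;
  return c.text.substr(start, c.pos - start);
}

// Parse exactly keys.size() `key = value` pairs, in any order. The map doubles
// as the set of "seen" flags: a key already present is a duplicate.
bool parseStruct(const std::string &text, const std::vector<std::string> &keys,
                 std::map<std::string, int> &out) {
  assert(!keys.empty() && "sketch assumes at least one key");
  Cursor c{text, 0};
  out.clear();
  for (std::size_t i = 0, n = keys.size(); i < n; ++i) {
    std::string key = readWhile(
        c, [](unsigned char ch) { return std::isalnum(ch) || ch == '_'; });
    if (key.empty() || !consume(c, '='))
      return false;
    // Reject unknown keys and keys that were already parsed.
    bool known = std::find(keys.begin(), keys.end(), key) != keys.end();
    if (!known || out.count(key))
      return false;
    std::string digits =
        readWhile(c, [](unsigned char ch) { return std::isdigit(ch) != 0; });
    if (digits.empty())
      return false;
    out[key] = std::stoi(digits); // May throw on very long digit runs.
    // A comma is mandatory between pairs and never consumed after the last
    // one, so a literal `,` following the struct stays available.
    if (i + 1 != n && !consume(c, ','))
      return false;
  }
  return true;
}
```

Because the comma is required on all but the last iteration instead of being optional, the sketch never swallows a trailing `,` that belongs to the enclosing format, which is the concern the comment raises about `parseOptionalComma`.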

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
new file mode 100644
index 0000000000000..2a10a157dfc90
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
@@ -0,0 +1,32 @@
+//===- AttrOrTypeFormatGen.h - MLIR attribute and type format generator ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
+#define MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
+
+#include "llvm/Support/raw_ostream.h"
+
+#include <string>
+
+namespace mlir {
+namespace tblgen {
+class AttrOrTypeDef;
+
+/// Generate a parser and printer based on a custom assembly format for an
+/// attribute or type.
+void generateAttrOrTypeFormat(const AttrOrTypeDef &def, llvm::raw_ostream &os);
+
+/// From the parameter name, get the name of the accessor function in camelcase.
+/// The first letter of the parameter is upper-cased and prefixed with "get".
+/// E.g. 'value' -> 'getValue'.
+std::string getParameterAccessorName(llvm::StringRef name);
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
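
The doc comment on `getParameterAccessorName` describes the naming rule in words. As a rough illustration only, a standalone version that follows that description could look like the sketch below. It takes `std::string` instead of `llvm::StringRef` so it compiles without LLVM, and it may differ in detail from the definition in `AttrOrTypeFormatGen.cpp`.

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Sketch of the documented rule: prefix "get" and upper-case the first letter
// of the parameter name. Only the first character is changed.
std::string getParameterAccessorName(const std::string &name) {
  assert(!name.empty() && "parameter has empty name");
  std::string result = "get" + name;
  result[3] = static_cast<char>(
      std::toupper(static_cast<unsigned char>(result[3])));
  return result;
}
```

Under this rule the `$count` and `$map` parameters from the commit message would be read through `getCount()` and `getMap()`.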

diff  --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index f16e8965daca4..a937a9d89a1d3 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -6,10 +6,12 @@ set(LLVM_LINK_COMPONENTS
 
 add_tablegen(mlir-tblgen MLIR
   AttrOrTypeDefGen.cpp
+  AttrOrTypeFormatGen.cpp
   CodeGenHelpers.cpp
   DialectGen.cpp
   DirectiveCommonGen.cpp
   EnumsGen.cpp
+  FormatGen.cpp
   LLVMIRConversionGen.cpp
   LLVMIRIntrinsicGen.cpp
   mlir-tblgen.cpp

diff  --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
new file mode 100644
index 0000000000000..fa6c0603ac7e1
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -0,0 +1,225 @@
+//===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "FormatGen.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/TableGen/Error.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// FormatToken
+//===----------------------------------------------------------------------===//
+
+llvm::SMLoc FormatToken::getLoc() const {
+  return llvm::SMLoc::getFromPointer(spelling.data());
+}
+
+//===----------------------------------------------------------------------===//
+// FormatLexer
+//===----------------------------------------------------------------------===//
+
+FormatLexer::FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc)
+    : mgr(mgr), loc(loc),
+      curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
+      curPtr(curBuffer.begin()) {}
+
+FormatToken FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
+  mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
+  llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
+                            "in custom assembly format for this operation");
+  return formToken(FormatToken::error, loc.getPointer());
+}
+
+FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
+  return emitError(llvm::SMLoc::getFromPointer(loc), msg);
+}
+
+FormatToken FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
+                                          const Twine &note) {
+  mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
+  llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
+                            "in custom assembly format for this operation");
+  mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
+  return formToken(FormatToken::error, loc.getPointer());
+}
+
+int FormatLexer::getNextChar() {
+  char curChar = *curPtr++;
+  switch (curChar) {
+  default:
+    return (unsigned char)curChar;
+  case 0: {
+    // A nul character in the stream is either the end of the current buffer or
+    // a random nul in the file. Disambiguate that here.
+    if (curPtr - 1 != curBuffer.end())
+      return 0;
+
+    // Otherwise, return end of file.
+    --curPtr;
+    return EOF;
+  }
+  case '\n':
+  case '\r':
+    // Handle the newline character by ignoring it and incrementing the line
+    // count. However, be careful about 'dos style' files with \n\r in them.
+    // Only treat a \n\r or \r\n as a single line.
+    if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
+      ++curPtr;
+    return '\n';
+  }
+}
+
+FormatToken FormatLexer::lexToken() {
+  const char *tokStart = curPtr;
+
+  // This always consumes at least one character.
+  int curChar = getNextChar();
+  switch (curChar) {
+  default:
+    // Handle identifiers: [a-zA-Z_]
+    if (isalpha(curChar) || curChar == '_')
+      return lexIdentifier(tokStart);
+
+    // Unknown character, emit an error.
+    return emitError(tokStart, "unexpected character");
+  case EOF:
+    // Return EOF denoting the end of lexing.
+    return formToken(FormatToken::eof, tokStart);
+
+  // Lex punctuation.
+  case '^':
+    return formToken(FormatToken::caret, tokStart);
+  case ':':
+    return formToken(FormatToken::colon, tokStart);
+  case ',':
+    return formToken(FormatToken::comma, tokStart);
+  case '=':
+    return formToken(FormatToken::equal, tokStart);
+  case '<':
+    return formToken(FormatToken::less, tokStart);
+  case '>':
+    return formToken(FormatToken::greater, tokStart);
+  case '?':
+    return formToken(FormatToken::question, tokStart);
+  case '(':
+    return formToken(FormatToken::l_paren, tokStart);
+  case ')':
+    return formToken(FormatToken::r_paren, tokStart);
+  case '*':
+    return formToken(FormatToken::star, tokStart);
+
+  // Ignore whitespace characters.
+  case 0:
+  case ' ':
+  case '\t':
+  case '\n':
+    return lexToken();
+
+  case '`':
+    return lexLiteral(tokStart);
+  case '$':
+    return lexVariable(tokStart);
+  }
+}
+
+FormatToken FormatLexer::lexLiteral(const char *tokStart) {
+  assert(curPtr[-1] == '`');
+
+  // Lex a literal surrounded by ``.
+  while (const char curChar = *curPtr++) {
+    if (curChar == '`')
+      return formToken(FormatToken::literal, tokStart);
+  }
+  return emitError(curPtr - 1, "unexpected end of file in literal");
+}
+
+FormatToken FormatLexer::lexVariable(const char *tokStart) {
+  if (!isalpha(curPtr[0]) && curPtr[0] != '_')
+    return emitError(curPtr - 1, "expected variable name");
+
+  // Otherwise, consume the rest of the characters.
+  while (isalnum(*curPtr) || *curPtr == '_')
+    ++curPtr;
+  return formToken(FormatToken::variable, tokStart);
+}
+
+FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
+  // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
+  while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
+    ++curPtr;
+
+  // Check to see if this identifier is a keyword.
+  StringRef str(tokStart, curPtr - tokStart);
+  auto kind =
+      StringSwitch<FormatToken::Kind>(str)
+          .Case("attr-dict", FormatToken::kw_attr_dict)
+          .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
+          .Case("custom", FormatToken::kw_custom)
+          .Case("functional-type", FormatToken::kw_functional_type)
+          .Case("operands", FormatToken::kw_operands)
+          .Case("params", FormatToken::kw_params)
+          .Case("ref", FormatToken::kw_ref)
+          .Case("regions", FormatToken::kw_regions)
+          .Case("results", FormatToken::kw_results)
+          .Case("struct", FormatToken::kw_struct)
+          .Case("successors", FormatToken::kw_successors)
+          .Case("type", FormatToken::kw_type)
+          .Default(FormatToken::identifier);
+  return FormatToken(kind, str);
+}
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
+                                         bool lastWasPunctuation) {
+  if (value.size() != 1 && value != "->")
+    return true;
+  if (lastWasPunctuation)
+    return !StringRef(">)}],").contains(value.front());
+  return !StringRef("<>(){}[],").contains(value.front());
+}
+
+bool mlir::tblgen::canFormatStringAsKeyword(StringRef value) {
+  if (!isalpha(value.front()) && value.front() != '_')
+    return false;
+  return llvm::all_of(value.drop_front(), [](char c) {
+    return isalnum(c) || c == '_' || c == '$' || c == '.';
+  });
+}
+
+bool mlir::tblgen::isValidLiteral(StringRef value) {
+  if (value.empty())
+    return false;
+  char front = value.front();
+
+  // If there is only one character, this must either be punctuation or a
+  // single character bare identifier.
+  if (value.size() == 1)
+    return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
+
+  // Check punctuation that is longer than a single character.
+  if (value == "->")
+    return true;
+
+  // Otherwise, this must be an identifier.
+  return canFormatStringAsKeyword(value);
+}
+
+//===----------------------------------------------------------------------===//
+// Commandline Options
+//===----------------------------------------------------------------------===//
+
+llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal(
+    "asmformat-error-is-fatal",
+    llvm::cl::desc("Emit a fatal error if format parsing fails"),
+    llvm::cl::init(true));
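
The literal helpers added in `FormatGen.cpp` are small enough to try in isolation. The sketch below re-expresses `shouldEmitSpaceBefore`, `canFormatStringAsKeyword`, and `isValidLiteral` with `std::string` in place of `llvm::StringRef` so that it builds without LLVM. The logic mirrors the patch; `containsChar` and the empty-string guard in `canFormatStringAsKeyword` are additions made for this illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>

// Hypothetical helper: true if `set` contains the character `c`.
static bool containsChar(const std::string &set, char c) {
  return set.find(c) != std::string::npos;
}

// Mirrors mlir::tblgen::shouldEmitSpaceBefore from the patch.
bool shouldEmitSpaceBefore(const std::string &value, bool lastWasPunctuation) {
  assert(!value.empty() && "expected a non-empty literal");
  // Keywords and other multi-character literals (except `->`) get a space.
  if (value.size() != 1 && value != "->")
    return true;
  if (lastWasPunctuation)
    return !containsChar(">)}],", value.front());
  return !containsChar("<>(){}[],", value.front());
}

// Mirrors mlir::tblgen::canFormatStringAsKeyword, plus an empty-string guard.
bool canFormatStringAsKeyword(const std::string &value) {
  if (value.empty())
    return false;
  unsigned char front = static_cast<unsigned char>(value.front());
  if (!std::isalpha(front) && front != '_')
    return false;
  return std::all_of(value.begin() + 1, value.end(), [](unsigned char c) {
    return std::isalnum(c) || c == '_' || c == '$' || c == '.';
  });
}

// Mirrors mlir::tblgen::isValidLiteral from the patch.
bool isValidLiteral(const std::string &value) {
  if (value.empty())
    return false;
  char front = value.front();
  // A single character is either punctuation or a one-letter identifier.
  if (value.size() == 1)
    return std::isalpha(static_cast<unsigned char>(front)) ||
           containsChar("_:,=<>()[]{}?+*", front);
  // `->` is the only multi-character punctuation.
  if (value == "->")
    return true;
  // Otherwise, this must be an identifier.
  return canFormatStringAsKeyword(value);
}
```

For instance, a `<` that follows a keyword needs no leading space, while a `<` that follows other punctuation does; closing punctuation and `,` never get one.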

diff  --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
new file mode 100644
index 0000000000000..f061d1ed5c678
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -0,0 +1,161 @@
+//===- FormatGen.h - Utilities for custom assembly formats ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains common classes for building custom assembly format parsers
+// and generators.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
+#define MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/SMLoc.h"
+
+namespace llvm {
+class SourceMgr;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+//===----------------------------------------------------------------------===//
+// FormatToken
+//===----------------------------------------------------------------------===//
+
+/// This class represents a specific token in the input format.
+class FormatToken {
+public:
+  /// Basic token kinds.
+  enum Kind {
+    // Markers.
+    eof,
+    error,
+
+    // Tokens with no info.
+    l_paren,
+    r_paren,
+    caret,
+    colon,
+    comma,
+    equal,
+    less,
+    greater,
+    question,
+    star,
+
+    // Keywords.
+    keyword_start,
+    kw_attr_dict,
+    kw_attr_dict_w_keyword,
+    kw_custom,
+    kw_functional_type,
+    kw_operands,
+    kw_params,
+    kw_ref,
+    kw_regions,
+    kw_results,
+    kw_struct,
+    kw_successors,
+    kw_type,
+    keyword_end,
+
+    // String valued tokens.
+    identifier,
+    literal,
+    variable,
+  };
+
+  FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
+
+  /// Return the bytes that make up this token.
+  StringRef getSpelling() const { return spelling; }
+
+  /// Return the kind of this token.
+  Kind getKind() const { return kind; }
+
+  /// Return a location for this token.
+  llvm::SMLoc getLoc() const;
+
+  /// Return if this token is a keyword.
+  bool isKeyword() const {
+    return getKind() > Kind::keyword_start && getKind() < Kind::keyword_end;
+  }
+
+private:
+  /// Discriminator that indicates the kind of token this is.
+  Kind kind;
+
+  /// A reference to the entire token contents; this is always a pointer into
+  /// a memory buffer owned by the source manager.
+  StringRef spelling;
+};
+
+//===----------------------------------------------------------------------===//
+// FormatLexer
+//===----------------------------------------------------------------------===//
+
+/// This class implements a simple lexer for operation assembly format strings.
+class FormatLexer {
+public:
+  FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc);
+
+  /// Lex the next token and return it.
+  FormatToken lexToken();
+
+  /// Emit an error to the lexer with the given location and message.
+  FormatToken emitError(llvm::SMLoc loc, const Twine &msg);
+  FormatToken emitError(const char *loc, const Twine &msg);
+
+  FormatToken emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
+                               const Twine &note);
+
+private:
+  /// Return the next character in the stream.
+  int getNextChar();
+
+  /// Lex an identifier, literal, or variable.
+  FormatToken lexIdentifier(const char *tokStart);
+  FormatToken lexLiteral(const char *tokStart);
+  FormatToken lexVariable(const char *tokStart);
+
+  /// Create a token with the current pointer and a start pointer.
+  FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
+    return FormatToken(kind, StringRef(tokStart, curPtr - tokStart));
+  }
+
+  /// The source manager containing the format string.
+  llvm::SourceMgr &mgr;
+  /// Location of the format string.
+  llvm::SMLoc loc;
+  /// Buffer containing the format string.
+  StringRef curBuffer;
+  /// Current pointer in the buffer.
+  const char *curPtr;
+};
+
+/// Whether a space needs to be emitted before a literal. E.g., two keywords
+/// back-to-back require a space separator, but a keyword followed by '<' does
+/// not require a space.
+bool shouldEmitSpaceBefore(StringRef value, bool lastWasPunctuation);
+
+/// Returns true if the given string can be formatted as a keyword.
+bool canFormatStringAsKeyword(StringRef value);
+
+/// Returns true if the given string is valid format literal element.
+bool isValidLiteral(StringRef value);
+
+/// Whether a failure in parsing the assembly format should be a fatal error.
+extern llvm::cl::opt<bool> formatErrorIsFatal;
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_

diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index ccc4ac9cf4287..19dd6fa7c1016 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "OpFormatGen.h"
+#include "FormatGen.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
@@ -20,7 +21,6 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
@@ -30,20 +30,6 @@
 using namespace mlir;
 using namespace mlir::tblgen;
 
-static llvm::cl::opt<bool> formatErrorIsFatal(
-    "asmformat-error-is-fatal",
-    llvm::cl::desc("Emit a fatal error if format parsing fails"),
-    llvm::cl::init(true));
-
-/// Returns true if the given string can be formatted as a keyword.
-static bool canFormatStringAsKeyword(StringRef value) {
-  if (!isalpha(value.front()) && value.front() != '_')
-    return false;
-  return llvm::all_of(value.drop_front(), [](char c) {
-    return isalnum(c) || c == '_' || c == '$' || c == '.';
-  });
-}
-
 //===----------------------------------------------------------------------===//
 // Element
 //===----------------------------------------------------------------------===//
@@ -273,33 +259,12 @@ class LiteralElement : public Element {
   /// Return the literal for this element.
   StringRef getLiteral() const { return literal; }
 
-  /// Returns true if the given string is a valid literal.
-  static bool isValidLiteral(StringRef value);
-
 private:
   /// The spelling of the literal for this element.
   StringRef literal;
 };
 } // end anonymous namespace
 
-bool LiteralElement::isValidLiteral(StringRef value) {
-  if (value.empty())
-    return false;
-  char front = value.front();
-
-  // If there is only one character, this must either be punctuation or a
-  // single character bare identifier.
-  if (value.size() == 1)
-    return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
-
-  // Check the punctuation that are larger than a single character.
-  if (value == "->")
-    return true;
-
-  // Otherwise, this must be an identifier.
-  return canFormatStringAsKeyword(value);
-}
-
 //===----------------------------------------------------------------------===//
 // WhitespaceElement
 
@@ -1705,14 +1670,7 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
   body << "  _odsPrinter";
 
   // Don't insert a space for certain punctuation.
-  auto shouldPrintSpaceBeforeLiteral = [&] {
-    if (value.size() != 1 && value != "->")
-      return true;
-    if (lastWasPunctuation)
-      return !StringRef(">)}],").contains(value.front());
-    return !StringRef("<>(){}[],").contains(value.front());
-  };
-  if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral())
+  if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
     body << " << ' '";
   body << " << \"" << value << "\";\n";
 
@@ -2101,253 +2059,6 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
                       lastWasPunctuation);
 }
 
-//===----------------------------------------------------------------------===//
-// FormatLexer
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class represents a specific token in the input format.
-class Token {
-public:
-  enum Kind {
-    // Markers.
-    eof,
-    error,
-
-    // Tokens with no info.
-    l_paren,
-    r_paren,
-    caret,
-    colon,
-    comma,
-    equal,
-    less,
-    greater,
-    question,
-
-    // Keywords.
-    keyword_start,
-    kw_attr_dict,
-    kw_attr_dict_w_keyword,
-    kw_custom,
-    kw_functional_type,
-    kw_operands,
-    kw_ref,
-    kw_regions,
-    kw_results,
-    kw_successors,
-    kw_type,
-    keyword_end,
-
-    // String valued tokens.
-    identifier,
-    literal,
-    variable,
-  };
-  Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
-
-  /// Return the bytes that make up this token.
-  StringRef getSpelling() const { return spelling; }
-
-  /// Return the kind of this token.
-  Kind getKind() const { return kind; }
-
-  /// Return a location for this token.
-  llvm::SMLoc getLoc() const {
-    return llvm::SMLoc::getFromPointer(spelling.data());
-  }
-
-  /// Return if this token is a keyword.
-  bool isKeyword() const { return kind > keyword_start && kind < keyword_end; }
-
-private:
-  /// Discriminator that indicates the kind of token this is.
-  Kind kind;
-
-  /// A reference to the entire token contents; this is always a pointer into
-  /// a memory buffer owned by the source manager.
-  StringRef spelling;
-};
-
-/// This class implements a simple lexer for operation assembly format strings.
-class FormatLexer {
-public:
-  FormatLexer(llvm::SourceMgr &mgr, Operator &op);
-
-  /// Lex the next token and return it.
-  Token lexToken();
-
-  /// Emit an error to the lexer with the given location and message.
-  Token emitError(llvm::SMLoc loc, const Twine &msg);
-  Token emitError(const char *loc, const Twine &msg);
-
-  Token emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, const Twine &note);
-
-private:
-  Token formToken(Token::Kind kind, const char *tokStart) {
-    return Token(kind, StringRef(tokStart, curPtr - tokStart));
-  }
-
-  /// Return the next character in the stream.
-  int getNextChar();
-
-  /// Lex an identifier, literal, or variable.
-  Token lexIdentifier(const char *tokStart);
-  Token lexLiteral(const char *tokStart);
-  Token lexVariable(const char *tokStart);
-
-  llvm::SourceMgr &srcMgr;
-  Operator &op;
-  StringRef curBuffer;
-  const char *curPtr;
-};
-} // end anonymous namespace
-
-FormatLexer::FormatLexer(llvm::SourceMgr &mgr, Operator &op)
-    : srcMgr(mgr), op(op) {
-  curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
-  curPtr = curBuffer.begin();
-}
-
-Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
-  srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
-  llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
-                            "in custom assembly format for this operation");
-  return formToken(Token::error, loc.getPointer());
-}
-Token FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
-                                    const Twine &note) {
-  srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
-  llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
-                            "in custom assembly format for this operation");
-  srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
-  return formToken(Token::error, loc.getPointer());
-}
-Token FormatLexer::emitError(const char *loc, const Twine &msg) {
-  return emitError(llvm::SMLoc::getFromPointer(loc), msg);
-}
-
-int FormatLexer::getNextChar() {
-  char curChar = *curPtr++;
-  switch (curChar) {
-  default:
-    return (unsigned char)curChar;
-  case 0: {
-    // A nul character in the stream is either the end of the current buffer or
-    // a random nul in the file. Disambiguate that here.
-    if (curPtr - 1 != curBuffer.end())
-      return 0;
-
-    // Otherwise, return end of file.
-    --curPtr;
-    return EOF;
-  }
-  case '\n':
-  case '\r':
-    // Handle the newline character by ignoring it and incrementing the line
-    // count. However, be careful about 'dos style' files with \n\r in them.
-    // Only treat a \n\r or \r\n as a single line.
-    if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
-      ++curPtr;
-    return '\n';
-  }
-}
-
-Token FormatLexer::lexToken() {
-  const char *tokStart = curPtr;
-
-  // This always consumes at least one character.
-  int curChar = getNextChar();
-  switch (curChar) {
-  default:
-    // Handle identifiers: [a-zA-Z_]
-    if (isalpha(curChar) || curChar == '_')
-      return lexIdentifier(tokStart);
-
-    // Unknown character, emit an error.
-    return emitError(tokStart, "unexpected character");
-  case EOF:
-    // Return EOF denoting the end of lexing.
-    return formToken(Token::eof, tokStart);
-
-  // Lex punctuation.
-  case '^':
-    return formToken(Token::caret, tokStart);
-  case ':':
-    return formToken(Token::colon, tokStart);
-  case ',':
-    return formToken(Token::comma, tokStart);
-  case '=':
-    return formToken(Token::equal, tokStart);
-  case '<':
-    return formToken(Token::less, tokStart);
-  case '>':
-    return formToken(Token::greater, tokStart);
-  case '?':
-    return formToken(Token::question, tokStart);
-  case '(':
-    return formToken(Token::l_paren, tokStart);
-  case ')':
-    return formToken(Token::r_paren, tokStart);
-
-  // Ignore whitespace characters.
-  case 0:
-  case ' ':
-  case '\t':
-  case '\n':
-    return lexToken();
-
-  case '`':
-    return lexLiteral(tokStart);
-  case '$':
-    return lexVariable(tokStart);
-  }
-}
-
-Token FormatLexer::lexLiteral(const char *tokStart) {
-  assert(curPtr[-1] == '`');
-
-  // Lex a literal surrounded by ``.
-  while (const char curChar = *curPtr++) {
-    if (curChar == '`')
-      return formToken(Token::literal, tokStart);
-  }
-  return emitError(curPtr - 1, "unexpected end of file in literal");
-}
-
-Token FormatLexer::lexVariable(const char *tokStart) {
-  if (!isalpha(curPtr[0]) && curPtr[0] != '_')
-    return emitError(curPtr - 1, "expected variable name");
-
-  // Otherwise, consume the rest of the characters.
-  while (isalnum(*curPtr) || *curPtr == '_')
-    ++curPtr;
-  return formToken(Token::variable, tokStart);
-}
-
-Token FormatLexer::lexIdentifier(const char *tokStart) {
-  // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
-  while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
-    ++curPtr;
-
-  // Check to see if this identifier is a keyword.
-  StringRef str(tokStart, curPtr - tokStart);
-  Token::Kind kind =
-      StringSwitch<Token::Kind>(str)
-          .Case("attr-dict", Token::kw_attr_dict)
-          .Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
-          .Case("custom", Token::kw_custom)
-          .Case("functional-type", Token::kw_functional_type)
-          .Case("operands", Token::kw_operands)
-          .Case("ref", Token::kw_ref)
-          .Case("regions", Token::kw_regions)
-          .Case("results", Token::kw_results)
-          .Case("successors", Token::kw_successors)
-          .Case("type", Token::kw_type)
-          .Default(Token::identifier);
-  return Token(kind, str);
-}
-
 //===----------------------------------------------------------------------===//
 // FormatParser
 //===----------------------------------------------------------------------===//
@@ -2366,8 +2077,8 @@ namespace {
 class FormatParser {
 public:
   FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
-      : lexer(mgr, op), curToken(lexer.lexToken()), fmt(format), op(op),
-        seenOperandTypes(op.getNumOperands()),
+      : lexer(mgr, op.getLoc()[0]), curToken(lexer.lexToken()), fmt(format),
+        op(op), seenOperandTypes(op.getNumOperands()),
         seenResultTypes(op.getNumResults()) {}
 
   /// Parse the operation assembly format.
@@ -2469,7 +2180,8 @@ class FormatParser {
   LogicalResult parseCustomDirectiveParameter(
       std::vector<std::unique_ptr<Element>> &parameters);
   LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
-                                             Token tok, ParserContext context);
+                                             FormatToken tok,
+                                             ParserContext context);
   LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
                                        llvm::SMLoc loc, ParserContext context);
   LogicalResult parseReferenceDirective(std::unique_ptr<Element> &element,
@@ -2481,8 +2193,8 @@ class FormatParser {
   LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
                                          llvm::SMLoc loc,
                                          ParserContext context);
-  LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
-                                   ParserContext context);
+  LogicalResult parseTypeDirective(std::unique_ptr<Element> &element,
+                                   FormatToken tok, ParserContext context);
   LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
                                           bool isRefChild = false);
 
@@ -2492,12 +2204,12 @@ class FormatParser {
 
   /// Advance the current lexer onto the next token.
   void consumeToken() {
-    assert(curToken.getKind() != Token::eof &&
-           curToken.getKind() != Token::error &&
+    assert(curToken.getKind() != FormatToken::eof &&
+           curToken.getKind() != FormatToken::error &&
            "shouldn't advance past EOF or errors");
     curToken = lexer.lexToken();
   }
-  LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
+  LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
     if (curToken.getKind() != kind)
       return emitError(curToken.getLoc(), msg);
     consumeToken();
@@ -2518,7 +2230,7 @@ class FormatParser {
   //===--------------------------------------------------------------------===//
 
   FormatLexer lexer;
-  Token curToken;
+  FormatToken curToken;
   OperationFormat &fmt;
   Operator &op;
 
@@ -2539,7 +2251,7 @@ LogicalResult FormatParser::parse() {
   llvm::SMLoc loc = curToken.getLoc();
 
   // Parse each of the format elements into the main format.
-  while (curToken.getKind() != Token::eof) {
+  while (curToken.getKind() != FormatToken::eof) {
     std::unique_ptr<Element> element;
     if (failed(parseElement(element, TopLevelContext)))
       return ::mlir::failure();
@@ -2864,13 +2576,13 @@ LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
   if (curToken.isKeyword())
     return parseDirective(element, context);
   // Literals.
-  if (curToken.getKind() == Token::literal)
+  if (curToken.getKind() == FormatToken::literal)
     return parseLiteral(element, context);
   // Optionals.
-  if (curToken.getKind() == Token::l_paren)
+  if (curToken.getKind() == FormatToken::l_paren)
     return parseOptional(element, context);
   // Variables.
-  if (curToken.getKind() == Token::variable)
+  if (curToken.getKind() == FormatToken::variable)
     return parseVariable(element, context);
   return emitError(curToken.getLoc(),
                    "expected directive, literal, variable, or optional group");
@@ -2878,7 +2590,7 @@ LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
 
 LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
                                           ParserContext context) {
-  Token varTok = curToken;
+  FormatToken varTok = curToken;
   consumeToken();
 
   StringRef name = varTok.getSpelling().drop_front();
@@ -2958,31 +2670,31 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
 
 LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
                                            ParserContext context) {
-  Token dirTok = curToken;
+  FormatToken dirTok = curToken;
   consumeToken();
 
   switch (dirTok.getKind()) {
-  case Token::kw_attr_dict:
+  case FormatToken::kw_attr_dict:
     return parseAttrDictDirective(element, dirTok.getLoc(), context,
                                   /*withKeyword=*/false);
-  case Token::kw_attr_dict_w_keyword:
+  case FormatToken::kw_attr_dict_w_keyword:
     return parseAttrDictDirective(element, dirTok.getLoc(), context,
                                   /*withKeyword=*/true);
-  case Token::kw_custom:
+  case FormatToken::kw_custom:
     return parseCustomDirective(element, dirTok.getLoc(), context);
-  case Token::kw_functional_type:
+  case FormatToken::kw_functional_type:
     return parseFunctionalTypeDirective(element, dirTok, context);
-  case Token::kw_operands:
+  case FormatToken::kw_operands:
     return parseOperandsDirective(element, dirTok.getLoc(), context);
-  case Token::kw_regions:
+  case FormatToken::kw_regions:
     return parseRegionsDirective(element, dirTok.getLoc(), context);
-  case Token::kw_results:
+  case FormatToken::kw_results:
     return parseResultsDirective(element, dirTok.getLoc(), context);
-  case Token::kw_successors:
+  case FormatToken::kw_successors:
     return parseSuccessorsDirective(element, dirTok.getLoc(), context);
-  case Token::kw_ref:
+  case FormatToken::kw_ref:
     return parseReferenceDirective(element, dirTok.getLoc(), context);
-  case Token::kw_type:
+  case FormatToken::kw_type:
     return parseTypeDirective(element, dirTok, context);
 
   default:
@@ -2992,7 +2704,7 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
 
 LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
                                          ParserContext context) {
-  Token literalTok = curToken;
+  FormatToken literalTok = curToken;
   if (context != TopLevelContext) {
     return emitError(
         literalTok.getLoc(),
@@ -3014,7 +2726,7 @@ LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
   }
 
   // Check that the parsed literal is valid.
-  if (!LiteralElement::isValidLiteral(value))
+  if (!isValidLiteral(value))
     return emitError(literalTok.getLoc(), "expected valid literal");
 
   element = std::make_unique<LiteralElement>(value);
@@ -3035,14 +2747,15 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
   do {
     if (failed(parseOptionalChildElement(thenElements, anchorIdx)))
       return ::mlir::failure();
-  } while (curToken.getKind() != Token::r_paren);
+  } while (curToken.getKind() != FormatToken::r_paren);
   consumeToken();
 
   // Parse the `else` elements of this optional group.
-  if (curToken.getKind() == Token::colon) {
+  if (curToken.getKind() == FormatToken::colon) {
     consumeToken();
-    if (failed(parseToken(Token::l_paren, "expected '(' to start else branch "
-                                          "of optional group")))
+    if (failed(parseToken(FormatToken::l_paren,
+                          "expected '(' to start else branch "
+                          "of optional group")))
       return failure();
     do {
       llvm::SMLoc childLoc = curToken.getLoc();
@@ -3051,11 +2764,12 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
           failed(verifyOptionalChildElement(elseElements.back().get(), childLoc,
                                             /*isAnchor=*/false)))
         return failure();
-    } while (curToken.getKind() != Token::r_paren);
+    } while (curToken.getKind() != FormatToken::r_paren);
     consumeToken();
   }
 
-  if (failed(parseToken(Token::question, "expected '?' after optional group")))
+  if (failed(parseToken(FormatToken::question,
+                        "expected '?' after optional group")))
     return ::mlir::failure();
 
   // The optional group is required to have an anchor.
@@ -3090,7 +2804,7 @@ LogicalResult FormatParser::parseOptionalChildElement(
     return ::mlir::failure();
 
   // Check to see if this element is the anchor of the optional group.
-  bool isAnchor = curToken.getKind() == Token::caret;
+  bool isAnchor = curToken.getKind() == FormatToken::caret;
   if (isAnchor) {
     if (anchorIdx)
       return emitError(childLoc, "only one element can be marked as the anchor "
@@ -3194,16 +2908,16 @@ FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
     return emitError(loc, "'custom' is only valid as a top-level directive");
 
   // Parse the custom directive name.
-  if (failed(
-          parseToken(Token::less, "expected '<' before custom directive name")))
+  if (failed(parseToken(FormatToken::less,
+                        "expected '<' before custom directive name")))
     return ::mlir::failure();
 
-  Token nameTok = curToken;
-  if (failed(parseToken(Token::identifier,
+  FormatToken nameTok = curToken;
+  if (failed(parseToken(FormatToken::identifier,
                         "expected custom directive name identifier")) ||
-      failed(parseToken(Token::greater,
+      failed(parseToken(FormatToken::greater,
                         "expected '>' after custom directive name")) ||
-      failed(parseToken(Token::l_paren,
+      failed(parseToken(FormatToken::l_paren,
                         "expected '(' before custom directive parameters")))
     return ::mlir::failure();
 
@@ -3212,12 +2926,12 @@ FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
   do {
     if (failed(parseCustomDirectiveParameter(elements)))
       return ::mlir::failure();
-    if (curToken.getKind() != Token::comma)
+    if (curToken.getKind() != FormatToken::comma)
       break;
     consumeToken();
   } while (true);
 
-  if (failed(parseToken(Token::r_paren,
+  if (failed(parseToken(FormatToken::r_paren,
                         "expected ')' after custom directive parameters")))
     return ::mlir::failure();
 
@@ -3254,9 +2968,8 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
   return ::mlir::success();
 }
 
-LogicalResult
-FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
-                                           Token tok, ParserContext context) {
+LogicalResult FormatParser::parseFunctionalTypeDirective(
+    std::unique_ptr<Element> &element, FormatToken tok, ParserContext context) {
   llvm::SMLoc loc = tok.getLoc();
   if (context != TopLevelContext)
     return emitError(
@@ -3264,11 +2977,14 @@ FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
 
   // Parse the main operand.
   std::unique_ptr<Element> inputs, results;
-  if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before argument list")) ||
       failed(parseTypeDirectiveOperand(inputs)) ||
-      failed(parseToken(Token::comma, "expected ',' after inputs argument")) ||
+      failed(parseToken(FormatToken::comma,
+                        "expected ',' after inputs argument")) ||
       failed(parseTypeDirectiveOperand(results)) ||
-      failed(parseToken(Token::r_paren, "expected ')' after argument list")))
+      failed(
+          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
     return ::mlir::failure();
   element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
                                                       std::move(results));
@@ -3299,9 +3015,11 @@ FormatParser::parseReferenceDirective(std::unique_ptr<Element> &element,
     return emitError(loc, "'ref' is only valid within a `custom` directive");
 
   std::unique_ptr<Element> operand;
-  if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before argument list")) ||
       failed(parseElement(operand, RefDirectiveContext)) ||
-      failed(parseToken(Token::r_paren, "expected ')' after argument list")))
+      failed(
+          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
     return ::mlir::failure();
 
   element = std::make_unique<RefDirective>(std::move(operand));
@@ -3360,17 +3078,19 @@ FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
 }
 
 LogicalResult
-FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
-                                 ParserContext context) {
+FormatParser::parseTypeDirective(std::unique_ptr<Element> &element,
+                                 FormatToken tok, ParserContext context) {
   llvm::SMLoc loc = tok.getLoc();
   if (context == TypeDirectiveContext)
     return emitError(loc, "'type' cannot be used as a child of another `type`");
 
   bool isRefChild = context == RefDirectiveContext;
   std::unique_ptr<Element> operand;
-  if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before argument list")) ||
       failed(parseTypeDirectiveOperand(operand, isRefChild)) ||
-      failed(parseToken(Token::r_paren, "expected ')' after argument list")))
+      failed(
+          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
     return ::mlir::failure();
 
   element = std::make_unique<TypeDirective>(std::move(operand));

diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index f776696ade5ba..eb19a15cabc85 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -220,6 +220,7 @@ cc_library(
         "//mlir:SideEffects",
         "//mlir:StandardOps",
         "//mlir:StandardOpsTransforms",
+        "//mlir:Support",
         "//mlir:TensorDialect",
         "//mlir:TransformUtils",
         "//mlir:Transforms",

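Note on the shared helpers: this patch deletes `canFormatStringAsKeyword`, `LiteralElement::isValidLiteral`, and the `shouldPrintSpaceBeforeLiteral` lambda from OpFormatGen.cpp and declares free functions with matching roles in FormatGen.h. The sketch below restates the deleted logic as standalone C++ so the rules are easy to check in isolation. It is an illustration only: it assumes the moved definitions behave like the deleted ones, it uses `std::string_view` in place of `llvm::StringRef`, and the empty-string guard in `canFormatStringAsKeyword` is an addition that the deleted code does not have.

```cpp
#include <algorithm>
#include <cctype>
#include <string_view>

// Keyword rule from the deleted helper: the first character is a letter or
// '_', and the remaining characters are alphanumeric, '_', '$', or '.'.
bool canFormatStringAsKeyword(std::string_view value) {
  if (value.empty()) // Assumption: guard added here; the original calls front().
    return false;
  unsigned char front = static_cast<unsigned char>(value.front());
  if (!std::isalpha(front) && front != '_')
    return false;
  return std::all_of(value.begin() + 1, value.end(), [](unsigned char c) {
    return std::isalnum(c) || c == '_' || c == '$' || c == '.';
  });
}

// Literal rule from the deleted LiteralElement::isValidLiteral.
bool isValidLiteral(std::string_view value) {
  if (value.empty())
    return false;
  char front = value.front();

  // A single character is either punctuation or a one-letter bare identifier.
  if (value.size() == 1)
    return std::isalpha(static_cast<unsigned char>(front)) ||
           std::string_view("_:,=<>()[]{}?+*").find(front) !=
               std::string_view::npos;

  // The only multi-character punctuation accepted.
  if (value == "->")
    return true;

  // Otherwise the literal must be an identifier.
  return canFormatStringAsKeyword(value);
}

// Spacing rule from the deleted shouldPrintSpaceBeforeLiteral lambda.
bool shouldEmitSpaceBefore(std::string_view value, bool lastWasPunctuation) {
  // Multi-character literals other than "->" always get a leading space.
  if (value.size() != 1 && value != "->")
    return true;
  // After punctuation, only closing punctuation and ',' suppress the space.
  if (lastWasPunctuation)
    return std::string_view(">)}],").find(value.front()) ==
           std::string_view::npos;
  // Otherwise any bracket or ',' suppresses the space.
  return std::string_view("<>(){}[],").find(value.front()) ==
         std::string_view::npos;
}
```

Under these rules a keyword followed by `<` prints without a separating space, while `<` printed directly after other punctuation does get one. This matches the doc comment on `shouldEmitSpaceBefore` in FormatGen.h.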

        

