[Mlir-commits] [mlir] 0748639 - [mlir][ods] Optional Attribute or Type Parameters

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 8 12:09:49 PST 2022


Author: Mogball
Date: 2022-02-08T20:09:44Z
New Revision: 07486395d2d05c9c567994456774cafdcc1611d0

URL: https://github.com/llvm/llvm-project/commit/07486395d2d05c9c567994456774cafdcc1611d0
DIFF: https://github.com/llvm/llvm-project/commit/07486395d2d05c9c567994456774cafdcc1611d0.diff

LOG: [mlir][ods] Optional Attribute or Type Parameters

Implements optional attribute or type parameters, including support for such parameters in the assembly format `struct` directive. Also implements optional groups.

Depends on D117971

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D118208

Added: 
    

Modified: 
    mlir/docs/Tutorials/DefiningAttributesAndTypes.md
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/AttrOrTypeDef.h
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/lib/Dialect/Test/TestTypes.h
    mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
    mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/test/mlir-tblgen/attrdefs.td
    mlir/test/mlir-tblgen/typedefs.td
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 3212267b64dd3..1501b54af5279 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -1,9 +1,9 @@
 # Defining Dialect Attributes and Types
 
 This document is a quickstart to defining dialect specific extensions to the
-[attribute](../LangRef.md/#attributes) and [type](../LangRef.md/#type-system) systems in
-MLIR. The main part of this tutorial focuses on defining types, but the
-instructions are nearly identical for defining attributes.
+[attribute](../LangRef.md/#attributes) and [type](../LangRef.md/#type-system)
+systems in MLIR. The main part of this tutorial focuses on defining types, but
+the instructions are nearly identical for defining attributes.
 
 See [MLIR specification](../LangRef.md) for more information about MLIR, the
 structure of the IR, operations, etc.
@@ -24,18 +24,19 @@ defining a new `Type` it isn't always necessary to define a new storage class.
 So before defining the derived `Type`, it's important to know which of the two
 classes of `Type` we are defining:
 
-Some types are _singleton_ in nature, meaning they have no parameters and only
-ever have one instance, like the [`index` type](../Dialects/Builtin.md/#indextype).
+Some types are *singleton* in nature, meaning they have no parameters and only
+ever have one instance, like the
+[`index` type](../Dialects/Builtin.md/#indextype).
 
-Other types are _parametric_, and contain additional information that
+Other types are *parametric*, and contain additional information that
 differentiates different instances of the same `Type`. For example the
-[`integer` type](../Dialects/Builtin.md/#integertype) contains a bitwidth, with `i8` and
-`i16` representing different instances of
-[`integer` type](../Dialects/Builtin.md/#integertype). _Parametric_ may also contain a
-mutable component, which can be used, for example, to construct self-referring
-recursive types. The mutable component _cannot_ be used to differentiate
-instances of a type class, so usually such types contain other parametric
-components that serve to identify them.
+[`integer` type](../Dialects/Builtin.md/#integertype) contains a bitwidth, with
+`i8` and `i16` representing different instances of
+[`integer` type](../Dialects/Builtin.md/#integertype). *Parametric* may also
+contain a mutable component, which can be used, for example, to construct
+self-referring recursive types. The mutable component *cannot* be used to
+differentiate instances of a type class, so usually such types contain other
+parametric components that serve to identify them.
 
 #### Singleton types
 
@@ -389,12 +390,12 @@ Attributes and types defined in ODS with a mnemonic can define an
 `assemblyFormat` to declaratively describe custom parsers and printers. The
 assembly format consists of literals, variables, and directives.
 
-* A literal is a keyword or valid punctuation enclosed in backticks, e.g.
-  `` `keyword` `` or `` `<` ``.
-* A variable is a parameter name preceeded by a dollar sign, e.g. `$param0`,
-  which captures one attribute or type parameter.
-* A directive is a keyword followed by an optional argument list that defines
-  special parser and printer behaviour.
+*   A literal is a keyword or valid punctuation enclosed in backticks, e.g. ``
+    `keyword` `` or `` `<` ``.
+*   A variable is a parameter name preceded by a dollar sign, e.g. `$param0`,
+    which captures one attribute or type parameter.
+*   A directive is a keyword followed by an optional argument list that defines
+    special parser and printer behaviour.
 
 ```tablegen
 // An example type with an assembly format.
@@ -412,8 +413,8 @@ def MyType : TypeDef<My_Dialect, "MyType"> {
 }
 ```
 
-The declarative assembly format for `MyType` results in the following format
-in the IR:
+The declarative assembly format for `MyType` results in the following format in
+the IR:
 
 ```mlir
 !my_dialect.my_type<42, map = affine_map<(i, j) -> (j, i)>>
@@ -421,15 +422,15 @@ in the IR:
 
 ### Parameter Parsing and Printing
 
-For many basic parameter types, no additional work is needed to define how
-these parameters are parsed or printed.
+For many basic parameter types, no additional work is needed to define how these
+parameters are parsed or printed.
 
-* The default printer for any parameter is `$_printer << $_self`,
-  where `$_self` is the C++ value of the parameter and `$_printer` is an
-  `AsmPrinter`.
-* The default parser for a parameter is
-  `FieldParser<$cppClass>::parse($_parser)`, where `$cppClass` is the C++ type
-  of the parameter and `$_parser` is an `AsmParser`.
+*   The default printer for any parameter is `$_printer << $_self`, where
+    `$_self` is the C++ value of the parameter and `$_printer` is an
+    `AsmPrinter`.
+*   The default parser for a parameter is
+    `FieldParser<$cppClass>::parse($_parser)`, where `$cppClass` is the C++ type
+    of the parameter and `$_parser` is an `AsmParser`.
 
 Printing and parsing behaviour can be added to additional C++ types by
 overloading these functions or by defining a `parser` and `printer` in an ODS
@@ -470,8 +471,8 @@ def MyParameter : TypeParameter<"std::pair<int, int>", "pair of ints"> {
 }
 ```
 
-A type using this parameter with the assembly format `` `<` $myParam `>` ``
-will look as follows in the IR:
+A type using this parameter with the assembly format `` `<` $myParam `>` `` will
+look as follows in the IR:
 
 ```mlir
 !my_dialect.my_type<42 * 24>
@@ -480,10 +481,42 @@ will look as follows in the IR:
 #### Non-POD Parameters
 
 Parameters that aren't plain-old-data (e.g. references) may need to define a
-`cppStorageType` to contain the data until it is copied into the allocator.
-For example, `StringRefParameter` uses `std::string` as its storage type,
-whereas `ArrayRefParameter` uses `SmallVector` as its storage type. The parsers
-for these parameters are expected to return `FailureOr<$cppStorageType>`.
+`cppStorageType` to contain the data until it is copied into the allocator. For
+example, `StringRefParameter` uses `std::string` as its storage type, whereas
+`ArrayRefParameter` uses `SmallVector` as its storage type. The parsers for
+these parameters are expected to return `FailureOr<$cppStorageType>`.
+
+#### Optional Parameters
+
+Optional parameters in the assembly format can be indicated by setting
+`isOptional`. The C++ type of an optional parameter is required to satisfy the
+following requirements:
+
+*   is default-constructible
+*   is contextually convertible to `bool`
+*   only the default-constructed value is `false`
+
+The parameter parser should return the default-constructed value to indicate "no
+value present". The printer will guard on the presence of a value to print the
+parameter.
+
+If a value was not parsed for an optional parameter, then the parameter will be
+set to its default-constructed C++ value. For example, `Optional<int>` will be
+set to `llvm::None` and `Attribute` will be set to `nullptr`.
+
+Only optional parameters or directives that only capture optional parameters can
+be used in optional groups. An optional group is a set of elements optionally
+printed based on the presence of an anchor. Suppose parameter `a` is an
+`IntegerAttr`.
+
+```
+( `(` $a^ `)` ) : (`x`)?
+```
+
+In the above assembly format, if `a` is present (non-null), then it will be
+printed as `(5 : i32)`. If it is not present, it will be `x`. Directives that
+are used inside optional groups are allowed only if all captured parameters are
+also optional.
 
 ### Assembly Format Directives
 
@@ -497,9 +530,9 @@ Attribute and type assembly formats have the following directives:
 
 #### `params` Directive
 
-This directive is used to refer to all parameters of an attribute or type.
-When used as a top-level directive, `params` generates a parser and printer for
-a comma-separated list of the parameters. For example:
+This directive is used to refer to all parameters of an attribute or type. When
+used as a top-level directive, `params` generates a parser and printer for a
+comma-separated list of the parameters. For example:
 
 ```tablegen
 def MyPairType : TypeDef<My_Dialect, "MyPairType"> {
@@ -547,12 +580,16 @@ In the IR, the types will appear as:
 !my_dialect.outer_qual<pair : !mydialect.pair<42, 24>>
 ```
 
+Optional parameters that are not present are omitted from the printed
+parameter list.
+
 #### `struct` Directive
 
 The `struct` directive accepts a list of variables to capture and will generate
-a parser and printer for a comma-separated list of key-value pairs. The
-variables are printed in the order they are specified in the argument list **but
-can be parsed in any order**. For example:
+a parser and printer for a comma-separated list of key-value pairs. If an
+optional parameter is included in the `struct`, it can be elided. The variables
+are printed in the order they are specified in the argument list **but can be
+parsed in any order**. For example:
 
 ```tablegen
 def MyStructType : TypeDef<My_Dialect, "MyStructType"> {

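The tutorial hunk above says that the printer guards on the presence of an optional parameter, that `params` leaves absent optional parameters out of the list, and that `struct` may elide them. That behaviour can be sketched in plain C++. This is a minimal illustration under stated assumptions, not the code `mlir-tblgen` emits: `ParamSketch`, `printParamsSketch`, and `printStructSketch` are hypothetical names, `std::nullopt` stands in for "holds its default-constructed value", and the `name = value` spelling for `struct` is assumed for illustration.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for one parameter: its name plus its printed form.
// std::nullopt models an optional parameter holding its default-constructed
// ("not present") value, e.g. an empty Optional<int> or a null Attribute.
struct ParamSketch {
  std::string name;
  std::optional<std::string> printed;
};

// Shared helper: comma-separated list of the present parameters, optionally
// prefixed with `name = `.
static std::string joinPresent(const std::vector<ParamSketch> &params,
                               bool withKeys) {
  std::string out;
  bool first = true;
  for (const ParamSketch &param : params) {
    if (!param.printed) // presence guard: absent optionals print nothing
      continue;
    if (!first)
      out += ", ";
    if (withKeys)
      out += param.name + " = ";
    out += *param.printed;
    first = false;
  }
  return out;
}

// `params`-style printing: values only.
std::string printParamsSketch(const std::vector<ParamSketch> &params) {
  return joinPresent(params, /*withKeys=*/false);
}

// `struct`-style printing: key-value pairs.
std::string printStructSketch(const std::vector<ParamSketch> &params) {
  return joinPresent(params, /*withKeys=*/true);
}
```

For a list where `a` is absent and `b` prints as `"foo"`, `printParamsSketch` yields `"foo"` and `printStructSketch` yields `b = "foo"`; the `TestTypeOptionalParams` and `TestTypeOptionalStruct` definitions later in this patch exercise the corresponding real formats.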
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 064145724ab72..bb108f714a789 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -3130,15 +3130,19 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
   string summary = desc;
   // The format string for the asm syntax (documentation only).
   string syntax = ?;
-  // The default parameter parser is `::mlir::parseField<T>($_parser)`, which
-  // returns `FailureOr<T>`. Overload `parseField` to support parsing for your
-  // type. Or you can provide a customer printer. For attributes, "$_type" will
-  // be replaced with the required attribute type.
+  // The default parameter parser is `::mlir::FieldParser<T>::parse($_parser)`,
+  // which returns `FailureOr<T>`. Specialize `FieldParser` to support parsing
+  // for your type. Or you can provide a custom parser. For attributes,
+  // "$_type" will be replaced with the required attribute type.
   string parser = ?;
   // The default parameter printer is `$_printer << $_self`. Overload the stream
   // operator of `AsmPrinter` as necessary to print your type. Or you can
   // provide a custom printer.
   string printer = ?;
+  // Mark a parameter as optional. The C++ type of parameters marked as optional
+  // must be default constructible and be contextually convertible to `bool`.
+  // Any `Optional<T>` and any attribute type satisfy these requirements.
+  bit isOptional = 0;
 }
 class AttrParameter<string type, string desc, string accessorType = "">
  : AttrOrTypeParameter<type, desc, accessorType>;
@@ -3183,6 +3187,12 @@ class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
   }];
 }
 
+// An optional parameter.
+class OptionalParameter<string type, string desc = ""> :
+    AttrOrTypeParameter<type, desc> {
+  let isOptional = 1;
+}
+
 // This is a special parameter used for AttrDefs that represents a `mlir::Type`
 // that is also used as the value `Type` of the attribute. Only one parameter
 // of the attribute may be of this type.

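The `isOptional` comment in the `OpBase.td` hunk requires a parameter's C++ type to be default constructible and contextually convertible to `bool`, and the tutorial adds that only the default-constructed value may convert to `false`. A compile-time approximation of the first two rules, plus a one-sample runtime spot check of the third, might look like the following. `kMeetsOptionalParamTraits` and `defaultFalseSampleTrue` are hypothetical helpers, `std::optional` stands in for `llvm::Optional`, and `std::shared_ptr` plays the role of a nullable attribute handle purely for illustration.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>

// Compile-time part of the contract. `bool b(value);` is a contextual
// conversion, which std::is_constructible<bool, T> approximates.
template <typename T>
inline constexpr bool kMeetsOptionalParamTraits =
    std::is_default_constructible_v<T> && std::is_constructible_v<bool, T>;

// Runtime spot check for one sample: the default-constructed value must be
// `false` and the given present value `true`. A single sample cannot prove
// that *only* the default value is false; it just catches obvious violations.
template <typename T>
bool defaultFalseSampleTrue(const T &presentSample) {
  return !static_cast<bool>(T{}) && static_cast<bool>(presentSample);
}

static_assert(kMeetsOptionalParamTraits<std::optional<int>>);
static_assert(kMeetsOptionalParamTraits<std::shared_ptr<int>>);
// std::string has no conversion to bool, so it could not be marked optional.
static_assert(!kMeetsOptionalParamTraits<std::string>);
```

An engaged `std::optional<int>` holding `0` still converts to `true`, which is why a wrapper type is the natural choice here: with a bare `int`, the value `0` would be indistinguishable from "not present".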
diff  --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 7be684d4c5343..dcfdc8ab28a63 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -45,47 +45,50 @@ class AttrOrTypeBuilder : public Builder {
 // AttrOrTypeParameter
 //===----------------------------------------------------------------------===//
 
-// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to
-// AttrOrTypeDefs to parameterize them.
+/// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to
+/// AttrOrTypeDefs to parameterize them.
 class AttrOrTypeParameter {
 public:
   explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
       : def(def), index(index) {}
 
-  // Get the parameter name.
+  /// Get the parameter name.
   StringRef getName() const;
 
-  // If specified, get the custom allocator code for this parameter.
+  /// If specified, get the custom allocator code for this parameter.
   Optional<StringRef> getAllocator() const;
 
-  // If specified, get the custom comparator code for this parameter.
+  /// If specified, get the custom comparator code for this parameter.
   Optional<StringRef> getComparator() const;
 
-  // Get the C++ type of this parameter.
+  /// Get the C++ type of this parameter.
   StringRef getCppType() const;
 
-  // Get the C++ accessor type of this parameter.
+  /// Get the C++ accessor type of this parameter.
   StringRef getCppAccessorType() const;
 
-  // Get the C++ storage type of this parameter.
+  /// Get the C++ storage type of this parameter.
   StringRef getCppStorageType() const;
 
-  // Get an optional C++ parameter parser.
+  /// Get an optional C++ parameter parser.
   Optional<StringRef> getParser() const;
 
-  // Get an optional C++ parameter printer.
+  /// Get an optional C++ parameter printer.
   Optional<StringRef> getPrinter() const;
 
-  // Get a description of this parameter for documentation purposes.
+  /// Get a description of this parameter for documentation purposes.
   Optional<StringRef> getSummary() const;
 
-  // Get the assembly syntax documentation.
+  /// Get the assembly syntax documentation.
   StringRef getSyntax() const;
 
-  // Return the underlying def of this parameter.
-  const llvm::Init *getDef() const;
+  /// Returns true if the parameter is optional.
+  bool isOptional() const;
 
-  // The parameter is pointer-comparable.
+  /// Return the underlying def of this parameter.
+  llvm::Init *getDef() const;
+
+  /// The parameter is pointer-comparable.
   bool operator==(const AttrOrTypeParameter &other) const {
     return def == other.def && index == other.index;
   }
@@ -94,6 +97,11 @@ class AttrOrTypeParameter {
   }
 
 private:
+  /// A parameter can be either a string or a def. Get a potentially null value
+  /// from the def.
+  template <typename InitT>
+  auto getDefValue(StringRef name) const;
+
   /// The underlying tablegen parameter list this parameter is a part of.
   const llvm::DagInit *def;
   /// The index of the parameter within the parameter list (`def`).
@@ -121,113 +129,113 @@ class AttrOrTypeDef {
 public:
   explicit AttrOrTypeDef(const llvm::Record *def);
 
-  // Get the dialect for which this def belongs.
+  /// Get the dialect for which this def belongs.
   Dialect getDialect() const;
 
-  // Returns the name of this AttrOrTypeDef record.
+  /// Returns the name of this AttrOrTypeDef record.
   StringRef getName() const;
 
-  // Query functions for the documentation of the def.
+  /// Query functions for the documentation of the def.
   bool hasDescription() const;
   StringRef getDescription() const;
   bool hasSummary() const;
   StringRef getSummary() const;
 
-  // Returns the name of the C++ class to generate.
+  /// Returns the name of the C++ class to generate.
   StringRef getCppClassName() const;
 
-  // Returns the name of the C++ base class to use when generating this def.
+  /// Returns the name of the C++ base class to use when generating this def.
   StringRef getCppBaseClassName() const;
 
-  // Returns the name of the storage class for this def.
+  /// Returns the name of the storage class for this def.
   StringRef getStorageClassName() const;
 
-  // Returns the C++ namespace for this def's storage class.
+  /// Returns the C++ namespace for this def's storage class.
   StringRef getStorageNamespace() const;
 
-  // Returns true if we should generate the storage class.
+  /// Returns true if we should generate the storage class.
   bool genStorageClass() const;
 
-  // Indicates whether or not to generate the storage class constructor.
+  /// Indicates whether or not to generate the storage class constructor.
   bool hasStorageCustomConstructor() const;
 
   /// Get the parameters of this attribute or type.
   ArrayRef<AttrOrTypeParameter> getParameters() const { return parameters; }
 
-  // Return the number of parameters
+  /// Return the number of parameters
   unsigned getNumParameters() const;
 
-  // Return the keyword/mnemonic to use in the printer/parser methods if we are
-  // supposed to auto-generate them.
+  /// Return the keyword/mnemonic to use in the printer/parser methods if we are
+  /// supposed to auto-generate them.
   Optional<StringRef> getMnemonic() const;
 
-  // Returns the code to use as the types printer method. If not specified,
-  // return a non-value. Otherwise, return the contents of that code block.
+  /// Returns the code to use as the types printer method. If not specified,
+  /// return a non-value. Otherwise, return the contents of that code block.
   Optional<StringRef> getPrinterCode() const;
 
-  // Returns the code to use as the parser method. If not specified, returns
-  // None. Otherwise, returns the contents of that code block.
+  /// Returns the code to use as the parser method. If not specified, returns
+  /// None. Otherwise, returns the contents of that code block.
   Optional<StringRef> getParserCode() const;
 
-  // Returns the custom assembly format, if one was specified.
+  /// Returns the custom assembly format, if one was specified.
   Optional<StringRef> getAssemblyFormat() const;
 
-  // An attribute or type with parameters needs a parser.
+  /// An attribute or type with parameters needs a parser.
   bool needsParserPrinter() const { return getNumParameters() != 0; }
 
-  // Returns true if this attribute or type has a generated parser.
+  /// Returns true if this attribute or type has a generated parser.
   bool hasGeneratedParser() const {
     return getParserCode() || getAssemblyFormat();
   }
 
-  // Returns true if this attribute or type has a generated printer.
+  /// Returns true if this attribute or type has a generated printer.
   bool hasGeneratedPrinter() const {
     return getPrinterCode() || getAssemblyFormat();
   }
 
-  // Returns true if the accessors based on the parameters should be generated.
+  /// Returns true if the accessors based on the parameters should be generated.
   bool genAccessors() const;
 
-  // Return true if we need to generate the verify declaration and getChecked
-  // method.
+  /// Return true if we need to generate the verify declaration and getChecked
+  /// method.
   bool genVerifyDecl() const;
 
-  // Returns the def's extra class declaration code.
+  /// Returns the def's extra class declaration code.
   Optional<StringRef> getExtraDecls() const;
 
-  // Get the code location (for error printing).
+  /// Get the code location (for error printing).
   ArrayRef<SMLoc> getLoc() const;
 
-  // Returns true if the default get/getChecked methods should be skipped during
-  // generation.
+  /// Returns true if the default get/getChecked methods should be skipped
+  /// during generation.
   bool skipDefaultBuilders() const;
 
-  // Returns the builders of this def.
+  /// Returns the builders of this def.
   ArrayRef<AttrOrTypeBuilder> getBuilders() const { return builders; }
 
-  // Returns the traits of this def.
+  /// Returns the traits of this def.
   ArrayRef<Trait> getTraits() const { return traits; }
 
-  // Returns whether two AttrOrTypeDefs are equal by checking the equality of
-  // the underlying record.
+  /// Returns whether two AttrOrTypeDefs are equal by checking the equality of
+  /// the underlying record.
   bool operator==(const AttrOrTypeDef &other) const;
 
-  // Compares two AttrOrTypeDefs by comparing the names of the dialects.
+  /// Compares two AttrOrTypeDefs by comparing the names of the dialects.
   bool operator<(const AttrOrTypeDef &other) const;
 
-  // Returns whether the AttrOrTypeDef is defined.
+  /// Returns whether the AttrOrTypeDef is defined.
   operator bool() const { return def != nullptr; }
 
-  // Return the underlying def.
+  /// Return the underlying def.
   const llvm::Record *getDef() const { return def; }
 
 protected:
   const llvm::Record *def;
 
-  // The builders of this definition.
+  /// The builders of this definition.
   SmallVector<AttrOrTypeBuilder> builders;
 
-  // The traits of this definition.
+  /// The traits of this definition.
   SmallVector<Trait> traits;
 
   /// The parameters of this attribute or type.
@@ -243,8 +251,8 @@ class AttrDef : public AttrOrTypeDef {
 public:
   using AttrOrTypeDef::AttrOrTypeDef;
 
-  // Returns the attributes value type builder code block, or None if it doesn't
-  // have one.
+  /// Returns the attribute's value type builder code block, or None if it
+  /// doesn't have one.
   Optional<StringRef> getTypeBuilder() const;
 
   static bool classof(const AttrOrTypeDef *def);

diff  --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index c3a14bb5f1ddf..2d92a5837c79f 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -177,32 +177,30 @@ bool AttrDef::classof(const AttrOrTypeDef *def) {
 // AttrOrTypeParameter
 //===----------------------------------------------------------------------===//
 
+template <typename InitT>
+auto AttrOrTypeParameter::getDefValue(StringRef name) const {
+  Optional<decltype(std::declval<InitT>().getValue())> result;
+  if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
+    if (auto *init = param->getDef()->getValue(name))
+      if (auto *value = dyn_cast_or_null<InitT>(init->getValue()))
+        result = value->getValue();
+  return result;
+}
+
 StringRef AttrOrTypeParameter::getName() const {
   return def->getArgName(index)->getValue();
 }
 
 Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
-  llvm::Init *parameterType = def->getArg(index);
-  if (isa<llvm::StringInit>(parameterType))
-    return Optional<StringRef>();
-  if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
-    return param->getDef()->getValueAsOptionalString("allocator");
-  llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
-                        "defs which inherit from AttrOrTypeParameter\n");
+  return getDefValue<llvm::StringInit>("allocator");
 }
 
 Optional<StringRef> AttrOrTypeParameter::getComparator() const {
-  llvm::Init *parameterType = def->getArg(index);
-  if (isa<llvm::StringInit>(parameterType))
-    return Optional<StringRef>();
-  if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
-    return param->getDef()->getValueAsOptionalString("comparator");
-  llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
-                        "defs which inherit from AttrOrTypeParameter\n");
+  return getDefValue<llvm::StringInit>("comparator");
 }
 
 StringRef AttrOrTypeParameter::getCppType() const {
-  auto *parameterType = def->getArg(index);
+  llvm::Init *parameterType = getDef();
   if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
     return stringType->getValue();
   if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
@@ -213,74 +211,45 @@ StringRef AttrOrTypeParameter::getCppType() const {
 }
 
 StringRef AttrOrTypeParameter::getCppAccessorType() const {
-  if (auto *param = dyn_cast<llvm::DefInit>(def->getArg(index))) {
-    if (Optional<StringRef> type =
-            param->getDef()->getValueAsOptionalString("cppAccessorType"))
-      return *type;
-  }
-  return getCppType();
+  return getDefValue<llvm::StringInit>("cppAccessorType")
+      .getValueOr(getCppType());
 }
 
 StringRef AttrOrTypeParameter::getCppStorageType() const {
-  if (auto *param = dyn_cast<llvm::DefInit>(def->getArg(index))) {
-    if (auto type = param->getDef()->getValueAsOptionalString("cppStorageType"))
-      return *type;
-  }
-  return getCppType();
+  return getDefValue<llvm::StringInit>("cppStorageType")
+      .getValueOr(getCppType());
 }
 
 Optional<StringRef> AttrOrTypeParameter::getParser() const {
-  auto *parameterType = def->getArg(index);
-  if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
-    if (auto parser = param->getDef()->getValueAsOptionalString("parser"))
-      return *parser;
-  }
-  return {};
+  return getDefValue<llvm::StringInit>("parser");
 }
 
 Optional<StringRef> AttrOrTypeParameter::getPrinter() const {
-  auto *parameterType = def->getArg(index);
-  if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
-    if (auto printer = param->getDef()->getValueAsOptionalString("printer"))
-      return *printer;
-  }
-  return {};
+  return getDefValue<llvm::StringInit>("printer");
 }
 
 Optional<StringRef> AttrOrTypeParameter::getSummary() const {
-  auto *parameterType = def->getArg(index);
-  if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
-    const auto *desc = param->getDef()->getValue("summary");
-    if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(desc->getValue()))
-      return ci->getValue();
-  }
-  return Optional<StringRef>();
+  return getDefValue<llvm::StringInit>("summary");
 }
 
 StringRef AttrOrTypeParameter::getSyntax() const {
-  auto *parameterType = def->getArg(index);
-  if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
+  if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
     return stringType->getValue();
-  if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
-    const auto *syntax = param->getDef()->getValue("syntax");
-    if (syntax && isa<llvm::StringInit>(syntax->getValue()))
-      return cast<llvm::StringInit>(syntax->getValue())->getValue();
-    return getCppType();
-  }
-  llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
-                        "defs which inherit from AttrOrTypeParameter");
+  return getDefValue<llvm::StringInit>("syntax").getValueOr(getCppType());
 }
 
-const llvm::Init *AttrOrTypeParameter::getDef() const {
-  return def->getArg(index);
+bool AttrOrTypeParameter::isOptional() const {
+  return getDefValue<llvm::BitInit>("isOptional").getValueOr(false);
 }
 
+llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
+
 //===----------------------------------------------------------------------===//
 // AttributeSelfTypeParameter
 //===----------------------------------------------------------------------===//
 
 bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
-  const llvm::Init *paramDef = param->getDef();
+  llvm::Init *paramDef = param->getDef();
   if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
     return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
   return false;

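The `AttrOrTypeDef.cpp` hunk folds the repeated lookups into one `getDefValue<InitT>(name)` helper. It yields an empty `Optional` when the parameter is a plain string rather than a def, when the field is missing, or when the field's value is not of kind `InitT`; each accessor then supplies its own fallback with `getValueOr`. The same shape can be sketched without LLVM's TableGen classes. `RecordSketch` is a hypothetical map standing in for an `llvm::Record`, a null record pointer stands in for a string parameter, and `std::optional::value_or` replaces `getValueOr`.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <variant>

// Hypothetical stand-in for the fields of a TableGen record. The real code
// inspects llvm::DefInit / llvm::RecordVal and casts to an llvm::*Init kind.
using FieldValue = std::variant<std::string, bool>;
using RecordSketch = std::map<std::string, FieldValue>;

// Sketch of the getDefValue<InitT>(name) pattern: empty when there is no
// record (plain string parameter), when the field does not exist, or when
// the field holds a different kind of value (e.g. an unset `?` field).
template <typename T>
std::optional<T> getDefValueSketch(const RecordSketch *record,
                                   const std::string &name) {
  if (!record)
    return std::nullopt;
  auto it = record->find(name);
  if (it == record->end())
    return std::nullopt;
  if (const T *value = std::get_if<T>(&it->second))
    return *value;
  return std::nullopt;
}

// Each accessor becomes a one-liner that supplies its own fallback.
std::string getCppStorageTypeSketch(const RecordSketch *record,
                                    const std::string &cppType) {
  return getDefValueSketch<std::string>(record, "cppStorageType")
      .value_or(cppType);
}

bool isOptionalSketch(const RecordSketch *record) {
  return getDefValueSketch<bool>(record, "isOptional").value_or(false);
}
```

Centralizing the "missing or wrong kind means empty" rule is what lets the patch drop the per-accessor `PrintFatalError` branches.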
diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 3a052dc98ed0c..8751ee27eef42 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -22,7 +22,7 @@ include "mlir/Interfaces/DataLayoutInterfaces.td"
 
 // All of the types will extend this class.
 class Test_Type<string name, list<Trait> traits = []>
-  : TypeDef<Test_Dialect, name, traits>;
+    : TypeDef<Test_Dialect, name, traits>;
 
 def SimpleTypeA : Test_Type<"SimpleA"> {
   let mnemonic = "smpla";
@@ -42,7 +42,7 @@ def CompoundTypeA : Test_Type<"CompoundA"> {
     ArrayRefParameter<
       "int", // The parameter C++ type.
       "An example of an array of ints" // Parameter description.
-      >: $arrayOfInts
+    >:$arrayOfInts
   );
 
   let extraClassDeclaration = [{
@@ -110,14 +110,14 @@ def IntegerType : Test_Type<"TestInteger"> {
 
   // The parser is defined here also.
   let parser = [{
-    if (parser.parseLess()) return Type();
+    if ($_parser.parseLess()) return Type();
     SignednessSemantics signedness;
-    if (parseSignedness($_parser, signedness)) return mlir::Type();
+    if (parseSignedness($_parser, signedness)) return Type();
     if ($_parser.parseComma()) return Type();
     int width;
     if ($_parser.parseInteger(width)) return Type();
     if ($_parser.parseGreater()) return Type();
-    ::mlir::Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
+    Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
     return getChecked(loc, loc.getContext(), width, signedness);
   }];
 
@@ -262,4 +262,66 @@ def TestTypeStructCaptureAll : Test_Type<"TestStructTypeCaptureAll"> {
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def TestTypeOptionalParam : Test_Type<"TestTypeOptionalParam"> {
+  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a, "int":$b);
+  let mnemonic = "optional_param";
+  let assemblyFormat = "`<` $a `,` $b `>`";
+}
+
+def TestTypeOptionalParams : Test_Type<"TestTypeOptionalParams"> {
+  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+                        StringRefParameter<>:$b);
+  let mnemonic = "optional_params";
+  let assemblyFormat = "`<` params `>`";
+}
+
+def TestTypeOptionalParamsAfterRequired
+    : Test_Type<"TestTypeOptionalParamsAfterRequired"> {
+  let parameters = (ins StringRefParameter<>:$a,
+                        OptionalParameter<"mlir::Optional<int>">:$b);
+  let mnemonic = "optional_params_after";
+  let assemblyFormat = "`<` params `>`";
+}
+
+def TestTypeOptionalStruct : Test_Type<"TestTypeOptionalStruct"> {
+  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+                        StringRefParameter<>:$b);
+  let mnemonic = "optional_struct";
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def TestTypeAllOptionalParams : Test_Type<"TestTypeAllOptionalParams"> {
+  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+                        OptionalParameter<"mlir::Optional<int>">:$b);
+  let mnemonic = "all_optional_params";
+  let assemblyFormat = "`<` params `>`";
+}
+
+def TestTypeAllOptionalStruct : Test_Type<"TestTypeAllOptionalStruct"> {
+  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+                        OptionalParameter<"mlir::Optional<int>">:$b);
+  let mnemonic = "all_optional_struct";
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def TestTypeOptionalGroup : Test_Type<"TestTypeOptionalGroup"> {
+  let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional<int>">:$b);
+  let mnemonic = "optional_group";
+  let assemblyFormat = "`<` (`(` $b^ `)`) : (`x`)? $a `>`";
+}
+
+def TestTypeOptionalGroupParams : Test_Type<"TestTypeOptionalGroupParams"> {
+  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+                        OptionalParameter<"mlir::Optional<int>">:$b);
+  let mnemonic = "optional_group_params";
+  let assemblyFormat = "`<` (`(` params^ `)`) : (`x`)? `>`";
+}
+
+def TestTypeOptionalGroupStruct : Test_Type<"TestTypeOptionalGroupStruct"> {
+  let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+                        OptionalParameter<"mlir::Optional<int>">:$b);
+  let mnemonic = "optional_group_struct";
+  let assemblyFormat = "`<` (`(` struct(params)^ `)`) : (`x`)? `>`";
+}
+
 #endif // TEST_TYPEDEFS

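The `params` directive over optional parameters (for example `TestTypeAllOptionalParams`, which the round-trip test prints as `<>`, `<5>` and `<5, 6>`) emits only the parameters that are present. Below is a minimal, hand-written C++ sketch of that behaviour. It assumes plain `int` values and a toy `std::istringstream` parser, and it is not the code ODS generates. `AllOptionalParams`, `printParams` and `parseParams` are hypothetical names.

```cpp
#include <cstddef>
#include <initializer_list>
#include <optional>
#include <sstream>
#include <string>

// Hypothetical model of a type with two optional `int` parameters and the
// assembly format `params` (compare TestTypeAllOptionalParams).
struct AllOptionalParams {
  std::optional<int> a;
  std::optional<int> b;
};

// Print the present parameters in declaration order, separated by ", ".
std::string printParams(const AllOptionalParams &p) {
  std::ostringstream os;
  const char *sep = "";
  for (const std::optional<int> &v : {p.a, p.b}) {
    if (!v)
      continue;  // absent optional parameters are skipped entirely
    os << sep << *v;
    sep = ", ";
  }
  return os.str();
}

// Parse "", "5" or "5, 6". Values fill `a`, then `b`, in declaration order.
// Returns std::nullopt on malformed input or too many values.
std::optional<AllOptionalParams> parseParams(const std::string &text) {
  constexpr std::size_t kNumParams = 2;
  AllOptionalParams result;
  std::optional<int> *slots[kNumParams] = {&result.a, &result.b};
  std::istringstream is(text);
  for (std::size_t i = 0; i < kNumParams; ++i) {
    int value = 0;
    if (!(is >> value)) {
      if (i != 0)
        return std::nullopt;  // a comma must be followed by a value
      is.clear();
      break;                  // empty list: no parameters present
    }
    *slots[i] = value;
    char sep = 0;
    if (i + 1 == kNumParams || !(is >> sep))
      break;                  // last slot filled, or end of input
    if (sep != ',')
      return std::nullopt;    // only commas separate parameters
  }
  std::string rest;
  if (is >> rest)
    return std::nullopt;      // trailing input or extra values
  return result;
}
```

In this sketch a lone value always binds to the first parameter, because values fill the slots in declaration order.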
diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 91650afebb9b1..d7b1163f7b39d 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -64,11 +64,28 @@ struct FieldParser<test::CustomParam> {
     return test::CustomParam{value.getValue()};
   }
 };
+
 inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer,
                                     test::CustomParam param) {
   return printer << param.value;
 }
 
+/// Overload the attribute parameter parser for optional integers.
+template <>
+struct FieldParser<Optional<int>> {
+  static FailureOr<Optional<int>> parse(AsmParser &parser) {
+    Optional<int> value;
+    value.emplace();
+    OptionalParseResult result = parser.parseOptionalInteger(*value);
+    if (result.hasValue()) {
+      if (succeeded(*result))
+        return value;
+      return failure();
+    }
+    value.reset();
+    return value;
+  }
+};
 } // namespace mlir
 
 #include "TestTypeInterfaces.h.inc"

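The `FieldParser<Optional<int>>` specialization above maps the three-way result of `parseOptionalInteger` onto a parameter value. There are three cases:

* No integer present: success, with an empty value.
* Integer present and well formed: success, with that value.
* Integer present but malformed: failure.

Here is the same logic sketched without MLIR. It uses `std::optional` in place of `FailureOr` and `OptionalParseResult`. `MiniParser`, `FieldResult` and `parseOptionalIntField` are hypothetical stand-ins, and the toy integer grammar is an assumption of the sketch.

```cpp
#include <cctype>
#include <climits>
#include <cstddef>
#include <optional>
#include <string>
#include <utility>

// Hypothetical toy parser; it only models what this sketch needs.
struct MiniParser {
  std::string text;
  std::size_t pos = 0;

  // Tri-state result, mirroring the shape of an optional parse:
  //   std::nullopt -> no integer at this position, nothing consumed
  //   true         -> an integer was parsed into `value` and consumed
  //   false        -> an integer started ('-') but was malformed or too large
  std::optional<bool> parseOptionalInteger(int &value) {
    std::size_t p = pos;
    while (p < text.size() && text[p] == ' ')
      ++p;
    bool negative = p < text.size() && text[p] == '-';
    if (negative)
      ++p;
    auto isDigit = [&](std::size_t i) {
      return i < text.size() &&
             std::isdigit(static_cast<unsigned char>(text[i]));
    };
    if (!isDigit(p))
      return negative ? std::optional<bool>(false) : std::nullopt;
    long long magnitude = 0;
    while (isDigit(p)) {
      magnitude = magnitude * 10 + (text[p] - '0');
      if (magnitude > INT_MAX)
        return false;  // does not fit in `int` (INT_MIN itself is rejected too)
      ++p;
    }
    pos = p;  // consume input only on success
    value = static_cast<int>(negative ? -magnitude : magnitude);
    return true;
  }
};

// The outer optional plays the role of FailureOr (empty == failure); the inner
// optional is the parameter value itself (empty == parameter absent).
using FieldResult = std::optional<std::optional<int>>;

FieldResult parseOptionalIntField(MiniParser &parser) {
  int value = 0;
  std::optional<bool> result = parser.parseOptionalInteger(value);
  if (result.has_value()) {
    if (*result)
      return FieldResult(std::in_place, value);  // present and well formed
    return FieldResult();                         // present but malformed
  }
  return FieldResult(std::in_place);  // absent: success with no value
}
```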
diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
index 012685fd05cba..d92ae1677100c 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -11,28 +11,28 @@ class InvalidType<string name, string asm> : TypeDef<Test_Dialect, name> {
   let mnemonic = asm;
 }
 
-/// Test format is missing a parameter capture.
+// Test format is missing a parameter capture.
 def InvalidTypeA : InvalidType<"InvalidTypeA", "invalid_a"> {
   let parameters = (ins "int":$v0, "int":$v1);
   // CHECK: format is missing reference to parameter: v1
   let assemblyFormat = "`<` $v0 `>`";
 }
 
-/// Test format has duplicate parameter captures.
+// Test format has duplicate parameter captures.
 def InvalidTypeB : InvalidType<"InvalidTypeB", "invalid_b"> {
   let parameters = (ins "int":$v0, "int":$v1);
   // CHECK: duplicate parameter 'v0'
   let assemblyFormat = "`<` $v0 `,` $v1 `,` $v0 `>`";
 }
 
-/// Test format has invalid syntax.
+// Test format has invalid syntax.
 def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
   let parameters = (ins "int":$v0, "int":$v1);
   // CHECK: expected literal, variable, directive, or optional group
   let assemblyFormat = "`<` $v0, $v1 `>`";
 }
 
-/// Test struct directive has invalid syntax.
+// Test struct directive has invalid syntax.
 def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
   let parameters = (ins "int":$v0);
   // CHECK: literals may only be used in the top-level section of the format
@@ -40,37 +40,70 @@ def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
   let assemblyFormat = "`<` struct($v0, `,`) `>`";
 }
 
-/// Test struct directive cannot capture zero parameters.
+// Test struct directive cannot capture zero parameters.
 def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
   let parameters = (ins "int":$v0);
   // CHECK: `struct` argument list expected a variable or directive
   let assemblyFormat = "`<` struct() $v0 `>`";
 }
 
-/// Test capture parameter that does not exist.
+// Test capture parameter that does not exist.
 def InvalidTypeF : InvalidType<"InvalidTypeF", "invalid_f"> {
   let parameters = (ins "int":$v0);
   // CHECK: InvalidTypeF has no parameter named 'v1'
   let assemblyFormat = "`<` $v0 $v1 `>`";
 }
 
-/// Test duplicate capture of parameter in capture-all struct.
+// Test duplicate capture of parameter in capture-all struct.
 def InvalidTypeG : InvalidType<"InvalidTypeG", "invalid_g"> {
   let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
   // CHECK: duplicate parameter 'v0'
   let assemblyFormat = "`<` struct(params) $v0 `>`";
 }
 
-/// Test capture-all struct duplicate capture.
+// Test capture-all struct duplicate capture.
 def InvalidTypeH : InvalidType<"InvalidTypeH", "invalid_h"> {
   let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
   // CHECK: `params` captures duplicate parameter: v0
   let assemblyFormat = "`<` $v0 struct(params) `>`";
 }
 
-/// Test capture of parameter after `params` directive.
+// Test capture of parameter after `params` directive.
 def InvalidTypeI : InvalidType<"InvalidTypeI", "invalid_i"> {
   let parameters = (ins "int":$v0);
   // CHECK: duplicate parameter 'v0'
   let assemblyFormat = "`<` params $v0 `>`";
 }
+
+// Test `struct` with optional parameter followed by comma.
+def InvalidTypeJ : InvalidType<"InvalidTypeJ", "invalid_j"> {
+  let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+  // CHECK: directive with optional parameters cannot be followed by a comma literal
+  let assemblyFormat = "struct($a) `,` $b";
+}
+
+// Test `struct` in optional group must have all optional parameters.
+def InvalidTypeK : InvalidType<"InvalidTypeK", "invalid_k"> {
+  let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+  // CHECK: is only allowed in an optional group if all captured parameters are optional
+  let assemblyFormat = "(`(` struct(params)^ `)`)?";
+}
+
+// Test `params` in optional group must have all optional parameters.
+def InvalidTypeL : InvalidType<"InvalidTypeL", "invalid_l"> {
+  let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+  // CHECK: directive allowed in optional group only if all parameters are optional
+  let assemblyFormat = "(`(` params^ `)`)?";
+}
+
+def InvalidTypeM : InvalidType<"InvalidTypeM", "invalid_m"> {
+  let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+  // CHECK: parameters in an optional group must be optional
+  let assemblyFormat = "(`(` $a^ `,` $b `)`)?";
+}
+
+def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> {
+  let parameters = (ins OptionalParameter<"int">:$a);
+  // CHECK: optional group anchor must be a parameter or directive
+  let assemblyFormat = "(`(` $a `)`^)?";
+}

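The new invalid-format tests encode a few rules for optional groups:

* Parameters inside the group must be optional.
* The anchor must be a parameter or directive, not a literal.
* `params` and `struct(params)` are allowed only if every captured parameter is optional.

Below is a rough sketch of those checks over a hypothetical, simplified element model (`Kind`, `Element`, `verifyOptionalGroup`). The diagnostic strings are adapted from the CHECK lines, but the data structure is an assumption. The "exactly one anchor" rule is this sketch's own and is not taken from the diff.

```cpp
#include <string>
#include <vector>

// Hypothetical, simplified model of one element inside an optional group.
enum class Kind { Literal, Parameter, Directive };

struct Element {
  Kind kind;
  bool isAnchor;               // marked with `^`
  std::vector<bool> optional;  // optionality of each captured parameter
                               // (empty for literals)
};

// Returns an empty string if the group is valid, otherwise a diagnostic.
std::string verifyOptionalGroup(const std::vector<Element> &group) {
  int anchors = 0;
  for (const Element &e : group) {
    if (e.isAnchor) {
      ++anchors;
      if (e.kind == Kind::Literal)
        return "optional group anchor must be a parameter or directive";
    }
    for (bool isOptional : e.optional)
      if (!isOptional)
        return e.kind == Kind::Directive
                   ? "directive allowed in optional group only if all "
                     "parameters are optional"
                   : "parameters in an optional group must be optional";
  }
  if (anchors != 1)
    return "optional group must have exactly one anchor";  // sketch's own rule
  return "";
}
```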
diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
index 20aef66e183f4..3b05d6c7c1371 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -20,4 +20,52 @@ attributes {
 // CHECK-LABEL: @test_roundtrip_default_parsers_struct
 // CHECK: !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
 // CHECK: !test.struct_capture_all<v0 = 0, v1 = 1, v2 = 2, v3 = 3>
-func private @test_roundtrip_default_parsers_struct(!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>) -> !test.struct_capture_all<v3 = 3, v1 = 1, v2 = 2, v0 = 0>
+// CHECK: !test.optional_param<, 6>
+// CHECK: !test.optional_param<5, 6>
+// CHECK: !test.optional_params<"a">
+// CHECK: !test.optional_params<5, "a">
+// CHECK: !test.optional_struct<b = "a">
+// CHECK: !test.optional_struct<a = 5, b = "a">
+// CHECK: !test.optional_params_after<"a">
+// CHECK: !test.optional_params_after<"a", 5>
+// CHECK: !test.all_optional_params<>
+// CHECK: !test.all_optional_params<5>
+// CHECK: !test.all_optional_params<5, 6>
+// CHECK: !test.all_optional_struct<>
+// CHECK: !test.all_optional_struct<b = 5>
+// CHECK: !test.all_optional_struct<a = 5, b = 10>
+// CHECK: !test.optional_group<(5) 6>
+// CHECK: !test.optional_group<x 6>
+// CHECK: !test.optional_group_params<x>
+// CHECK: !test.optional_group_params<(5)>
+// CHECK: !test.optional_group_params<(5, 6)>
+// CHECK: !test.optional_group_struct<x>
+// CHECK: !test.optional_group_struct<(b = 5)>
+// CHECK: !test.optional_group_struct<(a = 10, b = 5)>
+func private @test_roundtrip_default_parsers_struct(
+  !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
+) -> (
+  !test.struct_capture_all<v3 = 3, v1 = 1, v2 = 2, v0 = 0>,
+  !test.optional_param<, 6>,
+  !test.optional_param<5, 6>,
+  !test.optional_params<"a">,
+  !test.optional_params<5, "a">,
+  !test.optional_struct<b = "a">,
+  !test.optional_struct<b = "a", a = 5>,
+  !test.optional_params_after<"a">,
+  !test.optional_params_after<"a", 5>,
+  !test.all_optional_params<>,
+  !test.all_optional_params<5>,
+  !test.all_optional_params<5, 6>,
+  !test.all_optional_struct<>,
+  !test.all_optional_struct<b = 5>,
+  !test.all_optional_struct<b = 10, a = 5>,
+  !test.optional_group<(5) 6>,
+  !test.optional_group<x 6>,
+  !test.optional_group_params<x>,
+  !test.optional_group_params<(5)>,
+  !test.optional_group_params<(5, 6)>,
+  !test.optional_group_struct<x>,
+  !test.optional_group_struct<(b = 5)>,
+  !test.optional_group_struct<(b = 5, a = 10)>
+)

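In the round-trip test, `struct(params)` accepts keys in any order and prints them back in declaration order, omitting absent optional parameters. For example, `<b = 10, a = 5>` prints as `<a = 5, b = 10>`. Below is a minimal sketch of that parse/print pair for two optional `int` parameters. It uses hypothetical names (`AllOptionalStruct`, `parseStruct`, `printStruct`) and a toy `std::istringstream` parser rather than the generated code. It reads `key = value` entries while a comma follows, and it rejects unknown or repeated keys.

```cpp
#include <cctype>
#include <istream>
#include <optional>
#include <sstream>
#include <string>

// Hypothetical model of a type whose two parameters are both optional and
// whose format is `struct(params)` (compare TestTypeAllOptionalStruct).
struct AllOptionalStruct {
  std::optional<int> a;
  std::optional<int> b;
};

// Read a lowercase identifier, skipping leading whitespace.
static std::string readKey(std::istream &is) {
  std::string key;
  is >> std::ws;
  while (std::islower(is.peek()))
    key.push_back(static_cast<char>(is.get()));
  return key;
}

std::optional<AllOptionalStruct> parseStruct(const std::string &text) {
  AllOptionalStruct result;
  std::istringstream is(text);
  is >> std::ws;
  if (is.peek() == std::char_traits<char>::eof())
    return result;  // empty list: every parameter absent
  while (true) {
    std::string key = readKey(is);
    std::optional<int> *slot =
        key == "a" ? &result.a : key == "b" ? &result.b : nullptr;
    if (!slot || slot->has_value())
      return std::nullopt;  // unknown or repeated key
    char eq = 0;
    int value = 0;
    if (!(is >> eq >> value) || eq != '=')
      return std::nullopt;  // expected `= <integer>`
    *slot = value;
    is >> std::ws;
    if (is.peek() != ',')
      break;   // no comma: the list is finished
    is.get();  // consume the comma and parse the next entry
  }
  is >> std::ws;
  if (is.peek() != std::char_traits<char>::eof())
    return std::nullopt;  // trailing input
  return result;
}

// Print present parameters in declaration order, regardless of input order.
std::string printStruct(const AllOptionalStruct &s) {
  std::ostringstream os;
  if (s.a)
    os << "a = " << *s.a;
  if (s.b)
    os << (s.a ? ", " : "") << "b = " << *s.b;
  return os.str();
}
```

A required parameter would also need a check after the loop that its key was seen, as the `TestHType` CHECK lines do with `_seen_b`.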
diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index d354c35480a95..4ed281d488db5 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -35,38 +35,38 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
 
 /// Check simple attribute parser and printer are generated correctly.
 
-// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::AsmParser &parser,
-// ATTR:                                    ::mlir::Type type) {
+// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::AsmParser &odsParser,
+// ATTR:                                    ::mlir::Type odsType) {
 // ATTR:   FailureOr<IntegerAttr> _result_value;
 // ATTR:   FailureOr<TestParamA> _result_complex;
-// ATTR:   if (parser.parseKeyword("hello"))
+// ATTR:   if (odsParser.parseKeyword("hello"))
 // ATTR:     return {};
-// ATTR:   if (parser.parseEqual())
+// ATTR:   if (odsParser.parseEqual())
 // ATTR:     return {};
-// ATTR:   _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// ATTR:   if (failed(_result_value))
+// ATTR:   _result_value = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// ATTR:   if (::mlir::failed(_result_value))
 // ATTR:     return {};
-// ATTR:   if (parser.parseComma())
+// ATTR:   if (odsParser.parseComma())
 // ATTR:     return {};
-// ATTR:   _result_complex = ::parseAttrParamA(parser, type);
-// ATTR:   if (failed(_result_complex))
+// ATTR:   _result_complex = ::parseAttrParamA(odsParser, odsType);
+// ATTR:   if (::mlir::failed(_result_complex))
 // ATTR:     return {};
-// ATTR:   if (parser.parseRParen())
+// ATTR:   if (odsParser.parseRParen())
 // ATTR:     return {};
-// ATTR:   return TestAAttr::get(parser.getContext(),
-// ATTR:                         _result_value.getValue(),
-// ATTR:                         _result_complex.getValue());
+// ATTR:   return TestAAttr::get(odsParser.getContext(),
+// ATTR:                         *_result_value,
+// ATTR:                         *_result_complex);
 // ATTR: }
 
-// ATTR: void TestAAttr::print(::mlir::AsmPrinter &printer) const {
-// ATTR:   printer << ' ' << "hello";
-// ATTR:   printer << ' ' << "=";
-// ATTR:   printer << ' ';
-// ATTR:   printer.printStrippedAttrOrType(getValue());
-// ATTR:   printer << ",";
-// ATTR:   printer << ' ';
-// ATTR:   ::printAttrParamA(printer, getComplex());
-// ATTR:   printer << ")";
+// ATTR: void TestAAttr::print(::mlir::AsmPrinter &odsPrinter) const {
+// ATTR:   odsPrinter << ' ' << "hello";
+// ATTR:   odsPrinter << ' ' << "=";
+// ATTR:   odsPrinter << ' ';
+// ATTR:   odsPrinter.printStrippedAttrOrType(getValue());
+// ATTR:   odsPrinter << ",";
+// ATTR:   odsPrinter << ' ';
+// ATTR:   ::printAttrParamA(odsPrinter, getComplex());
+// ATTR:   odsPrinter << ")";
 // ATTR: }
 
 def AttrA : TestAttr<"TestA"> {
@@ -81,47 +81,48 @@ def AttrA : TestAttr<"TestA"> {
 
 /// Test simple struct parser and printer are generated correctly.
 
-// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::AsmParser &parser,
-// ATTR:                                    ::mlir::Type type) {
+// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::AsmParser &odsParser,
+// ATTR:                                    ::mlir::Type odsType) {
 // ATTR:   bool _seen_v0 = false;
 // ATTR:   bool _seen_v1 = false;
-// ATTR:   for (unsigned _index = 0; _index < 2; ++_index) {
-// ATTR:     StringRef _paramKey;
-// ATTR:     if (parser.parseKeyword(&_paramKey))
-// ATTR:       return {};
-// ATTR:     if (parser.parseEqual())
+// ATTR:   const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
+// ATTR:     if (odsParser.parseEqual())
 // ATTR:       return {};
 // ATTR:     if (!_seen_v0 && _paramKey == "v0") {
 // ATTR:       _seen_v0 = true;
-// ATTR:       _result_v0 = ::parseAttrParamA(parser, type);
-// ATTR:       if (failed(_result_v0))
+// ATTR:       _result_v0 = ::parseAttrParamA(odsParser, odsType);
+// ATTR:       if (::mlir::failed(_result_v0))
 // ATTR:         return {};
 // ATTR:     } else if (!_seen_v1 && _paramKey == "v1") {
 // ATTR:       _seen_v1 = true;
-// ATTR:       _result_v1 = type ? ::parseAttrWithType(parser, type) : ::parseAttrWithout(parser);
-// ATTR:       if (failed(_result_v1))
+// ATTR:       _result_v1 = odsType ? ::parseAttrWithType(odsParser, odsType) :
+// ATTR-SAME:                         ::parseAttrWithout(odsParser);
+// ATTR:       if (::mlir::failed(_result_v1))
 // ATTR:         return {};
 // ATTR:     } else {
 // ATTR:       return {};
 // ATTR:     }
-// ATTR:     if ((_index != 2 - 1) && parser.parseComma())
+// ATTR:     return true;
+// ATTR:   }
+// ATTR:   for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) {
+// ATTR:     StringRef _paramKey;
+// ATTR:     if (odsParser.parseKeyword(&_paramKey))
+// ATTR:       return {};
+// ATTR:     if (!_loop_body(_paramKey)) return {};
+// ATTR:     if ((odsStructIndex != 2 - 1) && odsParser.parseComma())
 // ATTR:       return {};
 // ATTR:   }
-// ATTR:   return TestBAttr::get(parser.getContext(),
-// ATTR:                         _result_v0.getValue(),
-// ATTR:                         _result_v1.getValue());
+// ATTR:   return TestBAttr::get(odsParser.getContext(),
+// ATTR:                         *_result_v0,
+// ATTR:                         *_result_v1);
 // ATTR: }
 
-// ATTR: void TestBAttr::print(::mlir::AsmPrinter &printer) const {
-// ATTR:   printer << "v0";
-// ATTR:   printer << ' ' << "=";
-// ATTR:   printer << ' ';
-// ATTR:   ::printAttrParamA(printer, getV0());
-// ATTR:   printer << ",";
-// ATTR:   printer << ' ' << "v1";
-// ATTR:   printer << ' ' << "=";
-// ATTR:   printer << ' ';
-// ATTR:   ::printAttrB(printer, getV1());
+// ATTR: void TestBAttr::print(::mlir::AsmPrinter &odsPrinter) const {
+// ATTR:   odsPrinter << "v0 = ";
+// ATTR:   ::printAttrParamA(odsPrinter, getV0());
+// ATTR:   odsPrinter << ", ";
+// ATTR:   odsPrinter << "v1 = ";
+// ATTR:   ::printAttrB(odsPrinter, getV1());
 // ATTR: }
 
 def AttrB : TestAttr<"TestB"> {
@@ -136,29 +137,21 @@ def AttrB : TestAttr<"TestB"> {
 
 /// Test attribute with capture-all params has correct parser and printer.
 
-// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::AsmParser &parser,
-// ATTR:                                    ::mlir::Type type) {
+// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::AsmParser &odsParser,
+// ATTR:                                    ::mlir::Type odsType) {
 // ATTR:   ::mlir::FailureOr<int> _result_v0;
 // ATTR:   ::mlir::FailureOr<int> _result_v1;
-// ATTR:   _result_v0 = ::mlir::FieldParser<int>::parse(parser);
-// ATTR:   if (failed(_result_v0))
+// ATTR:   _result_v0 = ::mlir::FieldParser<int>::parse(odsParser);
+// ATTR:   if (::mlir::failed(_result_v0))
 // ATTR:     return {};
-// ATTR:   if (parser.parseComma())
+// ATTR:   if (odsParser.parseComma())
 // ATTR:     return {};
-// ATTR:   _result_v1 = ::mlir::FieldParser<int>::parse(parser);
-// ATTR:   if (failed(_result_v1))
+// ATTR:   _result_v1 = ::mlir::FieldParser<int>::parse(odsParser);
+// ATTR:   if (::mlir::failed(_result_v1))
 // ATTR:     return {};
-// ATTR:   return TestFAttr::get(parser.getContext(),
-// ATTR:     _result_v0.getValue(),
-// ATTR:     _result_v1.getValue());
-// ATTR: }
-
-// ATTR: void TestFAttr::print(::mlir::AsmPrinter &printer) const {
-// ATTR:   printer << ' ';
-// ATTR:   printer.printStrippedAttrOrType(getV0());
-// ATTR:   printer << ",";
-// ATTR:   printer << ' ';
-// ATTR:   printer.printStrippedAttrOrType(getV1());
+// ATTR:   return TestFAttr::get(odsParser.getContext(),
+// ATTR:     *_result_v0,
+// ATTR:     *_result_v1);
 // ATTR: }
 
 def AttrC : TestAttr<"TestF"> {
@@ -171,55 +164,57 @@ def AttrC : TestAttr<"TestF"> {
 /// Test type parser and printer that mix variables and struct are generated
 /// correctly.
 
-// TYPE: ::mlir::Type TestCType::parse(::mlir::AsmParser &parser) {
+// TYPE: ::mlir::Type TestCType::parse(::mlir::AsmParser &odsParser) {
 // TYPE:  FailureOr<IntegerAttr> _result_value;
 // TYPE:  FailureOr<TestParamC> _result_complex;
-// TYPE:  if (parser.parseKeyword("foo"))
+// TYPE:  if (odsParser.parseKeyword("foo"))
 // TYPE:    return {};
-// TYPE:  if (parser.parseComma())
+// TYPE:  if (odsParser.parseComma())
 // TYPE:    return {};
-// TYPE:  if (parser.parseColon())
+// TYPE:  if (odsParser.parseColon())
 // TYPE:    return {};
-// TYPE:  if (parser.parseKeyword("bob"))
+// TYPE:  if (odsParser.parseKeyword("bob"))
 // TYPE:    return {};
-// TYPE:  if (parser.parseKeyword("bar"))
+// TYPE:  if (odsParser.parseKeyword("bar"))
 // TYPE:    return {};
-// TYPE:  _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// TYPE:  if (failed(_result_value))
+// TYPE:  _result_value = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// TYPE:  if (::mlir::failed(_result_value))
 // TYPE:    return {};
 // TYPE:  bool _seen_complex = false;
-// TYPE:  for (unsigned _index = 0; _index < 1; ++_index) {
-// TYPE:    StringRef _paramKey;
-// TYPE:    if (parser.parseKeyword(&_paramKey))
-// TYPE:      return {};
+// TYPE:  const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
 // TYPE:    if (!_seen_complex && _paramKey == "complex") {
 // TYPE:      _seen_complex = true;
-// TYPE:      _result_complex = ::parseTypeParamC(parser);
-// TYPE:      if (failed(_result_complex))
+// TYPE:      _result_complex = ::parseTypeParamC(odsParser);
+// TYPE:      if (::mlir::failed(_result_complex))
 // TYPE:        return {};
 // TYPE:    } else {
 // TYPE:      return {};
 // TYPE:    }
-// TYPE:    if ((_index != 1 - 1) && parser.parseComma())
+// TYPE:    return true;
+// TYPE:  }
+// TYPE:  for (unsigned odsStructIndex = 0; odsStructIndex < 1; ++odsStructIndex) {
+// TYPE:    StringRef _paramKey;
+// TYPE:    if (odsParser.parseKeyword(&_paramKey))
+// TYPE:      return {};
+// TYPE:    if (!_loop_body(_paramKey)) return {};
+// TYPE:    if ((odsStructIndex != 1 - 1) && odsParser.parseComma())
 // TYPE:      return {};
 // TYPE:  }
-// TYPE:  if (parser.parseRParen())
+// TYPE:  if (odsParser.parseRParen())
 // TYPE:    return {};
 // TYPE:  }
 
-// TYPE: void TestCType::print(::mlir::AsmPrinter &printer) const {
-// TYPE:   printer << ' ' << "foo";
-// TYPE:   printer << ",";
-// TYPE:   printer << ' ' << ":";
-// TYPE:   printer << ' ' << "bob";
-// TYPE:   printer << ' ' << "bar";
-// TYPE:   printer << ' ';
-// TYPE:   printer.printStrippedAttrOrType(getValue());
-// TYPE:   printer << ' ' << "complex";
-// TYPE:   printer << ' ' << "=";
-// TYPE:   printer << ' ';
-// TYPE:   printer << getComplex();
-// TYPE:   printer << ")";
+// TYPE: void TestCType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE:   odsPrinter << ' ' << "foo";
+// TYPE:   odsPrinter << ",";
+// TYPE:   odsPrinter << ' ' << ":";
+// TYPE:   odsPrinter << ' ' << "bob";
+// TYPE:   odsPrinter << ' ' << "bar";
+// TYPE:   odsPrinter << ' ';
+// TYPE:   odsPrinter.printStrippedAttrOrType(getValue());
+// TYPE:   odsPrinter << "complex = ";
+// TYPE:   odsPrinter << getComplex();
+// TYPE:   odsPrinter << ")";
 // TYPE: }
 
 def TypeA : TestType<"TestC"> {
@@ -235,51 +230,53 @@ def TypeA : TestType<"TestC"> {
 /// Test type parser and printer with mix of variables and struct are generated
 /// correctly.
 
-// TYPE: ::mlir::Type TestDType::parse(::mlir::AsmParser &parser) {
-// TYPE:   _result_v0 = ::parseTypeParamC(parser);
-// TYPE:   if (failed(_result_v0))
+// TYPE: ::mlir::Type TestDType::parse(::mlir::AsmParser &odsParser) {
+// TYPE:   _result_v0 = ::parseTypeParamC(odsParser);
+// TYPE:   if (::mlir::failed(_result_v0))
 // TYPE:     return {};
 // TYPE:   bool _seen_v1 = false;
 // TYPE:   bool _seen_v2 = false;
-// TYPE:   for (unsigned _index = 0; _index < 2; ++_index) {
-// TYPE:     StringRef _paramKey;
-// TYPE:     if (parser.parseKeyword(&_paramKey))
-// TYPE:       return {};
-// TYPE:     if (parser.parseEqual())
+// TYPE:   const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
+// TYPE:     if (odsParser.parseEqual())
 // TYPE:       return {};
 // TYPE:     if (!_seen_v1 && _paramKey == "v1") {
 // TYPE:       _seen_v1 = true;
 // TYPE:       _result_v1 = someFcnCall();
-// TYPE:       if (failed(_result_v1))
+// TYPE:       if (::mlir::failed(_result_v1))
 // TYPE:         return {};
 // TYPE:     } else if (!_seen_v2 && _paramKey == "v2") {
 // TYPE:       _seen_v2 = true;
-// TYPE:       _result_v2 = ::parseTypeParamC(parser);
-// TYPE:       if (failed(_result_v2))
+// TYPE:       _result_v2 = ::parseTypeParamC(odsParser);
+// TYPE:       if (::mlir::failed(_result_v2))
 // TYPE:         return {};
 // TYPE:     } else  {
 // TYPE:       return {};
 // TYPE:     }
-// TYPE:     if ((_index != 2 - 1) && parser.parseComma())
+// TYPE:     return true;
+// TYPE:   }
+// TYPE:   for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) {
+// TYPE:     StringRef _paramKey;
+// TYPE:     if (odsParser.parseKeyword(&_paramKey))
+// TYPE:       return {};
+// TYPE:     if (!_loop_body(_paramKey)) return {};
+// TYPE:     if ((odsStructIndex != 2 - 1) && odsParser.parseComma())
 // TYPE:       return {};
 // TYPE:   }
 // TYPE:   _result_v3 = someFcnCall();
-// TYPE:   if (failed(_result_v3))
+// TYPE:   if (::mlir::failed(_result_v3))
 // TYPE:     return {};
-// TYPE:   return TestDType::get(parser.getContext(),
-// TYPE:                         _result_v0.getValue(),
-// TYPE:                         _result_v1.getValue(),
-// TYPE:                         _result_v2.getValue(),
-// TYPE:                         _result_v3.getValue());
+// TYPE:   return TestDType::get(odsParser.getContext(),
+// TYPE:                         *_result_v0,
+// TYPE:                         *_result_v1,
+// TYPE:                         *_result_v2,
+// TYPE:                         *_result_v3);
 // TYPE: }
 
-// TYPE: void TestDType::print(::mlir::AsmPrinter &printer) const {
-// TYPE:   printer << getV0();
+// TYPE: void TestDType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE:   odsPrinter << getV0();
 // TYPE:   myPrinter(getV1());
-// TYPE:   printer << ' ' << "v2";
-// TYPE:   printer << ' ' << "=";
-// TYPE:   printer << ' ';
-// TYPE:   printer << getV2();
+// TYPE:   odsPrinter << "v2 = ";
+// TYPE:   odsPrinter << getV2();
 // TYPE:   myPrinter(getV3());
 // TYPE: }
 
@@ -298,85 +295,86 @@ def TypeB : TestType<"TestD"> {
 /// Type test with two struct directives has correctly generated parser and
 /// printer.
 
-// TYPE: ::mlir::Type TestEType::parse(::mlir::AsmParser &parser) {
+// TYPE: ::mlir::Type TestEType::parse(::mlir::AsmParser &odsParser) {
 // TYPE:   FailureOr<IntegerAttr> _result_v0;
 // TYPE:   FailureOr<IntegerAttr> _result_v1;
 // TYPE:   FailureOr<IntegerAttr> _result_v2;
 // TYPE:   FailureOr<IntegerAttr> _result_v3;
 // TYPE:   bool _seen_v0 = false;
 // TYPE:   bool _seen_v2 = false;
-// TYPE:   for (unsigned _index = 0; _index < 2; ++_index) {
-// TYPE:     StringRef _paramKey;
-// TYPE:     if (parser.parseKeyword(&_paramKey))
-// TYPE:       return {};
-// TYPE:     if (parser.parseEqual())
+// TYPE:   const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
+// TYPE:     if (odsParser.parseEqual())
 // TYPE:       return {};
 // TYPE:     if (!_seen_v0 && _paramKey == "v0") {
 // TYPE:       _seen_v0 = true;
-// TYPE:       _result_v0 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// TYPE:       if (failed(_result_v0))
+// TYPE:       _result_v0 = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// TYPE:       if (::mlir::failed(_result_v0))
 // TYPE:         return {};
 // TYPE:     } else if (!_seen_v2 && _paramKey == "v2") {
 // TYPE:       _seen_v2 = true;
-// TYPE:       _result_v2 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// TYPE:       if (failed(_result_v2))
+// TYPE:       _result_v2 = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// TYPE:       if (::mlir::failed(_result_v2))
 // TYPE:         return {};
 // TYPE:     } else  {
 // TYPE:       return {};
 // TYPE:     }
-// TYPE:     if ((_index != 2 - 1) && parser.parseComma())
+// TYPE:     return true;
+// TYPE:   }
+// TYPE:   for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) {
+// TYPE:     StringRef _paramKey;
+// TYPE:     if (odsParser.parseKeyword(&_paramKey))
+// TYPE:       return {};
+// TYPE:     if (!_loop_body(_paramKey)) return {};
+// TYPE:     if ((odsStructIndex != 2 - 1) && odsParser.parseComma())
 // TYPE:       return {};
 // TYPE:   }
 // TYPE:   bool _seen_v1 = false;
 // TYPE:   bool _seen_v3 = false;
-// TYPE:   for (unsigned _index = 0; _index < 2; ++_index) {
-// TYPE:     StringRef _paramKey;
-// TYPE:     if (parser.parseKeyword(&_paramKey))
-// TYPE:       return {};
-// TYPE:     if (parser.parseEqual())
+// TYPE:   const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
+// TYPE:     if (odsParser.parseEqual())
 // TYPE:       return {};
 // TYPE:     if (!_seen_v1 && _paramKey == "v1") {
 // TYPE:       _seen_v1 = true;
-// TYPE:       _result_v1 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// TYPE:       if (failed(_result_v1))
+// TYPE:       _result_v1 = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// TYPE:       if (::mlir::failed(_result_v1))
 // TYPE:         return {};
 // TYPE:     } else if (!_seen_v3 && _paramKey == "v3") {
 // TYPE:       _seen_v3 = true;
-// TYPE:       _result_v3 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// TYPE:       if (failed(_result_v3))
+// TYPE:       _result_v3 = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// TYPE:       if (::mlir::failed(_result_v3))
 // TYPE:         return {};
 // TYPE:     } else  {
 // TYPE:       return {};
 // TYPE:     }
-// TYPE:     if ((_index != 2 - 1) && parser.parseComma())
+// TYPE:     return true;
+// TYPE:   }
+// TYPE:   for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) {
+// TYPE:     StringRef _paramKey;
+// TYPE:     if (odsParser.parseKeyword(&_paramKey))
+// TYPE:       return {};
+// TYPE:     if (!_loop_body(_paramKey)) return {};
+// TYPE:     if ((odsStructIndex != 2 - 1) && odsParser.parseComma())
 // TYPE:       return {};
 // TYPE:   }
-// TYPE:   return TestEType::get(parser.getContext(),
-// TYPE:     _result_v0.getValue(),
-// TYPE:     _result_v1.getValue(),
-// TYPE:     _result_v2.getValue(),
-// TYPE:     _result_v3.getValue());
+// TYPE:   return TestEType::get(odsParser.getContext(),
+// TYPE:     *_result_v0,
+// TYPE:     *_result_v1,
+// TYPE:     *_result_v2,
+// TYPE:     *_result_v3);
 // TYPE: }
 
-// TYPE: void TestEType::print(::mlir::AsmPrinter &printer) const {
-// TYPE:   printer << "v0";
-// TYPE:   printer << ' ' << "=";
-// TYPE:   printer << ' ';
-// TYPE:   printer.printStrippedAttrOrType(getV0());
-// TYPE:   printer << ",";
-// TYPE:   printer << ' ' << "v2";
-// TYPE:   printer << ' ' << "=";
-// TYPE:   printer << ' ';
-// TYPE:   printer.printStrippedAttrOrType(getV2());
-// TYPE:   printer << "v1";
-// TYPE:   printer << ' ' << "=";
-// TYPE:   printer << ' ';
-// TYPE:   printer.printStrippedAttrOrType(getV1());
-// TYPE:   printer << ",";
-// TYPE:   printer << ' ' << "v3";
-// TYPE:   printer << ' ' << "=";
-// TYPE:   printer << ' ';
-// TYPE:   printer.printStrippedAttrOrType(getV3());
+// TYPE: void TestEType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE:   odsPrinter << "v0 = ";
+// TYPE:   odsPrinter.printStrippedAttrOrType(getV0());
+// TYPE:   odsPrinter << ", ";
+// TYPE:   odsPrinter << "v2 = ";
+// TYPE:   odsPrinter.printStrippedAttrOrType(getV2());
+// TYPE:   odsPrinter << ", ";
+// TYPE:   odsPrinter << "v1 = ";
+// TYPE:   odsPrinter.printStrippedAttrOrType(getV1());
+// TYPE:   odsPrinter << ", ";
+// TYPE:   odsPrinter << "v3 = ";
+// TYPE:   odsPrinter.printStrippedAttrOrType(getV3());
 // TYPE: }
 
 def TypeC : TestType<"TestE"> {
@@ -390,3 +388,99 @@ def TypeC : TestType<"TestE"> {
   let mnemonic = "type_e";
   let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`";
 }
+
+// TYPE: void TestFType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE if (getA()) {
+// TYPE   printer << ' ';
+// TYPE   printer.printStrippedAttrOrType(getA());
+def TypeD : TestType<"TestF"> {
+  let parameters = (ins OptionalParameter<"int">:$a);
+  let mnemonic = "type_f";
+  let assemblyFormat = "$a";
+}
+
+// TYPE: ::mlir::Type TestGType::parse(::mlir::AsmParser &odsParser) {
+// TYPE:   if (::mlir::failed(_result_a))
+// TYPE:     return {};
+// TYPE:   if (::mlir::succeeded(_result_a) && *_result_a)
+// TYPE:     if (odsParser.parseComma())
+// TYPE:       return {};
+
+// TYPE: if (getA())
+// TYPE:   odsPrinter.printStrippedAttrOrType(getA());
+// TYPE: odsPrinter << ", ";
+// TYPE: odsPrinter.printStrippedAttrOrType(getB());
+
+def TypeE : TestType<"TestG"> {
+  let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+  let mnemonic = "type_g";
+  let assemblyFormat = "params";
+}
+
+
+// TYPE: ::mlir::Type TestHType::parse(::mlir::AsmParser &odsParser) {
+// TYPE:   do {
+// TYPE:     if (!_loop_body(_paramKey)) return {};
+// TYPE:   } while(!odsParser.parseOptionalComma());
+// TYPE:   if (!_seen_b)
+// TYPE:     return {};
+
+// TYPE: void TestHType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE:   if (getA()) {
+// TYPE:     odsPrinter << "a = ";
+// TYPE:     odsPrinter.printStrippedAttrOrType(getA());
+// TYPE:     odsPrinter << ", ";
+// TYPE:   }
+
+def TypeF : TestType<"TestH"> {
+  let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+  let mnemonic = "type_h";
+  let assemblyFormat = "struct(params)";
+}
+
+
+// TYPE: do {
+// TYPE:   _result_a = ::mlir::FieldParser<int>::parse(odsParser);
+// TYPE:   if (::mlir::failed(_result_a))
+// TYPE:     return {};
+// TYPE:   if (odsParser.parseOptionalComma()) break;
+// TYPE:   _result_b = ::mlir::FieldParser<int>::parse(odsParser);
+// TYPE:   if (::mlir::failed(_result_b))
+// TYPE:     return {};
+// TYPE: } while(false);
+
+def TypeG : TestType<"TestI"> {
+  let parameters = (ins "int":$a, OptionalParameter<"int">:$b);
+  let mnemonic = "type_i";
+  let assemblyFormat = "params";
+}
+
+// TYPE: ::mlir::Type TestJType::parse(::mlir::AsmParser &odsParser) {
+// TYPE:   if (odsParser.parseOptionalLParen()) {
+// TYPE:     if (odsParser.parseKeyword("x")) return {};
+// TYPE:   } else {
+// TYPE:     _result_b = ::mlir::FieldParser<int>::parse(odsParser);
+// TYPE:     if (::mlir::failed(_result_b))
+// TYPE:       return {};
+// TYPE:     if (odsParser.parseRParen()) return {};
+// TYPE:   }
+// TYPE:   _result_a = ::mlir::FieldParser<int>::parse(odsParser);
+// TYPE:   if (::mlir::failed(_result_a))
+// TYPE:     return {};
+
+// TYPE: void TestJType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE:   if (getB()) {
+// TYPE:     odsPrinter << "(";
+// TYPE:     if (getB())
+// TYPE:       odsPrinter.printStrippedAttrOrType(getB());
+// TYPE:     odsPrinter << ")";
+// TYPE:   } else {
+// TYPE:     odsPrinter << ' ' << "x";
+// TYPE:   }
+// TYPE:   odsPrinter.printStrippedAttrOrType(getA());
+
+def TypeH : TestType<"TestJ"> {
+  let parameters = (ins "int":$a, OptionalParameter<"int">:$b);
+  let mnemonic = "type_j";
+  let assemblyFormat = "(`(` $b^ `)`) : (`x`)? $a";
+}

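The `TestJ` checks above show how an optional group with an `else` branch is lowered. The leading literal `(` is parsed with `parseOptionalLParen`, the anchor `$b` guards the printer, and `x` is expected when the group is absent. The standalone C++ sketch below imitates that control flow for the `(`(` $b^ `)`) : (`x`)? $a` format of `type_j`. It does not use MLIR. `TypeJ`, `Cursor`, `parseTypeJ`, and `printTypeJ` are hypothetical names. A zero `b` stands in for an absent optional `int` parameter, as the generated `if (getB())` guard implies.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <string_view>

// Hypothetical model of `type_j`: `b` == 0 plays the role of an absent
// OptionalParameter<"int">, matching the generated `if (getB())` guard.
struct TypeJ {
  int a = 0;
  int b = 0;
};

// Tiny input cursor standing in for mlir::AsmParser (illustration only).
// Unlike MLIR's parse methods, these return true/a value on success.
struct Cursor {
  std::string_view s;

  void skipSpace() {
    while (!s.empty() && std::isspace(static_cast<unsigned char>(s.front())))
      s.remove_prefix(1);
  }
  // Consume `lit` if it is next in the input.
  bool consume(std::string_view lit) {
    skipSpace();
    if (s.substr(0, lit.size()) != lit)
      return false;
    s.remove_prefix(lit.size());
    return true;
  }
  // Parse a non-negative decimal integer.
  std::optional<int> parseInt() {
    skipSpace();
    std::size_t n = 0;
    while (n < s.size() && std::isdigit(static_cast<unsigned char>(s[n])))
      ++n;
    if (n == 0)
      return std::nullopt;
    int value = std::stoi(std::string(s.substr(0, n)));
    s.remove_prefix(n);
    return value;
  }
};

// Parse `(`(` $b^ `)`) : (`x`)? $a`.
std::optional<TypeJ> parseTypeJ(std::string_view text) {
  Cursor c{text};
  TypeJ t;
  if (c.consume("(")) {
    // Group present: the leading literal matched, so `$b` and `)` must follow.
    std::optional<int> b = c.parseInt();
    if (!b || !c.consume(")"))
      return std::nullopt;
    t.b = *b;
  } else if (!c.consume("x")) {
    // Group absent: the else-group keyword `x` is mandatory.
    return std::nullopt;
  }
  std::optional<int> a = c.parseInt();
  if (!a)
    return std::nullopt;
  t.a = *a;
  c.skipSpace();
  if (!c.s.empty())
    return std::nullopt; // reject trailing input
  return t;
}

// Print the same format, guarding the group on its anchor `$b`.
std::string printTypeJ(const TypeJ &t) {
  std::string out = t.b ? "(" + std::to_string(t.b) + ")" : "x";
  return out + " " + std::to_string(t.a);
}
```

One consequence of treating 0 as "absent": in this sketch `(0) 3` parses successfully but prints back as `x 3`.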
diff  --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 34c8588225f70..2393fdae71ed6 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -67,8 +67,8 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
 // DECL:   return {"cmpnd_a"};
 // DECL: }
 // DECL: static ::mlir::Attribute parse(
-// DECL-SAME: ::mlir::AsmParser &parser, ::mlir::Type type);
-// DECL: void print(::mlir::AsmPrinter &printer) const;
+// DECL-SAME: ::mlir::AsmParser &odsParser, ::mlir::Type odsType);
+// DECL: void print(::mlir::AsmPrinter &odsPrinter) const;
 // DECL: int getWidthOfSomething() const;
 // DECL: ::test::SimpleTypeA getExampleTdType() const;
 // DECL: ::llvm::APFloat getApFloat() const;
@@ -107,8 +107,8 @@ def C_IndexAttr : TestAttr<"Index"> {
 // DECL:   return {"index"};
 // DECL: }
 // DECL: static ::mlir::Attribute parse(
-// DECL-SAME: ::mlir::AsmParser &parser, ::mlir::Type type);
-// DECL: void print(::mlir::AsmPrinter &printer) const;
+// DECL-SAME: ::mlir::AsmParser &odsParser, ::mlir::Type odsType);
+// DECL: void print(::mlir::AsmPrinter &odsPrinter) const;
 }
 
 def D_SingleParameterAttr : TestAttr<"SingleParameter"> {

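The `TestG` and `TestI` checks earlier in this patch exercise how `genParamsParser` places commas in a `params` directive. Separators up to and including the one before the last required parameter use `parseComma`. Later separators use `parseOptionalComma` and `break` out of the surrounding `do { ... } while(false)`. A separator that follows an optional parameter is only attempted if that parameter parsed to a non-empty value. The hypothetical helper `commaPlan` below computes this classification from a list of optionality flags. It is a sketch of the rule, not code from the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// How the separator before parameter `i` (i >= 1) would be parsed.
struct Separator {
  bool required; // `parseComma()`; otherwise `parseOptionalComma()` + break
  bool guarded;  // only attempted if the preceding optional parameter had a value
};

// Sketch of the comma rule in `genParamsParser`, given one optionality flag
// per parameter in declaration order.
std::vector<Separator> commaPlan(const std::vector<bool> &isOptional) {
  // Index of the last required parameter; 0 if there is none, mirroring the
  // generator's fallback to `params.begin()`.
  std::size_t lastReq = 0;
  for (std::size_t i = 0; i < isOptional.size(); ++i)
    if (!isOptional[i])
      lastReq = i;

  std::vector<Separator> plan;
  for (std::size_t i = 1; i < isOptional.size(); ++i)
    plan.push_back({/*required=*/i <= lastReq, /*guarded=*/isOptional[i - 1]});
  return plan;
}
```

For `type_i` (`a` required, `b` optional) this yields one optional, unguarded comma, which matches the `if (odsParser.parseOptionalComma()) break;` line checked above.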
diff  --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 733ca5e9b52f0..103dba3201eee 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -70,8 +70,8 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
 // DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
 // DECL:   return {"cmpnd_a"};
 // DECL: }
-// DECL: static ::mlir::Type parse(::mlir::AsmParser &parser);
-// DECL: void print(::mlir::AsmPrinter &printer) const;
+// DECL: static ::mlir::Type parse(::mlir::AsmParser &odsParser);
+// DECL: void print(::mlir::AsmPrinter &odsPrinter) const;
 // DECL: int getWidthOfSomething() const;
 // DECL: ::test::SimpleTypeA getExampleTdType() const;
 // DECL: SomeCppStruct getExampleCppType() const;
@@ -89,8 +89,8 @@ def C_IndexType : TestType<"Index"> {
 // DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
 // DECL:   return {"index"};
 // DECL: }
-// DECL: static ::mlir::Type parse(::mlir::AsmParser &parser);
-// DECL: void print(::mlir::AsmPrinter &printer) const;
+// DECL: static ::mlir::Type parse(::mlir::AsmParser &odsParser);
+// DECL: void print(::mlir::AsmPrinter &odsPrinter) const;
 }
 
 def D_SingleParameterType : TestType<"SingleParameter"> {

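On the printing side, `genCommaSeparatedPrinter` (in the `AttrOrTypeFormatGen.cpp` changes) emits a local `_firstPrinted` flag so that absent optional parameters can be skipped without leaving a dangling `, `. A rough standalone C++ equivalent for `struct(params)` is sketched below. It assumes a hypothetical `Param` record whose optional values count as absent when zero; `printStruct` is likewise an illustrative name.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical parameter record. An optional parameter whose value is 0 is
// treated as absent, like the `if (getA())` guard in the generated printer.
struct Param {
  std::string name;
  bool isOptional;
  int value;
};

// Print `name = value` pairs separated by ", ", skipping absent optional
// parameters. Like the generated `_firstPrinted`, the flag stays true until
// the first entry has been emitted.
std::string printStruct(const std::vector<Param> &params) {
  std::string out;
  bool firstPrinted = true;
  for (const Param &p : params) {
    if (p.isOptional && p.value == 0)
      continue; // absent: emit neither the entry nor a separator
    if (!firstPrinted)
      out += ", ";
    firstPrinted = false;
    out += p.name + " = " + std::to_string(p.value);
  }
  return out;
}
```

Deciding the separator at each printed entry, rather than per position, keeps the output well-formed whichever subset of optional parameters is present.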
diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index ce90cac72ce34..7e2b147478e17 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -266,9 +266,9 @@ void DefGen::emitParserPrinter() {
 
   // Declare the parser.
   SmallVector<MethodParameter> parserParams;
-  parserParams.emplace_back("::mlir::AsmParser &", "parser");
+  parserParams.emplace_back("::mlir::AsmParser &", "odsParser");
   if (isa<AttrDef>(&def))
-    parserParams.emplace_back("::mlir::Type", "type");
+    parserParams.emplace_back("::mlir::Type", "odsType");
   auto *parser = defCls.addMethod(
       strfmt("::mlir::{0}", valueType), "parse",
       def.hasGeneratedParser() ? Method::Static : Method::StaticDeclaration,
@@ -278,7 +278,7 @@ void DefGen::emitParserPrinter() {
       def.hasGeneratedPrinter() ? Method::Const : Method::ConstDeclaration;
   Method *printer =
       defCls.addMethod("void", "print", props,
-                       MethodParameter("::mlir::AsmPrinter &", "printer"));
+                       MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
   // Emit the bodies.
   emitParserPrinterBody(parser->body(), printer->body());
 }
@@ -431,14 +431,15 @@ void DefGen::emitParserPrinterBody(MethodBody &parser, MethodBody &printer) {
   if (asmFormat)
     return generateAttrOrTypeFormat(def, parser, printer);
 
-  FmtContext ctx = FmtContext(
-      {{"_parser", "parser"}, {"_printer", "printer"}, {"_type", "type"}});
+  FmtContext ctx = FmtContext({{"_parser", "odsParser"},
+                               {"_printer", "odsPrinter"},
+                               {"_type", "odsType"}});
   if (parserCode) {
-    ctx.addSubst("_ctxt", "parser.getContext()");
+    ctx.addSubst("_ctxt", "odsParser.getContext()");
     parser.indent().getStream().printReindented(tgfmt(*parserCode, &ctx).str());
   }
   if (printerCode) {
-    ctx.addSubst("_ctxt", "printer.getContext()");
+    ctx.addSubst("_ctxt", "odsPrinter.getContext()");
     printer.indent().getStream().printReindented(
         tgfmt(*printerCode, &ctx).str());
   }

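When a `struct` directive contains optional parameters, the number of `key = value` entries is no longer fixed. `genStructParser` therefore switches from a counted `for` loop to `do { ... } while(!parseOptionalComma())` and checks the `_seen_` flags of the required parameters afterwards, as the `TestHType::parse` checks above show. The standalone C++ sketch below follows that strategy for `type_h` (optional `a`, required `b`). `Lexer`, `TypeH`, and `parseTypeH` are illustrative names, not MLIR APIs. It covers only the mixed case; when every parameter is optional, the generator instead peels off the first iteration with `parseOptionalKeyword`.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <string_view>

// Hypothetical model of `type_h`: `a` is optional (0 when absent), `b` is
// required.
struct TypeH {
  int a = 0;
  int b = 0;
};

// Minimal lexer standing in for mlir::AsmParser; methods return true or a
// value on success (the opposite of MLIR's failure-is-true convention).
struct Lexer {
  std::string_view s;

  void skipSpace() {
    while (!s.empty() && std::isspace(static_cast<unsigned char>(s.front())))
      s.remove_prefix(1);
  }
  bool punct(char c) {
    skipSpace();
    if (s.empty() || s.front() != c)
      return false;
    s.remove_prefix(1);
    return true;
  }
  std::optional<std::string> keyword() {
    skipSpace();
    std::size_t n = 0;
    while (n < s.size() &&
           (std::isalpha(static_cast<unsigned char>(s[n])) || s[n] == '_'))
      ++n;
    if (n == 0)
      return std::nullopt;
    std::string word(s.substr(0, n));
    s.remove_prefix(n);
    return word;
  }
  std::optional<int> integer() {
    skipSpace();
    std::size_t n = 0;
    while (n < s.size() && std::isdigit(static_cast<unsigned char>(s[n])))
      ++n;
    if (n == 0)
      return std::nullopt;
    int value = std::stoi(std::string(s.substr(0, n)));
    s.remove_prefix(n);
    return value;
  }
};

// Parse `struct(params)` for `type_h`: entries in any order, separated by
// commas, with only `b` mandatory.
std::optional<TypeH> parseTypeH(std::string_view text) {
  Lexer lex{text};
  TypeH t;
  bool seenA = false, seenB = false;

  // One `key = value` entry, like the generated `_loop_body` lambda.
  auto loopBody = [&](const std::string &key) -> bool {
    if (!lex.punct('='))
      return false;
    int *slot = nullptr;
    if (!seenA && key == "a") {
      seenA = true;
      slot = &t.a;
    } else if (!seenB && key == "b") {
      seenB = true;
      slot = &t.b;
    } else {
      return false; // duplicate or unknown struct parameter name
    }
    std::optional<int> value = lex.integer();
    if (!value)
      return false;
    *slot = *value;
    return true;
  };

  // The entry count is not fixed, so keep looping while a comma follows.
  do {
    std::optional<std::string> key = lex.keyword();
    if (!key || !loopBody(*key))
      return std::nullopt;
  } while (lex.punct(','));

  if (!seenB)
    return std::nullopt; // struct is missing required parameter `b`
  lex.skipSpace();
  if (!lex.s.empty())
    return std::nullopt; // reject trailing input
  return t;
}
```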
diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 4ca20e59cc03b..d697ad2b4ea40 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -16,7 +16,9 @@
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/TableGenBackend.h"
@@ -48,36 +50,47 @@ class ParameterElement
     shouldBeQualifiedFlag = qualified;
   }
 
+  /// Returns true if the element contains an optional parameter.
+  bool isOptional() const { return param.isOptional(); }
+
+  /// Returns the name of the parameter.
+  StringRef getName() const { return param.getName(); }
+
 private:
   bool shouldBeQualifiedFlag = false;
   AttrOrTypeParameter param;
 };
 
+/// Shorthand predicates that can be used with range-based algorithms.
+static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
+static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
+
 /// Base class for a directive that contains references to multiple variables.
 template <DirectiveElement::Kind DirectiveKind>
 class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
 public:
   using Base = ParamsDirectiveBase<DirectiveKind>;
 
-  ParamsDirectiveBase(std::vector<FormatElement *> &&params)
+  ParamsDirectiveBase(std::vector<ParameterElement *> &&params)
       : params(std::move(params)) {}
 
   /// Get the parameters contained in this directive.
-  auto getParams() const {
-    return llvm::map_range(params, [](FormatElement *el) {
-      return cast<ParameterElement>(el)->getParam();
-    });
-  }
+  ArrayRef<ParameterElement *> getParams() const { return params; }
 
   /// Get the number of parameters.
   unsigned getNumParams() const { return params.size(); }
 
   /// Take all of the parameters from this directive.
-  std::vector<FormatElement *> takeParams() { return std::move(params); }
+  std::vector<ParameterElement *> takeParams() { return std::move(params); }
+
+  /// Returns true if there are optional parameters present.
+  bool hasOptionalParams() const {
+    return llvm::any_of(getParams(), paramIsOptional);
+  }
 
 private:
   /// The parameters captured by this directive.
-  std::vector<FormatElement *> params;
+  std::vector<ParameterElement *> params;
 };
 
 /// This class represents a `params` directive that refers to all parameters
@@ -125,36 +138,9 @@ static const char *const qualifiedParameterPrinter = "$_printer << $_self";
 /// Print an error when failing to parse an element.
 ///
 /// $0: The parameter C++ class name.
-static const char *const parseErrorStr =
+static const char *const parserErrorStr =
     "$_parser.emitError($_parser.getCurrentLocation(), ";
 
-/// Loop declaration for struct parser.
-///
-/// $0: Number of expected parameters.
-static const char *const structParseLoopStart = R"(
-  for (unsigned _index = 0; _index < $0; ++_index) {
-    StringRef _paramKey;
-    if ($_parser.parseKeyword(&_paramKey)) {
-      $_parser.emitError($_parser.getCurrentLocation(),
-                         "expected a parameter name in struct");
-      return {};
-    }
-)";
-
-/// Terminator code segment for the struct parser loop. Check for duplicate or
-/// unknown parameters. Parse a comma except on the last element.
-///
-/// {0}: Code template for printing an error.
-/// {1}: Number of elements in the struct.
-static const char *const structParseLoopEnd = R"({{
-    {0}"duplicate or unknown struct parameter name: ") << _paramKey;
-    return {{};
-  }
-  if ((_index != {1} - 1) && parser.parseComma())
-    return {{};
-}
-)";
-
 /// Code format to parse a variable. Separate by lines because variable parsers
 /// may be generated inside other directives, which requires indentation.
 ///
@@ -166,21 +152,20 @@ static const char *const structParseLoopEnd = R"({{
 static const char *const variableParser = R"(
 // Parse variable '{0}'
 _result_{0} = {1};
-if (failed(_result_{0})) {{
+if (::mlir::failed(_result_{0})) {{
   {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
   return {{};
 }
 )";
 
 //===----------------------------------------------------------------------===//
-// AttrOrTypeFormat
+// DefFormat
 //===----------------------------------------------------------------------===//
 
 namespace {
-class AttrOrTypeFormat {
+class DefFormat {
 public:
-  AttrOrTypeFormat(const AttrOrTypeDef &def,
-                   std::vector<FormatElement *> &&elements)
+  DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements)
       : def(def), elements(std::move(elements)) {}
 
   /// Generate the attribute or type parser.
@@ -192,26 +177,36 @@ class AttrOrTypeFormat {
   /// Generate the parser code for a specific format element.
   void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
   /// Generate the parser code for a literal.
-  void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os);
+  void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os,
+                        bool isOptional = false);
   /// Generate the parser code for a variable.
-  void genVariableParser(const AttrOrTypeParameter &param, FmtContext &ctx,
-                         MethodBody &os);
+  void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os);
   /// Generate the parser code for a `params` directive.
   void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
   /// Generate the parser code for a `struct` directive.
   void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
+  /// Generate the parser code for an optional group.
+  void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
+                              MethodBody &os);
 
   /// Generate the printer code for a specific format element.
   void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
   /// Generate the printer code for a literal.
   void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
   /// Generate the printer code for a variable.
-  void genVariablePrinter(const AttrOrTypeParameter &param, FmtContext &ctx,
-                          MethodBody &os, bool printQualified = false);
+  void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
+                          bool skipGuard = false);
+  /// Generate a printer for comma-separated parameters.
+  void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
+                                FmtContext &ctx, MethodBody &os,
+                                function_ref<void(ParameterElement *)> extra);
   /// Generate the printer code for a `params` directive.
   void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
   /// Generate the printer code for a `struct` directive.
   void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
+  /// Generate the printer code for an optional group.
+  void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
+                               MethodBody &os);
 
   /// The ODS definition of the attribute or type whose format is being used to
   /// generate a parser and printer.
@@ -230,35 +225,44 @@ class AttrOrTypeFormat {
 // ParserGen
 //===----------------------------------------------------------------------===//
 
-void AttrOrTypeFormat::genParser(MethodBody &os) {
+void DefFormat::genParser(MethodBody &os) {
   FmtContext ctx;
-  ctx.addSubst("_parser", "parser");
+  ctx.addSubst("_parser", "odsParser");
   if (isa<AttrDef>(def))
-    ctx.addSubst("_type", "type");
+    ctx.addSubst("_type", "odsType");
   os.indent();
 
-  /// Declare variables to store all of the parameters. Allocated parameters
-  /// such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
-  /// FailureOr<T> to defer type construction for parameters that are parsed in
-  /// a loop (parsers return FailureOr anyways).
+  // Declare variables to store all of the parameters. Allocated parameters
+  // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
+  // FailureOr<T> to defer type construction for parameters that are parsed in
+  // a loop (parsers return FailureOr anyways).
   ArrayRef<AttrOrTypeParameter> params = def.getParameters();
   for (const AttrOrTypeParameter &param : params) {
-    os << formatv("  ::mlir::FailureOr<{0}> _result_{1};\n",
+    os << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
                   param.getCppStorageType(), param.getName());
   }
 
-  /// Store the initial location of the parser.
-  ctx.addSubst("_loc", "loc");
+  // Store the initial location of the parser.
+  ctx.addSubst("_loc", "odsLoc");
   os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
               "(void) $_loc;\n",
               &ctx);
 
-  /// Generate call to each parameter parser.
+  // Generate call to each parameter parser.
   for (FormatElement *el : elements)
     genElementParser(el, ctx, os);
 
-  /// Generate call to the attribute or type builder. Use the checked getter
-  /// if one was generated.
+  // Emit an assert for each mandatory parameter. Triggering an assert means
+  // the generated parser is incorrect (i.e. there is a bug in this code).
+  for (const AttrOrTypeParameter &param : params) {
+    if (!param.isOptional()) {
+      os << formatv("assert(::mlir::succeeded(_result_{0}));\n",
+                    param.getName());
+    }
+  }
+
+  // Generate call to the attribute or type builder. Use the checked getter
+  // if one was generated.
   if (def.genVerifyDecl()) {
     os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
                 &ctx, def.getCppClassName());
@@ -266,29 +270,38 @@ void AttrOrTypeFormat::genParser(MethodBody &os) {
     os << tgfmt("return $0::get($_parser.getContext()", &ctx,
                 def.getCppClassName());
   }
-  for (const AttrOrTypeParameter &param : params)
-    os << formatv(",\n    _result_{0}.getValue()", param.getName());
+  for (const AttrOrTypeParameter &param : params) {
+    if (param.isOptional())
+      os << formatv(",\n    _result_{0}.getValueOr({1}())", param.getName(),
+                    param.getCppStorageType());
+    else
+      os << formatv(",\n    *_result_{0}", param.getName());
+  }
   os << ");";
 }
 
-void AttrOrTypeFormat::genElementParser(FormatElement *el, FmtContext &ctx,
-                                        MethodBody &os) {
+void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
+                                 MethodBody &os) {
   if (auto *literal = dyn_cast<LiteralElement>(el))
     return genLiteralParser(literal->getSpelling(), ctx, os);
   if (auto *var = dyn_cast<ParameterElement>(el))
-    return genVariableParser(var->getParam(), ctx, os);
+    return genVariableParser(var, ctx, os);
   if (auto *params = dyn_cast<ParamsDirective>(el))
     return genParamsParser(params, ctx, os);
   if (auto *strct = dyn_cast<StructDirective>(el))
     return genStructParser(strct, ctx, os);
+  if (auto *optional = dyn_cast<OptionalElement>(el))
+    return genOptionalGroupParser(optional, ctx, os);
 
   llvm_unreachable("unknown format element");
 }
 
-void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx,
-                                        MethodBody &os) {
+void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
+                                 MethodBody &os, bool isOptional) {
   os << "// Parse literal '" << value << "'\n";
   os << tgfmt("if ($_parser.parse", &ctx);
+  if (isOptional)
+    os << "Optional";
   if (value.front() == '_' || isalpha(value.front())) {
     os << "Keyword(\"" << value << "\")";
   } else {
@@ -310,70 +323,275 @@ void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx,
               .Case("*", "Star")
        << "()";
   }
-  os << ")\n";
+  if (isOptional) {
+    // Leave the `if` unclosed to guard optional groups.
+    return;
+  }
   // Parser will emit an error
-  os << "  return {};\n";
+  os << ") return {};\n";
 }
 
-void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter &param,
-                                         FmtContext &ctx, MethodBody &os) {
-  /// Check for a custom parser. Use the default attribute parser otherwise.
+void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
+                                  MethodBody &os) {
+  // Check for a custom parser. Use the default attribute parser otherwise.
+  const AttrOrTypeParameter &param = el->getParam();
   auto customParser = param.getParser();
   auto parser =
       customParser ? *customParser : StringRef(defaultParameterParser);
   os << formatv(variableParser, param.getName(),
                 tgfmt(parser, &ctx, param.getCppStorageType()),
-                tgfmt(parseErrorStr, &ctx), def.getName(), param.getCppType());
+                tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType());
 }
 
-void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
-                                       MethodBody &os) {
+void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
+                                MethodBody &os) {
   os << "// Parse parameter list\n";
-  llvm::interleave(
-      el->getParams(),
-      [&](auto param) { this->genVariableParser(param, ctx, os); },
-      [&]() { this->genLiteralParser(",", ctx, os); });
+
+  // If there are optional parameters, we need to switch to `parseOptionalComma`
+  // if there are no more required parameters after a certain point.
+  bool hasOptional = el->hasOptionalParams();
+  if (hasOptional) {
+    // Wrap everything in a do-while so that we can `break`.
+    os << "do {\n";
+    os.indent();
+  }
+
+  ArrayRef<ParameterElement *> params = el->getParams();
+  using IteratorT = ParameterElement *const *;
+  IteratorT it = params.begin();
+
+  // Find the last required parameter. Commas become optional afterwards.
+  // Note: IteratorT's copy assignment is deleted.
+  ParameterElement *lastReq = nullptr;
+  for (ParameterElement *param : params)
+    if (!param->isOptional())
+      lastReq = param;
+  IteratorT lastReqIt = lastReq ? llvm::find(params, lastReq) : params.begin();
+
+  auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); };
+  auto betweenFn = [&](IteratorT it) {
+    ParameterElement *el = *std::prev(it);
+    // Parse a comma if the last optional parameter had a value.
+    if (el->isOptional()) {
+      os << formatv("if (::mlir::succeeded(_result_{0}) && *_result_{0}) {{\n",
+                    el->getName());
+      os.indent();
+    }
+    if (it <= lastReqIt) {
+      genLiteralParser(",", ctx, os);
+    } else {
+      genLiteralParser(",", ctx, os, /*isOptional=*/true);
+      os << ") break;\n";
+    }
+    if (el->isOptional())
+      os.unindent() << "}\n";
+  };
+
+  // llvm::interleave
+  if (it != params.end()) {
+    eachFn(*it++);
+    for (IteratorT e = params.end(); it != e; ++it) {
+      betweenFn(it);
+      eachFn(*it);
+    }
+  }
+
+  if (hasOptional)
+    os.unindent() << "} while(false);\n";
 }
 
-void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx,
-                                       MethodBody &os) {
+void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
+                                MethodBody &os) {
+  // Loop declaration for struct parser with only required parameters.
+  //
+  // $0: Number of expected parameters.
+  const char *const loopHeader = R"(
+  for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) {
+)";
+
+  // Loop body start for struct parser.
+  const char *const loopStart = R"(
+    ::llvm::StringRef _paramKey;
+    if ($_parser.parseKeyword(&_paramKey)) {
+      $_parser.emitError($_parser.getCurrentLocation(),
+                         "expected a parameter name in struct");
+      return {};
+    }
+    if (!_loop_body(_paramKey)) return {};
+)";
+
+  // Struct parser loop end. Check for duplicate or unknown struct parameters.
+  //
+  // {0}: Code template for printing an error.
+  const char *const loopEnd = R"({{
+  {0}"duplicate or unknown struct parameter name: ") << _paramKey;
+  return {{};
+}
+)";
+
+  // Struct parser loop terminator. Parse a comma except on the last element.
+  //
+  // {0}: Number of elements in the struct.
+  const char *const loopTerminator = R"(
+  if ((odsStructIndex != {0} - 1) && odsParser.parseComma())
+    return {{};
+}
+)";
+
+  // Check that a mandatory parameter was parsed.
+  //
+  // {0}: Name of the parameter.
+  const char *const checkParam = R"(
+    if (!_seen_{0}) {
+      {1}"struct is missing required parameter: ") << "{0}";
+      return {{};
+    }
+)";
+
+  // Optional parameters in a struct must be parsed successfully if the
+  // keyword is present.
+  //
+  // {0}: Name of the parameter.
+  // {1}: Emit error string
+  const char *const checkOptionalParam = R"(
+    if (::mlir::succeeded(_result_{0}) && !*_result_{0}) {{
+      {1}"expected a value for parameter '{0}'");
+      return {{};
+    }
+)";
+
+  // First iteration of the loop parsing an optional struct.
+  const char *const optionalStructFirst = R"(
+  ::llvm::StringRef _paramKey;
+  if (!$_parser.parseOptionalKeyword(&_paramKey)) {
+    if (!_loop_body(_paramKey)) return {};
+    while (!$_parser.parseOptionalComma()) {
+)";
+
   os << "// Parse parameter struct\n";
 
-  /// Declare a "seen" variable for each key.
-  for (const AttrOrTypeParameter &param : el->getParams())
-    os << formatv("bool _seen_{0} = false;\n", param.getName());
+  // Declare a "seen" variable for each key.
+  for (ParameterElement *param : el->getParams())
+    os << formatv("bool _seen_{0} = false;\n", param->getName());
 
-  /// Generate the parsing loop.
-  os.getStream().printReindented(
-      tgfmt(structParseLoopStart, &ctx, el->getNumParams()).str());
-  os.indent();
-  genLiteralParser("=", ctx, os);
-  for (const AttrOrTypeParameter &param : el->getParams()) {
+  // Generate the body of the parsing loop inside a lambda.
+  os << "{\n";
+  os.indent()
+      << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
+  genLiteralParser("=", ctx, os.indent());
+  for (ParameterElement *param : el->getParams()) {
     os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
                   "  _seen_{0} = true;\n",
-                  param.getName());
+                  param->getName());
     genVariableParser(param, ctx, os.indent());
+    if (param->isOptional()) {
+      os.getStream().printReindented(strfmt(checkOptionalParam,
+                                            param->getName(),
+                                            tgfmt(parserErrorStr, &ctx).str()));
+    }
     os.unindent() << "} else ";
+    // Print the check for duplicate or unknown parameter.
   }
+  os.getStream().printReindented(strfmt(loopEnd, tgfmt(parserErrorStr, &ctx)));
+  os << "return true;\n";
+  os.unindent() << "};\n";
+
+  // Generate the parsing loop. If optional parameters are present, then the
+  // parse loop is guarded by commas.
+  unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional);
+  if (numOptional) {
+    // If the struct itself is optional, pull out the first iteration.
+    if (numOptional == el->getNumParams()) {
+      os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str());
+      os.indent();
+    } else {
+      os << "do {\n";
+    }
+  } else {
+    os.getStream().printReindented(
+        tgfmt(loopHeader, &ctx, el->getNumParams()).str());
+  }
+  os.indent();
+  os.getStream().printReindented(tgfmt(loopStart, &ctx).str());
   os.unindent();
 
-  /// Duplicate or unknown parameter.
-  os.getStream().printReindented(strfmt(
-      structParseLoopEnd, tgfmt(parseErrorStr, &ctx), el->getNumParams()));
+  // Print the loop terminator. For optional parameters, we have to check that
+  // all mandatory parameters have been parsed.
+  // The whole struct is optional if all its parameters are optional.
+  if (numOptional) {
+    if (numOptional == el->getNumParams()) {
+      os << "}\n";
+      os.unindent() << "}\n";
+    } else {
+      os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx);
+      for (ParameterElement *param : el->getParams()) {
+        if (param->isOptional())
+          continue;
+        os.getStream().printReindented(
+            strfmt(checkParam, param->getName(), tgfmt(parserErrorStr, &ctx)));
+      }
+    }
+  } else {
+    // Because the loop runs N times and each non-failing iteration sets 1 of
+    // N flags, successfully exiting the loop means that all parameters have
+    // been seen. `parseOptionalComma` would cause issues with any formats that
+    // use "struct(...) `,`" because structs aren't surrounded by braces.
+    os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams()));
+  }
+  os.unindent() << "}\n";
+}
+
+void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
+                                       MethodBody &os) {
+  ArrayRef<FormatElement *> elements =
+      el->getThenElements().drop_front(el->getParseStart());
+
+  FormatElement *first = elements.front();
+  const auto guardOn = [&](auto params) {
+    os << "if (!(";
+    llvm::interleave(
+        params, os,
+        [&](ParameterElement *el) {
+          os << formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})",
+                        el->getName());
+        },
+        " || ");
+    os << ")) {\n";
+  };
+  if (auto *literal = dyn_cast<LiteralElement>(first)) {
+    genLiteralParser(literal->getSpelling(), ctx, os, /*isOptional=*/true);
+    os << ") {\n";
+  } else if (auto *param = dyn_cast<ParameterElement>(first)) {
+    genVariableParser(param, ctx, os);
+    guardOn(llvm::makeArrayRef(param));
+  } else if (auto *params = dyn_cast<ParamsDirective>(first)) {
+    genParamsParser(params, ctx, os);
+    guardOn(params->getParams());
+  } else {
+    auto *strct = cast<StructDirective>(first);
+    genStructParser(strct, ctx, os);
+    guardOn(strct->getParams());
+  }
+  os.indent();
 
-  /// Because the loop loops N times and each non-failing iteration sets 1 of
-  /// N flags, successfully exiting the loop means that all parameters have been
-  /// seen. `parseOptionalComma` would cause issues with any formats that use
-  /// "struct(...) `,`" beacuse structs aren't sounded by braces.
+  // Generate the parsers for the rest of the elements.
+  for (FormatElement *element : el->getElseElements())
+    genElementParser(element, ctx, os);
+  os.unindent() << "} else {\n";
+  os.indent();
+  for (FormatElement *element : elements.drop_front())
+    genElementParser(element, ctx, os);
+  os.unindent() << "}\n";
 }
 
 //===----------------------------------------------------------------------===//
 // PrinterGen
 //===----------------------------------------------------------------------===//
 
-void AttrOrTypeFormat::genPrinter(MethodBody &os) {
+void DefFormat::genPrinter(MethodBody &os) {
   FmtContext ctx;
-  ctx.addSubst("_printer", "printer");
+  ctx.addSubst("_printer", "odsPrinter");
+  os.indent();
 
   /// Generate printers.
   shouldEmitSpace = true;
@@ -382,8 +600,8 @@ void AttrOrTypeFormat::genPrinter(MethodBody &os) {
     genElementPrinter(el, ctx, os);
 }
 
-void AttrOrTypeFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
-                                         MethodBody &os) {
+void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
+                                  MethodBody &os) {
   if (auto *literal = dyn_cast<LiteralElement>(el))
     return genLiteralPrinter(literal->getSpelling(), ctx, os);
   if (auto *params = dyn_cast<ParamsDirective>(el))
@@ -391,63 +609,147 @@ void AttrOrTypeFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
   if (auto *strct = dyn_cast<StructDirective>(el))
     return genStructPrinter(strct, ctx, os);
   if (auto *var = dyn_cast<ParameterElement>(el))
-    return genVariablePrinter(var->getParam(), ctx, os,
-                              var->shouldBeQualified());
+    return genVariablePrinter(var, ctx, os);
+  if (auto *optional = dyn_cast<OptionalElement>(el))
+    return genOptionalGroupPrinter(optional, ctx, os);
 
-  llvm_unreachable("unknown format element");
+  llvm::PrintFatalError("unsupported format element");
 }
 
-void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
-                                         MethodBody &os) {
-  /// Don't insert a space before certain punctuation.
+void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
+                                  MethodBody &os) {
+  // Don't insert a space before certain punctuation.
   bool needSpace =
       shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
-  os << tgfmt("  $_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
+  os << tgfmt("$_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
               value);
 
-  /// Update the flags.
+  // Update the flags.
   shouldEmitSpace =
       value.size() != 1 || !StringRef("<({[").contains(value.front());
   lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
 }
 
-void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter &param,
-                                          FmtContext &ctx, MethodBody &os,
-                                          bool printQualified) {
-  /// Insert a space before the next parameter, if necessary.
+void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
+                                   MethodBody &os, bool skipGuard) {
+  const AttrOrTypeParameter &param = el->getParam();
+  ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
+
+  // Guard the printer on the presence of optional parameters.
+  if (el->isOptional() && !skipGuard) {
+    os << tgfmt("if ($_self) {\n", &ctx);
+    os.indent();
+  }
+
+  // Insert a space before the next parameter, if necessary.
   if (shouldEmitSpace || !lastWasPunctuation)
-    os << tgfmt("  $_printer << ' ';\n", &ctx);
+    os << tgfmt("$_printer << ' ';\n", &ctx);
   shouldEmitSpace = true;
   lastWasPunctuation = false;
 
-  ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
-  os << "  ";
-  if (printQualified)
+  if (el->shouldBeQualified())
     os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n";
   else if (auto printer = param.getPrinter())
     os << tgfmt(*printer, &ctx) << ";\n";
   else
     os << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
+
+  if (el->isOptional() && !skipGuard)
+    os.unindent() << "}\n";
 }
 
-void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
-                                        MethodBody &os) {
-  llvm::interleave(
-      el->getParams(),
-      [&](auto param) { this->genVariablePrinter(param, ctx, os); },
-      [&] { this->genLiteralPrinter(",", ctx, os); });
+void DefFormat::genCommaSeparatedPrinter(
+    ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
+    function_ref<void(ParameterElement *)> extra) {
+  // Emit a space if necessary, but only if the struct is present.
+  if (shouldEmitSpace || !lastWasPunctuation) {
+    bool allOptional = llvm::all_of(params, paramIsOptional);
+    if (allOptional) {
+      os << "if (";
+      llvm::interleave(
+          params, os,
+          [&](ParameterElement *param) {
+            os << getParameterAccessorName(param->getName()) << "()";
+          },
+          " || ");
+      os << ") {\n";
+      os.indent();
+    }
+    os << tgfmt("$_printer << ' ';\n", &ctx);
+    if (allOptional)
+      os.unindent() << "}\n";
+  }
+
+  // The first printed element does not need to emit a comma.
+  os << "{\n";
+  os.indent() << "bool _firstPrinted = true;\n";
+  for (ParameterElement *param : params) {
+    if (param->isOptional()) {
+      os << tgfmt("if ($_self()) {\n",
+                  &ctx.withSelf(getParameterAccessorName(param->getName())));
+      os.indent();
+    }
+    os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
+    os << "_firstPrinted = false;\n";
+    extra(param);
+    shouldEmitSpace = false;
+    lastWasPunctuation = true;
+    genVariablePrinter(param, ctx, os);
+    if (param->isOptional())
+      os.unindent() << "}\n";
+  }
+  os.unindent() << "}\n";
+}
+
+void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
+                                 MethodBody &os) {
+  genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os,
+                           [&](ParameterElement *param) {});
+}
+
+void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
+                                 MethodBody &os) {
+  genCommaSeparatedPrinter(
+      llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) {
+        os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
+      });
 }
 
-void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
+void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
                                         MethodBody &os) {
-  llvm::interleave(
-      el->getParams(),
-      [&](auto param) {
-        this->genLiteralPrinter(param.getName(), ctx, os);
-        this->genLiteralPrinter("=", ctx, os);
-        this->genVariablePrinter(param, ctx, os);
-      },
-      [&] { this->genLiteralPrinter(",", ctx, os); });
+  // Emit the check on whether the group should be printed.
+  const auto guardOn = [&](auto params) {
+    os << "if (";
+    llvm::interleave(
+        params, os,
+        [&](ParameterElement *el) {
+          os << getParameterAccessorName(el->getName()) << "()";
+        },
+        " || ");
+    os << ") {\n";
+    os.indent();
+  };
+  FormatElement *anchor = el->getAnchor();
+  if (auto *param = dyn_cast<ParameterElement>(anchor)) {
+    guardOn(llvm::makeArrayRef(param));
+  } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
+    guardOn(params->getParams());
+  } else {
+    auto *strct = dyn_cast<StructDirective>(anchor);
+    guardOn(strct->getParams());
+  }
+  // Generate the printer for the contained elements.
+  {
+    llvm::SaveAndRestore<bool> shouldEmitSpaceFlag(shouldEmitSpace);
+    llvm::SaveAndRestore<bool> lastWasPunctuationFlag(lastWasPunctuation);
+    for (FormatElement *element : el->getThenElements())
+      genElementPrinter(element, ctx, os);
+  }
+  os.unindent() << "} else {\n";
+  os.indent();
+  for (FormatElement *element : el->getElseElements())
+    genElementPrinter(element, ctx, os);
+  os.unindent() << "}\n";
 }
 
 //===----------------------------------------------------------------------===//
@@ -462,7 +764,7 @@ class DefFormatParser : public FormatParser {
         seenParams(def.getNumParameters()) {}
 
   /// Parse the attribute or type format and create the format elements.
-  FailureOr<AttrOrTypeFormat> parse();
+  FailureOr<DefFormat> parse();
 
 protected:
   /// Verify the parsed elements.
@@ -476,9 +778,7 @@ class DefFormatParser : public FormatParser {
   /// Verify the elements of an optional group.
   LogicalResult
   verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
-                              Optional<unsigned> anchorIndex) override {
-    return emitError(loc, "optional groups not (yet) supported");
-  }
+                              Optional<unsigned> anchorIndex) override;
 
   /// Parse an attribute or type variable.
   FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
@@ -505,30 +805,76 @@ class DefFormatParser : public FormatParser {
 
 LogicalResult DefFormatParser::verify(SMLoc loc,
                                       ArrayRef<FormatElement *> elements) {
+  // Check that all parameters are referenced in the format.
   for (auto &it : llvm::enumerate(def.getParameters())) {
-    if (!seenParams.test(it.index())) {
+    if (!it.value().isOptional() && !seenParams.test(it.index())) {
       return emitError(loc, "format is missing reference to parameter: " +
                                 it.value().getName());
     }
   }
+  // A `struct` directive that contains optional parameters cannot be followed
+  // by a comma literal, which is ambiguous.
+  for (auto it : llvm::zip(elements.drop_back(), elements.drop_front())) {
+    auto *structEl = dyn_cast<StructDirective>(std::get<0>(it));
+    auto *literalEl = dyn_cast<LiteralElement>(std::get<1>(it));
+    if (!structEl || !literalEl)
+      continue;
+    if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
+      return emitError(loc, "`struct` directive with optional parameters "
+                            "cannot be followed by a comma literal");
+    }
+  }
+  return success();
+}
+
+LogicalResult
+DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
+                                             ArrayRef<FormatElement *> elements,
+                                             Optional<unsigned> anchorIndex) {
+  // `params` and `struct` directives are allowed only if all the contained
+  // parameters are optional.
+  for (FormatElement *el : elements) {
+    if (auto *param = dyn_cast<ParameterElement>(el)) {
+      if (!param->isOptional()) {
+        return emitError(loc,
+                         "parameters in an optional group must be optional");
+      }
+    } else if (auto *params = dyn_cast<ParamsDirective>(el)) {
+      if (llvm::any_of(params->getParams(), paramNotOptional)) {
+        return emitError(loc, "`params` directive allowed in optional group "
+                              "only if all parameters are optional");
+      }
+    } else if (auto *strct = dyn_cast<StructDirective>(el)) {
+      if (llvm::any_of(strct->getParams(), paramNotOptional)) {
+        return emitError(loc, "`struct` is only allowed in an optional group "
+                              "if all captured parameters are optional");
+      }
+    }
+  }
+  // The anchor must be a parameter or one of the aforementioned directives.
+  if (anchorIndex && !isa<ParameterElement, ParamsDirective, StructDirective>(
+                         elements[*anchorIndex])) {
+    return emitError(loc,
+                     "optional group anchor must be a parameter or directive");
+  }
   return success();
 }
 
-FailureOr<AttrOrTypeFormat> DefFormatParser::parse() {
+FailureOr<DefFormat> DefFormatParser::parse() {
   FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
   if (failed(elements))
     return failure();
-  return AttrOrTypeFormat(def, std::move(*elements));
+  return DefFormat(def, std::move(*elements));
 }
 
 FailureOr<FormatElement *>
 DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
-  /// Lookup the parameter.
+  // Lookup the parameter.
   ArrayRef<AttrOrTypeParameter> params = def.getParameters();
   auto *it = llvm::find_if(
       params, [&](auto &param) { return param.getName() == name; });
 
-  /// Check that the parameter reference is valid.
+  // Check that the parameter reference is valid.
   if (it == params.end()) {
     return emitError(loc,
                      def.getName() + " has no parameter named '" + name + "'");
@@ -581,9 +927,9 @@ DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
 }
 
 FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) {
-  /// Collect all of the attribute's or type's parameters.
-  std::vector<FormatElement *> vars;
-  /// Ensure that none of the parameters have already been captured.
+  // Collect all of the attribute's or type's parameters.
+  std::vector<ParameterElement *> vars;
+  // Ensure that none of the parameters have already been captured.
   for (const auto &it : llvm::enumerate(def.getParameters())) {
     if (seenParams.test(it.index())) {
       return emitError(loc, "`params` captures duplicate parameter: " +
@@ -600,27 +946,27 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) {
                         "expected '(' before `struct` argument list")))
     return failure();
 
-  /// Parse variables captured by `struct`.
-  std::vector<FormatElement *> vars;
+  // Parse variables captured by `struct`.
+  std::vector<ParameterElement *> vars;
 
-  /// Parse first captured parameter or a `params` directive.
+  // Parse first captured parameter or a `params` directive.
   FailureOr<FormatElement *> var = parseElement(StructDirectiveContext);
   if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) {
     return emitError(loc,
                      "`struct` argument list expected a variable or directive");
   }
   if (isa<VariableElement>(*var)) {
-    /// Parse any other parameters.
-    vars.push_back(std::move(*var));
+    // Parse any other parameters.
+    vars.push_back(cast<ParameterElement>(*var));
     while (peekToken().is(FormatToken::comma)) {
       consumeToken();
       var = parseElement(StructDirectiveContext);
       if (failed(var) || !isa<VariableElement>(*var))
         return emitError(loc, "expected a variable in `struct` argument list");
-      vars.push_back(std::move(*var));
+      vars.push_back(cast<ParameterElement>(*var));
     }
   } else {
-    /// `struct(params)` captures all parameters in the attribute or type.
+    // `struct(params)` captures all parameters in the attribute or type.
     vars = cast<ParamsDirective>(*var)->takeParams();
   }
 
@@ -642,16 +988,16 @@ void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
   mgr.AddNewSourceBuffer(
       llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc());
 
-  /// Parse the custom assembly format>
+  // Parse the custom assembly format.
   DefFormatParser fmtParser(mgr, def);
-  FailureOr<AttrOrTypeFormat> format = fmtParser.parse();
+  FailureOr<DefFormat> format = fmtParser.parse();
   if (failed(format)) {
     if (formatErrorIsFatal)
       PrintFatalError(def.getLoc(), "failed to parse assembly format");
     return;
   }
 
-  /// Generate the parser and printer.
+  // Generate the parser and printer.
   format->genParser(parser);
   format->genPrinter(printer);
 }
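The comma-separated printer added in this diff (`genCommaSeparatedPrinter`) skips absent optional parameters by wrapping each one in an `if (accessor()) { ... }` guard. It also tracks a `_firstPrinted` flag so that a `", "` separator is emitted only between parameters that are actually printed. The standalone sketch below models that control flow for a `struct` directive. It is not mlir-tblgen output. `Param` and `printStruct` are hypothetical names, `std::optional<int>` stands in for an optional attribute or type parameter, and the leading-space guard that the real generator emits is left out.

```cpp
#include <cassert>
#include <optional>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for an attribute/type parameter: a name plus a value
// that may be absent (std::nullopt models an unset optional parameter).
struct Param {
  std::string name;
  std::optional<int> value;
};

// Sketch of the logic genCommaSeparatedPrinter emits for a `struct`
// directive: absent optional parameters are skipped entirely, and a
// `firstPrinted` flag (`_firstPrinted` in the generated code) decides whether
// a ", " separator precedes the next printed `name = value` entry.
std::string printStruct(const std::vector<Param> &params) {
  std::ostringstream printer;
  bool firstPrinted = true;
  for (const Param &param : params) {
    if (!param.value)
      continue; // plays the role of the generated `if (getFoo()) { ... }`
    if (!firstPrinted)
      printer << ", ";
    firstPrinted = false;
    printer << param.name << " = " << *param.value;
  }
  return printer.str();
}
```

With `a = 1`, `b` absent, and `c = 3`, this prints `a = 1, c = 3`. When every parameter is absent it prints nothing at all. That empty case is the reason the new `DefFormatParser::verify` check rejects a `struct` directive with optional parameters that is followed by a `,` literal: the parser could not tell whether the comma belongs to the struct.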

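`genOptionalGroupPrinter` guards an optional group on its anchor. When the anchor is a `params` or `struct` directive, the guard ORs together the accessors of all captured parameters (`if (getA() || getB()) {`). The then-elements are printed when at least one of them is present; otherwise the else-elements are printed. The sketch below models that branch under two assumptions. First, the group's then-elements are taken to be a `<` literal, the present parameters separated by commas, and a `>` literal. Second, the else-elements are reduced to a fixed string. `printOptionalGroup` is a hypothetical name, not part of the generator.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Sketch of the branch emitted by genOptionalGroupPrinter for a directive
// anchor: print the group if ANY anchored parameter is present, otherwise
// fall back to the else-elements (modeled here as a fixed string).
std::string printOptionalGroup(const std::vector<std::optional<int>> &anchors,
                               const std::string &elseText) {
  bool anyPresent =
      std::any_of(anchors.begin(), anchors.end(),
                  [](const std::optional<int> &p) { return p.has_value(); });
  if (!anyPresent)
    return elseText;

  // Assumed then-elements: literal `<`, the present parameters, literal `>`.
  std::string out = "<";
  bool firstPrinted = true;
  for (const std::optional<int> &p : anchors) {
    if (!p)
      continue; // absent optional parameters inside the group are skipped
    if (!firstPrinted)
      out += ", ";
    firstPrinted = false;
    out += std::to_string(*p);
  }
  out += ">";
  return out;
}
```

Anchoring on all captured parameters means the group is dropped only when none of them would print anything. This also explains why `verifyOptionalGroupElements` requires every parameter inside the group to be optional.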

        

