[Mlir-commits] [mlir] 0748639 - [mlir][ods] Optional Attribute or Type Parameters
llvmlistbot at llvm.org
Tue Feb 8 12:09:49 PST 2022
Author: Mogball
Date: 2022-02-08T20:09:44Z
New Revision: 07486395d2d05c9c567994456774cafdcc1611d0
URL: https://github.com/llvm/llvm-project/commit/07486395d2d05c9c567994456774cafdcc1611d0
DIFF: https://github.com/llvm/llvm-project/commit/07486395d2d05c9c567994456774cafdcc1611d0.diff
LOG: [mlir][ods] Optional Attribute or Type Parameters
Implements optional attribute or type parameters, including support for such parameters in the assembly format `struct` directive. Also implements optional groups.
Depends on D117971
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D118208
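A standalone C++17 sketch of the contract that the new documentation places on optional parameters. It is not code from this commit or from MLIR. The contract is:

* the C++ type must be default-constructible and contextually convertible to `bool`;
* a parser signals "no value present" by returning the default-constructed value;
* the printer guards on presence.

The sketch uses `std::optional<int>` as a stand-in for the `mlir::Optional<int>` parameters used in the tests. The names `isUsableAsOptionalParam`, `parseOptionalInt`, and `printOptionalGroup` are hypothetical. The printer imitates the optional group `` `<` (`(` $b^ `)`) : (`x`)? $a `>` `` from `TestTypeOptionalGroup`, but its whitespace is illustrative rather than what MLIR actually prints.

```cpp
#include <cassert>  // used by the usage checks
#include <cctype>
#include <cstddef>
#include <optional>
#include <sstream>
#include <string>
#include <type_traits>

// Requirements 1 and 2 from the docs, checked at compile time. Requirement 3
// ("only the default-constructed value is false") is semantic and not checked.
template <typename T>
constexpr bool isUsableAsOptionalParam =
    std::is_default_constructible_v<T> && std::is_constructible_v<bool, T>;

static_assert(isUsableAsOptionalParam<std::optional<int>>);
static_assert(!isUsableAsOptionalParam<std::string>);

// Hypothetical stand-in for FieldParser<Optional<int>>::parse: consume a
// non-negative integer if one is present; otherwise return the
// default-constructed value (std::nullopt) and leave `pos` untouched.
std::optional<int> parseOptionalInt(const std::string &text, std::size_t &pos) {
  if (pos >= text.size() ||
      !std::isdigit(static_cast<unsigned char>(text[pos])))
    return std::nullopt;
  int value = 0;
  while (pos < text.size() &&
         std::isdigit(static_cast<unsigned char>(text[pos])))
    value = value * 10 + (text[pos++] - '0');
  return value;
}

// Hypothetical printer for the optional group `<` (`(` $b^ `)`) : (`x`)? $a `>`.
// The anchor $b^ selects the then-branch; otherwise the else-branch prints `x`.
std::string printOptionalGroup(int a, std::optional<int> b) {
  std::ostringstream os;
  os << '<';
  if (b)  // contextual conversion to bool: is a value present?
    os << '(' << *b << ')';
  else
    os << 'x';
  os << ' ' << a << '>';
  return os.str();
}
```

With this sketch, `printOptionalGroup(3, 5)` yields `<(5) 3>` and `printOptionalGroup(3, std::nullopt)` yields `<x 3>`. Returning the default-constructed value, rather than failing, is what lets a generated parser fall through to the else-branch of an optional group.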
Added:
Modified:
mlir/docs/Tutorials/DefiningAttributesAndTypes.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/Dialect/Test/TestTypes.h
mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
mlir/test/mlir-tblgen/attr-or-type-format.td
mlir/test/mlir-tblgen/attrdefs.td
mlir/test/mlir-tblgen/typedefs.td
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 3212267b64dd3..1501b54af5279 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -1,9 +1,9 @@
# Defining Dialect Attributes and Types
This document is a quickstart to defining dialect specific extensions to the
-[attribute](../LangRef.md/#attributes) and [type](../LangRef.md/#type-system) systems in
-MLIR. The main part of this tutorial focuses on defining types, but the
-instructions are nearly identical for defining attributes.
+[attribute](../LangRef.md/#attributes) and [type](../LangRef.md/#type-system)
+systems in MLIR. The main part of this tutorial focuses on defining types, but
+the instructions are nearly identical for defining attributes.
See [MLIR specification](../LangRef.md) for more information about MLIR, the
structure of the IR, operations, etc.
@@ -24,18 +24,19 @@ defining a new `Type` it isn't always necessary to define a new storage class.
So before defining the derived `Type`, it's important to know which of the two
classes of `Type` we are defining:
-Some types are _singleton_ in nature, meaning they have no parameters and only
-ever have one instance, like the [`index` type](../Dialects/Builtin.md/#indextype).
+Some types are *singleton* in nature, meaning they have no parameters and only
+ever have one instance, like the
+[`index` type](../Dialects/Builtin.md/#indextype).
-Other types are _parametric_, and contain additional information that
+Other types are *parametric*, and contain additional information that
differentiates different instances of the same `Type`. For example the
-[`integer` type](../Dialects/Builtin.md/#integertype) contains a bitwidth, with `i8` and
-`i16` representing different instances of
-[`integer` type](../Dialects/Builtin.md/#integertype). _Parametric_ may also contain a
-mutable component, which can be used, for example, to construct self-referring
-recursive types. The mutable component _cannot_ be used to differentiate
-instances of a type class, so usually such types contain other parametric
-components that serve to identify them.
+[`integer` type](../Dialects/Builtin.md/#integertype) contains a bitwidth, with
+`i8` and `i16` representing different instances of
+[`integer` type](../Dialects/Builtin.md/#integertype). *Parametric* may also
+contain a mutable component, which can be used, for example, to construct
+self-referring recursive types. The mutable component *cannot* be used to
+differentiate instances of a type class, so usually such types contain other
+parametric components that serve to identify them.
#### Singleton types
@@ -389,12 +390,12 @@ Attributes and types defined in ODS with a mnemonic can define an
`assemblyFormat` to declaratively describe custom parsers and printers. The
assembly format consists of literals, variables, and directives.
-* A literal is a keyword or valid punctuation enclosed in backticks, e.g.
- `` `keyword` `` or `` `<` ``.
-* A variable is a parameter name preceeded by a dollar sign, e.g. `$param0`,
- which captures one attribute or type parameter.
-* A directive is a keyword followed by an optional argument list that defines
- special parser and printer behaviour.
+* A literal is a keyword or valid punctuation enclosed in backticks, e.g. ``
+ `keyword` `` or `` `<` ``.
+* A variable is a parameter name preceeded by a dollar sign, e.g. `$param0`,
+ which captures one attribute or type parameter.
+* A directive is a keyword followed by an optional argument list that defines
+ special parser and printer behaviour.
```tablegen
// An example type with an assembly format.
@@ -412,8 +413,8 @@ def MyType : TypeDef<My_Dialect, "MyType"> {
}
```
-The declarative assembly format for `MyType` results in the following format
-in the IR:
+The declarative assembly format for `MyType` results in the following format in
+the IR:
```mlir
!my_dialect.my_type<42, map = affine_map<(i, j) -> (j, i)>
@@ -421,15 +422,15 @@ in the IR:
### Parameter Parsing and Printing
-For many basic parameter types, no additional work is needed to define how
-these parameters are parsed or printed.
+For many basic parameter types, no additional work is needed to define how these
+parameters are parsed or printed.
-* The default printer for any parameter is `$_printer << $_self`,
- where `$_self` is the C++ value of the parameter and `$_printer` is an
- `AsmPrinter`.
-* The default parser for a parameter is
- `FieldParser<$cppClass>::parse($_parser)`, where `$cppClass` is the C++ type
- of the parameter and `$_parser` is an `AsmParser`.
+* The default printer for any parameter is `$_printer << $_self`, where
+ `$_self` is the C++ value of the parameter and `$_printer` is an
+ `AsmPrinter`.
+* The default parser for a parameter is
+ `FieldParser<$cppClass>::parse($_parser)`, where `$cppClass` is the C++ type
+ of the parameter and `$_parser` is an `AsmParser`.
Printing and parsing behaviour can be added to additional C++ types by
overloading these functions or by defining a `parser` and `printer` in an ODS
@@ -470,8 +471,8 @@ def MyParameter : TypeParameter<"std::pair<int, int>", "pair of ints"> {
}
```
-A type using this parameter with the assembly format `` `<` $myParam `>` ``
-will look as follows in the IR:
+A type using this parameter with the assembly format `` `<` $myParam `>` `` will
+look as follows in the IR:
```mlir
!my_dialect.my_type<42 * 24>
@@ -480,10 +481,42 @@ will look as follows in the IR:
#### Non-POD Parameters
Parameters that aren't plain-old-data (e.g. references) may need to define a
-`cppStorageType` to contain the data until it is copied into the allocator.
-For example, `StringRefParameter` uses `std::string` as its storage type,
-whereas `ArrayRefParameter` uses `SmallVector` as its storage type. The parsers
-for these parameters are expected to return `FailureOr<$cppStorageType>`.
+`cppStorageType` to contain the data until it is copied into the allocator. For
+example, `StringRefParameter` uses `std::string` as its storage type, whereas
+`ArrayRefParameter` uses `SmallVector` as its storage type. The parsers for
+these parameters are expected to return `FailureOr<$cppStorageType>`.
+
+#### Optional Parameters
+
+Optional parameters in the assembly format can be indicated by setting
+`isOptional`. The C++ type of an optional parameter is required to satisfy the
+following requirements:
+
+* is default-constructible
+* is contextually convertible to `bool`
+* only the default-constructed value is `false`
+
+The parameter parser should return the default-constructed value to indicate "no
+value present". The printer will guard on the presence of a value to print the
+parameter.
+
+If a value was not parsed for an optional parameter, then the parameter will be
+set to its default-constructed C++ value. For example, `Optional<int>` will be
+set to `llvm::None` and `Attribute` will be set to `nullptr`.
+
+Only optional parameters or directives that only capture optional parameters can
+be used in optional groups. An optional group is a set of elements optionally
+printed based on the presence of an anchor. Suppose parameter `a` is an
+`IntegerAttr`.
+
+```
+( `(` $a^ `)` ) : (`x`)?
+```
+
+In the above assembly format, if `a` is present (non-null), then it will be
+printed as `(5 : i32)`. If it is not present, it will be `x`. Directives that
+are used inside optional groups are allowed only if all captured parameters are
+also optional.
### Assembly Format Directives
@@ -497,9 +530,9 @@ Attribute and type assembly formats have the following directives:
#### `params` Directive
-This directive is used to refer to all parameters of an attribute or type.
-When used as a top-level directive, `params` generates a parser and printer for
-a comma-separated list of the parameters. For example:
+This directive is used to refer to all parameters of an attribute or type. When
+used as a top-level directive, `params` generates a parser and printer for a
+comma-separated list of the parameters. For example:
```tablegen
def MyPairType : TypeDef<My_Dialect, "MyPairType"> {
@@ -547,12 +580,16 @@ In the IR, the types will appear as:
!my_dialect.outer_qual<pair : !mydialect.pair<42, 24>>
```
+If optional parameters are present, they are not printed in the parameter list
+if they are not present.
+
#### `struct` Directive
The `struct` directive accepts a list of variables to capture and will generate
-a parser and printer for a comma-separated list of key-value pairs. The
-variables are printed in the order they are specified in the argument list **but
-can be parsed in any order**. For example:
+a parser and printer for a comma-separated list of key-value pairs. If an
+optional parameter is included in the `struct`, it can be elided. The variables
+are printed in the order they are specified in the argument list **but can be
+parsed in any order**. For example:
```tablegen
def MyStructType : TypeDef<My_Dialect, "MyStructType"> {
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 064145724ab72..bb108f714a789 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -3130,15 +3130,19 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
string summary = desc;
// The format string for the asm syntax (documentation only).
string syntax = ?;
- // The default parameter parser is `::mlir::parseField<T>($_parser)`, which
- // returns `FailureOr<T>`. Overload `parseField` to support parsing for your
- // type. Or you can provide a customer printer. For attributes, "$_type" will
- // be replaced with the required attribute type.
+ // The default parameter parser is `::mlir::FieldParser<T>::parse($_parser)`,
+ // which returns `FailureOr<T>`. Specialize `FieldParser` to support parsing
+ // for your type. Or you can provide a customer printer. For attributes,
+ // "$_type" will be replaced with the required attribute type.
string parser = ?;
// The default parameter printer is `$_printer << $_self`. Overload the stream
// operator of `AsmPrinter` as necessary to print your type. Or you can
// provide a custom printer.
string printer = ?;
+ // Mark a parameter as optional. The C++ type of parameters marked as optional
+ // must be default constructible and be contextually convertible to `bool`.
+ // Any `Optional<T>` and any attribute type satisfies these requirements.
+ bit isOptional = 0;
}
class AttrParameter<string type, string desc, string accessorType = "">
: AttrOrTypeParameter<type, desc, accessorType>;
@@ -3183,6 +3187,12 @@ class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
}];
}
+// An optional parameter.
+class OptionalParameter<string type, string desc = ""> :
+ AttrOrTypeParameter<type, desc> {
+ let isOptional = 1;
+}
+
// This is a special parameter used for AttrDefs that represents a `mlir::Type`
// that is also used as the value `Type` of the attribute. Only one parameter
// of the attribute may be of this type.
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 7be684d4c5343..dcfdc8ab28a63 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -45,47 +45,50 @@ class AttrOrTypeBuilder : public Builder {
// AttrOrTypeParameter
//===----------------------------------------------------------------------===//
-// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to
-// AttrOrTypeDefs to parameterize them.
+/// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to
+/// AttrOrTypeDefs to parameterize them.
class AttrOrTypeParameter {
public:
explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
: def(def), index(index) {}
- // Get the parameter name.
+ /// Get the parameter name.
StringRef getName() const;
- // If specified, get the custom allocator code for this parameter.
+ /// If specified, get the custom allocator code for this parameter.
Optional<StringRef> getAllocator() const;
- // If specified, get the custom comparator code for this parameter.
+ /// If specified, get the custom comparator code for this parameter.
Optional<StringRef> getComparator() const;
- // Get the C++ type of this parameter.
+ /// Get the C++ type of this parameter.
StringRef getCppType() const;
- // Get the C++ accessor type of this parameter.
+ /// Get the C++ accessor type of this parameter.
StringRef getCppAccessorType() const;
- // Get the C++ storage type of this parameter.
+ /// Get the C++ storage type of this parameter.
StringRef getCppStorageType() const;
- // Get an optional C++ parameter parser.
+ /// Get an optional C++ parameter parser.
Optional<StringRef> getParser() const;
- // Get an optional C++ parameter printer.
+ /// Get an optional C++ parameter printer.
Optional<StringRef> getPrinter() const;
- // Get a description of this parameter for documentation purposes.
+ /// Get a description of this parameter for documentation purposes.
Optional<StringRef> getSummary() const;
- // Get the assembly syntax documentation.
+ /// Get the assembly syntax documentation.
StringRef getSyntax() const;
- // Return the underlying def of this parameter.
- const llvm::Init *getDef() const;
+ /// Returns true if the parameter is optional.
+ bool isOptional() const;
- // The parameter is pointer-comparable.
+ /// Return the underlying def of this parameter.
+ llvm::Init *getDef() const;
+
+ /// The parameter is pointer-comparable.
bool operator==(const AttrOrTypeParameter &other) const {
return def == other.def && index == other.index;
}
@@ -94,6 +97,11 @@ class AttrOrTypeParameter {
}
private:
+ /// A parameter can be either a string or a def. Get a potentially null value
+ /// from the def.
+ template <typename InitT>
+ auto getDefValue(StringRef name) const;
+
/// The underlying tablegen parameter list this parameter is a part of.
const llvm::DagInit *def;
/// The index of the parameter within the parameter list (`def`).
@@ -121,113 +129,113 @@ class AttrOrTypeDef {
public:
explicit AttrOrTypeDef(const llvm::Record *def);
- // Get the dialect for which this def belongs.
+ /// Get the dialect for which this def belongs.
Dialect getDialect() const;
- // Returns the name of this AttrOrTypeDef record.
+ /// Returns the name of this AttrOrTypeDef record.
StringRef getName() const;
- // Query functions for the documentation of the def.
+ /// Query functions for the documentation of the def.
bool hasDescription() const;
StringRef getDescription() const;
bool hasSummary() const;
StringRef getSummary() const;
- // Returns the name of the C++ class to generate.
+ /// Returns the name of the C++ class to generate.
StringRef getCppClassName() const;
- // Returns the name of the C++ base class to use when generating this def.
+ /// Returns the name of the C++ base class to use when generating this def.
StringRef getCppBaseClassName() const;
- // Returns the name of the storage class for this def.
+ /// Returns the name of the storage class for this def.
StringRef getStorageClassName() const;
- // Returns the C++ namespace for this def's storage class.
+ /// Returns the C++ namespace for this def's storage class.
StringRef getStorageNamespace() const;
- // Returns true if we should generate the storage class.
+ /// Returns true if we should generate the storage class.
bool genStorageClass() const;
- // Indicates whether or not to generate the storage class constructor.
+ /// Indicates whether or not to generate the storage class constructor.
bool hasStorageCustomConstructor() const;
/// Get the parameters of this attribute or type.
ArrayRef<AttrOrTypeParameter> getParameters() const { return parameters; }
- // Return the number of parameters
+ /// Return the number of parameters
unsigned getNumParameters() const;
- // Return the keyword/mnemonic to use in the printer/parser methods if we are
- // supposed to auto-generate them.
+ /// Return the keyword/mnemonic to use in the printer/parser methods if we are
+ /// supposed to auto-generate them.
Optional<StringRef> getMnemonic() const;
- // Returns the code to use as the types printer method. If not specified,
- // return a non-value. Otherwise, return the contents of that code block.
+ /// Returns the code to use as the types printer method. If not specified,
+ /// return a non-value. Otherwise, return the contents of that code block.
Optional<StringRef> getPrinterCode() const;
- // Returns the code to use as the parser method. If not specified, returns
- // None. Otherwise, returns the contents of that code block.
+ /// Returns the code to use as the parser method. If not specified, returns
+ /// None. Otherwise, returns the contents of that code block.
Optional<StringRef> getParserCode() const;
- // Returns the custom assembly format, if one was specified.
+ /// Returns the custom assembly format, if one was specified.
Optional<StringRef> getAssemblyFormat() const;
- // An attribute or type with parameters needs a parser.
+ /// An attribute or type with parameters needs a parser.
bool needsParserPrinter() const { return getNumParameters() != 0; }
- // Returns true if this attribute or type has a generated parser.
+ /// Returns true if this attribute or type has a generated parser.
bool hasGeneratedParser() const {
return getParserCode() || getAssemblyFormat();
}
- // Returns true if this attribute or type has a generated printer.
+ /// Returns true if this attribute or type has a generated printer.
bool hasGeneratedPrinter() const {
return getPrinterCode() || getAssemblyFormat();
}
- // Returns true if the accessors based on the parameters should be generated.
+ /// Returns true if the accessors based on the parameters should be generated.
bool genAccessors() const;
- // Return true if we need to generate the verify declaration and getChecked
- // method.
+ /// Return true if we need to generate the verify declaration and getChecked
+ /// method.
bool genVerifyDecl() const;
- // Returns the def's extra class declaration code.
+ /// Returns the def's extra class declaration code.
Optional<StringRef> getExtraDecls() const;
- // Get the code location (for error printing).
+ /// Get the code location (for error printing).
ArrayRef<SMLoc> getLoc() const;
- // Returns true if the default get/getChecked methods should be skipped during
- // generation.
+ /// Returns true if the default get/getChecked methods should be skipped
+ /// during generation.
bool skipDefaultBuilders() const;
- // Returns the builders of this def.
+ /// Returns the builders of this def.
ArrayRef<AttrOrTypeBuilder> getBuilders() const { return builders; }
- // Returns the traits of this def.
+ /// Returns the traits of this def.
ArrayRef<Trait> getTraits() const { return traits; }
- // Returns whether two AttrOrTypeDefs are equal by checking the equality of
- // the underlying record.
+ /// Returns whether two AttrOrTypeDefs are equal by checking the equality of
+ /// the underlying record.
bool operator==(const AttrOrTypeDef &other) const;
- // Compares two AttrOrTypeDefs by comparing the names of the dialects.
+ /// Compares two AttrOrTypeDefs by comparing the names of the dialects.
bool operator<(const AttrOrTypeDef &other) const;
- // Returns whether the AttrOrTypeDef is defined.
+ /// Returns whether the AttrOrTypeDef is defined.
operator bool() const { return def != nullptr; }
- // Return the underlying def.
+ /// Return the underlying def.
const llvm::Record *getDef() const { return def; }
protected:
const llvm::Record *def;
- // The builders of this definition.
+ /// The builders of this definition.
SmallVector<AttrOrTypeBuilder> builders;
- // The traits of this definition.
+ /// The traits of this definition.
SmallVector<Trait> traits;
/// The parameters of this attribute or type.
@@ -243,8 +251,8 @@ class AttrDef : public AttrOrTypeDef {
public:
using AttrOrTypeDef::AttrOrTypeDef;
- // Returns the attributes value type builder code block, or None if it doesn't
- // have one.
+ /// Returns the attributes value type builder code block, or None if it
+ /// doesn't have one.
Optional<StringRef> getTypeBuilder() const;
static bool classof(const AttrOrTypeDef *def);
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index c3a14bb5f1ddf..2d92a5837c79f 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -177,32 +177,30 @@ bool AttrDef::classof(const AttrOrTypeDef *def) {
// AttrOrTypeParameter
//===----------------------------------------------------------------------===//
+template <typename InitT>
+auto AttrOrTypeParameter::getDefValue(StringRef name) const {
+ Optional<decltype(std::declval<InitT>().getValue())> result;
+ if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
+ if (auto *init = param->getDef()->getValue(name))
+ if (auto *value = dyn_cast_or_null<InitT>(init->getValue()))
+ result = value->getValue();
+ return result;
+}
+
StringRef AttrOrTypeParameter::getName() const {
return def->getArgName(index)->getValue();
}
Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
- llvm::Init *parameterType = def->getArg(index);
- if (isa<llvm::StringInit>(parameterType))
- return Optional<StringRef>();
- if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
- return param->getDef()->getValueAsOptionalString("allocator");
- llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
- "defs which inherit from AttrOrTypeParameter\n");
+ return getDefValue<llvm::StringInit>("allocator");
}
Optional<StringRef> AttrOrTypeParameter::getComparator() const {
- llvm::Init *parameterType = def->getArg(index);
- if (isa<llvm::StringInit>(parameterType))
- return Optional<StringRef>();
- if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
- return param->getDef()->getValueAsOptionalString("comparator");
- llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
- "defs which inherit from AttrOrTypeParameter\n");
+ return getDefValue<llvm::StringInit>("comparator");
}
StringRef AttrOrTypeParameter::getCppType() const {
- auto *parameterType = def->getArg(index);
+ llvm::Init *parameterType = getDef();
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
return stringType->getValue();
if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
@@ -213,74 +211,45 @@ StringRef AttrOrTypeParameter::getCppType() const {
}
StringRef AttrOrTypeParameter::getCppAccessorType() const {
- if (auto *param = dyn_cast<llvm::DefInit>(def->getArg(index))) {
- if (Optional<StringRef> type =
- param->getDef()->getValueAsOptionalString("cppAccessorType"))
- return *type;
- }
- return getCppType();
+ return getDefValue<llvm::StringInit>("cppAccessorType")
+ .getValueOr(getCppType());
}
StringRef AttrOrTypeParameter::getCppStorageType() const {
- if (auto *param = dyn_cast<llvm::DefInit>(def->getArg(index))) {
- if (auto type = param->getDef()->getValueAsOptionalString("cppStorageType"))
- return *type;
- }
- return getCppType();
+ return getDefValue<llvm::StringInit>("cppStorageType")
+ .getValueOr(getCppType());
}
Optional<StringRef> AttrOrTypeParameter::getParser() const {
- auto *parameterType = def->getArg(index);
- if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
- if (auto parser = param->getDef()->getValueAsOptionalString("parser"))
- return *parser;
- }
- return {};
+ return getDefValue<llvm::StringInit>("parser");
}
Optional<StringRef> AttrOrTypeParameter::getPrinter() const {
- auto *parameterType = def->getArg(index);
- if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
- if (auto printer = param->getDef()->getValueAsOptionalString("printer"))
- return *printer;
- }
- return {};
+ return getDefValue<llvm::StringInit>("printer");
}
Optional<StringRef> AttrOrTypeParameter::getSummary() const {
- auto *parameterType = def->getArg(index);
- if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
- const auto *desc = param->getDef()->getValue("summary");
- if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(desc->getValue()))
- return ci->getValue();
- }
- return Optional<StringRef>();
+ return getDefValue<llvm::StringInit>("summary");
}
StringRef AttrOrTypeParameter::getSyntax() const {
- auto *parameterType = def->getArg(index);
- if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
+ if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
return stringType->getValue();
- if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
- const auto *syntax = param->getDef()->getValue("syntax");
- if (syntax && isa<llvm::StringInit>(syntax->getValue()))
- return cast<llvm::StringInit>(syntax->getValue())->getValue();
- return getCppType();
- }
- llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
- "defs which inherit from AttrOrTypeParameter");
+ return getDefValue<llvm::StringInit>("syntax").getValueOr(getCppType());
}
-const llvm::Init *AttrOrTypeParameter::getDef() const {
- return def->getArg(index);
+bool AttrOrTypeParameter::isOptional() const {
+ return getDefValue<llvm::BitInit>("isOptional").getValueOr(false);
}
+llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
+
//===----------------------------------------------------------------------===//
// AttributeSelfTypeParameter
//===----------------------------------------------------------------------===//
bool AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) {
- const llvm::Init *paramDef = param->getDef();
+ llvm::Init *paramDef = param->getDef();
if (auto *paramDefInit = dyn_cast<llvm::DefInit>(paramDef))
return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter");
return false;
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 3a052dc98ed0c..8751ee27eef42 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -22,7 +22,7 @@ include "mlir/Interfaces/DataLayoutInterfaces.td"
// All of the types will extend this class.
class Test_Type<string name, list<Trait> traits = []>
- : TypeDef<Test_Dialect, name, traits>;
+ : TypeDef<Test_Dialect, name, traits>;
def SimpleTypeA : Test_Type<"SimpleA"> {
let mnemonic = "smpla";
@@ -42,7 +42,7 @@ def CompoundTypeA : Test_Type<"CompoundA"> {
ArrayRefParameter<
"int", // The parameter C++ type.
"An example of an array of ints" // Parameter description.
- >: $arrayOfInts
+ >:$arrayOfInts
);
let extraClassDeclaration = [{
@@ -110,14 +110,14 @@ def IntegerType : Test_Type<"TestInteger"> {
// The parser is defined here also.
let parser = [{
- if (parser.parseLess()) return Type();
+ if ($_parser.parseLess()) return Type();
SignednessSemantics signedness;
- if (parseSignedness($_parser, signedness)) return mlir::Type();
+ if (parseSignedness($_parser, signedness)) return Type();
if ($_parser.parseComma()) return Type();
int width;
if ($_parser.parseInteger(width)) return Type();
if ($_parser.parseGreater()) return Type();
- ::mlir::Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
+ Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
return getChecked(loc, loc.getContext(), width, signedness);
}];
@@ -262,4 +262,66 @@ def TestTypeStructCaptureAll : Test_Type<"TestStructTypeCaptureAll"> {
let assemblyFormat = "`<` struct(params) `>`";
}
+def TestTypeOptionalParam : Test_Type<"TestTypeOptionalParam"> {
+ let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a, "int":$b);
+ let mnemonic = "optional_param";
+ let assemblyFormat = "`<` $a `,` $b `>`";
+}
+
+def TestTypeOptionalParams : Test_Type<"TestTypeOptionalParams"> {
+ let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+ StringRefParameter<>:$b);
+ let mnemonic = "optional_params";
+ let assemblyFormat = "`<` params `>`";
+}
+
+def TestTypeOptionalParamsAfterRequired
+ : Test_Type<"TestTypeOptionalParamsAfterRequired"> {
+ let parameters = (ins StringRefParameter<>:$a,
+ OptionalParameter<"mlir::Optional<int>">:$b);
+ let mnemonic = "optional_params_after";
+ let assemblyFormat = "`<` params `>`";
+}
+
+def TestTypeOptionalStruct : Test_Type<"TestTypeOptionalStruct"> {
+ let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+ StringRefParameter<>:$b);
+ let mnemonic = "optional_struct";
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def TestTypeAllOptionalParams : Test_Type<"TestTypeAllOptionalParams"> {
+ let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+ OptionalParameter<"mlir::Optional<int>">:$b);
+ let mnemonic = "all_optional_params";
+ let assemblyFormat = "`<` params `>`";
+}
+
+def TestTypeAllOptionalStruct : Test_Type<"TestTypeAllOptionalStruct"> {
+ let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+ OptionalParameter<"mlir::Optional<int>">:$b);
+ let mnemonic = "all_optional_struct";
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def TestTypeOptionalGroup : Test_Type<"TestTypeOptionalGroup"> {
+ let parameters = (ins "int":$a, OptionalParameter<"mlir::Optional<int>">:$b);
+ let mnemonic = "optional_group";
+ let assemblyFormat = "`<` (`(` $b^ `)`) : (`x`)? $a `>`";
+}
+
+def TestTypeOptionalGroupParams : Test_Type<"TestTypeOptionalGroupParams"> {
+ let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+ OptionalParameter<"mlir::Optional<int>">:$b);
+ let mnemonic = "optional_group_params";
+ let assemblyFormat = "`<` (`(` params^ `)`) : (`x`)? `>`";
+}
+
+def TestTypeOptionalGroupStruct : Test_Type<"TestTypeOptionalGroupStruct"> {
+ let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+ OptionalParameter<"mlir::Optional<int>">:$b);
+ let mnemonic = "optional_group_struct";
+ let assemblyFormat = "`<` (`(` struct(params)^ `)`) : (`x`)? `>`";
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 91650afebb9b1..d7b1163f7b39d 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -64,11 +64,28 @@ struct FieldParser<test::CustomParam> {
return test::CustomParam{value.getValue()};
}
};
+
inline mlir::AsmPrinter &operator<<(mlir::AsmPrinter &printer,
test::CustomParam param) {
return printer << param.value;
}
+/// Overload the attribute parameter parser for optional integers.
+template <>
+struct FieldParser<Optional<int>> {
+ static FailureOr<Optional<int>> parse(AsmParser &parser) {
+ Optional<int> value;
+ value.emplace();
+ OptionalParseResult result = parser.parseOptionalInteger(*value);
+ if (result.hasValue()) {
+ if (succeeded(*result))
+ return value;
+ return failure();
+ }
+ value.reset();
+ return value;
+ }
+};
} // namespace mlir
#include "TestTypeInterfaces.h.inc"
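The `FieldParser<Optional<int>>` specialization distinguishes three outcomes: no integer is present (success with an empty value), a valid integer is present, or an integer token fails to parse. The sketch below reproduces that control flow with standard types only. `ToyParser` and `ParseResult` are hypothetical stand-ins for `AsmParser` and `FailureOr<Optional<int>>`, and the overflow rule is an assumption chosen to give the "present but malformed" case something to trigger on.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <limits>
#include <optional>
#include <string>
#include <utility>

// Outer optional: did parsing succeed? Inner optional: was an integer present?
// (Hypothetical stand-in for FailureOr<Optional<int>>.)
using ParseResult = std::optional<std::optional<int>>;

// Toy parser over a string; not mlir::AsmParser.
struct ToyParser {
  std::string text;
  std::size_t pos = 0;

  // Shaped like parseOptionalInteger: std::nullopt if no integer starts here
  // (nothing consumed), false if the integer overflows, true on success.
  std::optional<bool> parseOptionalInteger(int &out) {
    std::size_t p = pos;
    bool negative = p < text.size() && text[p] == '-';
    if (negative)
      ++p;
    if (p >= text.size() || !std::isdigit(static_cast<unsigned char>(text[p])))
      return std::nullopt;
    long long value = 0;
    while (p < text.size() && std::isdigit(static_cast<unsigned char>(text[p]))) {
      value = value * 10 + (text[p] - '0');
      if (value > std::numeric_limits<int>::max())
        return false;
      ++p;
    }
    pos = p;
    out = static_cast<int>(negative ? -value : value);
    return true;
  }
};

// Same three-way control flow as FieldParser<Optional<int>>::parse above.
ParseResult parseOptionalIntField(ToyParser &parser) {
  int value = 0;
  std::optional<bool> result = parser.parseOptionalInteger(value);
  if (result.has_value()) {
    if (*result)
      return ParseResult(std::in_place, value); // present and valid
    return std::nullopt;                        // present but malformed
  }
  return ParseResult(std::in_place);            // absent: success, no value
}
```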
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
index 012685fd05cba..d92ae1677100c 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -11,28 +11,28 @@ class InvalidType<string name, string asm> : TypeDef<Test_Dialect, name> {
let mnemonic = asm;
}
-/// Test format is missing a parameter capture.
+// Test format is missing a parameter capture.
def InvalidTypeA : InvalidType<"InvalidTypeA", "invalid_a"> {
let parameters = (ins "int":$v0, "int":$v1);
// CHECK: format is missing reference to parameter: v1
let assemblyFormat = "`<` $v0 `>`";
}
-/// Test format has duplicate parameter captures.
+// Test format has duplicate parameter captures.
def InvalidTypeB : InvalidType<"InvalidTypeB", "invalid_b"> {
let parameters = (ins "int":$v0, "int":$v1);
// CHECK: duplicate parameter 'v0'
let assemblyFormat = "`<` $v0 `,` $v1 `,` $v0 `>`";
}
-/// Test format has invalid syntax.
+// Test format has invalid syntax.
def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
let parameters = (ins "int":$v0, "int":$v1);
// CHECK: expected literal, variable, directive, or optional group
let assemblyFormat = "`<` $v0, $v1 `>`";
}
-/// Test struct directive has invalid syntax.
+// Test struct directive has invalid syntax.
def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
let parameters = (ins "int":$v0);
// CHECK: literals may only be used in the top-level section of the format
@@ -40,37 +40,70 @@ def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
let assemblyFormat = "`<` struct($v0, `,`) `>`";
}
-/// Test struct directive cannot capture zero parameters.
+// Test struct directive cannot capture zero parameters.
def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
let parameters = (ins "int":$v0);
// CHECK: `struct` argument list expected a variable or directive
let assemblyFormat = "`<` struct() $v0 `>`";
}
-/// Test capture parameter that does not exist.
+// Test capture parameter that does not exist.
def InvalidTypeF : InvalidType<"InvalidTypeF", "invalid_f"> {
let parameters = (ins "int":$v0);
// CHECK: InvalidTypeF has no parameter named 'v1'
let assemblyFormat = "`<` $v0 $v1 `>`";
}
-/// Test duplicate capture of parameter in capture-all struct.
+// Test duplicate capture of parameter in capture-all struct.
def InvalidTypeG : InvalidType<"InvalidTypeG", "invalid_g"> {
let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
// CHECK: duplicate parameter 'v0'
let assemblyFormat = "`<` struct(params) $v0 `>`";
}
-/// Test capture-all struct duplicate capture.
+// Test capture-all struct duplicate capture.
def InvalidTypeH : InvalidType<"InvalidTypeH", "invalid_h"> {
let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
// CHECK: `params` captures duplicate parameter: v0
let assemblyFormat = "`<` $v0 struct(params) `>`";
}
-/// Test capture of parameter after `params` directive.
+// Test capture of parameter after `params` directive.
def InvalidTypeI : InvalidType<"InvalidTypeI", "invalid_i"> {
let parameters = (ins "int":$v0);
// CHECK: duplicate parameter 'v0'
let assemblyFormat = "`<` params $v0 `>`";
}
+
+// Test `struct` with optional parameter followed by comma.
+def InvalidTypeJ : InvalidType<"InvalidTypeJ", "invalid_j"> {
+ let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+ // CHECK: directive with optional parameters cannot be followed by a comma literal
+ let assemblyFormat = "struct($a) `,` $b";
+}
+
+// Test `struct` in optional group must have all optional parameters.
+def InvalidTypeK : InvalidType<"InvalidTypeK", "invalid_k"> {
+ let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+ // CHECK: is only allowed in an optional group if all captured parameters are optional
+ let assemblyFormat = "(`(` struct(params)^ `)`)?";
+}
+
+// Test `params` in optional group must have all optional parameters.
+def InvalidTypeL : InvalidType<"InvalidTypeL", "invalid_l"> {
+ let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+ // CHECK: directive allowed in optional group only if all parameters are optional
+ let assemblyFormat = "(`(` params^ `)`)?";
+}
+
+// Test parameters in an optional group must all be optional.
+def InvalidTypeM : InvalidType<"InvalidTypeM", "invalid_m"> {
+ let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+ // CHECK: parameters in an optional group must be optional
+ let assemblyFormat = "(`(` $a^ `,` $b `)`)?";
+}
+
+// Test optional group anchor cannot be a literal.
+def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> {
+ let parameters = (ins OptionalParameter<"int">:$a);
+ // CHECK: optional group anchor must be a parameter or directive
+ let assemblyFormat = "(`(` $a `)`^)?";
+}
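The new diagnostics encode a few rules for optional groups: the anchor must be a parameter or directive, every parameter inside the group must be optional, and a `params` or `struct` directive is only allowed there if all of its captured parameters are optional. A simplified, hypothetical checker over a flattened list of group elements could look like this. The element model and the order of the checks are assumptions. Some of the message strings are borrowed from the CHECK lines above.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, flattened model of one element inside an optional group.
enum class Kind { Literal, Param, Directive };

struct Element {
  Kind kind;
  // Param: the parameter is optional.
  // Directive (`params` or `struct(...)`): every captured parameter is optional.
  bool allOptional = false;
};

// Returns an empty string for a valid group, otherwise a diagnostic.
std::string verifyOptionalGroup(const std::vector<Element> &elements,
                                std::size_t anchor) {
  if (anchor >= elements.size() || elements[anchor].kind == Kind::Literal)
    return "optional group anchor must be a parameter or directive";
  for (const Element &element : elements) {
    if (element.kind == Kind::Param && !element.allOptional)
      return "parameters in an optional group must be optional";
    if (element.kind == Kind::Directive && !element.allOptional)
      return "directive allowed in optional group only if all parameters are "
             "optional";
  }
  return "";
}
```

For example, `InvalidTypeM` corresponds to a group whose second parameter is required, and `InvalidTypeN` to a group anchored on the closing `)` literal.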
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
index 20aef66e183f4..3b05d6c7c1371 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -20,4 +20,52 @@ attributes {
// CHECK-LABEL: @test_roundtrip_default_parsers_struct
// CHECK: !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
// CHECK: !test.struct_capture_all<v0 = 0, v1 = 1, v2 = 2, v3 = 3>
-func private @test_roundtrip_default_parsers_struct(!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>) -> !test.struct_capture_all<v3 = 3, v1 = 1, v2 = 2, v0 = 0>
+// CHECK: !test.optional_param<, 6>
+// CHECK: !test.optional_param<5, 6>
+// CHECK: !test.optional_params<"a">
+// CHECK: !test.optional_params<5, "a">
+// CHECK: !test.optional_struct<b = "a">
+// CHECK: !test.optional_struct<a = 5, b = "a">
+// CHECK: !test.optional_params_after<"a">
+// CHECK: !test.optional_params_after<"a", 5>
+// CHECK: !test.all_optional_params<>
+// CHECK: !test.all_optional_params<5>
+// CHECK: !test.all_optional_params<5, 6>
+// CHECK: !test.all_optional_struct<>
+// CHECK: !test.all_optional_struct<b = 5>
+// CHECK: !test.all_optional_struct<a = 5, b = 10>
+// CHECK: !test.optional_group<(5) 6>
+// CHECK: !test.optional_group<x 6>
+// CHECK: !test.optional_group_params<x>
+// CHECK: !test.optional_group_params<(5)>
+// CHECK: !test.optional_group_params<(5, 6)>
+// CHECK: !test.optional_group_struct<x>
+// CHECK: !test.optional_group_struct<(b = 5)>
+// CHECK: !test.optional_group_struct<(a = 10, b = 5)>
+func private @test_roundtrip_default_parsers_struct(
+ !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
+) -> (
+ !test.struct_capture_all<v3 = 3, v1 = 1, v2 = 2, v0 = 0>,
+ !test.optional_param<, 6>,
+ !test.optional_param<5, 6>,
+ !test.optional_params<"a">,
+ !test.optional_params<5, "a">,
+ !test.optional_struct<b = "a">,
+ !test.optional_struct<b = "a", a = 5>,
+ !test.optional_params_after<"a">,
+ !test.optional_params_after<"a", 5>,
+ !test.all_optional_params<>,
+ !test.all_optional_params<5>,
+ !test.all_optional_params<5, 6>,
+ !test.all_optional_struct<>,
+ !test.all_optional_struct<b = 5>,
+ !test.all_optional_struct<b = 10, a = 5>,
+ !test.optional_group<(5) 6>,
+ !test.optional_group<x 6>,
+ !test.optional_group_params<x>,
+ !test.optional_group_params<(5)>,
+ !test.optional_group_params<(5, 6)>,
+ !test.optional_group_struct<x>,
+ !test.optional_group_struct<(b = 5)>,
+ !test.optional_group_struct<(b = 5, a = 10)>
+)
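The roundtrip checks show that `struct` parameters may be written in any order and are printed back in declaration order, with absent optional parameters omitted (`b = 10, a = 5` prints as `a = 5, b = 10`). Below is a standalone sketch of that behavior for two optional integers, using "seen" flags in the spirit of the generated `_seen_*` variables. All names are hypothetical, and the parser assumes whitespace-separated `key = value` tokens.

```cpp
#include <cassert>
#include <optional>
#include <sstream>
#include <string>

// Hypothetical value for `struct(params)` over two optional ints $a and $b.
struct AllOptionalStruct {
  std::optional<int> a, b;
};

// Accepts `key = value` pairs separated by commas, in any order. Unknown or
// repeated keys are rejected, similar to the generated `_seen_*` checks.
std::optional<AllOptionalStruct> parseStruct(const std::string &text) {
  AllOptionalStruct result;
  bool seenA = false, seenB = false;
  std::istringstream in(text);
  in >> std::ws;
  if (in.peek() == std::char_traits<char>::eof())
    return result; // every parameter is optional, so an empty struct is valid
  while (true) {
    std::string key, eq;
    int value = 0;
    if (!(in >> key >> eq >> value) || eq != "=")
      return std::nullopt;
    if (key == "a" && !seenA) {
      seenA = true;
      result.a = value;
    } else if (key == "b" && !seenB) {
      seenB = true;
      result.b = value;
    } else {
      return std::nullopt; // unknown or duplicate key
    }
    in >> std::ws;
    int next = in.get();
    if (next == std::char_traits<char>::eof())
      return result;
    if (next != ',')
      return std::nullopt;
  }
}

// Prints in declaration order and skips parameters that are not set.
std::string printStruct(const AllOptionalStruct &s) {
  std::string out;
  auto emit = [&out](const char *name, const std::optional<int> &v) {
    if (!v)
      return;
    if (!out.empty())
      out += ", ";
    out += std::string(name) + " = " + std::to_string(*v);
  };
  emit("a", s.a);
  emit("b", s.b);
  return out;
}
```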
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index d354c35480a95..4ed281d488db5 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -35,38 +35,38 @@ def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
/// Check simple attribute parser and printer are generated correctly.
-// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::AsmParser &parser,
-// ATTR: ::mlir::Type type) {
+// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::AsmParser &odsParser,
+// ATTR: ::mlir::Type odsType) {
// ATTR: FailureOr<IntegerAttr> _result_value;
// ATTR: FailureOr<TestParamA> _result_complex;
-// ATTR: if (parser.parseKeyword("hello"))
+// ATTR: if (odsParser.parseKeyword("hello"))
// ATTR: return {};
-// ATTR: if (parser.parseEqual())
+// ATTR: if (odsParser.parseEqual())
// ATTR: return {};
-// ATTR: _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// ATTR: if (failed(_result_value))
+// ATTR: _result_value = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// ATTR: if (::mlir::failed(_result_value))
// ATTR: return {};
-// ATTR: if (parser.parseComma())
+// ATTR: if (odsParser.parseComma())
// ATTR: return {};
-// ATTR: _result_complex = ::parseAttrParamA(parser, type);
-// ATTR: if (failed(_result_complex))
+// ATTR: _result_complex = ::parseAttrParamA(odsParser, odsType);
+// ATTR: if (::mlir::failed(_result_complex))
// ATTR: return {};
-// ATTR: if (parser.parseRParen())
+// ATTR: if (odsParser.parseRParen())
// ATTR: return {};
-// ATTR: return TestAAttr::get(parser.getContext(),
-// ATTR: _result_value.getValue(),
-// ATTR: _result_complex.getValue());
+// ATTR: return TestAAttr::get(odsParser.getContext(),
+// ATTR: *_result_value,
+// ATTR: *_result_complex);
// ATTR: }
-// ATTR: void TestAAttr::print(::mlir::AsmPrinter &printer) const {
-// ATTR: printer << ' ' << "hello";
-// ATTR: printer << ' ' << "=";
-// ATTR: printer << ' ';
-// ATTR: printer.printStrippedAttrOrType(getValue());
-// ATTR: printer << ",";
-// ATTR: printer << ' ';
-// ATTR: ::printAttrParamA(printer, getComplex());
-// ATTR: printer << ")";
+// ATTR: void TestAAttr::print(::mlir::AsmPrinter &odsPrinter) const {
+// ATTR: odsPrinter << ' ' << "hello";
+// ATTR: odsPrinter << ' ' << "=";
+// ATTR: odsPrinter << ' ';
+// ATTR: odsPrinter.printStrippedAttrOrType(getValue());
+// ATTR: odsPrinter << ",";
+// ATTR: odsPrinter << ' ';
+// ATTR: ::printAttrParamA(odsPrinter, getComplex());
+// ATTR: odsPrinter << ")";
// ATTR: }
def AttrA : TestAttr<"TestA"> {
@@ -81,47 +81,48 @@ def AttrA : TestAttr<"TestA"> {
/// Test simple struct parser and printer are generated correctly.
-// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::AsmParser &parser,
-// ATTR: ::mlir::Type type) {
+// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::AsmParser &odsParser,
+// ATTR: ::mlir::Type odsType) {
// ATTR: bool _seen_v0 = false;
// ATTR: bool _seen_v1 = false;
-// ATTR: for (unsigned _index = 0; _index < 2; ++_index) {
-// ATTR: StringRef _paramKey;
-// ATTR: if (parser.parseKeyword(&_paramKey))
-// ATTR: return {};
-// ATTR: if (parser.parseEqual())
+// ATTR: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
+// ATTR: if (odsParser.parseEqual())
// ATTR: return {};
// ATTR: if (!_seen_v0 && _paramKey == "v0") {
// ATTR: _seen_v0 = true;
-// ATTR: _result_v0 = ::parseAttrParamA(parser, type);
-// ATTR: if (failed(_result_v0))
+// ATTR: _result_v0 = ::parseAttrParamA(odsParser, odsType);
+// ATTR: if (::mlir::failed(_result_v0))
// ATTR: return {};
// ATTR: } else if (!_seen_v1 && _paramKey == "v1") {
// ATTR: _seen_v1 = true;
-// ATTR: _result_v1 = type ? ::parseAttrWithType(parser, type) : ::parseAttrWithout(parser);
-// ATTR: if (failed(_result_v1))
+// ATTR: _result_v1 = odsType ? ::parseAttrWithType(odsParser, odsType) :
+// ATTR-SAME: ::parseAttrWithout(odsParser);
+// ATTR: if (::mlir::failed(_result_v1))
// ATTR: return {};
// ATTR: } else {
// ATTR: return {};
// ATTR: }
-// ATTR: if ((_index != 2 - 1) && parser.parseComma())
+// ATTR: return true;
+// ATTR: }
+// ATTR: for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) {
+// ATTR: StringRef _paramKey;
+// ATTR: if (odsParser.parseKeyword(&_paramKey))
+// ATTR: return {};
+// ATTR: if (!_loop_body(_paramKey)) return {};
+// ATTR: if ((odsStructIndex != 2 - 1) && odsParser.parseComma())
// ATTR: return {};
// ATTR: }
-// ATTR: return TestBAttr::get(parser.getContext(),
-// ATTR: _result_v0.getValue(),
-// ATTR: _result_v1.getValue());
+// ATTR: return TestBAttr::get(odsParser.getContext(),
+// ATTR: *_result_v0,
+// ATTR: *_result_v1);
// ATTR: }
-// ATTR: void TestBAttr::print(::mlir::AsmPrinter &printer) const {
-// ATTR: printer << "v0";
-// ATTR: printer << ' ' << "=";
-// ATTR: printer << ' ';
-// ATTR: ::printAttrParamA(printer, getV0());
-// ATTR: printer << ",";
-// ATTR: printer << ' ' << "v1";
-// ATTR: printer << ' ' << "=";
-// ATTR: printer << ' ';
-// ATTR: ::printAttrB(printer, getV1());
+// ATTR: void TestBAttr::print(::mlir::AsmPrinter &odsPrinter) const {
+// ATTR: odsPrinter << "v0 = ";
+// ATTR: ::printAttrParamA(odsPrinter, getV0());
+// ATTR: odsPrinter << ", ";
+// ATTR: odsPrinter << "v1 = ";
+// ATTR: ::printAttrB(odsPrinter, getV1());
// ATTR: }
def AttrB : TestAttr<"TestB"> {
@@ -136,29 +137,21 @@ def AttrB : TestAttr<"TestB"> {
/// Test attribute with capture-all params has correct parser and printer.
-// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::AsmParser &parser,
-// ATTR: ::mlir::Type type) {
+// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::AsmParser &odsParser,
+// ATTR: ::mlir::Type odsType) {
// ATTR: ::mlir::FailureOr<int> _result_v0;
// ATTR: ::mlir::FailureOr<int> _result_v1;
-// ATTR: _result_v0 = ::mlir::FieldParser<int>::parse(parser);
-// ATTR: if (failed(_result_v0))
+// ATTR: _result_v0 = ::mlir::FieldParser<int>::parse(odsParser);
+// ATTR: if (::mlir::failed(_result_v0))
// ATTR: return {};
-// ATTR: if (parser.parseComma())
+// ATTR: if (odsParser.parseComma())
// ATTR: return {};
-// ATTR: _result_v1 = ::mlir::FieldParser<int>::parse(parser);
-// ATTR: if (failed(_result_v1))
+// ATTR: _result_v1 = ::mlir::FieldParser<int>::parse(odsParser);
+// ATTR: if (::mlir::failed(_result_v1))
// ATTR: return {};
-// ATTR: return TestFAttr::get(parser.getContext(),
-// ATTR: _result_v0.getValue(),
-// ATTR: _result_v1.getValue());
-// ATTR: }
-
-// ATTR: void TestFAttr::print(::mlir::AsmPrinter &printer) const {
-// ATTR: printer << ' ';
-// ATTR: printer.printStrippedAttrOrType(getV0());
-// ATTR: printer << ",";
-// ATTR: printer << ' ';
-// ATTR: printer.printStrippedAttrOrType(getV1());
+// ATTR: return TestFAttr::get(odsParser.getContext(),
+// ATTR: *_result_v0,
+// ATTR: *_result_v1);
// ATTR: }
def AttrC : TestAttr<"TestF"> {
@@ -171,55 +164,57 @@ def AttrC : TestAttr<"TestF"> {
/// Test type parser and printer that mix variables and struct are generated
/// correctly.
-// TYPE: ::mlir::Type TestCType::parse(::mlir::AsmParser &parser) {
+// TYPE: ::mlir::Type TestCType::parse(::mlir::AsmParser &odsParser) {
// TYPE: FailureOr<IntegerAttr> _result_value;
// TYPE: FailureOr<TestParamC> _result_complex;
-// TYPE: if (parser.parseKeyword("foo"))
+// TYPE: if (odsParser.parseKeyword("foo"))
// TYPE: return {};
-// TYPE: if (parser.parseComma())
+// TYPE: if (odsParser.parseComma())
// TYPE: return {};
-// TYPE: if (parser.parseColon())
+// TYPE: if (odsParser.parseColon())
// TYPE: return {};
-// TYPE: if (parser.parseKeyword("bob"))
+// TYPE: if (odsParser.parseKeyword("bob"))
// TYPE: return {};
-// TYPE: if (parser.parseKeyword("bar"))
+// TYPE: if (odsParser.parseKeyword("bar"))
// TYPE: return {};
-// TYPE: _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// TYPE: if (failed(_result_value))
+// TYPE: _result_value = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// TYPE: if (::mlir::failed(_result_value))
// TYPE: return {};
// TYPE: bool _seen_complex = false;
-// TYPE: for (unsigned _index = 0; _index < 1; ++_index) {
-// TYPE: StringRef _paramKey;
-// TYPE: if (parser.parseKeyword(&_paramKey))
-// TYPE: return {};
+// TYPE: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
// TYPE: if (!_seen_complex && _paramKey == "complex") {
// TYPE: _seen_complex = true;
-// TYPE: _result_complex = ::parseTypeParamC(parser);
-// TYPE: if (failed(_result_complex))
+// TYPE: _result_complex = ::parseTypeParamC(odsParser);
+// TYPE: if (::mlir::failed(_result_complex))
// TYPE: return {};
// TYPE: } else {
// TYPE: return {};
// TYPE: }
-// TYPE: if ((_index != 1 - 1) && parser.parseComma())
+// TYPE: return true;
+// TYPE: }
+// TYPE: for (unsigned odsStructIndex = 0; odsStructIndex < 1; ++odsStructIndex) {
+// TYPE: StringRef _paramKey;
+// TYPE: if (odsParser.parseKeyword(&_paramKey))
+// TYPE: return {};
+// TYPE: if (!_loop_body(_paramKey)) return {};
+// TYPE: if ((odsStructIndex != 1 - 1) && odsParser.parseComma())
// TYPE: return {};
// TYPE: }
-// TYPE: if (parser.parseRParen())
+// TYPE: if (odsParser.parseRParen())
// TYPE: return {};
// TYPE: }
-// TYPE: void TestCType::print(::mlir::AsmPrinter &printer) const {
-// TYPE: printer << ' ' << "foo";
-// TYPE: printer << ",";
-// TYPE: printer << ' ' << ":";
-// TYPE: printer << ' ' << "bob";
-// TYPE: printer << ' ' << "bar";
-// TYPE: printer << ' ';
-// TYPE: printer.printStrippedAttrOrType(getValue());
-// TYPE: printer << ' ' << "complex";
-// TYPE: printer << ' ' << "=";
-// TYPE: printer << ' ';
-// TYPE: printer << getComplex();
-// TYPE: printer << ")";
+// TYPE: void TestCType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE: odsPrinter << ' ' << "foo";
+// TYPE: odsPrinter << ",";
+// TYPE: odsPrinter << ' ' << ":";
+// TYPE: odsPrinter << ' ' << "bob";
+// TYPE: odsPrinter << ' ' << "bar";
+// TYPE: odsPrinter << ' ';
+// TYPE: odsPrinter.printStrippedAttrOrType(getValue());
+// TYPE: odsPrinter << "complex = ";
+// TYPE: odsPrinter << getComplex();
+// TYPE: odsPrinter << ")";
// TYPE: }
def TypeA : TestType<"TestC"> {
@@ -235,51 +230,53 @@ def TypeA : TestType<"TestC"> {
/// Test type parser and printer with mix of variables and struct are generated
/// correctly.
-// TYPE: ::mlir::Type TestDType::parse(::mlir::AsmParser &parser) {
-// TYPE: _result_v0 = ::parseTypeParamC(parser);
-// TYPE: if (failed(_result_v0))
+// TYPE: ::mlir::Type TestDType::parse(::mlir::AsmParser &odsParser) {
+// TYPE: _result_v0 = ::parseTypeParamC(odsParser);
+// TYPE: if (::mlir::failed(_result_v0))
// TYPE: return {};
// TYPE: bool _seen_v1 = false;
// TYPE: bool _seen_v2 = false;
-// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
-// TYPE: StringRef _paramKey;
-// TYPE: if (parser.parseKeyword(&_paramKey))
-// TYPE: return {};
-// TYPE: if (parser.parseEqual())
+// TYPE: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
+// TYPE: if (odsParser.parseEqual())
// TYPE: return {};
// TYPE: if (!_seen_v1 && _paramKey == "v1") {
// TYPE: _seen_v1 = true;
// TYPE: _result_v1 = someFcnCall();
-// TYPE: if (failed(_result_v1))
+// TYPE: if (::mlir::failed(_result_v1))
// TYPE: return {};
// TYPE: } else if (!_seen_v2 && _paramKey == "v2") {
// TYPE: _seen_v2 = true;
-// TYPE: _result_v2 = ::parseTypeParamC(parser);
-// TYPE: if (failed(_result_v2))
+// TYPE: _result_v2 = ::parseTypeParamC(odsParser);
+// TYPE: if (::mlir::failed(_result_v2))
// TYPE: return {};
// TYPE: } else {
// TYPE: return {};
// TYPE: }
-// TYPE: if ((_index != 2 - 1) && parser.parseComma())
+// TYPE: return true;
+// TYPE: }
+// TYPE: for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) {
+// TYPE: StringRef _paramKey;
+// TYPE: if (odsParser.parseKeyword(&_paramKey))
+// TYPE: return {};
+// TYPE: if (!_loop_body(_paramKey)) return {};
+// TYPE: if ((odsStructIndex != 2 - 1) && odsParser.parseComma())
// TYPE: return {};
// TYPE: }
// TYPE: _result_v3 = someFcnCall();
-// TYPE: if (failed(_result_v3))
+// TYPE: if (::mlir::failed(_result_v3))
// TYPE: return {};
-// TYPE: return TestDType::get(parser.getContext(),
-// TYPE: _result_v0.getValue(),
-// TYPE: _result_v1.getValue(),
-// TYPE: _result_v2.getValue(),
-// TYPE: _result_v3.getValue());
+// TYPE: return TestDType::get(odsParser.getContext(),
+// TYPE: *_result_v0,
+// TYPE: *_result_v1,
+// TYPE: *_result_v2,
+// TYPE: *_result_v3);
// TYPE: }
-// TYPE: void TestDType::print(::mlir::AsmPrinter &printer) const {
-// TYPE: printer << getV0();
+// TYPE: void TestDType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE: odsPrinter << getV0();
// TYPE: myPrinter(getV1());
-// TYPE: printer << ' ' << "v2";
-// TYPE: printer << ' ' << "=";
-// TYPE: printer << ' ';
-// TYPE: printer << getV2();
+// TYPE: odsPrinter << "v2 = ";
+// TYPE: odsPrinter << getV2();
// TYPE: myPrinter(getV3());
// TYPE: }
@@ -298,85 +295,86 @@ def TypeB : TestType<"TestD"> {
/// Type test with two struct directives has correctly generated parser and
/// printer.
-// TYPE: ::mlir::Type TestEType::parse(::mlir::AsmParser &parser) {
+// TYPE: ::mlir::Type TestEType::parse(::mlir::AsmParser &odsParser) {
// TYPE: FailureOr<IntegerAttr> _result_v0;
// TYPE: FailureOr<IntegerAttr> _result_v1;
// TYPE: FailureOr<IntegerAttr> _result_v2;
// TYPE: FailureOr<IntegerAttr> _result_v3;
// TYPE: bool _seen_v0 = false;
// TYPE: bool _seen_v2 = false;
-// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
-// TYPE: StringRef _paramKey;
-// TYPE: if (parser.parseKeyword(&_paramKey))
-// TYPE: return {};
-// TYPE: if (parser.parseEqual())
+// TYPE: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
+// TYPE: if (odsParser.parseEqual())
// TYPE: return {};
// TYPE: if (!_seen_v0 && _paramKey == "v0") {
// TYPE: _seen_v0 = true;
-// TYPE: _result_v0 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// TYPE: if (failed(_result_v0))
+// TYPE: _result_v0 = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// TYPE: if (::mlir::failed(_result_v0))
// TYPE: return {};
// TYPE: } else if (!_seen_v2 && _paramKey == "v2") {
// TYPE: _seen_v2 = true;
-// TYPE: _result_v2 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// TYPE: if (failed(_result_v2))
+// TYPE: _result_v2 = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// TYPE: if (::mlir::failed(_result_v2))
// TYPE: return {};
// TYPE: } else {
// TYPE: return {};
// TYPE: }
-// TYPE: if ((_index != 2 - 1) && parser.parseComma())
+// TYPE: return true;
+// TYPE: }
+// TYPE: for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) {
+// TYPE: StringRef _paramKey;
+// TYPE: if (odsParser.parseKeyword(&_paramKey))
+// TYPE: return {};
+// TYPE: if (!_loop_body(_paramKey)) return {};
+// TYPE: if ((odsStructIndex != 2 - 1) && odsParser.parseComma())
// TYPE: return {};
// TYPE: }
// TYPE: bool _seen_v1 = false;
// TYPE: bool _seen_v3 = false;
-// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
-// TYPE: StringRef _paramKey;
-// TYPE: if (parser.parseKeyword(&_paramKey))
-// TYPE: return {};
-// TYPE: if (parser.parseEqual())
+// TYPE: const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {
+// TYPE: if (odsParser.parseEqual())
// TYPE: return {};
// TYPE: if (!_seen_v1 && _paramKey == "v1") {
// TYPE: _seen_v1 = true;
-// TYPE: _result_v1 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// TYPE: if (failed(_result_v1))
+// TYPE: _result_v1 = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// TYPE: if (::mlir::failed(_result_v1))
// TYPE: return {};
// TYPE: } else if (!_seen_v3 && _paramKey == "v3") {
// TYPE: _seen_v3 = true;
-// TYPE: _result_v3 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
-// TYPE: if (failed(_result_v3))
+// TYPE: _result_v3 = ::mlir::FieldParser<IntegerAttr>::parse(odsParser);
+// TYPE: if (::mlir::failed(_result_v3))
// TYPE: return {};
// TYPE: } else {
// TYPE: return {};
// TYPE: }
-// TYPE: if ((_index != 2 - 1) && parser.parseComma())
+// TYPE: return true;
+// TYPE: }
+// TYPE: for (unsigned odsStructIndex = 0; odsStructIndex < 2; ++odsStructIndex) {
+// TYPE: StringRef _paramKey;
+// TYPE: if (odsParser.parseKeyword(&_paramKey))
+// TYPE: return {};
+// TYPE: if (!_loop_body(_paramKey)) return {};
+// TYPE: if ((odsStructIndex != 2 - 1) && odsParser.parseComma())
// TYPE: return {};
// TYPE: }
-// TYPE: return TestEType::get(parser.getContext(),
-// TYPE: _result_v0.getValue(),
-// TYPE: _result_v1.getValue(),
-// TYPE: _result_v2.getValue(),
-// TYPE: _result_v3.getValue());
+// TYPE: return TestEType::get(odsParser.getContext(),
+// TYPE: *_result_v0,
+// TYPE: *_result_v1,
+// TYPE: *_result_v2,
+// TYPE: *_result_v3);
// TYPE: }
-// TYPE: void TestEType::print(::mlir::AsmPrinter &printer) const {
-// TYPE: printer << "v0";
-// TYPE: printer << ' ' << "=";
-// TYPE: printer << ' ';
-// TYPE: printer.printStrippedAttrOrType(getV0());
-// TYPE: printer << ",";
-// TYPE: printer << ' ' << "v2";
-// TYPE: printer << ' ' << "=";
-// TYPE: printer << ' ';
-// TYPE: printer.printStrippedAttrOrType(getV2());
-// TYPE: printer << "v1";
-// TYPE: printer << ' ' << "=";
-// TYPE: printer << ' ';
-// TYPE: printer.printStrippedAttrOrType(getV1());
-// TYPE: printer << ",";
-// TYPE: printer << ' ' << "v3";
-// TYPE: printer << ' ' << "=";
-// TYPE: printer << ' ';
-// TYPE: printer.printStrippedAttrOrType(getV3());
+// TYPE: void TestEType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE: odsPrinter << "v0 = ";
+// TYPE: odsPrinter.printStrippedAttrOrType(getV0());
+// TYPE: odsPrinter << ", ";
+// TYPE: odsPrinter << "v2 = ";
+// TYPE: odsPrinter.printStrippedAttrOrType(getV2());
+// TYPE: odsPrinter << ", ";
+// TYPE: odsPrinter << "v1 = ";
+// TYPE: odsPrinter.printStrippedAttrOrType(getV1());
+// TYPE: odsPrinter << ", ";
+// TYPE: odsPrinter << "v3 = ";
+// TYPE: odsPrinter.printStrippedAttrOrType(getV3());
// TYPE: }
def TypeC : TestType<"TestE"> {
@@ -390,3 +388,99 @@ def TypeC : TestType<"TestE"> {
let mnemonic = "type_e";
let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`";
}
+
+// TYPE: void TestFType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE if (getA()) {
+// TYPE printer << ' ';
+// TYPE printer.printStrippedAttrOrType(getA());
+def TypeD : TestType<"TestF"> {
+ let parameters = (ins OptionalParameter<"int">:$a);
+ let mnemonic = "type_f";
+ let assemblyFormat = "$a";
+}
+
+// TYPE: ::mlir::Type TestGType::parse(::mlir::AsmParser &odsParser) {
+// TYPE: if (::mlir::failed(_result_a))
+// TYPE: return {};
+// TYPE: if (::mlir::succeeded(_result_a) && *_result_a)
+// TYPE: if (odsParser.parseComma())
+// TYPE: return {};
+
+// TYPE: if (getA())
+// TYPE: odsPrinter.printStrippedAttrOrType(getA());
+// TYPE: odsPrinter << ", ";
+// TYPE: odsPrinter.printStrippedAttrOrType(getB());
+
+def TypeE : TestType<"TestG"> {
+ let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+ let mnemonic = "type_g";
+ let assemblyFormat = "params";
+}
+
+
+// TYPE: ::mlir::Type TestHType::parse(::mlir::AsmParser &odsParser) {
+// TYPE: do {
+// TYPE: if (!_loop_body(_paramKey)) return {};
+// TYPE: } while(!odsParser.parseOptionalComma());
+// TYPE: if (!_seen_b)
+// TYPE: return {};
+
+// TYPE: void TestHType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE: if (getA()) {
+// TYPE: odsPrinter << "a = ";
+// TYPE: odsPrinter.printStrippedAttrOrType(getA());
+// TYPE: odsPrinter << ", ";
+// TYPE: }
+
+def TypeF : TestType<"TestH"> {
+ let parameters = (ins OptionalParameter<"int">:$a, "int":$b);
+ let mnemonic = "type_h";
+ let assemblyFormat = "struct(params)";
+}
+
+
+// TYPE: do {
+// TYPE: _result_a = ::mlir::FieldParser<int>::parse(odsParser);
+// TYPE: if (::mlir::failed(_result_a))
+// TYPE: return {};
+// TYPE: if (odsParser.parseOptionalComma()) break;
+// TYPE: _result_b = ::mlir::FieldParser<int>::parse(odsParser);
+// TYPE: if (::mlir::failed(_result_b))
+// TYPE: return {};
+// TYPE: } while(false);
+
+def TypeG : TestType<"TestI"> {
+ let parameters = (ins "int":$a, OptionalParameter<"int">:$b);
+ let mnemonic = "type_i";
+ let assemblyFormat = "params";
+}
+
+// TYPE: ::mlir::Type TestJType::parse(::mlir::AsmParser &odsParser) {
+// TYPE: if (odsParser.parseOptionalLParen()) {
+// TYPE: if (odsParser.parseKeyword("x")) return {};
+// TYPE: } else {
+// TYPE: _result_b = ::mlir::FieldParser<int>::parse(odsParser);
+// TYPE: if (::mlir::failed(_result_b))
+// TYPE: return {};
+// TYPE: if (odsParser.parseRParen()) return {};
+// TYPE: }
+// TYPE: _result_a = ::mlir::FieldParser<int>::parse(odsParser);
+// TYPE: if (::mlir::failed(_result_a))
+// TYPE: return {};
+
+// TYPE: void TestJType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE: if (getB()) {
+// TYPE: odsPrinter << "(";
+// TYPE: if (getB())
+// TYPE: odsPrinter.printStrippedAttrOrType(getB());
+// TYPE: odsPrinter << ")";
+// TYPE: } else {
+// TYPE: odsPrinter << ' ' << "x";
+// TYPE: }
+// TYPE: odsPrinter.printStrippedAttrOrType(getA());
+
+def TypeH : TestType<"TestJ"> {
+ let parameters = (ins "int":$a, OptionalParameter<"int">:$b);
+ let mnemonic = "type_j";
+ let assemblyFormat = "(`(` $b^ `)`) : (`x`)? $a";
+}
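The `TestIType` checks show the shape generated for `params` ending in an optional parameter: a `do { ... } while(false)` block that breaks out early when no comma follows `$a`. Here is a toy C++ version of that early-exit pattern. `TrailingOptional`, `parseParams`, and `printParams` are hypothetical and work on plain strings instead of `AsmParser`.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>

// Hypothetical value for parameters `int $a` and optional `$b` printed as `params`.
struct TrailingOptional {
  int a = 0;
  std::optional<int> b;
};

std::optional<TrailingOptional> parseParams(const std::string &text) {
  std::size_t pos = 0;
  auto skip = [&] {
    while (pos < text.size() && text[pos] == ' ')
      ++pos;
  };
  // Non-negative integers of at most 9 digits, so std::stoi cannot overflow.
  auto parseInt = [&](int &out) {
    skip();
    std::size_t start = pos;
    while (pos < text.size() && pos - start < 9 &&
           std::isdigit(static_cast<unsigned char>(text[pos])))
      ++pos;
    if (pos == start)
      return false;
    out = std::stoi(text.substr(start, pos - start));
    return true;
  };
  TrailingOptional v;
  do {
    if (!parseInt(v.a))
      return std::nullopt;
    skip();
    if (pos == text.size() || text[pos] != ',')
      break; // no comma: the trailing optional $b was omitted
    ++pos;
    int b = 0;
    if (!parseInt(b))
      return std::nullopt; // a comma commits the parser to $b
    v.b = b;
  } while (false);
  skip();
  if (pos != text.size())
    return std::nullopt;
  return v;
}

std::string printParams(const TrailingOptional &v) {
  std::string out = std::to_string(v.a);
  if (v.b)
    out += ", " + std::to_string(*v.b);
  return out;
}
```

The `do`/`while (false)` wrapper lets a single `break` skip all remaining optional parameters without nested `if` blocks, which appears to be why the generator emits it.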
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 34c8588225f70..2393fdae71ed6 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -67,8 +67,8 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
// DECL: return {"cmpnd_a"};
// DECL: }
// DECL: static ::mlir::Attribute parse(
-// DECL-SAME: ::mlir::AsmParser &parser, ::mlir::Type type);
-// DECL: void print(::mlir::AsmPrinter &printer) const;
+// DECL-SAME: ::mlir::AsmParser &odsParser, ::mlir::Type odsType);
+// DECL: void print(::mlir::AsmPrinter &odsPrinter) const;
// DECL: int getWidthOfSomething() const;
// DECL: ::test::SimpleTypeA getExampleTdType() const;
// DECL: ::llvm::APFloat getApFloat() const;
@@ -107,8 +107,8 @@ def C_IndexAttr : TestAttr<"Index"> {
// DECL: return {"index"};
// DECL: }
// DECL: static ::mlir::Attribute parse(
-// DECL-SAME: ::mlir::AsmParser &parser, ::mlir::Type type);
-// DECL: void print(::mlir::AsmPrinter &printer) const;
+// DECL-SAME: ::mlir::AsmParser &odsParser, ::mlir::Type odsType);
+// DECL: void print(::mlir::AsmPrinter &odsPrinter) const;
}
def D_SingleParameterAttr : TestAttr<"SingleParameter"> {
diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index 733ca5e9b52f0..103dba3201eee 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -70,8 +70,8 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
// DECL: return {"cmpnd_a"};
// DECL: }
-// DECL: static ::mlir::Type parse(::mlir::AsmParser &parser);
-// DECL: void print(::mlir::AsmPrinter &printer) const;
+// DECL: static ::mlir::Type parse(::mlir::AsmParser &odsParser);
+// DECL: void print(::mlir::AsmPrinter &odsPrinter) const;
// DECL: int getWidthOfSomething() const;
// DECL: ::test::SimpleTypeA getExampleTdType() const;
// DECL: SomeCppStruct getExampleCppType() const;
@@ -89,8 +89,8 @@ def C_IndexType : TestType<"Index"> {
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
// DECL: return {"index"};
// DECL: }
-// DECL: static ::mlir::Type parse(::mlir::AsmParser &parser);
-// DECL: void print(::mlir::AsmPrinter &printer) const;
+// DECL: static ::mlir::Type parse(::mlir::AsmParser &odsParser);
+// DECL: void print(::mlir::AsmPrinter &odsPrinter) const;
}
def D_SingleParameterType : TestType<"SingleParameter"> {
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index ce90cac72ce34..7e2b147478e17 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -266,9 +266,9 @@ void DefGen::emitParserPrinter() {
// Declare the parser.
SmallVector<MethodParameter> parserParams;
- parserParams.emplace_back("::mlir::AsmParser &", "parser");
+ parserParams.emplace_back("::mlir::AsmParser &", "odsParser");
if (isa<AttrDef>(&def))
- parserParams.emplace_back("::mlir::Type", "type");
+ parserParams.emplace_back("::mlir::Type", "odsType");
auto *parser = defCls.addMethod(
strfmt("::mlir::{0}", valueType), "parse",
def.hasGeneratedParser() ? Method::Static : Method::StaticDeclaration,
@@ -278,7 +278,7 @@ void DefGen::emitParserPrinter() {
def.hasGeneratedPrinter() ? Method::Const : Method::ConstDeclaration;
Method *printer =
defCls.addMethod("void", "print", props,
- MethodParameter("::mlir::AsmPrinter &", "printer"));
+ MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
// Emit the bodies.
emitParserPrinterBody(parser->body(), printer->body());
}
@@ -431,14 +431,15 @@ void DefGen::emitParserPrinterBody(MethodBody &parser, MethodBody &printer) {
if (asmFormat)
return generateAttrOrTypeFormat(def, parser, printer);
- FmtContext ctx = FmtContext(
- {{"_parser", "parser"}, {"_printer", "printer"}, {"_type", "type"}});
+ FmtContext ctx = FmtContext({{"_parser", "odsParser"},
+ {"_printer", "odsPrinter"},
+ {"_type", "odsType"}});
if (parserCode) {
- ctx.addSubst("_ctxt", "parser.getContext()");
+ ctx.addSubst("_ctxt", "odsParser.getContext()");
parser.indent().getStream().printReindented(tgfmt(*parserCode, &ctx).str());
}
if (printerCode) {
- ctx.addSubst("_ctxt", "printer.getContext()");
+ ctx.addSubst("_ctxt", "odsPrinter.getContext()");
printer.indent().getStream().printReindented(
tgfmt(*printerCode, &ctx).str());
}
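The hunk above changes only what the `$_parser`, `$_printer`, `$_type` and `$_ctxt` placeholders expand to (`odsParser`, `odsPrinter`, `odsType`, and `odsParser.getContext()` / `odsPrinter.getContext()`). As a rough illustration of that substitution mechanism, here is a minimal sketch. `substitute` and `parserSubsts` are hypothetical helpers, not mlir-tblgen's `FmtContext`/`tgfmt` API, and real `tgfmt` also expands positional `$0`-style arguments, which this sketch deliberately leaves untouched.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <map>
#include <string>

// Hypothetical, simplified stand-in for mlir-tblgen's FmtContext + tgfmt.
// Replaces each `$name` placeholder that has an entry in `substs`; unknown
// placeholders (including positional ones like `$0`) are copied verbatim.
std::string substitute(const std::string &fmt,
                       const std::map<std::string, std::string> &substs) {
  std::string out;
  std::size_t i = 0;
  while (i < fmt.size()) {
    if (fmt[i] != '$') {
      out += fmt[i++];
      continue;
    }
    // The placeholder name is the run of [A-Za-z0-9_] following the '$'.
    std::size_t j = i + 1;
    while (j < fmt.size() &&
           (std::isalnum(static_cast<unsigned char>(fmt[j])) || fmt[j] == '_'))
      ++j;
    auto it = substs.find(fmt.substr(i + 1, j - i - 1));
    // Known placeholder: emit its value; otherwise keep `$name` as written.
    out += it != substs.end() ? it->second : fmt.substr(i, j - i);
    i = j;
  }
  return out;
}

// Substitutions mirroring the parser-side context that
// DefGen::emitParserPrinterBody sets up after this change.
std::map<std::string, std::string> parserSubsts() {
  return {{"_parser", "odsParser"},
          {"_printer", "odsPrinter"},
          {"_type", "odsType"},
          {"_ctxt", "odsParser.getContext()"}};
}
```

With these substitutions, `substitute("if ($_parser.parseLess()) return {};", parserSubsts())` yields `if (odsParser.parseLess()) return {};`. Custom parser or printer code written against the `$_` placeholders is therefore unaffected by the rename; only code that referred to the old argument names (`parser`, `printer`, `type`) directly would need updating.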
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 4ca20e59cc03b..d697ad2b4ea40 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -16,7 +16,9 @@
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/TableGenBackend.h"
@@ -48,36 +50,47 @@ class ParameterElement
shouldBeQualifiedFlag = qualified;
}
+ /// Returns true if the element contains an optional parameter.
+ bool isOptional() const { return param.isOptional(); }
+
+ /// Returns the name of the parameter.
+ StringRef getName() const { return param.getName(); }
+
private:
bool shouldBeQualifiedFlag = false;
AttrOrTypeParameter param;
};
+/// Shorthand predicates for use with range-based algorithms such as `llvm::any_of`.
+static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
+static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
+
/// Base class for a directive that contains references to multiple variables.
template <DirectiveElement::Kind DirectiveKind>
class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
public:
using Base = ParamsDirectiveBase<DirectiveKind>;
-  ParamsDirectiveBase(std::vector<FormatElement *> &&params)
+  ParamsDirectiveBase(std::vector<ParameterElement *> &&params)
: params(std::move(params)) {}
/// Get the parameters contained in this directive.
- auto getParams() const {
- return llvm::map_range(params, [](FormatElement *el) {
- return cast<ParameterElement>(el)->getParam();
- });
- }
+ ArrayRef<ParameterElement *> getParams() const { return params; }
/// Get the number of parameters.
unsigned getNumParams() const { return params.size(); }
/// Take all of the parameters from this directive.
- std::vector<FormatElement *> takeParams() { return std::move(params); }
+ std::vector<ParameterElement *> takeParams() { return std::move(params); }
+
+ /// Returns true if there are optional parameters present.
+ bool hasOptionalParams() const {
+ return llvm::any_of(getParams(), paramIsOptional);
+ }
private:
/// The parameters captured by this directive.
- std::vector<FormatElement *> params;
+ std::vector<ParameterElement *> params;
};
/// This class represents a `params` directive that refers to all parameters
@@ -125,36 +138,9 @@ static const char *const qualifiedParameterPrinter = "$_printer << $_self";
/// Print an error when failing to parse an element.
///
/// $0: The parameter C++ class name.
-static const char *const parseErrorStr =
+static const char *const parserErrorStr =
"$_parser.emitError($_parser.getCurrentLocation(), ";
-/// Loop declaration for struct parser.
-///
-/// $0: Number of expected parameters.
-static const char *const structParseLoopStart = R"(
- for (unsigned _index = 0; _index < $0; ++_index) {
- StringRef _paramKey;
- if ($_parser.parseKeyword(&_paramKey)) {
- $_parser.emitError($_parser.getCurrentLocation(),
- "expected a parameter name in struct");
- return {};
- }
-)";
-
-/// Terminator code segment for the struct parser loop. Check for duplicate or
-/// unknown parameters. Parse a comma except on the last element.
-///
-/// {0}: Code template for printing an error.
-/// {1}: Number of elements in the struct.
-static const char *const structParseLoopEnd = R"({{
- {0}"duplicate or unknown struct parameter name: ") << _paramKey;
- return {{};
- }
- if ((_index != {1} - 1) && parser.parseComma())
- return {{};
-}
-)";
-
/// Code format to parse a variable. Separate by lines because variable parsers
/// may be generated inside other directives, which requires indentation.
///
@@ -166,21 +152,20 @@ static const char *const structParseLoopEnd = R"({{
static const char *const variableParser = R"(
// Parse variable '{0}'
_result_{0} = {1};
-if (failed(_result_{0})) {{
+if (::mlir::failed(_result_{0})) {{
{2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
return {{};
}
)";
//===----------------------------------------------------------------------===//
-// AttrOrTypeFormat
+// DefFormat
//===----------------------------------------------------------------------===//
namespace {
-class AttrOrTypeFormat {
+class DefFormat {
public:
- AttrOrTypeFormat(const AttrOrTypeDef &def,
- std::vector<FormatElement *> &&elements)
+ DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements)
: def(def), elements(std::move(elements)) {}
/// Generate the attribute or type parser.
@@ -192,26 +177,36 @@ class AttrOrTypeFormat {
/// Generate the parser code for a specific format element.
void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a literal.
- void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os);
+ void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os,
+ bool isOptional = false);
/// Generate the parser code for a variable.
-  void genVariableParser(const AttrOrTypeParameter &param, FmtContext &ctx,
- MethodBody &os);
+ void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a `params` directive.
void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a `struct` directive.
void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
+ /// Generate the parser code for an optional group.
+ void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
+ MethodBody &os);
/// Generate the printer code for a specific format element.
void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a literal.
void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a variable.
-  void genVariablePrinter(const AttrOrTypeParameter &param, FmtContext &ctx,
- MethodBody &os, bool printQualified = false);
+ void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
+ bool skipGuard = false);
+ /// Generate a printer for comma-separated parameters.
+ void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
+ FmtContext &ctx, MethodBody &os,
+ function_ref<void(ParameterElement *)> extra);
/// Generate the printer code for a `params` directive.
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `struct` directive.
void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
+ /// Generate the printer code for an optional group.
+ void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
+ MethodBody &os);
/// The ODS definition of the attribute or type whose format is being used to
/// generate a parser and printer.
@@ -230,35 +225,44 @@ class AttrOrTypeFormat {
// ParserGen
//===----------------------------------------------------------------------===//
-void AttrOrTypeFormat::genParser(MethodBody &os) {
+void DefFormat::genParser(MethodBody &os) {
FmtContext ctx;
- ctx.addSubst("_parser", "parser");
+ ctx.addSubst("_parser", "odsParser");
if (isa<AttrDef>(def))
- ctx.addSubst("_type", "type");
+ ctx.addSubst("_type", "odsType");
os.indent();
- /// Declare variables to store all of the parameters. Allocated parameters
- /// such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
- /// FailureOr<T> to defer type construction for parameters that are parsed in
- /// a loop (parsers return FailureOr anyways).
+ // Declare variables to store all of the parameters. Allocated parameters
+ // such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
+ // FailureOr<T> to defer type construction for parameters that are parsed in
+ // a loop (parsers return FailureOr anyways).
ArrayRef<AttrOrTypeParameter> params = def.getParameters();
for (const AttrOrTypeParameter &param : params) {
- os << formatv(" ::mlir::FailureOr<{0}> _result_{1};\n",
+ os << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
param.getCppStorageType(), param.getName());
}
- /// Store the initial location of the parser.
- ctx.addSubst("_loc", "loc");
+ // Store the initial location of the parser.
+ ctx.addSubst("_loc", "odsLoc");
os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
"(void) $_loc;\n",
&ctx);
- /// Generate call to each parameter parser.
+ // Generate call to each parameter parser.
for (FormatElement *el : elements)
genElementParser(el, ctx, os);
- /// Generate call to the attribute or type builder. Use the checked getter
- /// if one was generated.
+ // Emit an assert for each mandatory parameter. Triggering an assert means
+ // the generated parser is incorrect (i.e. there is a bug in this code).
+  for (const AttrOrTypeParameter &param : params) {
+ if (!param.isOptional()) {
+ os << formatv("assert(::mlir::succeeded(_result_{0}));\n",
+ param.getName());
+ }
+ }
+
+ // Generate call to the attribute or type builder. Use the checked getter
+ // if one was generated.
if (def.genVerifyDecl()) {
os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
&ctx, def.getCppClassName());
@@ -266,29 +270,38 @@ void AttrOrTypeFormat::genParser(MethodBody &os) {
os << tgfmt("return $0::get($_parser.getContext()", &ctx,
def.getCppClassName());
}
-  for (const AttrOrTypeParameter &param : params)
- os << formatv(",\n _result_{0}.getValue()", param.getName());
+  for (const AttrOrTypeParameter &param : params) {
+ if (param.isOptional())
+ os << formatv(",\n _result_{0}.getValueOr({1}())", param.getName(),
+ param.getCppStorageType());
+ else
+ os << formatv(",\n *_result_{0}", param.getName());
+ }
os << ");";
}
-void AttrOrTypeFormat::genElementParser(FormatElement *el, FmtContext &ctx,
- MethodBody &os) {
+void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
+ MethodBody &os) {
if (auto *literal = dyn_cast<LiteralElement>(el))
return genLiteralParser(literal->getSpelling(), ctx, os);
if (auto *var = dyn_cast<ParameterElement>(el))
- return genVariableParser(var->getParam(), ctx, os);
+ return genVariableParser(var, ctx, os);
if (auto *params = dyn_cast<ParamsDirective>(el))
return genParamsParser(params, ctx, os);
if (auto *strct = dyn_cast<StructDirective>(el))
return genStructParser(strct, ctx, os);
+ if (auto *optional = dyn_cast<OptionalElement>(el))
+ return genOptionalGroupParser(optional, ctx, os);
llvm_unreachable("unknown format element");
}
-void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx,
- MethodBody &os) {
+void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
+ MethodBody &os, bool isOptional) {
os << "// Parse literal '" << value << "'\n";
os << tgfmt("if ($_parser.parse", &ctx);
+ if (isOptional)
+ os << "Optional";
if (value.front() == '_' || isalpha(value.front())) {
os << "Keyword(\"" << value << "\")";
} else {
@@ -310,70 +323,275 @@ void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx,
.Case("*", "Star")
<< "()";
}
- os << ")\n";
+ if (isOptional) {
+ // Leave the `if` unclosed to guard optional groups.
+ return;
+ }
// Parser will emit an error
- os << " return {};\n";
+ os << ") return {};\n";
}
-void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter &param,
- FmtContext &ctx, MethodBody &os) {
- /// Check for a custom parser. Use the default attribute parser otherwise.
+void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
+ MethodBody &os) {
+ // Check for a custom parser. Use the default attribute parser otherwise.
+  const AttrOrTypeParameter &param = el->getParam();
auto customParser = param.getParser();
auto parser =
customParser ? *customParser : StringRef(defaultParameterParser);
os << formatv(variableParser, param.getName(),
tgfmt(parser, &ctx, param.getCppStorageType()),
- tgfmt(parseErrorStr, &ctx), def.getName(), param.getCppType());
+ tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType());
}
-void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
- MethodBody &os) {
+void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
+ MethodBody &os) {
os << "// Parse parameter list\n";
- llvm::interleave(
- el->getParams(),
- [&](auto param) { this->genVariableParser(param, ctx, os); },
- [&]() { this->genLiteralParser(",", ctx, os); });
+
+ // If there are optional parameters, we need to switch to `parseOptionalComma`
+ // if there are no more required parameters after a certain point.
+ bool hasOptional = el->hasOptionalParams();
+ if (hasOptional) {
+ // Wrap everything in a do-while so that we can `break`.
+ os << "do {\n";
+ os.indent();
+ }
+
+ ArrayRef<ParameterElement *> params = el->getParams();
+ using IteratorT = ParameterElement *const *;
+ IteratorT it = params.begin();
+
+  // Find the last required parameter. Commas become optional afterwards.
+ // Note: IteratorT's copy assignment is deleted.
+ ParameterElement *lastReq = nullptr;
+ for (ParameterElement *param : params)
+ if (!param->isOptional())
+ lastReq = param;
+ IteratorT lastReqIt = lastReq ? llvm::find(params, lastReq) : params.begin();
+
+ auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); };
+ auto betweenFn = [&](IteratorT it) {
+ ParameterElement *el = *std::prev(it);
+ // Parse a comma if the last optional parameter had a value.
+ if (el->isOptional()) {
+ os << formatv("if (::mlir::succeeded(_result_{0}) && *_result_{0}) {{\n",
+ el->getName());
+ os.indent();
+ }
+ if (it <= lastReqIt) {
+ genLiteralParser(",", ctx, os);
+ } else {
+ genLiteralParser(",", ctx, os, /*isOptional=*/true);
+ os << ") break;\n";
+ }
+ if (el->isOptional())
+ os.unindent() << "}\n";
+ };
+
+ // llvm::interleave
+ if (it != params.end()) {
+ eachFn(*it++);
+ for (IteratorT e = params.end(); it != e; ++it) {
+ betweenFn(it);
+ eachFn(*it);
+ }
+ }
+
+ if (hasOptional)
+ os.unindent() << "} while(false);\n";
}
-void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx,
- MethodBody &os) {
+void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
+ MethodBody &os) {
+ // Loop declaration for struct parser with only required parameters.
+ //
+ // $0: Number of expected parameters.
+ const char *const loopHeader = R"(
+ for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) {
+)";
+
+ // Loop body start for struct parser.
+ const char *const loopStart = R"(
+ ::llvm::StringRef _paramKey;
+ if ($_parser.parseKeyword(&_paramKey)) {
+ $_parser.emitError($_parser.getCurrentLocation(),
+ "expected a parameter name in struct");
+ return {};
+ }
+ if (!_loop_body(_paramKey)) return {};
+)";
+
+ // Struct parser loop end. Check for duplicate or unknown struct parameters.
+ //
+ // {0}: Code template for printing an error.
+ const char *const loopEnd = R"({{
+ {0}"duplicate or unknown struct parameter name: ") << _paramKey;
+ return {{};
+}
+)";
+
+ // Struct parser loop terminator. Parse a comma except on the last element.
+ //
+ // {0}: Number of elements in the struct.
+ const char *const loopTerminator = R"(
+ if ((odsStructIndex != {0} - 1) && odsParser.parseComma())
+ return {{};
+}
+)";
+
+  // Check that a mandatory parameter was parsed.
+ //
+ // {0}: Name of the parameter.
+ const char *const checkParam = R"(
+ if (!_seen_{0}) {
+ {1}"struct is missing required parameter: ") << "{0}";
+ return {{};
+ }
+)";
+
+ // Optional parameters in a struct must be parsed successfully if the
+ // keyword is present.
+ //
+ // {0}: Name of the parameter.
+ // {1}: Emit error string
+ const char *const checkOptionalParam = R"(
+ if (::mlir::succeeded(_result_{0}) && !*_result_{0}) {{
+ {1}"expected a value for parameter '{0}'");
+ return {{};
+ }
+)";
+
+ // First iteration of the loop parsing an optional struct.
+ const char *const optionalStructFirst = R"(
+ ::llvm::StringRef _paramKey;
+ if (!$_parser.parseOptionalKeyword(&_paramKey)) {
+ if (!_loop_body(_paramKey)) return {};
+ while (!$_parser.parseOptionalComma()) {
+)";
+
os << "// Parse parameter struct\n";
- /// Declare a "seen" variable for each key.
-  for (const AttrOrTypeParameter &param : el->getParams())
- os << formatv("bool _seen_{0} = false;\n", param.getName());
+ // Declare a "seen" variable for each key.
+ for (ParameterElement *param : el->getParams())
+ os << formatv("bool _seen_{0} = false;\n", param->getName());
- /// Generate the parsing loop.
- os.getStream().printReindented(
- tgfmt(structParseLoopStart, &ctx, el->getNumParams()).str());
- os.indent();
- genLiteralParser("=", ctx, os);
-  for (const AttrOrTypeParameter &param : el->getParams()) {
+ // Generate the body of the parsing loop inside a lambda.
+ os << "{\n";
+ os.indent()
+ << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
+ genLiteralParser("=", ctx, os.indent());
+ for (ParameterElement *param : el->getParams()) {
os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
" _seen_{0} = true;\n",
- param.getName());
+ param->getName());
genVariableParser(param, ctx, os.indent());
+ if (param->isOptional()) {
+ os.getStream().printReindented(strfmt(checkOptionalParam,
+ param->getName(),
+ tgfmt(parserErrorStr, &ctx).str()));
+ }
os.unindent() << "} else ";
+ // Print the check for duplicate or unknown parameter.
}
+ os.getStream().printReindented(strfmt(loopEnd, tgfmt(parserErrorStr, &ctx)));
+ os << "return true;\n";
+ os.unindent() << "};\n";
+
+ // Generate the parsing loop. If optional parameters are present, then the
+ // parse loop is guarded by commas.
+ unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional);
+ if (numOptional) {
+ // If the struct itself is optional, pull out the first iteration.
+ if (numOptional == el->getNumParams()) {
+ os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str());
+ os.indent();
+ } else {
+ os << "do {\n";
+ }
+ } else {
+ os.getStream().printReindented(
+ tgfmt(loopHeader, &ctx, el->getNumParams()).str());
+ }
+ os.indent();
+ os.getStream().printReindented(tgfmt(loopStart, &ctx).str());
os.unindent();
- /// Duplicate or unknown parameter.
- os.getStream().printReindented(strfmt(
- structParseLoopEnd, tgfmt(parseErrorStr, &ctx), el->getNumParams()));
+ // Print the loop terminator. For optional parameters, we have to check that
+ // all mandatory parameters have been parsed.
+ // The whole struct is optional if all its parameters are optional.
+ if (numOptional) {
+ if (numOptional == el->getNumParams()) {
+ os << "}\n";
+ os.unindent() << "}\n";
+ } else {
+ os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx);
+ for (ParameterElement *param : el->getParams()) {
+ if (param->isOptional())
+ continue;
+ os.getStream().printReindented(
+ strfmt(checkParam, param->getName(), tgfmt(parserErrorStr, &ctx)));
+ }
+ }
+ } else {
+ // Because the loop loops N times and each non-failing iteration sets 1 of
+ // N flags, successfully exiting the loop means that all parameters have
+ // been seen. `parseOptionalComma` would cause issues with any formats that
+    // use "struct(...) `,`" because structs aren't surrounded by braces.
+ os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams()));
+ }
+ os.unindent() << "}\n";
+}
+
+void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
+ MethodBody &os) {
+ ArrayRef<FormatElement *> elements =
+ el->getThenElements().drop_front(el->getParseStart());
+
+ FormatElement *first = elements.front();
+ const auto guardOn = [&](auto params) {
+ os << "if (!(";
+ llvm::interleave(
+ params, os,
+ [&](ParameterElement *el) {
+ os << formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})",
+ el->getName());
+ },
+ " || ");
+ os << ")) {\n";
+ };
+ if (auto *literal = dyn_cast<LiteralElement>(first)) {
+ genLiteralParser(literal->getSpelling(), ctx, os, /*isOptional=*/true);
+ os << ") {\n";
+ } else if (auto *param = dyn_cast<ParameterElement>(first)) {
+ genVariableParser(param, ctx, os);
+ guardOn(llvm::makeArrayRef(param));
+ } else if (auto *params = dyn_cast<ParamsDirective>(first)) {
+ genParamsParser(params, ctx, os);
+ guardOn(params->getParams());
+ } else {
+ auto *strct = cast<StructDirective>(first);
+ genStructParser(strct, ctx, os);
+    guardOn(strct->getParams());
+ }
+ os.indent();
- /// Because the loop loops N times and each non-failing iteration sets 1 of
- /// N flags, successfully exiting the loop means that all parameters have been
- /// seen. `parseOptionalComma` would cause issues with any formats that use
-  /// "struct(...) `,`" because structs aren't surrounded by braces.
+ // Generate the parsers for the rest of the elements.
+ for (FormatElement *element : el->getElseElements())
+ genElementParser(element, ctx, os);
+ os.unindent() << "} else {\n";
+ os.indent();
+ for (FormatElement *element : elements.drop_front())
+ genElementParser(element, ctx, os);
+ os.unindent() << "}\n";
}
//===----------------------------------------------------------------------===//
// PrinterGen
//===----------------------------------------------------------------------===//
-void AttrOrTypeFormat::genPrinter(MethodBody &os) {
+void DefFormat::genPrinter(MethodBody &os) {
FmtContext ctx;
- ctx.addSubst("_printer", "printer");
+ ctx.addSubst("_printer", "odsPrinter");
+ os.indent();
/// Generate printers.
shouldEmitSpace = true;
@@ -382,8 +600,8 @@ void AttrOrTypeFormat::genPrinter(MethodBody &os) {
genElementPrinter(el, ctx, os);
}
-void AttrOrTypeFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
- MethodBody &os) {
+void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
+ MethodBody &os) {
if (auto *literal = dyn_cast<LiteralElement>(el))
return genLiteralPrinter(literal->getSpelling(), ctx, os);
if (auto *params = dyn_cast<ParamsDirective>(el))
@@ -391,63 +609,147 @@ void AttrOrTypeFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
if (auto *strct = dyn_cast<StructDirective>(el))
return genStructPrinter(strct, ctx, os);
if (auto *var = dyn_cast<ParameterElement>(el))
- return genVariablePrinter(var->getParam(), ctx, os,
- var->shouldBeQualified());
+ return genVariablePrinter(var, ctx, os);
+ if (auto *optional = dyn_cast<OptionalElement>(el))
+ return genOptionalGroupPrinter(optional, ctx, os);
- llvm_unreachable("unknown format element");
+ llvm::PrintFatalError("unsupported format element");
}
-void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
- MethodBody &os) {
- /// Don't insert a space before certain punctuation.
+void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
+ MethodBody &os) {
+ // Don't insert a space before certain punctuation.
bool needSpace =
shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
- os << tgfmt(" $_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
+ os << tgfmt("$_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
value);
- /// Update the flags.
+ // Update the flags.
shouldEmitSpace =
value.size() != 1 || !StringRef("<({[").contains(value.front());
lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
}
-void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter &param,
- FmtContext &ctx, MethodBody &os,
- bool printQualified) {
- /// Insert a space before the next parameter, if necessary.
+void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
+ MethodBody &os, bool skipGuard) {
+  const AttrOrTypeParameter &param = el->getParam();
+ ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
+
+ // Guard the printer on the presence of optional parameters.
+ if (el->isOptional() && !skipGuard) {
+ os << tgfmt("if ($_self) {\n", &ctx);
+ os.indent();
+ }
+
+ // Insert a space before the next parameter, if necessary.
if (shouldEmitSpace || !lastWasPunctuation)
- os << tgfmt(" $_printer << ' ';\n", &ctx);
+ os << tgfmt("$_printer << ' ';\n", &ctx);
shouldEmitSpace = true;
lastWasPunctuation = false;
- ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
- os << " ";
- if (printQualified)
+ if (el->shouldBeQualified())
os << tgfmt(qualifiedParameterPrinter, &ctx) << ";\n";
else if (auto printer = param.getPrinter())
os << tgfmt(*printer, &ctx) << ";\n";
else
os << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
+
+ if (el->isOptional() && !skipGuard)
+ os.unindent() << "}\n";
}
-void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
- MethodBody &os) {
- llvm::interleave(
- el->getParams(),
- [&](auto param) { this->genVariablePrinter(param, ctx, os); },
- [&] { this->genLiteralPrinter(",", ctx, os); });
+void DefFormat::genCommaSeparatedPrinter(
+ ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
+ function_ref<void(ParameterElement *)> extra) {
+  // Emit a space if necessary, but only if at least one parameter is present.
+ if (shouldEmitSpace || !lastWasPunctuation) {
+ bool allOptional = llvm::all_of(params, paramIsOptional);
+ if (allOptional) {
+ os << "if (";
+ llvm::interleave(
+ params, os,
+ [&](ParameterElement *param) {
+ os << getParameterAccessorName(param->getName()) << "()";
+ },
+ " || ");
+ os << ") {\n";
+ os.indent();
+ }
+ os << tgfmt("$_printer << ' ';\n", &ctx);
+ if (allOptional)
+ os.unindent() << "}\n";
+ }
+
+ // The first printed element does not need to emit a comma.
+ os << "{\n";
+ os.indent() << "bool _firstPrinted = true;\n";
+ for (ParameterElement *param : params) {
+ if (param->isOptional()) {
+ os << tgfmt("if ($_self()) {\n",
+ &ctx.withSelf(getParameterAccessorName(param->getName())));
+ os.indent();
+ }
+ os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
+ os << "_firstPrinted = false;\n";
+ extra(param);
+ shouldEmitSpace = false;
+ lastWasPunctuation = true;
+ genVariablePrinter(param, ctx, os);
+ if (param->isOptional())
+ os.unindent() << "}\n";
+ }
+ os.unindent() << "}\n";
+}
+
+void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
+ MethodBody &os) {
+ genCommaSeparatedPrinter(llvm::to_vector(el->getParams()), ctx, os,
+ [&](ParameterElement *param) {});
+}
+
+void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
+ MethodBody &os) {
+ genCommaSeparatedPrinter(
+ llvm::to_vector(el->getParams()), ctx, os, [&](ParameterElement *param) {
+ os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
+ });
}
-void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
+void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
MethodBody &os) {
- llvm::interleave(
- el->getParams(),
- [&](auto param) {
- this->genLiteralPrinter(param.getName(), ctx, os);
- this->genLiteralPrinter("=", ctx, os);
- this->genVariablePrinter(param, ctx, os);
- },
- [&] { this->genLiteralPrinter(",", ctx, os); });
+ // Emit the check on whether the group should be printed.
+ const auto guardOn = [&](auto params) {
+ os << "if (";
+ llvm::interleave(
+ params, os,
+ [&](ParameterElement *el) {
+ os << getParameterAccessorName(el->getName()) << "()";
+ },
+ " || ");
+ os << ") {\n";
+ os.indent();
+ };
+ FormatElement *anchor = el->getAnchor();
+ if (auto *param = dyn_cast<ParameterElement>(anchor)) {
+ guardOn(llvm::makeArrayRef(param));
+ } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
+ guardOn(params->getParams());
+ } else {
+    auto *strct = cast<StructDirective>(anchor);
+ guardOn(strct->getParams());
+ }
+ // Generate the printer for the contained elements.
+ {
+ llvm::SaveAndRestore<bool> shouldEmitSpaceFlag(shouldEmitSpace);
+ llvm::SaveAndRestore<bool> lastWasPunctuationFlag(lastWasPunctuation);
+ for (FormatElement *element : el->getThenElements())
+ genElementPrinter(element, ctx, os);
+ }
+ os.unindent() << "} else {\n";
+ os.indent();
+ for (FormatElement *element : el->getElseElements())
+ genElementPrinter(element, ctx, os);
+ os.unindent() << "}\n";
}
//===----------------------------------------------------------------------===//
@@ -462,7 +764,7 @@ class DefFormatParser : public FormatParser {
seenParams(def.getNumParameters()) {}
/// Parse the attribute or type format and create the format elements.
- FailureOr<AttrOrTypeFormat> parse();
+ FailureOr<DefFormat> parse();
protected:
/// Verify the parsed elements.
@@ -476,9 +778,7 @@ class DefFormatParser : public FormatParser {
/// Verify the elements of an optional group.
LogicalResult
verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
- Optional<unsigned> anchorIndex) override {
- return emitError(loc, "optional groups not (yet) supported");
- }
+ Optional<unsigned> anchorIndex) override;
/// Parse an attribute or type variable.
FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
@@ -505,30 +805,76 @@ class DefFormatParser : public FormatParser {
LogicalResult DefFormatParser::verify(SMLoc loc,
ArrayRef<FormatElement *> elements) {
+ // Check that all parameters are referenced in the format.
for (auto &it : llvm::enumerate(def.getParameters())) {
- if (!seenParams.test(it.index())) {
+ if (!it.value().isOptional() && !seenParams.test(it.index())) {
return emitError(loc, "format is missing reference to parameter: " +
it.value().getName());
}
}
+ // A `struct` directive that contains optional parameters cannot be followed
+ // by a comma literal, which is ambiguous.
+ for (auto it : llvm::zip(elements.drop_back(), elements.drop_front())) {
+ auto *structEl = dyn_cast<StructDirective>(std::get<0>(it));
+ auto *literalEl = dyn_cast<LiteralElement>(std::get<1>(it));
+ if (!structEl || !literalEl)
+ continue;
+ if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
+ return emitError(loc, "`struct` directive with optional parameters "
+ "cannot be followed by a comma literal");
+ }
+ }
+ return success();
+}
+
+LogicalResult
+DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
+ ArrayRef<FormatElement *> elements,
+ Optional<unsigned> anchorIndex) {
+ // `params` and `struct` directives are allowed only if all the contained
+ // parameters are optional.
+ for (FormatElement *el : elements) {
+ if (auto *param = dyn_cast<ParameterElement>(el)) {
+ if (!param->isOptional()) {
+ return emitError(loc,
+ "parameters in an optional group must be optional");
+ }
+ } else if (auto *params = dyn_cast<ParamsDirective>(el)) {
+ if (llvm::any_of(params->getParams(), paramNotOptional)) {
+ return emitError(loc, "`params` directive allowed in optional group "
+ "only if all parameters are optional");
+ }
+ } else if (auto *strct = dyn_cast<StructDirective>(el)) {
+ if (llvm::any_of(strct->getParams(), paramNotOptional)) {
+ return emitError(loc, "`struct` is only allowed in an optional group "
+ "if all captured parameters are optional");
+ }
+ }
+ }
+ // The anchor must be a parameter or one of the aforementioned directives.
+ if (anchorIndex && !isa<ParameterElement, ParamsDirective, StructDirective>(
+ elements[*anchorIndex])) {
+ return emitError(loc,
+ "optional group anchor must be a parameter or directive");
+ }
return success();
}
-FailureOr<AttrOrTypeFormat> DefFormatParser::parse() {
+FailureOr<DefFormat> DefFormatParser::parse() {
FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
if (failed(elements))
return failure();
- return AttrOrTypeFormat(def, std::move(*elements));
+ return DefFormat(def, std::move(*elements));
}
FailureOr<FormatElement *>
DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
- /// Lookup the parameter.
+ // Lookup the parameter.
ArrayRef<AttrOrTypeParameter> params = def.getParameters();
auto *it = llvm::find_if(
params, [&](auto &param) { return param.getName() == name; });
- /// Check that the parameter reference is valid.
+ // Check that the parameter reference is valid.
if (it == params.end()) {
return emitError(loc,
def.getName() + " has no parameter named '" + name + "'");
@@ -581,9 +927,9 @@ DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
}
FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) {
- /// Collect all of the attribute's or type's parameters.
- std::vector<FormatElement *> vars;
- /// Ensure that none of the parameters have already been captured.
+ // Collect all of the attribute's or type's parameters.
+ std::vector<ParameterElement *> vars;
+ // Ensure that none of the parameters have already been captured.
for (const auto &it : llvm::enumerate(def.getParameters())) {
if (seenParams.test(it.index())) {
return emitError(loc, "`params` captures duplicate parameter: " +
@@ -600,27 +946,27 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) {
"expected '(' before `struct` argument list")))
return failure();
- /// Parse variables captured by `struct`.
- std::vector<FormatElement *> vars;
+ // Parse variables captured by `struct`.
+ std::vector<ParameterElement *> vars;
- /// Parse first captured parameter or a `params` directive.
+ // Parse first captured parameter or a `params` directive.
FailureOr<FormatElement *> var = parseElement(StructDirectiveContext);
if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) {
return emitError(loc,
"`struct` argument list expected a variable or directive");
}
if (isa<VariableElement>(*var)) {
- /// Parse any other parameters.
- vars.push_back(std::move(*var));
+ // Parse any other parameters.
+ vars.push_back(cast<ParameterElement>(*var));
while (peekToken().is(FormatToken::comma)) {
consumeToken();
var = parseElement(StructDirectiveContext);
if (failed(var) || !isa<VariableElement>(*var))
return emitError(loc, "expected a variable in `struct` argument list");
- vars.push_back(std::move(*var));
+ vars.push_back(cast<ParameterElement>(*var));
}
} else {
- /// `struct(params)` captures all parameters in the attribute or type.
+ // `struct(params)` captures all parameters in the attribute or type.
vars = cast<ParamsDirective>(*var)->takeParams();
}
@@ -642,16 +988,16 @@ void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
mgr.AddNewSourceBuffer(
llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc());
- /// Parse the custom assembly format>
+ // Parse the custom assembly format.
DefFormatParser fmtParser(mgr, def);
- FailureOr<AttrOrTypeFormat> format = fmtParser.parse();
+ FailureOr<DefFormat> format = fmtParser.parse();
if (failed(format)) {
if (formatErrorIsFatal)
PrintFatalError(def.getLoc(), "failed to parse assembly format");
return;
}
- /// Generate the parser and printer.
+ // Generate the parser and printer.
format->genParser(parser);
format->genPrinter(printer);
}
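The two checks added to `DefFormatParser::verify` can be illustrated outside of tblgen. The following is a minimal standalone sketch, not the tblgen implementation. `Param`, `Element`, and `verifyFormat` are hypothetical, simplified stand-ins for `AttrOrTypeParameter`, the format element classes, and the verifier, and errors are returned as strings rather than emitted at an `SMLoc`. The sketch shows two things: only non-optional parameters must be referenced, and a `struct` directive capturing optional parameters may not be immediately followed by a `,` literal. The reasoning assumed here is that `struct` prints its entries as a comma-separated list of key-value pairs. Once some entries can be omitted, a parser that sees a comma cannot tell whether another entry follows or whether the comma is the literal.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for AttrOrTypeParameter and the format
// elements; the real tblgen classes carry much more state.
struct Param {
  std::string name;
  bool optional;
};

struct Element {
  enum Kind { Literal, Struct, Variable } kind;
  std::string spelling;              // literal text, used when kind == Literal
  std::vector<const Param *> params; // captured parameters otherwise
};

// Returns an error message, or an empty string if the format is valid.
std::string verifyFormat(const std::vector<Param> &defParams,
                         const std::vector<bool> &seenParams,
                         const std::vector<Element> &elements) {
  assert(seenParams.size() == defParams.size());

  // Only required parameters must be referenced somewhere in the format.
  for (std::size_t i = 0; i < defParams.size(); ++i)
    if (!defParams[i].optional && !seenParams[i])
      return "format is missing reference to parameter: " + defParams[i].name;

  // Visit each adjacent pair (elements[i], elements[i + 1]). In the patch this
  // is done by zipping drop_back() with drop_front().
  for (std::size_t i = 0; i + 1 < elements.size(); ++i) {
    const Element &cur = elements[i];
    const Element &next = elements[i + 1];
    if (cur.kind != Element::Struct || next.kind != Element::Literal)
      continue;
    bool hasOptional = false;
    for (const Param *p : cur.params)
      hasOptional = hasOptional || p->optional;
    if (hasOptional && next.spelling == ",")
      return "`struct` directive with optional parameters cannot be followed "
             "by a comma literal";
  }
  return "";
}
```

A different literal such as `>` after the `struct` directive is still accepted, because it cannot be confused with the separator between entries.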
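`verifyOptionalGroupElements` requires every parameter reachable from an optional group to be optional, whether it is referenced directly or captured by `params` or `struct`. It also requires the group's anchor to be a parameter or one of those directives. The following is a minimal sketch under stated assumptions. `GroupElement` and `verifyOptionalGroup` are hypothetical names. Literals are the only other element kind modelled, and a single generic diagnostic replaces the three distinct messages in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical simplified model of the elements inside an optional group.
struct Param {
  std::string name;
  bool optional;
};

struct GroupElement {
  enum Kind { Parameter, ParamsDirective, StructDirective, Literal } kind;
  // The single parameter for Parameter, every captured parameter for the two
  // directives, and nothing for Literal.
  std::vector<Param> params;
};

// Returns an error message, or an empty string if the group is valid.
std::string verifyOptionalGroup(const std::vector<GroupElement> &elements,
                                std::optional<std::size_t> anchorIndex) {
  assert(!anchorIndex || *anchorIndex < elements.size());

  // Every parameter reachable from the group, directly or through `params` /
  // `struct`, must be optional.
  for (const GroupElement &el : elements)
    for (const Param &p : el.params)
      if (!p.optional)
        return "parameter '" + p.name + "' in an optional group must be optional";

  // The anchor must be a parameter or a `params` / `struct` directive.
  if (anchorIndex && elements[*anchorIndex].kind == GroupElement::Literal)
    return "optional group anchor must be a parameter or directive";
  return "";
}
```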
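The `parseStructDirective` hunk accepts either a single `params` directive or one or more comma-separated variables: it parses the first element, then consumes a comma and parses another variable for as long as a comma follows. The sketch below reproduces that grammar with a hand-written character scanner instead of tblgen's `FormatParser`. `StructArgs` and `ArgListParser` are hypothetical names. The sketch makes three assumptions: variables are spelled `$name`, `params` may only appear as the sole argument, and the `struct` keyword itself has already been consumed. The duplicate-capture checks are omitted.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Result of parsing a hypothetical `struct(...)` argument list.
struct StructArgs {
  bool capturesAll = false;      // true for struct(params)
  std::vector<std::string> vars; // names from struct($a, $b, ...)
};

class ArgListParser {
public:
  explicit ArgListParser(std::string text) : s(std::move(text)) {}

  // Parses "($a, $b)" or "(params)"; returns std::nullopt on malformed input.
  std::optional<StructArgs> parse() {
    StructArgs args;
    if (!consume('('))
      return std::nullopt;
    if (consumeWord("params")) {
      args.capturesAll = true; // `params` must be the only argument
    } else {
      std::optional<std::string> v = parseVar();
      if (!v)
        return std::nullopt;
      args.vars.push_back(*v);
      while (consume(',')) { // mirrors the peekToken()/consumeToken() loop
        v = parseVar();
        if (!v)
          return std::nullopt; // e.g. `params` after a comma is rejected
        args.vars.push_back(*v);
      }
    }
    if (!consume(')'))
      return std::nullopt;
    return args;
  }

private:
  std::string s;
  std::size_t pos = 0;

  void skipSpace() {
    while (pos < s.size() && std::isspace(static_cast<unsigned char>(s[pos])))
      ++pos;
  }
  bool consume(char c) {
    skipSpace();
    if (pos < s.size() && s[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  }
  bool consumeWord(const std::string &w) {
    skipSpace();
    if (s.compare(pos, w.size(), w) != 0)
      return false;
    std::size_t end = pos + w.size();
    if (end < s.size() &&
        (std::isalnum(static_cast<unsigned char>(s[end])) || s[end] == '_'))
      return false; // e.g. "paramsX" is not the keyword
    pos = end;
    return true;
  }
  std::optional<std::string> parseVar() {
    if (!consume('$'))
      return std::nullopt;
    std::size_t start = pos;
    while (pos < s.size() &&
           (std::isalnum(static_cast<unsigned char>(s[pos])) || s[pos] == '_'))
      ++pos;
    if (pos == start)
      return std::nullopt;
    return s.substr(start, pos - start);
  }
};
```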