[Mlir-commits] [mlir] 9a2fdc3 - [MLIR] Attribute and type formats in ODS
llvmlistbot at llvm.org
Mon Nov 8 09:38:32 PST 2021
Author: Jeff Niu
Date: 2021-11-08T17:38:28Z
New Revision: 9a2fdc369dae21a6bb1d2145479a76d3775c980b
URL: https://github.com/llvm/llvm-project/commit/9a2fdc369dae21a6bb1d2145479a76d3775c980b
DIFF: https://github.com/llvm/llvm-project/commit/9a2fdc369dae21a6bb1d2145479a76d3775c980b.diff
LOG: [MLIR] Attribute and type formats in ODS
Add declarative assembly formats for attributes and types. Defining an
`assemblyFormat` field in an attribute or type def that also has a
`mnemonic` generates its parser and printer.
```tablegen
def MyAttr : AttrDef<MyDialect, "MyAttr"> {
let parameters = (ins "int64_t":$count, "AffineMap":$map);
let mnemonic = "my_attr";
let assemblyFormat = "`<` $count `,` $map `>`";
}
```
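As a rough mental model (a hand-written simplification, not the code `mlir-tblgen` emits), a format like the one above expands into a straight-line sequence in format order. Each backticked literal becomes a punctuation check, and each `$variable` becomes a parse of that parameter. The self-contained sketch below uses hypothetical names: `ToyParser` stands in for `DialectAsmParser`, `ToyPair` for the attribute's parameters, and `parseToyPair`/`printToyPair` for the generated hooks. It also replaces the `AffineMap` parameter with a second integer.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>

// Hypothetical stand-in for DialectAsmParser (not an MLIR API): a cursor
// over the text that skips whitespace before each token.
class ToyParser {
public:
  explicit ToyParser(std::string_view text) : text_(text) {}

  // Consume the punctuation character `p` if it comes next.
  bool parseLiteral(char p) {
    skipSpace();
    if (pos_ < text_.size() && text_[pos_] == p) {
      ++pos_;
      return true;
    }
    return false;
  }

  // Parse an optionally signed decimal integer; consumes nothing on failure.
  std::optional<int64_t> parseInteger() {
    skipSpace();
    std::size_t start = pos_;
    if (pos_ < text_.size() && text_[pos_] == '-')
      ++pos_;
    std::size_t digits = pos_;
    while (pos_ < text_.size() &&
           std::isdigit(static_cast<unsigned char>(text_[pos_])))
      ++pos_;
    if (pos_ == digits) {
      pos_ = start;
      return std::nullopt;
    }
    return std::stoll(std::string(text_.substr(start, pos_ - start)));
  }

  bool atEnd() {
    skipSpace();
    return pos_ == text_.size();
  }

private:
  void skipSpace() {
    while (pos_ < text_.size() &&
           std::isspace(static_cast<unsigned char>(text_[pos_])))
      ++pos_;
  }

  std::string_view text_;
  std::size_t pos_ = 0;
};

// Hypothetical two-parameter value; `other` replaces the AffineMap.
struct ToyPair {
  int64_t count;
  int64_t other;
};

// Hand-written analogue of "`<` $count `,` $other `>`": literals become
// punctuation checks, variables become parameter parses, in format order.
std::optional<ToyPair> parseToyPair(ToyParser &parser, std::string &error) {
  if (!parser.parseLiteral('<')) {
    error = "expected '<'";
    return std::nullopt;
  }
  std::optional<int64_t> count = parser.parseInteger();
  if (!count) {
    error = "failed to parse parameter 'count'";
    return std::nullopt;
  }
  if (!parser.parseLiteral(',')) {
    error = "expected ','";
    return std::nullopt;
  }
  std::optional<int64_t> other = parser.parseInteger();
  if (!other) {
    error = "failed to parse parameter 'other'";
    return std::nullopt;
  }
  if (!parser.parseLiteral('>')) {
    error = "expected '>'";
    return std::nullopt;
  }
  return ToyPair{*count, *other};
}

// The printer walks the same format, emitting literals and parameter values.
std::string printToyPair(const ToyPair &value) {
  return "<" + std::to_string(value.count) + ", " +
         std::to_string(value.other) + ">";
}
```

The first literal or parameter that fails decides the diagnostic. That is the same shape as the `expected ','` and `failed to parse ... parameter` errors exercised by the tests in this patch.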
Use `struct` to define a comma-separated list of key-value pairs:
```tablegen
def MyType : TypeDef<MyDialect, "MyType"> {
let parameters = (ins "int":$one, "int":$two, "int":$three);
let mnemonic = "my_type";
let assemblyFormat = "`<` $three `:` struct($one, $two) `>`";
}
```
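What distinguishes `struct` is the bookkeeping. Keys may arrive in any order, each must appear exactly once, and printing always follows declaration order. Below is a minimal sketch of that logic. It assumes integer-valued parameters and uses a hypothetical `StructParser` class and `printStruct` function (neither is an MLIR API). The error strings are borrowed from the diagnostics checked in this patch's tests.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <map>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

// Hypothetical helper (not an MLIR API) modelling the `struct` directive for
// integer-valued parameters: `name = value` pairs, comma-separated, any order.
class StructParser {
public:
  explicit StructParser(std::string_view text) : text_(text) {}

  // `names` holds the captured parameters in declaration order. Each name must
  // appear exactly once, so exactly names.size() pairs are read.
  bool parse(const std::vector<std::string> &names,
             std::map<std::string, int64_t> &values, std::string &error) {
    for (std::size_t i = 0; i < names.size(); ++i) {
      if (i != 0 && !consume(',')) {
        error = "expected ','";
        return false;
      }
      std::string key = parseKeyword();
      if (key.empty()) {
        error = "expected a parameter name in struct";
        return false;
      }
      bool known = false;
      for (const std::string &name : names)
        known = known || name == key;
      if (!known || values.count(key) != 0) {
        error = "duplicate or unknown struct parameter";
        return false;
      }
      if (!consume('=')) {
        error = "expected '='";
        return false;
      }
      std::optional<int64_t> value = parseInteger();
      if (!value) {
        error = "expected integer value";
        return false;
      }
      values[key] = *value;
    }
    return true;
  }

private:
  void skipSpace() {
    while (pos_ < text_.size() &&
           std::isspace(static_cast<unsigned char>(text_[pos_])))
      ++pos_;
  }
  bool consume(char c) {
    skipSpace();
    if (pos_ < text_.size() && text_[pos_] == c) {
      ++pos_;
      return true;
    }
    return false;
  }
  // A run of letters, digits, or '_' (a simplification of MLIR keywords).
  std::string parseKeyword() {
    skipSpace();
    std::size_t start = pos_;
    while (pos_ < text_.size() &&
           (std::isalnum(static_cast<unsigned char>(text_[pos_])) ||
            text_[pos_] == '_'))
      ++pos_;
    return std::string(text_.substr(start, pos_ - start));
  }
  // A non-negative decimal integer.
  std::optional<int64_t> parseInteger() {
    skipSpace();
    std::size_t start = pos_;
    while (pos_ < text_.size() &&
           std::isdigit(static_cast<unsigned char>(text_[pos_])))
      ++pos_;
    if (pos_ == start)
      return std::nullopt;
    return std::stoll(std::string(text_.substr(start, pos_ - start)));
  }

  std::string_view text_;
  std::size_t pos_ = 0;
};

// Printing ignores input order and always follows declaration order.
std::string printStruct(const std::vector<std::string> &names,
                        const std::map<std::string, int64_t> &values) {
  std::string out;
  for (std::size_t i = 0; i < names.size(); ++i) {
    if (i != 0)
      out += ", ";
    out += names[i] + " = " + std::to_string(values.at(names[i]));
  }
  return out;
}
```

Here, reading exactly `names.size()` pairs is what turns a missing key into an `expected ','` error. An alternative design could read until the closing delimiter and then name the missing parameters.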
Use `struct(params)` to capture all parameters.
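Some parameters supply their own `parser`. Every other captured parameter, including those pulled in through a capture-all `struct`, is parsed by dispatching on its C++ type through the `FieldParser<T, typename = T>` template that this patch adds to `DialectImplementation.h`. The sketch below reproduces that SFINAE dispatch pattern in isolation. `Cursor`, `ToyFieldParser`, and `PushBackResult` are hypothetical names, and the parsing rules are simplified: strings have no escapes, and lists have no surrounding brackets.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical cursor standing in for DialectAsmParser (not an MLIR API).
struct Cursor {
  std::string_view text;
  std::size_t pos = 0;

  void skipSpace() {
    while (pos < text.size() &&
           std::isspace(static_cast<unsigned char>(text[pos])))
      ++pos;
  }
  bool consume(char c) {
    skipSpace();
    if (pos < text.size() && text[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  }
};

// Same shape as FieldParser<T, typename = T>: the defaulted second parameter
// lets partial specializations opt in through std::enable_if_t.
template <typename T, typename = T>
struct ToyFieldParser;

// Any integral type: an optionally signed run of decimal digits.
template <typename IntT>
struct ToyFieldParser<IntT,
                      std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
  static std::optional<IntT> parse(Cursor &c) {
    c.skipSpace();
    std::size_t start = c.pos;
    if (c.pos < c.text.size() && c.text[c.pos] == '-')
      ++c.pos;
    std::size_t digits = c.pos;
    while (c.pos < c.text.size() &&
           std::isdigit(static_cast<unsigned char>(c.text[c.pos])))
      ++c.pos;
    if (c.pos == digits) {
      c.pos = start;
      return std::nullopt;
    }
    std::string token(c.text.substr(start, c.pos - start));
    return static_cast<IntT>(std::stoll(token));
  }
};

// A full specialization for std::string wins over the container case below.
template <>
struct ToyFieldParser<std::string> {
  static std::optional<std::string> parse(Cursor &c) {
    if (!c.consume('"'))
      return std::nullopt;
    std::size_t end = c.text.find('"', c.pos);
    if (end == std::string_view::npos)
      return std::nullopt;
    std::string value(c.text.substr(c.pos, end - c.pos));
    c.pos = end + 1;
    return value;
  }
};

// Well-formed only if push_back(value_type) is callable on C.
template <typename C>
using PushBackResult = decltype(std::declval<C &>().push_back(
    std::declval<typename C::value_type>()));

// Any back-insertable container: a comma-separated list whose elements are
// parsed by recursively dispatching on the element type.
template <typename C>
struct ToyFieldParser<
    C, std::enable_if_t<std::is_void<PushBackResult<C>>::value, C>> {
  static std::optional<C> parse(Cursor &c) {
    C elements;
    do {
      auto element = ToyFieldParser<typename C::value_type>::parse(c);
      if (!element)
        return std::nullopt;
      elements.push_back(*element);
    } while (c.consume(','));
    return elements;
  }
};
```

The patch detects containers by taking `&ContainerT::push_back`. The sketch instead checks that `push_back(value_type)` is callable, so it also accepts `std::vector`, whose `push_back` is overloaded. `std::string` also has a `push_back`, but its explicit specialization is still chosen over the container case, because a full specialization takes precedence over any partial one.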
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D111594
Added:
mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
mlir/test/mlir-tblgen/attr-or-type-format.mlir
mlir/test/mlir-tblgen/attr-or-type-format.td
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
mlir/tools/mlir-tblgen/FormatGen.cpp
mlir/tools/mlir-tblgen/FormatGen.h
Modified:
mlir/docs/Tutorials/DefiningAttributesAndTypes.md
mlir/include/mlir/IR/DialectImplementation.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/include/mlir/TableGen/Dialect.h
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestAttributes.cpp
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/Dialect/Test/TestTypes.h
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
mlir/tools/mlir-tblgen/CMakeLists.txt
mlir/tools/mlir-tblgen/OpFormatGen.cpp
utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 30d6a6e9412e8..0f8edc5bf1ae7 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -382,3 +382,172 @@ the things named `*Type` are generally now named `*Attr`.
Aside from that, all of the interfaces for uniquing and storage construction are
all the same.
+
+## Defining Custom Parsers and Printers using Assembly Formats
+
+Attributes and types defined in ODS with a mnemonic can define an
+`assemblyFormat` to declaratively describe custom parsers and printers. The
+assembly format consists of literals, variables, and directives.
+
+* A literal is a keyword or valid punctuation enclosed in backticks, e.g.
+ `` `keyword` `` or `` `<` ``.
+* A variable is a parameter name preceded by a dollar sign, e.g. `$param0`,
+ which captures one attribute or type parameter.
+* A directive is a keyword followed by an optional argument list that defines
+ special parser and printer behaviour.
+
+```tablegen
+// An example type with an assembly format.
+def MyType : TypeDef<My_Dialect, "MyType"> {
+ // Define a mnemonic to allow the dialect's parser hook to call into the
+ // generated parser.
+ let mnemonic = "my_type";
+
+ // Define two parameters whose C++ types are indicated in string literals.
+ let parameters = (ins "int":$count, "AffineMap":$map);
+
+ // Define the assembly format. Surround the format with less-than `<` and
+ // greater-than `>` so that MLIR's printers use the pretty format.
+ let assemblyFormat = "`<` $count `,` `map` `=` $map `>`";
+}
+```
+
+The declarative assembly format for `MyType` results in the following format
+in the IR:
+
+```mlir
+!my_dialect.my_type<42, map = affine_map<(i, j) -> (j, i)>>
+```
+
+### Parameter Parsing and Printing
+
+For many basic parameter types, no additional work is needed to define how
+these parameters are parsed or printed.
+
+* The default printer for any parameter is `$_printer << $_self`,
+ where `$_self` is the C++ value of the parameter and `$_printer` is a
+ `DialectAsmPrinter`.
+* The default parser for a parameter is
+ `FieldParser<$cppClass>::parse($_parser)`, where `$cppClass` is the C++ type
+ of the parameter and `$_parser` is a `DialectAsmParser`.
+
+Printing and parsing behaviour can be added to additional C++ types by
+overloading these functions or by defining a `parser` and `printer` in an ODS
+parameter class.
+
+Example of overloading:
+
+```c++
+using MyParameter = std::pair<int, int>;
+
+DialectAsmPrinter &operator<<(DialectAsmPrinter &printer, MyParameter param) {
+ return printer << param.first << " * " << param.second;
+}
+
+template <> struct FieldParser<MyParameter> {
+ static FailureOr<MyParameter> parse(DialectAsmParser &parser) {
+ int a, b;
+ if (parser.parseInteger(a) || parser.parseStar() ||
+ parser.parseInteger(b))
+ return failure();
+ return MyParameter(a, b);
+ }
+};
+```
+
+Example of using ODS parameter classes:
+
+```tablegen
+def MyParameter : TypeParameter<"std::pair<int, int>", "pair of ints"> {
+ let printer = [{ $_printer << $_self.first << " * " << $_self.second }];
+ let parser = [{ [&]() -> FailureOr<std::pair<int, int>> {
+ int a, b;
+ if ($_parser.parseInteger(a) || $_parser.parseStar() ||
+ $_parser.parseInteger(b))
+ return failure();
+ return std::make_pair(a, b);
+ }() }];
+}
+```
+
+A type using this parameter with the assembly format `` `<` $myParam `>` ``
+will look as follows in the IR:
+
+```mlir
+!my_dialect.my_type<42 * 24>
+```
+
+#### Non-POD Parameters
+
+Parameters that aren't plain-old-data (e.g. references) may need to define a
+`cppStorageType` to contain the data until it is copied into the allocator.
+For example, `StringRefParameter` uses `std::string` as its storage type,
+whereas `ArrayRefParameter` uses `SmallVector` as its storage type. The parsers
+for these parameters are expected to return `FailureOr<$cppStorageType>`.
+
+### Assembly Format Directives
+
+Attribute and type assembly formats have the following directives:
+
+* `params`: capture all parameters of an attribute or type.
+* `struct`: generate a "struct-like" parser and printer for a list of key-value
+ pairs.
+
+#### `params` Directive
+
+This directive is used to refer to all parameters of an attribute or type.
+When used as a top-level directive, `params` generates a parser and printer for
+a comma-separated list of the parameters. For example:
+
+```tablegen
+def MyPairType : TypeDef<My_Dialect, "MyPairType"> {
+ let parameters = (ins "int":$a, "int":$b);
+ let mnemonic = "pair";
+ let assemblyFormat = "`<` params `>`";
+}
+```
+
+In the IR, this type will appear as:
+
+```mlir
+!my_dialect.pair<42, 24>
+```
+
+The `params` directive can also be passed to other directives, such as `struct`,
+as an argument that refers to all parameters in place of explicitly listing all
+parameters as variables.
+
+#### `struct` Directive
+
+The `struct` directive accepts a list of variables to capture and will generate
+a parser and printer for a comma-separated list of key-value pairs. The
+variables are printed in the order they are specified in the argument list **but
+can be parsed in any order**. For example:
+
+```tablegen
+def MyStructType : TypeDef<My_Dialect, "MyStructType"> {
+ let parameters = (ins StringRefParameter<>:$sym_name,
+ "int":$a, "int":$b, "int":$c);
+ let mnemonic = "struct";
+ let assemblyFormat = "`<` $sym_name `->` struct($a, $b, $c) `>`";
+}
+```
+
+In the IR, this type can appear with any permutation of the order of the
+parameters captured in the directive.
+
+```mlir
+!my_dialect.struct<"foo" -> a = 1, b = 2, c = 3>
+!my_dialect.struct<"foo" -> b = 2, c = 3, a = 1>
+```
+
+Passing `params` as the only argument to `struct` makes the directive capture
+all the parameters of the attribute or type. For the same type above, the
+format `` `<` struct(params) `>` `` accepts any permutation, for example:
+
+```mlir
+!my_dialect.struct<b = 2, sym_name = "foo", c = 3, a = 1>
+```
+
+The order in which the parameters are printed is the order in which they are
+declared in the attribute's or type's `parameters` list.
diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index 728e24605c29f..9302eb9146d4b 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -47,6 +47,74 @@ class DialectAsmParser : public AsmParser {
virtual StringRef getFullSymbolSpec() const = 0;
};
+//===----------------------------------------------------------------------===//
+// Parse Fields
+//===----------------------------------------------------------------------===//
+
+/// Provide a template class that can be specialized by users to dispatch to
+/// parsers. Auto-generated parsers generate calls to `FieldParser<T>::parse`,
+/// where `T` is the parameter storage type, to parse custom types.
+template <typename T, typename = T>
+struct FieldParser;
+
+/// Parse an attribute.
+template <typename AttributeT>
+struct FieldParser<
+ AttributeT, std::enable_if_t<std::is_base_of<Attribute, AttributeT>::value,
+ AttributeT>> {
+ static FailureOr<AttributeT> parse(DialectAsmParser &parser) {
+ AttributeT value;
+ if (parser.parseAttribute(value))
+ return failure();
+ return value;
+ }
+};
+
+/// Parse any integer.
+template <typename IntT>
+struct FieldParser<IntT,
+ std::enable_if_t<std::is_integral<IntT>::value, IntT>> {
+ static FailureOr<IntT> parse(DialectAsmParser &parser) {
+ IntT value;
+ if (parser.parseInteger(value))
+ return failure();
+ return value;
+ }
+};
+
+/// Parse a string.
+template <>
+struct FieldParser<std::string> {
+ static FailureOr<std::string> parse(DialectAsmParser &parser) {
+ std::string value;
+ if (parser.parseString(&value))
+ return failure();
+ return value;
+ }
+};
+
+/// Parse any container that supports back insertion as a list.
+template <typename ContainerT>
+struct FieldParser<
+ ContainerT, std::enable_if_t<std::is_member_function_pointer<
+ decltype(&ContainerT::push_back)>::value,
+ ContainerT>> {
+ using ElementT = typename ContainerT::value_type;
+ static FailureOr<ContainerT> parse(DialectAsmParser &parser) {
+ ContainerT elements;
+ auto elementParser = [&]() {
+ auto element = FieldParser<ElementT>::parse(parser);
+ if (failed(element))
+ return failure();
+ elements.push_back(element.getValue());
+ return success();
+ };
+ if (parser.parseCommaSeparatedList(elementParser))
+ return failure();
+ return elements;
+ }
+};
+
} // end namespace mlir
-#endif
+#endif // MLIR_IR_DIALECTIMPLEMENTATION_H
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index e63a2672bf316..37bf1d233c2b3 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2886,6 +2886,11 @@ class AttrOrTypeDef<string valueType, string name, list<Trait> defTraits,
code printer = ?;
code parser = ?;
+ // Custom assembly format. Requires 'mnemonic' to be specified. Cannot be
+ // specified at the same time as either 'printer' or 'parser'. The generated
+ // printer requires 'genAccessors' to be true.
+ string assemblyFormat = ?;
+
// If set, generate accessors for each parameter.
bit genAccessors = 1;
@@ -2964,10 +2969,22 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
string cppType = type;
// The C++ type of the accessor for this parameter.
string cppAccessorType = !if(!empty(accessorType), type, accessorType);
+ // The C++ storage type of this parameter if it is a reference, e.g.
+ // `std::string` for `StringRef` or `SmallVector` for `ArrayRef`.
+ string cppStorageType = ?;
// One-line human-readable description of the argument.
string summary = desc;
// The format string for the asm syntax (documentation only).
string syntax = ?;
+ // The default parameter parser is `::mlir::FieldParser<T>::parse($_parser)`,
+ // which returns `FailureOr<T>`. Specialize `FieldParser` to support parsing
+ // for your type, or provide a custom parser. For attributes, "$_type" will
+ // be replaced with the required attribute type.
+ string parser = ?;
+ // The default parameter printer is `$_printer << $_self`. Overload the stream
+ // operator of `DialectAsmPrinter` as necessary to print your type. Or you can
+ // provide a custom printer.
+ string printer = ?;
}
class AttrParameter<string type, string desc, string accessorType = "">
: AttrOrTypeParameter<type, desc, accessorType>;
@@ -2978,6 +2995,8 @@ class TypeParameter<string type, string desc, string accessorType = "">
class StringRefParameter<string desc = ""> :
AttrOrTypeParameter<"::llvm::StringRef", desc> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+ let printer = [{$_printer << '"' << $_self << '"';}];
+ let cppStorageType = "std::string";
}
// For APFloats, which require comparison.
@@ -2990,6 +3009,7 @@ class APFloatParameter<string desc> :
class ArrayRefParameter<string arrayOf, string desc = ""> :
AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
+ let cppStorageType = "::llvm::SmallVector<" # arrayOf # ">";
}
// For classes which require allocation and have their own allocateInto method.
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index db1f7a3c071d2..34e6cd08ea3c7 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -182,10 +182,10 @@ operator<<(AsmPrinterT &p, const TypeRange &types) {
llvm::interleaveComma(types, p);
return p;
}
-template <typename AsmPrinterT>
+template <typename AsmPrinterT, typename ElementT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
-operator<<(AsmPrinterT &p, ArrayRef<Type> types) {
+operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
llvm::interleaveComma(types, p);
return p;
}
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 2029c0e624cd3..09294c2fa8081 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -101,6 +101,9 @@ class AttrOrTypeDef {
// None. Otherwise, returns the contents of that code block.
Optional<StringRef> getParserCode() const;
+ // Returns the custom assembly format, if one was specified.
+ Optional<StringRef> getAssemblyFormat() const;
+
// Returns true if the accessors based on the parameters should be generated.
bool genAccessors() const;
@@ -199,6 +202,15 @@ class AttrOrTypeParameter {
// Get the C++ accessor type of this parameter.
StringRef getCppAccessorType() const;
+ // Get the C++ storage type of this parameter.
+ StringRef getCppStorageType() const;
+
+ // Get an optional C++ parameter parser.
+ Optional<StringRef> getParser() const;
+
+ // Get an optional C++ parameter printer.
+ Optional<StringRef> getPrinter() const;
+
// Get a description of this parameter for documentation purposes.
Optional<StringRef> getSummary() const;
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 2de0d9b0406eb..3030d6556b5bd 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -1,3 +1,4 @@
+//===- Dialect.h - Dialect class --------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 2a0ad96ea4e93..f43949c30a222 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -132,6 +132,10 @@ Optional<StringRef> AttrOrTypeDef::getParserCode() const {
return def->getValueAsOptionalString("parser");
}
+Optional<StringRef> AttrOrTypeDef::getAssemblyFormat() const {
+ return def->getValueAsOptionalString("assemblyFormat");
+}
+
bool AttrOrTypeDef::genAccessors() const {
return def->getValueAsBit("genAccessors");
}
@@ -219,6 +223,32 @@ StringRef AttrOrTypeParameter::getCppAccessorType() const {
return getCppType();
}
+StringRef AttrOrTypeParameter::getCppStorageType() const {
+ if (auto *param = dyn_cast<llvm::DefInit>(def->getArg(index))) {
+ if (auto type = param->getDef()->getValueAsOptionalString("cppStorageType"))
+ return *type;
+ }
+ return getCppType();
+}
+
+Optional<StringRef> AttrOrTypeParameter::getParser() const {
+ auto *parameterType = def->getArg(index);
+ if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
+ if (auto parser = param->getDef()->getValueAsOptionalString("parser"))
+ return *parser;
+ }
+ return {};
+}
+
+Optional<StringRef> AttrOrTypeParameter::getPrinter() const {
+ auto *parameterType = def->getArg(index);
+ if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
+ if (auto printer = param->getDef()->getValueAsOptionalString("printer"))
+ return *printer;
+ }
+ return {};
+}
+
Optional<StringRef> AttrOrTypeParameter::getSummary() const {
auto *parameterType = def->getArg(index);
if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) {
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 3062fd6c65ca9..06e2599def7df 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -116,4 +116,44 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
);
}
+def TestParamOne : AttrParameter<"int64_t", ""> {}
+
+def TestParamTwo : AttrParameter<"std::string", "", "llvm::StringRef"> {
+ let printer = "$_printer << '\"' << $_self << '\"'";
+}
+
+def TestParamFour : ArrayRefParameter<"int", ""> {
+ let cppStorageType = "llvm::SmallVector<int>";
+ let parser = "::parseIntArray($_parser)";
+ let printer = "::printIntArray($_printer, $_self)";
+}
+
+def TestAttrWithFormat : Test_Attr<"TestAttrWithFormat"> {
+ let parameters = (
+ ins
+ TestParamOne:$one,
+ TestParamTwo:$two,
+ "::mlir::IntegerAttr":$three,
+ TestParamFour:$four
+ );
+
+ let mnemonic = "attr_with_format";
+ let assemblyFormat = "`<` $one `:` struct($two, $four) `:` $three `>`";
+ let genVerifyDecl = 1;
+}
+
+def TestAttrUgly : Test_Attr<"TestAttrUgly"> {
+ let parameters = (ins "::mlir::Attribute":$attr);
+
+ let mnemonic = "attr_ugly";
+ let assemblyFormat = "`begin` $attr `end`";
+}
+
+def TestAttrParams: Test_Attr<"TestAttrParams"> {
+ let parameters = (ins "int":$v0, "int":$v1);
+
+ let mnemonic = "attr_params";
+ let assemblyFormat = "`<` params `>`";
+}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 9cd9c574a7bf2..e0c8ebbb7b93f 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -16,9 +16,11 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Types.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/ADT/bit.h"
using namespace mlir;
using namespace test;
@@ -127,6 +129,36 @@ TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+LogicalResult
+TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+ int64_t one, std::string two, IntegerAttr three,
+ ArrayRef<int> four) {
+ if (four.size() != static_cast<unsigned>(one))
+ return emitError() << "expected 'one' to equal 'four.size()'";
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Utility Functions for Generated Attributes
+//===----------------------------------------------------------------------===//
+
+static FailureOr<SmallVector<int>> parseIntArray(DialectAsmParser &parser) {
+ SmallVector<int> ints;
+ if (parser.parseLSquare() || parser.parseCommaSeparatedList([&]() {
+ ints.push_back(0);
+ return parser.parseInteger(ints.back());
+ }) ||
+ parser.parseRSquare())
+ return failure();
+ return ints;
+}
+
+static void printIntArray(DialectAsmPrinter &printer, ArrayRef<int> ints) {
+ printer << '[';
+ llvm::interleaveComma(ints, printer);
+ printer << ']';
+}
+
//===----------------------------------------------------------------------===//
// TestSubElementsAccessAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index aef9baa894737..66008f12d574b 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -15,6 +15,7 @@
// To get the test dialect def.
include "TestOps.td"
+include "TestAttrDefs.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
@@ -189,4 +190,44 @@ def TestTypeWithTrait : Test_Type<"TestTypeWithTrait", [TestTypeTrait]> {
let mnemonic = "test_type_with_trait";
}
+// Type with assembly format.
+def TestTypeWithFormat : Test_Type<"TestTypeWithFormat"> {
+ let parameters = (
+ ins
+ TestParamOne:$one,
+ TestParamTwo:$two,
+ "::mlir::Attribute":$three
+ );
+
+ let mnemonic = "type_with_format";
+ let assemblyFormat = "`<` $one `,` struct($three, $two) `>`";
+}
+
+// Test dispatch to FieldParser
+def TestTypeNoParser : Test_Type<"TestTypeNoParser"> {
+ let parameters = (
+ ins
+ "uint32_t":$one,
+ ArrayRefParameter<"int64_t">:$two,
+ StringRefParameter<>:$three,
+ "::test::CustomParam":$four
+ );
+
+ let mnemonic = "no_parser";
+ let assemblyFormat = "`<` $one `,` `[` $two `]` `,` $three `,` $four `>`";
+}
+
+def TestTypeStructCaptureAll : Test_Type<"TestStructTypeCaptureAll"> {
+ let parameters = (
+ ins
+ "int":$v0,
+ "int":$v1,
+ "int":$v2,
+ "int":$v3
+ );
+
+ let mnemonic = "struct_capture_all";
+ let assemblyFormat = "`<` struct(params) `>`";
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h
index 9da2e1713d9d0..7614ae401d1f0 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.h
+++ b/mlir/test/lib/Dialect/Test/TestTypes.h
@@ -38,8 +38,38 @@ struct FieldInfo {
}
};
+/// A custom type for a test type parameter.
+struct CustomParam {
+ int value;
+
+ bool operator==(const CustomParam &other) const {
+ return other.value == value;
+ }
+};
+
+inline llvm::hash_code hash_value(const test::CustomParam &param) {
+ return llvm::hash_value(param.value);
+}
+
} // namespace test
+namespace mlir {
+template <>
+struct FieldParser<test::CustomParam> {
+ static FailureOr<test::CustomParam> parse(DialectAsmParser &parser) {
+ auto value = FieldParser<int>::parse(parser);
+ if (failed(value))
+ return failure();
+ return test::CustomParam{value.getValue()};
+ }
+};
+} // end namespace mlir
+
+inline mlir::DialectAsmPrinter &operator<<(mlir::DialectAsmPrinter &printer,
+ const test::CustomParam &param) {
+ return printer << param.value;
+}
+
#include "TestTypeInterfaces.h.inc"
#define GET_TYPEDEF_CLASSES
@@ -52,17 +82,19 @@ namespace test {
struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
using KeyTy = ::llvm::StringRef;
- explicit TestRecursiveTypeStorage(::llvm::StringRef key) : name(key), body(::mlir::Type()) {}
+ explicit TestRecursiveTypeStorage(::llvm::StringRef key)
+ : name(key), body(::mlir::Type()) {}
bool operator==(const KeyTy &other) const { return name == other; }
- static TestRecursiveTypeStorage *construct(::mlir::TypeStorageAllocator &allocator,
- const KeyTy &key) {
+ static TestRecursiveTypeStorage *
+ construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
return new (allocator.allocate<TestRecursiveTypeStorage>())
TestRecursiveTypeStorage(allocator.copyInto(key));
}
- ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator, ::mlir::Type newBody) {
+ ::mlir::LogicalResult mutate(::mlir::TypeStorageAllocator &allocator,
+ ::mlir::Type newBody) {
// Cannot set a different body than before.
if (body && body != newBody)
return ::mlir::failure();
@@ -79,11 +111,13 @@ struct TestRecursiveTypeStorage : public ::mlir::TypeStorage {
/// type, potentially itself. This requires the body to be mutated separately
/// from type creation.
class TestRecursiveType
- : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type, TestRecursiveTypeStorage> {
+ : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
+ TestRecursiveTypeStorage> {
public:
using Base::Base;
- static TestRecursiveType get(::mlir::MLIRContext *ctx, ::llvm::StringRef name) {
+ static TestRecursiveType get(::mlir::MLIRContext *ctx,
+ ::llvm::StringRef name) {
return Base::get(ctx, name);
}
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
new file mode 100644
index 0000000000000..372aef6dfa3e5
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -0,0 +1,76 @@
+// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include -asmformat-error-is-fatal=false %s 2>&1 | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+ let name = "TestDialect";
+ let cppNamespace = "::test";
+}
+
+class InvalidType<string name, string asm> : TypeDef<Test_Dialect, name> {
+ let mnemonic = asm;
+}
+
+/// Test format is missing a parameter capture.
+def InvalidTypeA : InvalidType<"InvalidTypeA", "invalid_a"> {
+ let parameters = (ins "int":$v0, "int":$v1);
+ // CHECK: format is missing reference to parameter: v1
+ let assemblyFormat = "`<` $v0 `>`";
+}
+
+/// Test format has duplicate parameter captures.
+def InvalidTypeB : InvalidType<"InvalidTypeB", "invalid_b"> {
+ let parameters = (ins "int":$v0, "int":$v1);
+ // CHECK: duplicate parameter 'v0'
+ let assemblyFormat = "`<` $v0 `,` $v1 `,` $v0 `>`";
+}
+
+/// Test format has invalid syntax.
+def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
+ let parameters = (ins "int":$v0, "int":$v1);
+ // CHECK: expected literal, directive, or variable
+ let assemblyFormat = "`<` $v0, $v1 `>`";
+}
+
+/// Test struct directive has invalid syntax.
+def InvalidTypeD : InvalidType<"InvalidTypeD", "invalid_d"> {
+ let parameters = (ins "int":$v0);
+ // CHECK: literals may only be used in the top-level section of the format
+ // CHECK: expected a variable in `struct` argument list
+ let assemblyFormat = "`<` struct($v0, `,`) `>`";
+}
+
+/// Test struct directive cannot capture zero parameters.
+def InvalidTypeE : InvalidType<"InvalidTypeE", "invalid_e"> {
+ let parameters = (ins "int":$v0);
+ // CHECK: `struct` argument list expected a variable or directive
+ let assemblyFormat = "`<` struct() $v0 `>`";
+}
+
+/// Test capture parameter that does not exist.
+def InvalidTypeF : InvalidType<"InvalidTypeF", "invalid_f"> {
+ let parameters = (ins "int":$v0);
+ // CHECK: InvalidTypeF has no parameter named 'v1'
+ let assemblyFormat = "`<` $v0 $v1 `>`";
+}
+
+/// Test duplicate capture of parameter in capture-all struct.
+def InvalidTypeG : InvalidType<"InvalidTypeG", "invalid_g"> {
+ let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
+ // CHECK: duplicate parameter 'v0'
+ let assemblyFormat = "`<` struct(params) $v0 `>`";
+}
+
+/// Test capture-all struct duplicate capture.
+def InvalidTypeH : InvalidType<"InvalidTypeH", "invalid_h"> {
+ let parameters = (ins "int":$v0, "int":$v1, "int":$v2);
+ // CHECK: `params` captures duplicate parameter: v0
+ let assemblyFormat = "`<` $v0 struct(params) `>`";
+}
+
+/// Test capture of parameter after `params` directive.
+def InvalidTypeI : InvalidType<"InvalidTypeI", "invalid_i"> {
+ let parameters = (ins "int":$v0);
+ // CHECK: duplicate parameter 'v0'
+ let assemblyFormat = "`<` params $v0 `>`";
+}
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
new file mode 100644
index 0000000000000..f403f6f4b059d
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @test_roundtrip_parameter_parsers
+// CHECK: !test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo">
+// CHECK: !test.type_with_format<2147, three = "hi", two = "hi">
+func private @test_roundtrip_parameter_parsers(!test.type_with_format<111, three = #test<"attr_ugly begin 5 : index end">, two = "foo">) -> !test.type_with_format<2147, two = "hi", three = "hi">
+attributes {
+ // CHECK: #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64>
+ attr0 = #test.attr_with_format<3 : two = "hello", four = [1, 2, 3] : 42 : i64>,
+ // CHECK: #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8>
+ attr1 = #test.attr_with_format<5 : two = "a_string", four = [4, 5, 6, 7, 8] : 8 : i8>,
+ // CHECK: #test<"attr_ugly begin 5 : index end">
+ attr2 = #test<"attr_ugly begin 5 : index end">,
+ // CHECK: #test.attr_params<42, 24>
+ attr3 = #test.attr_params<42, 24>
+}
+
+// CHECK-LABEL: @test_roundtrip_default_parsers_struct
+// CHECK: !test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
+// CHECK: !test.struct_capture_all<v0 = 0, v1 = 1, v2 = 2, v3 = 3>
+func private @test_roundtrip_default_parsers_struct(!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>) -> !test.struct_capture_all<v3 = 3, v1 = 1, v2 = 2, v0 = 0>
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.mlir b/mlir/test/mlir-tblgen/attr-or-type-format.mlir
new file mode 100644
index 0000000000000..3ff638c5f640f
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.mlir
@@ -0,0 +1,127 @@
+// RUN: mlir-opt --split-input-file %s --verify-diagnostics
+
+func private @test_ugly_attr_cannot_be_pretty() -> () attributes {
+ // expected-error@+1 {{expected 'begin'}}
+ attr = #test.attr_ugly
+}
+
+// -----
+
+func private @test_ugly_attr_no_mnemonic() -> () attributes {
+ // expected-error@+1 {{expected valid keyword}}
+ attr = #test<"">
+}
+
+// -----
+
+func private @test_ugly_attr_parser_dispatch() -> () attributes {
+ // expected-error@+1 {{expected 'begin'}}
+ attr = #test<"attr_ugly">
+}
+
+// -----
+
+func private @test_ugly_attr_missing_parameter() -> () attributes {
+ // expected-error@+2 {{failed to parse TestAttrUgly parameter 'attr'}}
+ // expected-error@+1 {{expected non-function type}}
+ attr = #test<"attr_ugly begin">
+}
+
+// -----
+
+func private @test_ugly_attr_missing_literal() -> () attributes {
+ // expected-error@+1 {{expected 'end'}}
+ attr = #test<"attr_ugly begin \"string_attr\"">
+}
+
+// -----
+
+func private @test_pretty_attr_expects_less() -> () attributes {
+ // expected-error@+1 {{expected '<'}}
+ attr = #test.attr_with_format
+}
+
+// -----
+
+func private @test_pretty_attr_missing_param() -> () attributes {
+ // expected-error@+2 {{expected integer value}}
+ // expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'one'}}
+ attr = #test.attr_with_format<>
+}
+
+// -----
+
+func private @test_parse_invalid_param() -> () attributes {
+ // Test parameter parser failure is propagated
+ // expected-error@+2 {{expected integer value}}
+ // expected-error@+1 {{failed to parse TestAttrWithFormat parameter 'one'}}
+ attr = #test.attr_with_format<"hi">
+}
+
+// -----
+
+func private @test_pretty_attr_invalid_syntax() -> () attributes {
+ // expected-error@+1 {{expected ':'}}
+ attr = #test.attr_with_format<42>
+}
+
+// -----
+
+func private @test_struct_missing_key() -> () attributes {
+ // expected-error@+2 {{expected valid keyword}}
+ // expected-error@+1 {{expected a parameter name in struct}}
+ attr = #test.attr_with_format<42 :>
+}
+
+// -----
+
+func private @test_struct_unknown_key() -> () attributes {
+ // expected-error @+1 {{duplicate or unknown struct parameter}}
+ attr = #test.attr_with_format<42 : nine = "foo">
+}
+
+// -----
+
+func private @test_struct_duplicate_key() -> () attributes {
+ // expected-error @+1 {{duplicate or unknown struct parameter}}
+ attr = #test.attr_with_format<42 : two = "foo", two = "bar">
+}
+
+// -----
+
+func private @test_struct_not_enough_values() -> () attributes {
+ // expected-error @+1 {{expected ','}}
+ attr = #test.attr_with_format<42 : two = "foo">
+}
+
+// -----
+
+func private @test_parse_param_after_struct() -> () attributes {
+ // expected-error @+2 {{expected non-function type}}
+ // expected-error @+1 {{failed to parse TestAttrWithFormat parameter 'three'}}
+ attr = #test.attr_with_format<42 : two = "foo", four = [1, 2, 3] : >
+}
+
+// -----
+
+// expected-error @+1 {{expected '<'}}
+func private @test_invalid_type() -> !test.type_with_format
+
+// -----
+
+// expected-error @+2 {{expected integer value}}
+// expected-error @+1 {{failed to parse TestTypeWithFormat parameter 'one'}}
+func private @test_pretty_type_invalid_param() -> !test.type_with_format<>
+
+// -----
+
+// expected-error @+2 {{expected ':'}}
+// expected-error @+1 {{failed to parse TestTypeWithFormat parameter 'three'}}
+func private @test_type_syntax_error() -> !test.type_with_format<42, two = "hi", three = #test.attr_with_format<42>>
+
+// -----
+
+func private @test_verifier_fails() -> () attributes {
+ // expected-error @+1 {{expected 'one' to equal 'four.size()'}}
+ attr = #test.attr_with_format<42 : two = "hello", four = [1, 2, 3] : 42 : i64>
+}
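The struct cases above (missing key, unknown key, duplicate key, too few entries) all exercise the loop that `struct(...)` generates: one iteration per parameter, a `_seen_` flag per key, an `=` after each key, and a comma after every entry except the last. What follows is a minimal standalone C++17 sketch of that control flow, not the generated code itself. `Cursor` and `parseStruct` are hypothetical stand-ins for `DialectAsmParser` and the generated parser, and values are assumed to be bare identifiers to keep the tokenizer short.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical character cursor standing in for DialectAsmParser.
struct Cursor {
  std::string text;
  std::size_t pos = 0;

  explicit Cursor(std::string t) : text(std::move(t)) {}

  void skipSpace() {
    while (pos < text.size() &&
           std::isspace(static_cast<unsigned char>(text[pos])))
      ++pos;
  }

  // Consume `c` if it is the next non-space character.
  bool consume(char c) {
    skipSpace();
    if (pos < text.size() && text[pos] == c) {
      ++pos;
      return true;
    }
    return false;
  }

  // Read a bare identifier: [A-Za-z_][A-Za-z0-9_]*.
  std::optional<std::string> word() {
    skipSpace();
    auto isStart = [](char ch) {
      return std::isalpha(static_cast<unsigned char>(ch)) || ch == '_';
    };
    auto isBody = [](char ch) {
      return std::isalnum(static_cast<unsigned char>(ch)) || ch == '_';
    };
    if (pos >= text.size() || !isStart(text[pos]))
      return std::nullopt;
    std::size_t start = pos;
    while (pos < text.size() && isBody(text[pos]))
      ++pos;
    return text.substr(start, pos - start);
  }
};

// Parse `key = value (, key = value)*` where each name in `keys` must appear
// exactly once, in any order. Returns std::nullopt on any error.
std::optional<std::map<std::string, std::string>>
parseStruct(Cursor &in, const std::vector<std::string> &keys) {
  std::map<std::string, bool> seen; // plays the role of the _seen_* flags
  for (const std::string &k : keys)
    seen[k] = false;
  assert(seen.size() == keys.size() && "struct keys must be distinct");

  std::map<std::string, std::string> result;
  for (std::size_t index = 0; index < keys.size(); ++index) {
    std::optional<std::string> key = in.word();
    if (!key)
      return std::nullopt; // "expected a parameter name in struct"
    if (!in.consume('='))
      return std::nullopt; // "expected '='"
    auto it = seen.find(*key);
    if (it == seen.end() || it->second)
      return std::nullopt; // "duplicate or unknown struct parameter name"
    it->second = true;
    std::optional<std::string> value = in.word();
    if (!value)
      return std::nullopt; // the parameter's own parser failed
    result[*key] = *value;
    // Require a comma between entries, but not after the last one.
    if (index != keys.size() - 1 && !in.consume(','))
      return std::nullopt; // "expected ','"
  }
  return result;
}
```

With `keys = {"two", "four"}`, both `two = foo, four = bar` and `four = bar, two = foo` are accepted, while `two = foo, two = bar` and `nine = foo, two = bar` are rejected like the duplicate/unknown-key tests, and `two = foo` fails at the point where the test expects `expected ','`.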
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
new file mode 100644
index 0000000000000..2d426935fa415
--- /dev/null
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -0,0 +1,394 @@
+// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=ATTR
+// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=TYPE
+
+include "mlir/IR/OpBase.td"
+
+/// Test that attribute and type printers and parsers are correctly generated.
+def Test_Dialect : Dialect {
+ let name = "TestDialect";
+ let cppNamespace = "::test";
+}
+
+class TestAttr<string name> : AttrDef<Test_Dialect, name>;
+class TestType<string name> : TypeDef<Test_Dialect, name>;
+
+def AttrParamA : AttrParameter<"TestParamA", "an attribute param A"> {
+ let parser = "::parseAttrParamA($_parser, $_type)";
+ let printer = "::printAttrParamA($_printer, $_self)";
+}
+
+def AttrParamB : AttrParameter<"TestParamB", "an attribute param B"> {
+ let parser = "$_type ? ::parseAttrWithType($_parser, $_type) : ::parseAttrWithout($_parser)";
+ let printer = "::printAttrB($_printer, $_self)";
+}
+
+def TypeParamA : TypeParameter<"TestParamC", "a type param C"> {
+ let parser = "::parseTypeParamC($_parser)";
+ let printer = "$_printer << $_self";
+}
+
+def TypeParamB : TypeParameter<"TestParamD", "a type param D"> {
+ let parser = "someFcnCall()";
+ let printer = "myPrinter($_self)";
+}
+
+/// Check that a simple attribute parser and printer are generated correctly.
+
+// ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::DialectAsmParser &parser,
+// ATTR: ::mlir::Type attrType) {
+// ATTR: FailureOr<IntegerAttr> _result_value;
+// ATTR: FailureOr<TestParamA> _result_complex;
+// ATTR: if (parser.parseKeyword("hello"))
+// ATTR: return {};
+// ATTR: if (parser.parseEqual())
+// ATTR: return {};
+// ATTR: _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// ATTR: if (failed(_result_value))
+// ATTR: return {};
+// ATTR: if (parser.parseComma())
+// ATTR: return {};
+// ATTR: _result_complex = ::parseAttrParamA(parser, attrType);
+// ATTR: if (failed(_result_complex))
+// ATTR: return {};
+// ATTR: if (parser.parseRParen())
+// ATTR: return {};
+// ATTR: return TestAAttr::get(parser.getContext(),
+// ATTR: _result_value.getValue(),
+// ATTR: _result_complex.getValue());
+// ATTR: }
+
+// ATTR: void TestAAttr::print(::mlir::DialectAsmPrinter &printer) const {
+// ATTR: printer << "attr_a";
+// ATTR: printer << ' ' << "hello";
+// ATTR: printer << ' ' << "=";
+// ATTR: printer << ' ';
+// ATTR: printer << getValue();
+// ATTR: printer << ",";
+// ATTR: printer << ' ';
+// ATTR: ::printAttrParamA(printer, getComplex());
+// ATTR: printer << ")";
+// ATTR: }
+
+def AttrA : TestAttr<"TestA"> {
+ let parameters = (ins
+ "IntegerAttr":$value,
+ AttrParamA:$complex
+ );
+
+ let mnemonic = "attr_a";
+ let assemblyFormat = "`hello` `=` $value `,` $complex `)`";
+}
+
+/// Test that a simple struct parser and printer are generated correctly.
+
+// ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::DialectAsmParser &parser,
+// ATTR: ::mlir::Type attrType) {
+// ATTR: bool _seen_v0 = false;
+// ATTR: bool _seen_v1 = false;
+// ATTR: for (unsigned _index = 0; _index < 2; ++_index) {
+// ATTR: StringRef _paramKey;
+// ATTR: if (parser.parseKeyword(&_paramKey))
+// ATTR: return {};
+// ATTR: if (parser.parseEqual())
+// ATTR: return {};
+// ATTR: if (!_seen_v0 && _paramKey == "v0") {
+// ATTR: _seen_v0 = true;
+// ATTR: _result_v0 = ::parseAttrParamA(parser, attrType);
+// ATTR: if (failed(_result_v0))
+// ATTR: return {};
+// ATTR: } else if (!_seen_v1 && _paramKey == "v1") {
+// ATTR: _seen_v1 = true;
+// ATTR: _result_v1 = attrType ? ::parseAttrWithType(parser, attrType) : ::parseAttrWithout(parser);
+// ATTR: if (failed(_result_v1))
+// ATTR: return {};
+// ATTR: } else {
+// ATTR: return {};
+// ATTR: }
+// ATTR: if ((_index != 2 - 1) && parser.parseComma())
+// ATTR: return {};
+// ATTR: }
+// ATTR: return TestBAttr::get(parser.getContext(),
+// ATTR: _result_v0.getValue(),
+// ATTR: _result_v1.getValue());
+// ATTR: }
+
+// ATTR: void TestBAttr::print(::mlir::DialectAsmPrinter &printer) const {
+// ATTR: printer << "v0";
+// ATTR: printer << ' ' << "=";
+// ATTR: printer << ' ';
+// ATTR: ::printAttrParamA(printer, getV0());
+// ATTR: printer << ",";
+// ATTR: printer << ' ' << "v1";
+// ATTR: printer << ' ' << "=";
+// ATTR: printer << ' ';
+// ATTR: ::printAttrB(printer, getV1());
+// ATTR: }
+
+def AttrB : TestAttr<"TestB"> {
+ let parameters = (ins
+ AttrParamA:$v0,
+ AttrParamB:$v1
+ );
+
+ let mnemonic = "attr_b";
+ let assemblyFormat = "`{` struct($v0, $v1) `}`";
+}
+
+/// Test that an attribute using `params` has a correct parser and printer.
+
+// ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::DialectAsmParser &parser,
+// ATTR: ::mlir::Type attrType) {
+// ATTR: ::mlir::FailureOr<int> _result_v0;
+// ATTR: ::mlir::FailureOr<int> _result_v1;
+// ATTR: _result_v0 = ::mlir::FieldParser<int>::parse(parser);
+// ATTR: if (failed(_result_v0))
+// ATTR: return {};
+// ATTR: if (parser.parseComma())
+// ATTR: return {};
+// ATTR: _result_v1 = ::mlir::FieldParser<int>::parse(parser);
+// ATTR: if (failed(_result_v1))
+// ATTR: return {};
+// ATTR: return TestFAttr::get(parser.getContext(),
+// ATTR: _result_v0.getValue(),
+// ATTR: _result_v1.getValue());
+// ATTR: }
+
+// ATTR: void TestFAttr::print(::mlir::DialectAsmPrinter &printer) const {
+// ATTR: printer << "attr_c";
+// ATTR: printer << ' ';
+// ATTR: printer << getV0();
+// ATTR: printer << ",";
+// ATTR: printer << ' ';
+// ATTR: printer << getV1();
+// ATTR: }
+
+def AttrC : TestAttr<"TestF"> {
+ let parameters = (ins "int":$v0, "int":$v1);
+
+ let mnemonic = "attr_c";
+ let assemblyFormat = "params";
+}
+
+/// Test that a type parser and printer mixing literals, a variable, and a
+/// struct are generated correctly.
+
+// TYPE: ::mlir::Type TestCType::parse(::mlir::DialectAsmParser &parser) {
+// TYPE: FailureOr<IntegerAttr> _result_value;
+// TYPE: FailureOr<TestParamC> _result_complex;
+// TYPE: if (parser.parseKeyword("foo"))
+// TYPE: return {};
+// TYPE: if (parser.parseComma())
+// TYPE: return {};
+// TYPE: if (parser.parseColon())
+// TYPE: return {};
+// TYPE: if (parser.parseKeyword("bob"))
+// TYPE: return {};
+// TYPE: if (parser.parseKeyword("bar"))
+// TYPE: return {};
+// TYPE: _result_value = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// TYPE: if (failed(_result_value))
+// TYPE: return {};
+// TYPE: bool _seen_complex = false;
+// TYPE: for (unsigned _index = 0; _index < 1; ++_index) {
+// TYPE: StringRef _paramKey;
+// TYPE: if (parser.parseKeyword(&_paramKey))
+// TYPE: return {};
+// TYPE: if (!_seen_complex && _paramKey == "complex") {
+// TYPE: _seen_complex = true;
+// TYPE: _result_complex = ::parseTypeParamC(parser);
+// TYPE: if (failed(_result_complex))
+// TYPE: return {};
+// TYPE: } else {
+// TYPE: return {};
+// TYPE: }
+// TYPE: if ((_index != 1 - 1) && parser.parseComma())
+// TYPE: return {};
+// TYPE: }
+// TYPE: if (parser.parseRParen())
+// TYPE: return {};
+// TYPE: }
+
+// TYPE: void TestCType::print(::mlir::DialectAsmPrinter &printer) const {
+// TYPE: printer << "type_c";
+// TYPE: printer << ' ' << "foo";
+// TYPE: printer << ",";
+// TYPE: printer << ' ' << ":";
+// TYPE: printer << ' ' << "bob";
+// TYPE: printer << ' ' << "bar";
+// TYPE: printer << ' ';
+// TYPE: printer << getValue();
+// TYPE: printer << ' ' << "complex";
+// TYPE: printer << ' ' << "=";
+// TYPE: printer << ' ';
+// TYPE: printer << getComplex();
+// TYPE: printer << ")";
+// TYPE: }
+
+def TypeA : TestType<"TestC"> {
+ let parameters = (ins
+ "IntegerAttr":$value,
+ TypeParamA:$complex
+ );
+
+ let mnemonic = "type_c";
+ let assemblyFormat = "`foo` `,` `:` `bob` `bar` $value struct($complex) `)`";
+}
+
+/// Test that a type parser and printer with variables on both sides of a
+/// struct are generated correctly.
+
+// TYPE: ::mlir::Type TestDType::parse(::mlir::DialectAsmParser &parser) {
+// TYPE: _result_v0 = ::parseTypeParamC(parser);
+// TYPE: if (failed(_result_v0))
+// TYPE: return {};
+// TYPE: bool _seen_v1 = false;
+// TYPE: bool _seen_v2 = false;
+// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
+// TYPE: StringRef _paramKey;
+// TYPE: if (parser.parseKeyword(&_paramKey))
+// TYPE: return {};
+// TYPE: if (parser.parseEqual())
+// TYPE: return {};
+// TYPE: if (!_seen_v1 && _paramKey == "v1") {
+// TYPE: _seen_v1 = true;
+// TYPE: _result_v1 = someFcnCall();
+// TYPE: if (failed(_result_v1))
+// TYPE: return {};
+// TYPE: } else if (!_seen_v2 && _paramKey == "v2") {
+// TYPE: _seen_v2 = true;
+// TYPE: _result_v2 = ::parseTypeParamC(parser);
+// TYPE: if (failed(_result_v2))
+// TYPE: return {};
+// TYPE: } else {
+// TYPE: return {};
+// TYPE: }
+// TYPE: if ((_index != 2 - 1) && parser.parseComma())
+// TYPE: return {};
+// TYPE: }
+// TYPE: _result_v3 = someFcnCall();
+// TYPE: if (failed(_result_v3))
+// TYPE: return {};
+// TYPE: return TestDType::get(parser.getContext(),
+// TYPE: _result_v0.getValue(),
+// TYPE: _result_v1.getValue(),
+// TYPE: _result_v2.getValue(),
+// TYPE: _result_v3.getValue());
+// TYPE: }
+
+// TYPE: void TestDType::print(::mlir::DialectAsmPrinter &printer) const {
+// TYPE: printer << getV0();
+// TYPE: myPrinter(getV1());
+// TYPE: printer << ' ' << "v2";
+// TYPE: printer << ' ' << "=";
+// TYPE: printer << ' ';
+// TYPE: printer << getV2();
+// TYPE: myPrinter(getV3());
+// TYPE: }
+
+def TypeB : TestType<"TestD"> {
+ let parameters = (ins
+ TypeParamA:$v0,
+ TypeParamB:$v1,
+ TypeParamA:$v2,
+ TypeParamB:$v3
+ );
+
+ let mnemonic = "type_d";
+ let assemblyFormat = "`<` `foo` `:` $v0 `,` struct($v1, $v2) `,` $v3 `>`";
+}
+
+/// Test that a type with two struct directives has a correctly generated
+/// parser and printer.
+
+// TYPE: ::mlir::Type TestEType::parse(::mlir::DialectAsmParser &parser) {
+// TYPE: FailureOr<IntegerAttr> _result_v0;
+// TYPE: FailureOr<IntegerAttr> _result_v1;
+// TYPE: FailureOr<IntegerAttr> _result_v2;
+// TYPE: FailureOr<IntegerAttr> _result_v3;
+// TYPE: bool _seen_v0 = false;
+// TYPE: bool _seen_v2 = false;
+// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
+// TYPE: StringRef _paramKey;
+// TYPE: if (parser.parseKeyword(&_paramKey))
+// TYPE: return {};
+// TYPE: if (parser.parseEqual())
+// TYPE: return {};
+// TYPE: if (!_seen_v0 && _paramKey == "v0") {
+// TYPE: _seen_v0 = true;
+// TYPE: _result_v0 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// TYPE: if (failed(_result_v0))
+// TYPE: return {};
+// TYPE: } else if (!_seen_v2 && _paramKey == "v2") {
+// TYPE: _seen_v2 = true;
+// TYPE: _result_v2 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// TYPE: if (failed(_result_v2))
+// TYPE: return {};
+// TYPE: } else {
+// TYPE: return {};
+// TYPE: }
+// TYPE: if ((_index != 2 - 1) && parser.parseComma())
+// TYPE: return {};
+// TYPE: }
+// TYPE: bool _seen_v1 = false;
+// TYPE: bool _seen_v3 = false;
+// TYPE: for (unsigned _index = 0; _index < 2; ++_index) {
+// TYPE: StringRef _paramKey;
+// TYPE: if (parser.parseKeyword(&_paramKey))
+// TYPE: return {};
+// TYPE: if (parser.parseEqual())
+// TYPE: return {};
+// TYPE: if (!_seen_v1 && _paramKey == "v1") {
+// TYPE: _seen_v1 = true;
+// TYPE: _result_v1 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// TYPE: if (failed(_result_v1))
+// TYPE: return {};
+// TYPE: } else if (!_seen_v3 && _paramKey == "v3") {
+// TYPE: _seen_v3 = true;
+// TYPE: _result_v3 = ::mlir::FieldParser<IntegerAttr>::parse(parser);
+// TYPE: if (failed(_result_v3))
+// TYPE: return {};
+// TYPE: } else {
+// TYPE: return {};
+// TYPE: }
+// TYPE: if ((_index != 2 - 1) && parser.parseComma())
+// TYPE: return {};
+// TYPE: }
+// TYPE: return TestEType::get(parser.getContext(),
+// TYPE: _result_v0.getValue(),
+// TYPE: _result_v1.getValue(),
+// TYPE: _result_v2.getValue(),
+// TYPE: _result_v3.getValue());
+// TYPE: }
+
+// TYPE: void TestEType::print(::mlir::DialectAsmPrinter &printer) const {
+// TYPE: printer << "v0";
+// TYPE: printer << ' ' << "=";
+// TYPE: printer << ' ';
+// TYPE: printer << getV0();
+// TYPE: printer << ",";
+// TYPE: printer << ' ' << "v2";
+// TYPE: printer << ' ' << "=";
+// TYPE: printer << ' ';
+// TYPE: printer << getV2();
+// TYPE: printer << "v1";
+// TYPE: printer << ' ' << "=";
+// TYPE: printer << ' ';
+// TYPE: printer << getV1();
+// TYPE: printer << ",";
+// TYPE: printer << ' ' << "v3";
+// TYPE: printer << ' ' << "=";
+// TYPE: printer << ' ';
+// TYPE: printer << getV3();
+// TYPE: }
+
+def TypeC : TestType<"TestE"> {
+ let parameters = (ins
+ "IntegerAttr":$v0,
+ "IntegerAttr":$v1,
+ "IntegerAttr":$v2,
+ "IntegerAttr":$v3
+ );
+
+ let mnemonic = "type_e";
+ let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`";
+}
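The CHECK lines above show the text layout the generated printers produce: a `params` directive prints the values separated by `,` and a space (as in `TestFAttr`), and a `struct` directive prints `name = value` entries separated the same way (as in `TestBAttr` and `TestEType`). The sketch below reproduces only that layout with `std::ostringstream`. `printParams` and `printStruct` are hypothetical helpers, not part of the generator; the real printers write through `DialectAsmPrinter` and call a parameter's custom `printer` where one is given.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// `params`: value (`,` value)*
std::string printParams(const std::vector<std::string> &values) {
  std::ostringstream os;
  for (std::size_t i = 0; i < values.size(); ++i) {
    if (i != 0)
      os << ", "; // comma, then a space before the next value
    os << values[i];
  }
  return os.str();
}

// `struct`: name `=` value (`,` name `=` value)*
std::string printStruct(const std::vector<std::string> &names,
                        const std::vector<std::string> &values) {
  assert(names.size() == values.size() && "one value per parameter name");
  std::ostringstream os;
  for (std::size_t i = 0; i < names.size(); ++i) {
    if (i != 0)
      os << ", ";
    os << names[i] << " = " << values[i];
  }
  return os.str();
}
```

For `attr_c` (format `params`) with values 1 and 2, this layout gives `1, 2` after the mnemonic; for `attr_b` the struct body reads `v0 = ..., v1 = ...` between the literal braces.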
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index a1b0836a55d7b..5e4cb4d73a392 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "AttrOrTypeFormatGen.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/CodeGenHelpers.h"
@@ -24,6 +25,17 @@
using namespace mlir;
using namespace mlir::tblgen;
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+std::string mlir::tblgen::getParameterAccessorName(StringRef name) {
+ assert(!name.empty() && "parameter has empty name");
+ auto ret = "get" + name.str();
+ ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name
+ return ret;
+}
+
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
/// specified and can only find one dialect's defs, use that.
static void collectAllDefs(StringRef selectedDialect,
@@ -399,7 +411,8 @@ void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
<< " }\n";
// If mnemonic specified, emit print/parse declarations.
- if (def.getParserCode() || def.getPrinterCode() || !params.empty()) {
+ if (def.getParserCode() || def.getPrinterCode() ||
+ def.getAssemblyFormat() || !params.empty()) {
os << llvm::formatv(defDeclParsePrintStr, valueType,
isAttrGenerator ? ", ::mlir::Type type" : "");
}
@@ -410,10 +423,8 @@ void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) {
def.getParameters(parameters);
for (AttrOrTypeParameter &parameter : parameters) {
- SmallString<16> name = parameter.getName();
- name[0] = llvm::toUpper(name[0]);
- os << formatv(" {0} get{1}() const;\n", parameter.getCppAccessorType(),
- name);
+ os << formatv(" {0} {1}() const;\n", parameter.getCppAccessorType(),
+ getParameterAccessorName(parameter.getName()));
}
}
@@ -700,8 +711,32 @@ void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) {
}
void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
+ auto printerCode = def.getPrinterCode();
+ auto parserCode = def.getParserCode();
+ auto assemblyFormat = def.getAssemblyFormat();
+ if (assemblyFormat && (printerCode || parserCode)) {
+ // Custom assembly format cannot be specified at the same time as either
+ // custom printer or parser code.
+ PrintFatalError(def.getLoc(),
+ def.getName() + ": assembly format cannot be specified at "
+ "the same time as printer or parser code");
+ }
+
+ // Generate a parser and printer based on the assembly format, if specified.
+ if (assemblyFormat) {
+ // A custom assembly format requires accessors to be generated for the
+ // generated printer.
+ if (!def.genAccessors()) {
+ PrintFatalError(def.getLoc(),
+ def.getName() +
+ ": the generated printer from 'assemblyFormat' "
+ "requires 'genAccessors' to be true");
+ }
+ return generateAttrOrTypeFormat(def, os);
+ }
+
// Emit the printer code, if specified.
- if (Optional<StringRef> printerCode = def.getPrinterCode()) {
+ if (printerCode) {
// Both the mnenomic and printerCode must be defined (for parity with
// parserCode).
os << "void " << def.getCppClassName()
@@ -717,7 +752,7 @@ void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) {
}
// Emit the parser code, if specified.
- if (Optional<StringRef> parserCode = def.getParserCode()) {
+ if (parserCode) {
FmtContext fmtCtxt;
fmtCtxt.addSubst("_parser", "parser")
.addSubst("_ctxt", "parser.getContext()");
@@ -857,11 +892,10 @@ void DefGenerator::emitDefDef(const AttrOrTypeDef &def) {
paramStorageName = param.getName();
}
- SmallString<16> name = param.getName();
- name[0] = llvm::toUpper(name[0]);
- os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
- param.getCppAccessorType(), name, paramStorageName,
- def.getCppClassName());
+ os << formatv("{0} {3}::{1}() const {{ return getImpl()->{2}; }\n",
+ param.getCppAccessorType(),
+ getParameterAccessorName(param.getName()),
+ paramStorageName, def.getCppClassName());
}
}
}
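`getParameterAccessorName` centralizes the rule that was previously duplicated in `emitDefDecl` and `emitDefDef`: prefix `get` and uppercase the first letter of the parameter name, so `value` becomes `getValue` and `v0` becomes `getV0`, matching the accessor calls in the CHECK lines of `attr-or-type-format.td`. A standard-library sketch of the same rule follows; `accessorNameFor` is a hypothetical name, and `std::toupper` stands in for `llvm::toUpper` (the two agree on ASCII letters in the default "C" locale).

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Sketch of the naming rule in getParameterAccessorName(): prefix "get" and
// uppercase the first letter of the parameter name.
std::string accessorNameFor(const std::string &name) {
  assert(!name.empty() && "parameter has empty name");
  std::string ret = "get" + name;
  // Index 3 is the first character after the "get" prefix.
  ret[3] = static_cast<char>(std::toupper(static_cast<unsigned char>(ret[3])));
  return ret;
}
```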
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
new file mode 100644
index 0000000000000..52e921f2fb27a
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -0,0 +1,781 @@
+//===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "AttrOrTypeFormatGen.h"
+#include "FormatGen.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/TableGen/AttrOrTypeDef.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using llvm::formatv;
+
+//===----------------------------------------------------------------------===//
+// Element
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// This class represents a single format element.
+class Element {
+public:
+ /// LLVM-style RTTI.
+ enum class Kind {
+ /// This element is a directive.
+ ParamsDirective,
+ StructDirective,
+
+ /// This element is a literal.
+ Literal,
+
+ /// This element is a variable.
+ Variable,
+ };
+ Element(Kind kind) : kind(kind) {}
+ virtual ~Element() = default;
+
+ /// Return the kind of this element.
+ Kind getKind() const { return kind; }
+
+private:
+ /// The kind of this element.
+ Kind kind;
+};
+
+/// This class represents an instance of a literal element.
+class LiteralElement : public Element {
+public:
+ LiteralElement(StringRef literal)
+ : Element(Kind::Literal), literal(literal) {}
+
+ static bool classof(const Element *el) {
+ return el->getKind() == Kind::Literal;
+ }
+
+ /// Get the literal spelling.
+ StringRef getSpelling() const { return literal; }
+
+private:
+ /// The spelling of the literal for this element.
+ StringRef literal;
+};
+
+/// This class represents an instance of a variable element. A variable refers
+/// to an attribute or type parameter.
+class VariableElement : public Element {
+public:
+ VariableElement(AttrOrTypeParameter param)
+ : Element(Kind::Variable), param(param) {}
+
+ static bool classof(const Element *el) {
+ return el->getKind() == Kind::Variable;
+ }
+
+ /// Get the parameter in the element.
+ const AttrOrTypeParameter &getParam() const { return param; }
+
+private:
+ AttrOrTypeParameter param;
+};
+
+/// Base class for a directive that contains references to multiple variables.
+template <Element::Kind ElementKind>
+class ParamsDirectiveBase : public Element {
+public:
+ using Base = ParamsDirectiveBase<ElementKind>;
+
+ ParamsDirectiveBase(SmallVector<std::unique_ptr<Element>> &&params)
+ : Element(ElementKind), params(std::move(params)) {}
+
+ static bool classof(const Element *el) {
+ return el->getKind() == ElementKind;
+ }
+
+ /// Get the parameters contained in this directive.
+ auto getParams() const {
+ return llvm::map_range(params, [](auto &el) {
+ return cast<VariableElement>(el.get())->getParam();
+ });
+ }
+
+ /// Get the number of parameters.
+ unsigned getNumParams() const { return params.size(); }
+
+ /// Take all of the parameters from this directive.
+ SmallVector<std::unique_ptr<Element>> takeParams() {
+ return std::move(params);
+ }
+
+private:
+ /// The parameters captured by this directive.
+ SmallVector<std::unique_ptr<Element>> params;
+};
+
+/// This class represents a `params` directive that refers to all parameters
+/// of an attribute or type. When used as a top-level directive, it generates
+/// a format of the form:
+///
+/// (param-value (`,` param-value)*)?
+///
+/// When used as an argument to another directive that accepts variables,
+/// `params` can be used in place of manually listing all parameters of an
+/// attribute or type.
+class ParamsDirective
+ : public ParamsDirectiveBase<Element::Kind::ParamsDirective> {
+public:
+ using Base::Base;
+};
+
+/// This class represents a `struct` directive that generates a struct format
+/// of the form:
+///
+/// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
+///
+class StructDirective
+ : public ParamsDirectiveBase<Element::Kind::StructDirective> {
+public:
+ using Base::Base;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Format Strings
+//===----------------------------------------------------------------------===//
+
+/// Format for defining an attribute parser.
+///
+/// $0: The attribute C++ class name.
+static const char *const attrParserDefn = R"(
+::mlir::Attribute $0::parse(::mlir::DialectAsmParser &$_parser,
+ ::mlir::Type $_type) {
+)";
+
+/// Format for defining a type parser.
+///
+/// $0: The type C++ class name.
+static const char *const typeParserDefn = R"(
+::mlir::Type $0::parse(::mlir::DialectAsmParser &$_parser) {
+)";
+
+/// Default parser for attribute or type parameters.
+static const char *const defaultParameterParser =
+ "::mlir::FieldParser<$0>::parse($_parser)";
+
+/// Default printer for attribute or type parameters.
+static const char *const defaultParameterPrinter = "$_printer << $_self";
+
+/// Prefix of the code that emits an error at the current parser location.
+///
+/// The caller appends the error message and the closing parenthesis.
+static const char *const parseErrorStr =
+ "$_parser.emitError($_parser.getCurrentLocation(), ";
+
+/// Format for defining an attribute or type printer.
+///
+/// $0: The attribute or type C++ class name.
+/// $1: The attribute or type mnemonic.
+static const char *const attrOrTypePrinterDefn = R"(
+void $0::print(::mlir::DialectAsmPrinter &$_printer) const {
+ $_printer << "$1";
+)";
+
+/// Loop declaration for struct parser.
+///
+/// $0: Number of expected parameters.
+static const char *const structParseLoopStart = R"(
+ for (unsigned _index = 0; _index < $0; ++_index) {
+ StringRef _paramKey;
+ if ($_parser.parseKeyword(&_paramKey)) {
+ $_parser.emitError($_parser.getCurrentLocation(),
+ "expected a parameter name in struct");
+ return {};
+ }
+)";
+
+/// Terminator code segment for the struct parser loop. Check for duplicate or
+/// unknown parameters. Parse a comma except on the last element.
+///
+/// {0}: Code template for printing an error.
+/// {1}: Number of elements in the struct.
+static const char *const structParseLoopEnd = R"({{
+ {0}"duplicate or unknown struct parameter name: ") << _paramKey;
+ return {{};
+ }
+ if ((_index != {1} - 1) && parser.parseComma())
+ return {{};
+ }
+)";
+
+/// Code format to parse a variable. It is split into lines because variable
+/// parsers may be generated inside other directives, which requires indentation.
+///
+/// {0}: The parameter name.
+/// {1}: The parse code for the parameter.
+/// {2}: Code template for printing an error.
+/// {3}: Name of the attribute or type.
+/// {4}: C++ class of the parameter.
+static const char *const variableParser[] = {
+ " // Parse variable '{0}'",
+ " _result_{0} = {1};",
+ " if (failed(_result_{0})) {{",
+ " {2}\"failed to parse {3} parameter '{0}' which is to be a `{4}`\");",
+ " return {{};",
+ " }",
+};
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+/// Get a list of an attribute's or type's parameters. These can be wrapper
+/// objects around `AttrOrTypeParameter` or string inits.
+static auto getParameters(const AttrOrTypeDef &def) {
+ SmallVector<AttrOrTypeParameter> params;
+ def.getParameters(params);
+ return params;
+}
+
+//===----------------------------------------------------------------------===//
+// AttrOrTypeFormat
+//===----------------------------------------------------------------------===//
+
+namespace {
+class AttrOrTypeFormat {
+public:
+ AttrOrTypeFormat(const AttrOrTypeDef &def,
+ std::vector<std::unique_ptr<Element>> &&elements)
+ : def(def), elements(std::move(elements)) {}
+
+ /// Generate the attribute or type parser.
+ void genParser(raw_ostream &os);
+ /// Generate the attribute or type printer.
+ void genPrinter(raw_ostream &os);
+
+private:
+ /// Generate the parser code for a specific format element.
+ void genElementParser(Element *el, FmtContext &ctx, raw_ostream &os);
+ /// Generate the parser code for a literal.
+ void genLiteralParser(StringRef value, FmtContext &ctx, raw_ostream &os,
+ unsigned indent = 0);
+ /// Generate the parser code for a variable.
+ void genVariableParser(const AttrOrTypeParameter &param, FmtContext &ctx,
+ raw_ostream &os, unsigned indent = 0);
+ /// Generate the parser code for a `params` directive.
+ void genParamsParser(ParamsDirective *el, FmtContext &ctx, raw_ostream &os);
+ /// Generate the parser code for a `struct` directive.
+ void genStructParser(StructDirective *el, FmtContext &ctx, raw_ostream &os);
+
+ /// Generate the printer code for a specific format element.
+ void genElementPrinter(Element *el, FmtContext &ctx, raw_ostream &os);
+ /// Generate the printer code for a literal.
+ void genLiteralPrinter(StringRef value, FmtContext &ctx, raw_ostream &os);
+ /// Generate the printer code for a variable.
+ void genVariablePrinter(const AttrOrTypeParameter &param, FmtContext &ctx,
+ raw_ostream &os);
+ /// Generate the printer code for a `params` directive.
+ void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, raw_ostream &os);
+ /// Generate the printer code for a `struct` directive.
+ void genStructPrinter(StructDirective *el, FmtContext &ctx, raw_ostream &os);
+
+ /// The ODS definition of the attribute or type whose format is being used to
+ /// generate a parser and printer.
+ const AttrOrTypeDef &def;
+ /// The list of top-level format elements returned by the assembly format
+ /// parser.
+ std::vector<std::unique_ptr<Element>> elements;
+
+ /// Flags for printing spaces.
+ bool shouldEmitSpace;
+ bool lastWasPunctuation;
+};
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// ParserGen
+//===----------------------------------------------------------------------===//
+
+void AttrOrTypeFormat::genParser(raw_ostream &os) {
+ FmtContext ctx;
+ ctx.addSubst("_parser", "parser");
+
+ /// Generate the definition.
+ if (isa<AttrDef>(def)) {
+ ctx.addSubst("_type", "attrType");
+ os << tgfmt(attrParserDefn, &ctx, def.getCppClassName());
+ } else {
+ os << tgfmt(typeParserDefn, &ctx, def.getCppClassName());
+ }
+
+ /// Declare variables to store all of the parameters. Allocated parameters
+ /// such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
+ /// FailureOr<T> to defer type construction for parameters that are parsed in
+ /// a loop (parsers return FailureOr anyway).
+ SmallVector<AttrOrTypeParameter> params = getParameters(def);
+ for (const AttrOrTypeParameter &param : params) {
+ os << formatv(" ::mlir::FailureOr<{0}> _result_{1};\n",
+ param.getCppStorageType(), param.getName());
+ }
+
+ /// Store the initial location of the parser.
+ ctx.addSubst("_loc", "loc");
+ os << tgfmt(" ::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
+ " (void) $_loc;\n",
+ &ctx);
+
+ /// Generate call to each parameter parser.
+ for (auto &el : elements)
+ genElementParser(el.get(), ctx, os);
+
+ /// Generate call to the attribute or type builder. Use the checked getter
+ /// if one was generated.
+ if (def.genVerifyDecl()) {
+ os << tgfmt(" return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
+ &ctx, def.getCppClassName());
+ } else {
+ os << tgfmt(" return $0::get($_parser.getContext()", &ctx,
+ def.getCppClassName());
+ }
+ for (const AttrOrTypeParameter &param : params)
+ os << formatv(",\n _result_{0}.getValue()", param.getName());
+ os << ");\n}\n\n";
+}
+
+void AttrOrTypeFormat::genElementParser(Element *el, FmtContext &ctx,
+ raw_ostream &os) {
+ if (auto *literal = dyn_cast<LiteralElement>(el))
+ return genLiteralParser(literal->getSpelling(), ctx, os);
+ if (auto *var = dyn_cast<VariableElement>(el))
+ return genVariableParser(var->getParam(), ctx, os);
+ if (auto *params = dyn_cast<ParamsDirective>(el))
+ return genParamsParser(params, ctx, os);
+ if (auto *strct = dyn_cast<StructDirective>(el))
+ return genStructParser(strct, ctx, os);
+
+ llvm_unreachable("unknown format element");
+}
+
+void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx,
+ raw_ostream &os, unsigned indent) {
+ os.indent(indent) << " // Parse literal '" << value << "'\n";
+ os.indent(indent) << tgfmt(" if ($_parser.parse", &ctx);
+ if (value.front() == '_' || isalpha(value.front())) {
+ os << "Keyword(\"" << value << "\")";
+ } else {
+ os << StringSwitch<StringRef>(value)
+ .Case("->", "Arrow")
+ .Case(":", "Colon")
+ .Case(",", "Comma")
+ .Case("=", "Equal")
+ .Case("<", "Less")
+ .Case(">", "Greater")
+ .Case("{", "LBrace")
+ .Case("}", "RBrace")
+ .Case("(", "LParen")
+ .Case(")", "RParen")
+ .Case("[", "LSquare")
+ .Case("]", "RSquare")
+ .Case("?", "Question")
+ .Case("+", "Plus")
+ .Case("*", "Star")
+ << "()";
+ }
+ os << ")\n";
+ // Parser will emit an error
+ os.indent(indent) << " return {};\n";
+}
+
+void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter &param,
+ FmtContext &ctx, raw_ostream &os,
+ unsigned indent) {
+ /// Check for a custom parser. Use the default attribute parser otherwise.
+ auto customParser = param.getParser();
+ auto parser =
+ customParser ? *customParser : StringRef(defaultParameterParser);
+ for (const char *line : variableParser) {
+ os.indent(indent) << formatv(line, param.getName(),
+ tgfmt(parser, &ctx, param.getCppStorageType()),
+ tgfmt(parseErrorStr, &ctx), def.getName(),
+ param.getCppType())
+ << "\n";
+ }
+}
+
+void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
+ raw_ostream &os) {
+ os << " // Parse parameter list\n";
+ llvm::interleave(
+ el->getParams(), [&](auto param) { genVariableParser(param, ctx, os); },
+ [&]() { genLiteralParser(",", ctx, os); });
+}
+
+void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx,
+ raw_ostream &os) {
+ os << " // Parse parameter struct\n";
+
+ /// Declare a "seen" variable for each key.
+ for (const AttrOrTypeParameter &param : el->getParams())
+ os << formatv(" bool _seen_{0} = false;\n", param.getName());
+
+ /// Generate the parsing loop.
+ os << tgfmt(structParseLoopStart, &ctx, el->getNumParams());
+ genLiteralParser("=", ctx, os, 2);
+ os << " ";
+ for (const AttrOrTypeParameter &param : el->getParams()) {
+ os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
+ " _seen_{0} = true;\n",
+ param.getName());
+ genVariableParser(param, ctx, os, 4);
+ os << " } else ";
+ }
+
+ /// Duplicate or unknown parameter.
+ os << formatv(structParseLoopEnd, tgfmt(parseErrorStr, &ctx),
+ el->getNumParams());
+
+ /// Because the loop runs N times and each non-failing iteration sets 1 of
+ /// N flags, successfully exiting the loop means that all parameters have been
+ /// seen. `parseOptionalComma` would cause issues with any formats that use
+ /// "struct(...) `,`" because structs aren't surrounded by braces.
+}
+
+//===----------------------------------------------------------------------===//
+// PrinterGen
+//===----------------------------------------------------------------------===//
+
+void AttrOrTypeFormat::genPrinter(raw_ostream &os) {
+ FmtContext ctx;
+ ctx.addSubst("_printer", "printer");
+
+ /// Generate the definition.
+ os << tgfmt(attrOrTypePrinterDefn, &ctx, def.getCppClassName(),
+ *def.getMnemonic());
+
+ /// Generate printers.
+ shouldEmitSpace = true;
+ lastWasPunctuation = false;
+ for (auto &el : elements)
+ genElementPrinter(el.get(), ctx, os);
+
+ os << "}\n\n";
+}
+
+void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx,
+ raw_ostream &os) {
+ if (auto *literal = dyn_cast<LiteralElement>(el))
+ return genLiteralPrinter(literal->getSpelling(), ctx, os);
+ if (auto *params = dyn_cast<ParamsDirective>(el))
+ return genParamsPrinter(params, ctx, os);
+ if (auto *strct = dyn_cast<StructDirective>(el))
+ return genStructPrinter(strct, ctx, os);
+ if (auto *var = dyn_cast<VariableElement>(el))
+ return genVariablePrinter(var->getParam(), ctx, os);
+
+ llvm_unreachable("unknown format element");
+}
+
+void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
+ raw_ostream &os) {
+ /// Don't insert a space before certain punctuation.
+ bool needSpace =
+ shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation);
+ os << tgfmt(" $_printer$0 << \"$1\";\n", &ctx, needSpace ? " << ' '" : "",
+ value);
+
+ /// Update the flags.
+ shouldEmitSpace =
+ value.size() != 1 || !StringRef("<({[").contains(value.front());
+ lastWasPunctuation = !(value.front() == '_' || isalpha(value.front()));
+}
+
+void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter &param,
+ FmtContext &ctx, raw_ostream &os) {
+ /// Insert a space before the next parameter, if necessary.
+ if (shouldEmitSpace || !lastWasPunctuation)
+ os << tgfmt(" $_printer << ' ';\n", &ctx);
+ shouldEmitSpace = true;
+ lastWasPunctuation = false;
+
+ ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
+ os << " ";
+ if (auto printer = param.getPrinter())
+ os << tgfmt(*printer, &ctx) << ";\n";
+ else
+ os << tgfmt(defaultParameterPrinter, &ctx) << ";\n";
+}
+
+void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
+ raw_ostream &os) {
+ llvm::interleave(
+ el->getParams(), [&](auto param) { genVariablePrinter(param, ctx, os); },
+ [&]() { genLiteralPrinter(",", ctx, os); });
+}
+
+void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
+ raw_ostream &os) {
+ llvm::interleave(
+ el->getParams(),
+ [&](auto param) {
+ genLiteralPrinter(param.getName(), ctx, os);
+ genLiteralPrinter("=", ctx, os);
+ os << tgfmt(" $_printer << ' ';\n", &ctx);
+ genVariablePrinter(param, ctx, os);
+ },
+ [&]() { genLiteralPrinter(",", ctx, os); });
+}
+
+//===----------------------------------------------------------------------===//
+// FormatParser
+//===----------------------------------------------------------------------===//
+
+namespace {
+class FormatParser {
+public:
+ FormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
+ : lexer(mgr, def.getLoc()[0]), curToken(lexer.lexToken()), def(def),
+ seenParams(def.getNumParameters()) {}
+
+ /// Parse the attribute or type format and create the format elements.
+ FailureOr<AttrOrTypeFormat> parse();
+
+private:
+ /// The current context of the parser when parsing an element.
+ enum ParserContext {
+ /// The element is being parsed in the default context, at the top level of
+ /// the format.
+ TopLevelContext,
+ /// The element is being parsed as a child of a `struct` directive.
+ StructDirective,
+ };
+
+ /// Emit an error.
+ LogicalResult emitError(const Twine &msg) {
+ lexer.emitError(curToken.getLoc(), msg);
+ return failure();
+ }
+
+ /// Parse an expected token.
+ LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
+ if (curToken.getKind() != kind)
+ return emitError(msg);
+ consumeToken();
+ return success();
+ }
+
+ /// Advance the lexer to the next token.
+ void consumeToken() {
+ assert(curToken.getKind() != FormatToken::eof &&
+ curToken.getKind() != FormatToken::error &&
+ "shouldn't advance past EOF or errors");
+ curToken = lexer.lexToken();
+ }
+
+ /// Parse any element.
+ FailureOr<std::unique_ptr<Element>> parseElement(ParserContext ctx);
+ /// Parse a literal element.
+ FailureOr<std::unique_ptr<Element>> parseLiteral(ParserContext ctx);
+ /// Parse a variable element.
+ FailureOr<std::unique_ptr<Element>> parseVariable(ParserContext ctx);
+ /// Parse a directive.
+ FailureOr<std::unique_ptr<Element>> parseDirective(ParserContext ctx);
+ /// Parse a `params` directive.
+ FailureOr<std::unique_ptr<Element>> parseParamsDirective();
+ /// Parse a `struct` directive.
+ FailureOr<std::unique_ptr<Element>> parseStructDirective();
+
+ /// The current format lexer.
+ FormatLexer lexer;
+ /// The current token in the stream.
+ FormatToken curToken;
+ /// Attribute or type tablegen def.
+ const AttrOrTypeDef &def;
+
+ /// Seen attribute or type parameters.
+ llvm::BitVector seenParams;
+};
+} // end anonymous namespace
+
+FailureOr<AttrOrTypeFormat> FormatParser::parse() {
+ std::vector<std::unique_ptr<Element>> elements;
+ elements.reserve(16);
+
+ /// Parse the format elements.
+ while (curToken.getKind() != FormatToken::eof) {
+ auto element = parseElement(TopLevelContext);
+ if (failed(element))
+ return failure();
+
+ /// Add the format element and continue.
+ elements.push_back(std::move(*element));
+ }
+
+ /// Check that all parameters have been seen.
+ SmallVector<AttrOrTypeParameter> params = getParameters(def);
+ for (auto it : llvm::enumerate(params)) {
+ if (!seenParams.test(it.index())) {
+ return emitError("format is missing reference to parameter: " +
+ it.value().getName());
+ }
+ }
+
+ return AttrOrTypeFormat(def, std::move(elements));
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseElement(ParserContext ctx) {
+ if (curToken.getKind() == FormatToken::literal)
+ return parseLiteral(ctx);
+ if (curToken.getKind() == FormatToken::variable)
+ return parseVariable(ctx);
+ if (curToken.isKeyword())
+ return parseDirective(ctx);
+
+ return emitError("expected literal, directive, or variable");
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseLiteral(ParserContext ctx) {
+ if (ctx != TopLevelContext) {
+ return emitError(
+ "literals may only be used in the top-level section of the format");
+ }
+
+ /// Get the literal spelling without the surrounding "`".
+ auto value = curToken.getSpelling().drop_front().drop_back();
+ if (!isValidLiteral(value))
+ return emitError("literal '" + value + "' is not valid");
+
+ consumeToken();
+ return {std::make_unique<LiteralElement>(value)};
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseVariable(ParserContext ctx) {
+ /// Get the parameter name without the preceding "$".
+ auto name = curToken.getSpelling().drop_front();
+
+ /// Lookup the parameter.
+ SmallVector<AttrOrTypeParameter> params = getParameters(def);
+ auto *it = llvm::find_if(
+ params, [&](auto &param) { return param.getName() == name; });
+
+ /// Check that the parameter reference is valid.
+ if (it == params.end())
+ return emitError(def.getName() + " has no parameter named '" + name + "'");
+ auto idx = std::distance(params.begin(), it);
+ if (seenParams.test(idx))
+ return emitError("duplicate parameter '" + name + "'");
+ seenParams.set(idx);
+
+ consumeToken();
+ return {std::make_unique<VariableElement>(*it)};
+}
+
+FailureOr<std::unique_ptr<Element>>
+FormatParser::parseDirective(ParserContext ctx) {
+
+ switch (curToken.getKind()) {
+ case FormatToken::kw_params:
+ return parseParamsDirective();
+ case FormatToken::kw_struct:
+ if (ctx != TopLevelContext) {
+ return emitError(
+ "`struct` may only be used in the top-level section of the format");
+ }
+ return parseStructDirective();
+ default:
+ return emitError("unknown directive in format: " + curToken.getSpelling());
+ }
+}
+
+FailureOr<std::unique_ptr<Element>> FormatParser::parseParamsDirective() {
+ consumeToken();
+ /// Collect all of the attribute's or type's parameters.
+ SmallVector<AttrOrTypeParameter> params = getParameters(def);
+ SmallVector<std::unique_ptr<Element>> vars;
+ /// Ensure that none of the parameters have already been captured.
+ for (auto it : llvm::enumerate(params)) {
+ if (seenParams.test(it.index())) {
+ return emitError("`params` captures duplicate parameter: " +
+ it.value().getName());
+ }
+ seenParams.set(it.index());
+ vars.push_back(std::make_unique<VariableElement>(it.value()));
+ }
+ return {std::make_unique<ParamsDirective>(std::move(vars))};
+}
+
+FailureOr<std::unique_ptr<Element>> FormatParser::parseStructDirective() {
+ consumeToken();
+ if (failed(parseToken(FormatToken::l_paren,
+ "expected '(' before `struct` argument list")))
+ return failure();
+
+ /// Parse variables captured by `struct`.
+ SmallVector<std::unique_ptr<Element>> vars;
+
+ /// Parse first captured parameter or a `params` directive.
+ FailureOr<std::unique_ptr<Element>> var = parseElement(StructDirective);
+ if (failed(var) || !isa<VariableElement, ParamsDirective>(*var))
+ return emitError("`struct` argument list expected a variable or directive");
+ if (isa<VariableElement>(*var)) {
+ /// Parse any other parameters.
+ vars.push_back(std::move(*var));
+ while (curToken.getKind() == FormatToken::comma) {
+ consumeToken();
+ var = parseElement(StructDirective);
+ if (failed(var) || !isa<VariableElement>(*var))
+ return emitError("expected a variable in `struct` argument list");
+ vars.push_back(std::move(*var));
+ }
+ } else {
+ /// `struct(params)` captures all parameters in the attribute or type.
+ vars = cast<ParamsDirective>(var->get())->takeParams();
+ }
+
+ if (curToken.getKind() != FormatToken::r_paren)
+ return emitError("expected ')' at the end of an argument list");
+
+ consumeToken();
+ return {std::make_unique<::StructDirective>(std::move(vars))};
+}
+
+//===----------------------------------------------------------------------===//
+// Interface
+//===----------------------------------------------------------------------===//
+
+void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
+ raw_ostream &os) {
+ llvm::SourceMgr mgr;
+ mgr.AddNewSourceBuffer(
+ llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()),
+ llvm::SMLoc());
+
+ /// Parse the custom assembly format.
+ FormatParser parser(mgr, def);
+ FailureOr<AttrOrTypeFormat> format = parser.parse();
+ if (failed(format)) {
+ if (formatErrorIsFatal)
+ PrintFatalError(def.getLoc(), "failed to parse assembly format");
+ return;
+ }
+
+ /// Generate the parser and printer.
+ format->genParser(os);
+ format->genPrinter(os);
+}
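The comment at the end of `genStructParser` explains why the generated `struct` parser needs no trailing-comma lookahead: it runs exactly N iterations, and each successful iteration sets one of N "seen" flags. Below is a minimal standalone sketch of that technique, assuming plain `key = value` pairs with non-negative integer values. `parseStruct` and its string-based input are hypothetical; the code that mlir-tblgen actually emits calls methods on the `$_parser` object instead.

```cpp
#include <cassert>
#include <cctype>
#include <climits>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

// Hypothetical parser for a `struct`-style list such as "two = 2, one = 1".
// Every key in `keys` must appear exactly once, in any order; values are
// non-negative ints. Returns std::nullopt on a missing, duplicate, or unknown key.
std::optional<std::map<std::string, int>>
parseStruct(std::string_view in, const std::vector<std::string> &keys) {
  std::size_t pos = 0;
  auto skipSpace = [&] {
    while (pos < in.size() && std::isspace(static_cast<unsigned char>(in[pos])))
      ++pos;
  };
  auto parseChar = [&](char c) {
    skipSpace();
    if (pos == in.size() || in[pos] != c)
      return false;
    ++pos;
    return true;
  };
  auto parseKey = [&]() -> std::string {
    skipSpace();
    std::size_t start = pos;
    while (pos < in.size() &&
           (std::isalnum(static_cast<unsigned char>(in[pos])) || in[pos] == '_'))
      ++pos;
    return std::string(in.substr(start, pos - start));
  };
  auto parseInt = [&]() -> std::optional<int> {
    skipSpace();
    std::size_t start = pos;
    long long value = 0;
    while (pos < in.size() && std::isdigit(static_cast<unsigned char>(in[pos]))) {
      value = value * 10 + (in[pos++] - '0');
      if (value > INT_MAX)
        return std::nullopt; // reject values that do not fit in an int
    }
    if (pos == start)
      return std::nullopt;
    return static_cast<int>(value);
  };

  std::vector<bool> seen(keys.size(), false);
  std::map<std::string, int> result;
  // Run exactly N iterations; each successful one sets a distinct flag, so
  // leaving the loop means every key was seen exactly once.
  for (std::size_t i = 0; i < keys.size(); ++i) {
    std::string key = parseKey();
    if (key.empty() || !parseChar('='))
      return std::nullopt;
    std::size_t idx = 0;
    while (idx < keys.size() && (seen[idx] || keys[idx] != key))
      ++idx;
    if (idx == keys.size())
      return std::nullopt; // duplicate or unknown key
    seen[idx] = true;
    std::optional<int> value = parseInt();
    if (!value)
      return std::nullopt;
    result[key] = *value;
    // Commas only separate entries; a comma after the last entry is left for
    // whatever follows the struct in the format.
    if (i + 1 < keys.size() && !parseChar(','))
      return std::nullopt;
  }
  return result;
}
```

Because the loop never looks past the last entry, input such as `one = 1, two = 2, rest` succeeds and leaves `, rest` unconsumed, which is what lets a format continue with a `,` literal after `struct(...)`.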
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
new file mode 100644
index 0000000000000..2a10a157dfc90
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h
@@ -0,0 +1,32 @@
+//===- AttrOrTypeFormatGen.h - MLIR attribute and type format generator ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
+#define MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
+
+#include "llvm/Support/raw_ostream.h"
+
+#include <string>
+
+namespace mlir {
+namespace tblgen {
+class AttrOrTypeDef;
+
+/// Generate a parser and printer based on a custom assembly format for an
+/// attribute or type.
+void generateAttrOrTypeFormat(const AttrOrTypeDef &def, llvm::raw_ostream &os);
+
+/// From the parameter name, get the name of the accessor function in camelcase.
+/// The first letter of the parameter is upper-cased and prefixed with "get".
+/// E.g. 'value' -> 'getValue'.
+std::string getParameterAccessorName(llvm::StringRef name);
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_
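In `genVariablePrinter`, `$_self` is bound to `getParameterAccessorName(param.getName()) + "()"`, so a parameter `$count` is printed through `getCount()`. A minimal sketch of the mapping described by the header comment, assuming it only upper-cases the first letter and prepends `get`; `getParameterAccessorNameSketch` is a hypothetical stand-in, not the tblgen definition.

```cpp
#include <cassert>
#include <cctype>
#include <string>
#include <string_view>

// Sketch of the documented rule: 'value' -> 'getValue'. Any further
// camel-casing the real helper may do is not modelled here.
std::string getParameterAccessorNameSketch(std::string_view name) {
  assert(!name.empty() && "parameter has an empty name");
  std::string result = "get";
  result += name;
  // Upper-case the first letter of the parameter name (index 3, after "get").
  result[3] = static_cast<char>(std::toupper(static_cast<unsigned char>(result[3])));
  return result;
}
```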
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index f16e8965daca4..a937a9d89a1d3 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -6,10 +6,12 @@ set(LLVM_LINK_COMPONENTS
add_tablegen(mlir-tblgen MLIR
AttrOrTypeDefGen.cpp
+ AttrOrTypeFormatGen.cpp
CodeGenHelpers.cpp
DialectGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp
+ FormatGen.cpp
LLVMIRConversionGen.cpp
LLVMIRIntrinsicGen.cpp
mlir-tblgen.cpp
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
new file mode 100644
index 0000000000000..fa6c0603ac7e1
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -0,0 +1,225 @@
+//===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "FormatGen.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/TableGen/Error.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+//===----------------------------------------------------------------------===//
+// FormatToken
+//===----------------------------------------------------------------------===//
+
+llvm::SMLoc FormatToken::getLoc() const {
+ return llvm::SMLoc::getFromPointer(spelling.data());
+}
+
+//===----------------------------------------------------------------------===//
+// FormatLexer
+//===----------------------------------------------------------------------===//
+
+FormatLexer::FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc)
+ : mgr(mgr), loc(loc),
+ curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
+ curPtr(curBuffer.begin()) {}
+
+FormatToken FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
+ mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
+ llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
+ "in custom assembly format for this operation");
+ return formToken(FormatToken::error, loc.getPointer());
+}
+
+FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
+ return emitError(llvm::SMLoc::getFromPointer(loc), msg);
+}
+
+FormatToken FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
+ const Twine &note) {
+ mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
+ llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
+ "in custom assembly format for this operation");
+ mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
+ return formToken(FormatToken::error, loc.getPointer());
+}
+
+int FormatLexer::getNextChar() {
+ char curChar = *curPtr++;
+ switch (curChar) {
+ default:
+ return (unsigned char)curChar;
+ case 0: {
+ // A nul character in the stream is either the end of the current buffer or
+ // a random nul in the file. Disambiguate that here.
+ if (curPtr - 1 != curBuffer.end())
+ return 0;
+
+ // Otherwise, return end of file.
+ --curPtr;
+ return EOF;
+ }
+ case '\n':
+ case '\r':
+ // Handle the newline character by ignoring it and incrementing the line
+ // count. However, be careful about 'dos style' files with \n\r in them.
+ // Only treat a \n\r or \r\n as a single line.
+ if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
+ ++curPtr;
+ return '\n';
+ }
+}
+
+FormatToken FormatLexer::lexToken() {
+ const char *tokStart = curPtr;
+
+ // This always consumes at least one character.
+ int curChar = getNextChar();
+ switch (curChar) {
+ default:
+ // Handle identifiers: [a-zA-Z_]
+ if (isalpha(curChar) || curChar == '_')
+ return lexIdentifier(tokStart);
+
+ // Unknown character, emit an error.
+ return emitError(tokStart, "unexpected character");
+ case EOF:
+ // Return EOF denoting the end of lexing.
+ return formToken(FormatToken::eof, tokStart);
+
+ // Lex punctuation.
+ case '^':
+ return formToken(FormatToken::caret, tokStart);
+ case ':':
+ return formToken(FormatToken::colon, tokStart);
+ case ',':
+ return formToken(FormatToken::comma, tokStart);
+ case '=':
+ return formToken(FormatToken::equal, tokStart);
+ case '<':
+ return formToken(FormatToken::less, tokStart);
+ case '>':
+ return formToken(FormatToken::greater, tokStart);
+ case '?':
+ return formToken(FormatToken::question, tokStart);
+ case '(':
+ return formToken(FormatToken::l_paren, tokStart);
+ case ')':
+ return formToken(FormatToken::r_paren, tokStart);
+ case '*':
+ return formToken(FormatToken::star, tokStart);
+
+ // Ignore whitespace characters.
+ case 0:
+ case ' ':
+ case '\t':
+ case '\n':
+ return lexToken();
+
+ case '`':
+ return lexLiteral(tokStart);
+ case '$':
+ return lexVariable(tokStart);
+ }
+}
+
+FormatToken FormatLexer::lexLiteral(const char *tokStart) {
+ assert(curPtr[-1] == '`');
+
+ // Lex a literal surrounded by ``.
+ while (const char curChar = *curPtr++) {
+ if (curChar == '`')
+ return formToken(FormatToken::literal, tokStart);
+ }
+ return emitError(curPtr - 1, "unexpected end of file in literal");
+}
+
+FormatToken FormatLexer::lexVariable(const char *tokStart) {
+ if (!isalpha(curPtr[0]) && curPtr[0] != '_')
+ return emitError(curPtr - 1, "expected variable name");
+
+ // Otherwise, consume the rest of the characters.
+ while (isalnum(*curPtr) || *curPtr == '_')
+ ++curPtr;
+ return formToken(FormatToken::variable, tokStart);
+}
+
+FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
+ // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
+ while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
+ ++curPtr;
+
+ // Check to see if this identifier is a keyword.
+ StringRef str(tokStart, curPtr - tokStart);
+ auto kind =
+ StringSwitch<FormatToken::Kind>(str)
+ .Case("attr-dict", FormatToken::kw_attr_dict)
+ .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
+ .Case("custom", FormatToken::kw_custom)
+ .Case("functional-type", FormatToken::kw_functional_type)
+ .Case("operands", FormatToken::kw_operands)
+ .Case("params", FormatToken::kw_params)
+ .Case("ref", FormatToken::kw_ref)
+ .Case("regions", FormatToken::kw_regions)
+ .Case("results", FormatToken::kw_results)
+ .Case("struct", FormatToken::kw_struct)
+ .Case("successors", FormatToken::kw_successors)
+ .Case("type", FormatToken::kw_type)
+ .Default(FormatToken::identifier);
+ return FormatToken(kind, str);
+}
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
+bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
+ bool lastWasPunctuation) {
+ if (value.size() != 1 && value != "->")
+ return true;
+ if (lastWasPunctuation)
+ return !StringRef(">)}],").contains(value.front());
+ return !StringRef("<>(){}[],").contains(value.front());
+}
+
+bool mlir::tblgen::canFormatStringAsKeyword(StringRef value) {
+ if (!isalpha(value.front()) && value.front() != '_')
+ return false;
+ return llvm::all_of(value.drop_front(), [](char c) {
+ return isalnum(c) || c == '_' || c == '$' || c == '.';
+ });
+}
+
+bool mlir::tblgen::isValidLiteral(StringRef value) {
+ if (value.empty())
+ return false;
+ char front = value.front();
+
+ // If there is only one character, this must either be punctuation or a
+ // single character bare identifier.
+ if (value.size() == 1)
+ return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
+
+ // Check the punctuation that are larger than a single character.
+ if (value == "->")
+ return true;
+
+ // Otherwise, this must be an identifier.
+ return canFormatStringAsKeyword(value);
+}
+
+//===----------------------------------------------------------------------===//
+// Commandline Options
+//===----------------------------------------------------------------------===//
+
+llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal(
+ "asmformat-error-is-fatal",
+ llvm::cl::desc("Emit a fatal error if format parsing fails"),
+ llvm::cl::init(true));
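Printer spacing is decided jointly by `shouldEmitSpaceBefore` and the `shouldEmitSpace`/`lastWasPunctuation` flags that `genLiteralPrinter` and `genVariablePrinter` update. The sketch below replays that state machine directly on strings instead of emitting C++, assuming each parameter has already been printed to text; `Piece` and `printFormat` are hypothetical names, and the mnemonic that the real printer writes first is omitted.

```cpp
#include <cassert>
#include <cctype>
#include <string>
#include <string_view>
#include <vector>

// Same rule as shouldEmitSpaceBefore in FormatGen.cpp, on std::string_view.
bool shouldEmitSpaceBefore(std::string_view value, bool lastWasPunctuation) {
  if (value.size() != 1 && value != "->")
    return true;
  std::string_view noSpace = lastWasPunctuation ? ">)}]," : "<>(){}[],";
  return noSpace.find(value.front()) == std::string_view::npos;
}

// Hypothetical format element: a non-empty literal such as "<", or the
// already printed text of a parameter.
struct Piece {
  bool isLiteral;
  std::string text;
};

// Replays the flag updates of genLiteralPrinter/genVariablePrinter, but
// appends to a string instead of generating printer code.
std::string printFormat(const std::vector<Piece> &pieces) {
  std::string out;
  // Initial state set by genPrinter.
  bool shouldEmitSpace = true;
  bool lastWasPunctuation = false;
  for (const Piece &p : pieces) {
    if (p.isLiteral) {
      if (shouldEmitSpace && shouldEmitSpaceBefore(p.text, lastWasPunctuation))
        out += ' ';
      out += p.text;
      // No space is emitted after a single opening bracket.
      shouldEmitSpace = p.text.size() != 1 ||
                        std::string_view("<({[").find(p.text.front()) ==
                            std::string_view::npos;
      unsigned char front = static_cast<unsigned char>(p.text.front());
      lastWasPunctuation = !(front == '_' || std::isalpha(front));
    } else {
      if (shouldEmitSpace || !lastWasPunctuation)
        out += ' ';
      out += p.text;
      shouldEmitSpace = true;
      lastWasPunctuation = false;
    }
  }
  return out;
}
```

For the format from the commit message, "`<` $count `,` $map `>`", with `$count` printed as `3` and `$map` printed as `m`, this yields `<3, m>`: no space after `<` or before `,` and `>`, one space after `,`.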
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
new file mode 100644
index 0000000000000..f061d1ed5c678
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -0,0 +1,161 @@
+//===- FormatGen.h - Utilities for custom assembly formats ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains common classes for building custom assembly format parsers
+// and generators.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
+#define MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/SMLoc.h"
+
+namespace llvm {
+class SourceMgr;
+} // end namespace llvm
+
+namespace mlir {
+namespace tblgen {
+
+//===----------------------------------------------------------------------===//
+// FormatToken
+//===----------------------------------------------------------------------===//
+
+/// This class represents a specific token in the input format.
+class FormatToken {
+public:
+ /// Basic token kinds.
+ enum Kind {
+ // Markers.
+ eof,
+ error,
+
+ // Tokens with no info.
+ l_paren,
+ r_paren,
+ caret,
+ colon,
+ comma,
+ equal,
+ less,
+ greater,
+ question,
+ star,
+
+ // Keywords.
+ keyword_start,
+ kw_attr_dict,
+ kw_attr_dict_w_keyword,
+ kw_custom,
+ kw_functional_type,
+ kw_operands,
+ kw_params,
+ kw_ref,
+ kw_regions,
+ kw_results,
+ kw_struct,
+ kw_successors,
+ kw_type,
+ keyword_end,
+
+ // String valued tokens.
+ identifier,
+ literal,
+ variable,
+ };
+
+ FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
+
+ /// Return the bytes that make up this token.
+ StringRef getSpelling() const { return spelling; }
+
+ /// Return the kind of this token.
+ Kind getKind() const { return kind; }
+
+ /// Return a location for this token.
+ llvm::SMLoc getLoc() const;
+
+ /// Return if this token is a keyword.
+ bool isKeyword() const {
+ return getKind() > Kind::keyword_start && getKind() < Kind::keyword_end;
+ }
+
+private:
+ /// Discriminator that indicates the kind of token this is.
+ Kind kind;
+
+ /// A reference to the entire token contents; this is always a pointer into
+ /// a memory buffer owned by the source manager.
+ StringRef spelling;
+};
+
+//===----------------------------------------------------------------------===//
+// FormatLexer
+//===----------------------------------------------------------------------===//
+
+/// This class implements a simple lexer for operation assembly format strings.
+class FormatLexer {
+public:
+ FormatLexer(llvm::SourceMgr &mgr, llvm::SMLoc loc);
+
+ /// Lex the next token and return it.
+ FormatToken lexToken();
+
+ /// Emit an error to the lexer with the given location and message.
+ FormatToken emitError(llvm::SMLoc loc, const Twine &msg);
+ FormatToken emitError(const char *loc, const Twine &msg);
+
+ FormatToken emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
+ const Twine &note);
+
+private:
+ /// Return the next character in the stream.
+ int getNextChar();
+
+ /// Lex an identifier, literal, or variable.
+ FormatToken lexIdentifier(const char *tokStart);
+ FormatToken lexLiteral(const char *tokStart);
+ FormatToken lexVariable(const char *tokStart);
+
+ /// Create a token with the current pointer and a start pointer.
+ FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
+ return FormatToken(kind, StringRef(tokStart, curPtr - tokStart));
+ }
+
+ /// The source manager containing the format string.
+ llvm::SourceMgr &mgr;
+ /// Location of the format string.
+ llvm::SMLoc loc;
+ /// Buffer containing the format string.
+ StringRef curBuffer;
+ /// Current pointer in the buffer.
+ const char *curPtr;
+};
+
+/// Whether a space needs to be emitted before a literal. E.g., two keywords
+/// back-to-back require a space separator, but a keyword followed by '<' does
+/// not require a space.
+bool shouldEmitSpaceBefore(StringRef value, bool lastWasPunctuation);
+
+/// Returns true if the given string can be formatted as a keyword.
+bool canFormatStringAsKeyword(StringRef value);
+
+/// Returns true if the given string is a valid format literal element.
+bool isValidLiteral(StringRef value);
+
+/// Whether a failure in parsing the assembly format should be a fatal error.
+extern llvm::cl::opt<bool> formatErrorIsFatal;
+
+} // end namespace tblgen
+} // end namespace mlir
+
+#endif // MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
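`FormatToken::isKeyword()` works because every keyword kind is declared between the `keyword_start` and `keyword_end` sentinels, so adding `kw_params` and `kw_struct` only required placing them inside that range. The reduced lexer below is a hypothetical sketch of the same structure (backtick literals, `$` variables, identifiers checked against keywords); it is not `FormatLexer` itself and reports errors as an `error` token instead of printing diagnostics through `llvm::SourceMgr`.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string_view>
#include <vector>

// Reduced token kinds. Keywords sit between two sentinels so that isKeyword()
// is a range check, mirroring FormatToken::Kind.
enum class Kind {
  eof, error, l_paren, r_paren, comma,
  keyword_start, kw_params, kw_struct, keyword_end,
  identifier, literal, variable,
};

struct Token {
  Kind kind;
  std::string_view spelling;
  bool isKeyword() const {
    return kind > Kind::keyword_start && kind < Kind::keyword_end;
  }
};

// Hypothetical mini-lexer for formats such as "`<` struct($a, $b) `>`".
// Stops at the first error and returns it as the last token.
std::vector<Token> lex(std::string_view s) {
  std::vector<Token> toks;
  auto isIdentChar = [](char c) {
    return std::isalnum(static_cast<unsigned char>(c)) || c == '_';
  };
  std::size_t i = 0;
  while (true) {
    while (i < s.size() && std::isspace(static_cast<unsigned char>(s[i])))
      ++i;
    if (i == s.size()) {
      toks.push_back({Kind::eof, s.substr(i)});
      return toks;
    }
    std::size_t start = i;
    char c = s[i++];
    auto spelling = [&] { return s.substr(start, i - start); };
    if (c == '(' || c == ')' || c == ',') {
      Kind k = c == '(' ? Kind::l_paren : c == ')' ? Kind::r_paren : Kind::comma;
      toks.push_back({k, spelling()});
    } else if (c == '`') {
      // A literal runs up to and including the closing backtick.
      while (i < s.size() && s[i] != '`')
        ++i;
      if (i == s.size()) {
        toks.push_back({Kind::error, spelling()});
        return toks;
      }
      ++i;
      toks.push_back({Kind::literal, spelling()});
    } else if (c == '$') {
      // A variable name must start with a letter or underscore.
      if (i == s.size() ||
          !(std::isalpha(static_cast<unsigned char>(s[i])) || s[i] == '_')) {
        toks.push_back({Kind::error, spelling()});
        return toks;
      }
      while (i < s.size() && isIdentChar(s[i]))
        ++i;
      toks.push_back({Kind::variable, spelling()});
    } else if (std::isalpha(static_cast<unsigned char>(c)) || c == '_') {
      // Identifiers may contain '-'; afterwards, check for a keyword.
      while (i < s.size() && (isIdentChar(s[i]) || s[i] == '-'))
        ++i;
      std::string_view word = spelling();
      Kind k = word == "params"   ? Kind::kw_params
               : word == "struct" ? Kind::kw_struct
                                  : Kind::identifier;
      toks.push_back({k, word});
    } else {
      toks.push_back({Kind::error, spelling()});
      return toks;
    }
  }
}
```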
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index ccc4ac9cf4287..19dd6fa7c1016 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "OpFormatGen.h"
+#include "FormatGen.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
@@ -20,7 +21,6 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@@ -30,20 +30,6 @@
using namespace mlir;
using namespace mlir::tblgen;
-static llvm::cl::opt<bool> formatErrorIsFatal(
- "asmformat-error-is-fatal",
- llvm::cl::desc("Emit a fatal error if format parsing fails"),
- llvm::cl::init(true));
-
-/// Returns true if the given string can be formatted as a keyword.
-static bool canFormatStringAsKeyword(StringRef value) {
- if (!isalpha(value.front()) && value.front() != '_')
- return false;
- return llvm::all_of(value.drop_front(), [](char c) {
- return isalnum(c) || c == '_' || c == '$' || c == '.';
- });
-}
-
//===----------------------------------------------------------------------===//
// Element
//===----------------------------------------------------------------------===//
@@ -273,33 +259,12 @@ class LiteralElement : public Element {
/// Return the literal for this element.
StringRef getLiteral() const { return literal; }
- /// Returns true if the given string is a valid literal.
- static bool isValidLiteral(StringRef value);
-
private:
/// The spelling of the literal for this element.
StringRef literal;
};
} // end anonymous namespace
-bool LiteralElement::isValidLiteral(StringRef value) {
- if (value.empty())
- return false;
- char front = value.front();
-
- // If there is only one character, this must either be punctuation or a
- // single character bare identifier.
- if (value.size() == 1)
- return isalpha(front) || StringRef("_:,=<>()[]{}?+*").contains(front);
-
- // Check the punctuation that are larger than a single character.
- if (value == "->")
- return true;
-
- // Otherwise, this must be an identifier.
- return canFormatStringAsKeyword(value);
-}
-
//===----------------------------------------------------------------------===//
// WhitespaceElement
@@ -1705,14 +1670,7 @@ static void genLiteralPrinter(StringRef value, OpMethodBody &body,
body << " _odsPrinter";
// Don't insert a space for certain punctuation.
- auto shouldPrintSpaceBeforeLiteral = [&] {
- if (value.size() != 1 && value != "->")
- return true;
- if (lastWasPunctuation)
- return !StringRef(">)}],").contains(value.front());
- return !StringRef("<>(){}[],").contains(value.front());
- };
- if (shouldEmitSpace && shouldPrintSpaceBeforeLiteral())
+ if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
body << " << ' '";
body << " << \"" << value << "\";\n";
@@ -2101,253 +2059,6 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
lastWasPunctuation);
}
-//===----------------------------------------------------------------------===//
-// FormatLexer
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class represents a specific token in the input format.
-class Token {
-public:
- enum Kind {
- // Markers.
- eof,
- error,
-
- // Tokens with no info.
- l_paren,
- r_paren,
- caret,
- colon,
- comma,
- equal,
- less,
- greater,
- question,
-
- // Keywords.
- keyword_start,
- kw_attr_dict,
- kw_attr_dict_w_keyword,
- kw_custom,
- kw_functional_type,
- kw_operands,
- kw_ref,
- kw_regions,
- kw_results,
- kw_successors,
- kw_type,
- keyword_end,
-
- // String valued tokens.
- identifier,
- literal,
- variable,
- };
- Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
-
- /// Return the bytes that make up this token.
- StringRef getSpelling() const { return spelling; }
-
- /// Return the kind of this token.
- Kind getKind() const { return kind; }
-
- /// Return a location for this token.
- llvm::SMLoc getLoc() const {
- return llvm::SMLoc::getFromPointer(spelling.data());
- }
-
- /// Return if this token is a keyword.
- bool isKeyword() const { return kind > keyword_start && kind < keyword_end; }
-
-private:
- /// Discriminator that indicates the kind of token this is.
- Kind kind;
-
- /// A reference to the entire token contents; this is always a pointer into
- /// a memory buffer owned by the source manager.
- StringRef spelling;
-};
-
-/// This class implements a simple lexer for operation assembly format strings.
-class FormatLexer {
-public:
- FormatLexer(llvm::SourceMgr &mgr, Operator &op);
-
- /// Lex the next token and return it.
- Token lexToken();
-
- /// Emit an error to the lexer with the given location and message.
- Token emitError(llvm::SMLoc loc, const Twine &msg);
- Token emitError(const char *loc, const Twine &msg);
-
- Token emitErrorAndNote(llvm::SMLoc loc, const Twine &msg, const Twine &note);
-
-private:
- Token formToken(Token::Kind kind, const char *tokStart) {
- return Token(kind, StringRef(tokStart, curPtr - tokStart));
- }
-
- /// Return the next character in the stream.
- int getNextChar();
-
- /// Lex an identifier, literal, or variable.
- Token lexIdentifier(const char *tokStart);
- Token lexLiteral(const char *tokStart);
- Token lexVariable(const char *tokStart);
-
- llvm::SourceMgr &srcMgr;
- Operator &op;
- StringRef curBuffer;
- const char *curPtr;
-};
-} // end anonymous namespace
-
-FormatLexer::FormatLexer(llvm::SourceMgr &mgr, Operator &op)
- : srcMgr(mgr), op(op) {
- curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
- curPtr = curBuffer.begin();
-}
-
-Token FormatLexer::emitError(llvm::SMLoc loc, const Twine &msg) {
- srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
- llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
- "in custom assembly format for this operation");
- return formToken(Token::error, loc.getPointer());
-}
-Token FormatLexer::emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
- const Twine &note) {
- srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
- llvm::SrcMgr.PrintMessage(op.getLoc()[0], llvm::SourceMgr::DK_Note,
- "in custom assembly format for this operation");
- srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
- return formToken(Token::error, loc.getPointer());
-}
-Token FormatLexer::emitError(const char *loc, const Twine &msg) {
- return emitError(llvm::SMLoc::getFromPointer(loc), msg);
-}
-
-int FormatLexer::getNextChar() {
- char curChar = *curPtr++;
- switch (curChar) {
- default:
- return (unsigned char)curChar;
- case 0: {
- // A nul character in the stream is either the end of the current buffer or
- // a random nul in the file. Disambiguate that here.
- if (curPtr - 1 != curBuffer.end())
- return 0;
-
- // Otherwise, return end of file.
- --curPtr;
- return EOF;
- }
- case '\n':
- case '\r':
- // Handle the newline character by ignoring it and incrementing the line
- // count. However, be careful about 'dos style' files with \n\r in them.
- // Only treat a \n\r or \r\n as a single line.
- if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
- ++curPtr;
- return '\n';
- }
-}
-
-Token FormatLexer::lexToken() {
- const char *tokStart = curPtr;
-
- // This always consumes at least one character.
- int curChar = getNextChar();
- switch (curChar) {
- default:
- // Handle identifiers: [a-zA-Z_]
- if (isalpha(curChar) || curChar == '_')
- return lexIdentifier(tokStart);
-
- // Unknown character, emit an error.
- return emitError(tokStart, "unexpected character");
- case EOF:
- // Return EOF denoting the end of lexing.
- return formToken(Token::eof, tokStart);
-
- // Lex punctuation.
- case '^':
- return formToken(Token::caret, tokStart);
- case ':':
- return formToken(Token::colon, tokStart);
- case ',':
- return formToken(Token::comma, tokStart);
- case '=':
- return formToken(Token::equal, tokStart);
- case '<':
- return formToken(Token::less, tokStart);
- case '>':
- return formToken(Token::greater, tokStart);
- case '?':
- return formToken(Token::question, tokStart);
- case '(':
- return formToken(Token::l_paren, tokStart);
- case ')':
- return formToken(Token::r_paren, tokStart);
-
- // Ignore whitespace characters.
- case 0:
- case ' ':
- case '\t':
- case '\n':
- return lexToken();
-
- case '`':
- return lexLiteral(tokStart);
- case '$':
- return lexVariable(tokStart);
- }
-}
-
-Token FormatLexer::lexLiteral(const char *tokStart) {
- assert(curPtr[-1] == '`');
-
- // Lex a literal surrounded by ``.
- while (const char curChar = *curPtr++) {
- if (curChar == '`')
- return formToken(Token::literal, tokStart);
- }
- return emitError(curPtr - 1, "unexpected end of file in literal");
-}
-
-Token FormatLexer::lexVariable(const char *tokStart) {
- if (!isalpha(curPtr[0]) && curPtr[0] != '_')
- return emitError(curPtr - 1, "expected variable name");
-
- // Otherwise, consume the rest of the characters.
- while (isalnum(*curPtr) || *curPtr == '_')
- ++curPtr;
- return formToken(Token::variable, tokStart);
-}
-
-Token FormatLexer::lexIdentifier(const char *tokStart) {
- // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
- while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
- ++curPtr;
-
- // Check to see if this identifier is a keyword.
- StringRef str(tokStart, curPtr - tokStart);
- Token::Kind kind =
- StringSwitch<Token::Kind>(str)
- .Case("attr-dict", Token::kw_attr_dict)
- .Case("attr-dict-with-keyword", Token::kw_attr_dict_w_keyword)
- .Case("custom", Token::kw_custom)
- .Case("functional-type", Token::kw_functional_type)
- .Case("operands", Token::kw_operands)
- .Case("ref", Token::kw_ref)
- .Case("regions", Token::kw_regions)
- .Case("results", Token::kw_results)
- .Case("successors", Token::kw_successors)
- .Case("type", Token::kw_type)
- .Default(Token::identifier);
- return Token(kind, str);
-}
-
//===----------------------------------------------------------------------===//
// FormatParser
//===----------------------------------------------------------------------===//
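This hunk deletes the operation-specific `Token`/`FormatLexer` pair. The parser changes below use `FormatToken` and construct the lexer as `lexer(mgr, op.getLoc()[0])`. This suggests the lexer moves into the shared `FormatGen.h`/`FormatGen.cpp` listed among this commit's added files, presumably so attribute and type formats can reuse it.

The minimal sketch below reproduces the deleted tokenization rules without LLVM:
- backtick-quoted literals;
- `$`-prefixed variables;
- identifiers that may contain `-`;
- single-character punctuation.

It makes these assumptions:
- `std::string_view` replaces `llvm::StringRef`.
- The separate `kw_*` kinds collapse into one hypothetical `Kind::keyword`.
- Errors come back as `Kind::error` tokens rather than diagnostics through `llvm::SourceMgr`.
- All names are illustrative.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string_view>
#include <unordered_set>
#include <vector>

// Subset of the removed Token::Kind enum; every kw_* kind is folded into a
// single hypothetical `keyword` kind.
enum class Kind {
  eof, error,                              // markers
  l_paren, r_paren, caret, colon, comma,   // punctuation
  equal, less, greater, question,
  keyword, identifier, literal, variable   // string-valued tokens
};

struct Tok {
  Kind kind;
  std::string_view spelling; // view into the format string being lexed
};

class FormatLexerSketch {
public:
  explicit FormatLexerSketch(std::string_view buffer) : buf(buffer) {}

  // Lex the next token; whitespace (including '\r' and '\n') is skipped.
  Tok lexToken() {
    while (pos < buf.size() && std::isspace(uc(buf[pos])))
      ++pos;
    if (pos == buf.size())
      return form(Kind::eof, pos);
    std::size_t start = pos;
    char c = buf[pos++];
    switch (c) {
    case '^': return form(Kind::caret, start);
    case ':': return form(Kind::colon, start);
    case ',': return form(Kind::comma, start);
    case '=': return form(Kind::equal, start);
    case '<': return form(Kind::less, start);
    case '>': return form(Kind::greater, start);
    case '?': return form(Kind::question, start);
    case '(': return form(Kind::l_paren, start);
    case ')': return form(Kind::r_paren, start);
    case '`': return lexLiteral(start);
    case '$': return lexVariable(start);
    default:
      if (std::isalpha(uc(c)) || c == '_')
        return lexIdentifier(start);
      return form(Kind::error, start); // unexpected character
    }
  }

private:
  static unsigned char uc(char c) { return static_cast<unsigned char>(c); }

  Tok form(Kind kind, std::size_t start) const {
    return {kind, buf.substr(start, pos - start)};
  }

  // A literal spans up to and including the closing backtick.
  Tok lexLiteral(std::size_t start) {
    assert(buf[start] == '`');
    while (pos < buf.size())
      if (buf[pos++] == '`')
        return form(Kind::literal, start);
    return form(Kind::error, start); // unterminated literal
  }

  // A variable is '$' followed by [a-zA-Z_][a-zA-Z0-9_]*.
  Tok lexVariable(std::size_t start) {
    if (pos == buf.size() || !(std::isalpha(uc(buf[pos])) || buf[pos] == '_'))
      return form(Kind::error, start); // expected variable name
    while (pos < buf.size() && (std::isalnum(uc(buf[pos])) || buf[pos] == '_'))
      ++pos;
    return form(Kind::variable, start);
  }

  // Identifiers may contain '-', so `attr-dict` lexes as a single token.
  Tok lexIdentifier(std::size_t start) {
    while (pos < buf.size() &&
           (std::isalnum(uc(buf[pos])) || buf[pos] == '_' || buf[pos] == '-'))
      ++pos;
    static const std::unordered_set<std::string_view> keywords = {
        "attr-dict", "attr-dict-with-keyword", "custom", "functional-type",
        "operands",  "ref", "regions", "results", "successors", "type"};
    Tok tok = form(Kind::identifier, start);
    if (keywords.count(tok.spelling))
      tok.kind = Kind::keyword;
    return tok;
  }

  std::string_view buf;
  std::size_t pos = 0;
};

// Lex a whole format string, stopping after eof or the first error token.
std::vector<Tok> lexAll(std::string_view format) {
  FormatLexerSketch lexer(format);
  std::vector<Tok> toks;
  while (true) {
    toks.push_back(lexer.lexToken());
    if (toks.back().kind == Kind::eof || toks.back().kind == Kind::error)
      break;
  }
  return toks;
}
```

For example, lexing "`<` $count `,` attr-dict `>`" with `lexAll` yields a literal, a variable, a literal, a keyword, a literal, and then `eof`.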
@@ -2366,8 +2077,8 @@ namespace {
class FormatParser {
public:
FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
- : lexer(mgr, op), curToken(lexer.lexToken()), fmt(format), op(op),
- seenOperandTypes(op.getNumOperands()),
+ : lexer(mgr, op.getLoc()[0]), curToken(lexer.lexToken()), fmt(format),
+ op(op), seenOperandTypes(op.getNumOperands()),
seenResultTypes(op.getNumResults()) {}
/// Parse the operation assembly format.
@@ -2469,7 +2180,8 @@ class FormatParser {
LogicalResult parseCustomDirectiveParameter(
std::vector<std::unique_ptr<Element>> &parameters);
LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
- Token tok, ParserContext context);
+ FormatToken tok,
+ ParserContext context);
LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc, ParserContext context);
LogicalResult parseReferenceDirective(std::unique_ptr<Element> &element,
@@ -2481,8 +2193,8 @@ class FormatParser {
LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
llvm::SMLoc loc,
ParserContext context);
- LogicalResult parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
- ParserContext context);
+ LogicalResult parseTypeDirective(std::unique_ptr<Element> &element,
+ FormatToken tok, ParserContext context);
LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
bool isRefChild = false);
@@ -2492,12 +2204,12 @@ class FormatParser {
/// Advance the current lexer onto the next token.
void consumeToken() {
- assert(curToken.getKind() != Token::eof &&
- curToken.getKind() != Token::error &&
+ assert(curToken.getKind() != FormatToken::eof &&
+ curToken.getKind() != FormatToken::error &&
"shouldn't advance past EOF or errors");
curToken = lexer.lexToken();
}
- LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
+ LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
if (curToken.getKind() != kind)
return emitError(curToken.getLoc(), msg);
consumeToken();
@@ -2518,7 +2230,7 @@ class FormatParser {
//===--------------------------------------------------------------------===//
FormatLexer lexer;
- Token curToken;
+ FormatToken curToken;
OperationFormat &fmt;
Operator &op;
@@ -2539,7 +2251,7 @@ LogicalResult FormatParser::parse() {
llvm::SMLoc loc = curToken.getLoc();
// Parse each of the format elements into the main format.
- while (curToken.getKind() != Token::eof) {
+ while (curToken.getKind() != FormatToken::eof) {
std::unique_ptr<Element> element;
if (failed(parseElement(element, TopLevelContext)))
return ::mlir::failure();
@@ -2864,13 +2576,13 @@ LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
if (curToken.isKeyword())
return parseDirective(element, context);
// Literals.
- if (curToken.getKind() == Token::literal)
+ if (curToken.getKind() == FormatToken::literal)
return parseLiteral(element, context);
// Optionals.
- if (curToken.getKind() == Token::l_paren)
+ if (curToken.getKind() == FormatToken::l_paren)
return parseOptional(element, context);
// Variables.
- if (curToken.getKind() == Token::variable)
+ if (curToken.getKind() == FormatToken::variable)
return parseVariable(element, context);
return emitError(curToken.getLoc(),
"expected directive, literal, variable, or optional group");
@@ -2878,7 +2590,7 @@ LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
ParserContext context) {
- Token varTok = curToken;
+ FormatToken varTok = curToken;
consumeToken();
StringRef name = varTok.getSpelling().drop_front();
@@ -2958,31 +2670,31 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
ParserContext context) {
- Token dirTok = curToken;
+ FormatToken dirTok = curToken;
consumeToken();
switch (dirTok.getKind()) {
- case Token::kw_attr_dict:
+ case FormatToken::kw_attr_dict:
return parseAttrDictDirective(element, dirTok.getLoc(), context,
/*withKeyword=*/false);
- case Token::kw_attr_dict_w_keyword:
+ case FormatToken::kw_attr_dict_w_keyword:
return parseAttrDictDirective(element, dirTok.getLoc(), context,
/*withKeyword=*/true);
- case Token::kw_custom:
+ case FormatToken::kw_custom:
return parseCustomDirective(element, dirTok.getLoc(), context);
- case Token::kw_functional_type:
+ case FormatToken::kw_functional_type:
return parseFunctionalTypeDirective(element, dirTok, context);
- case Token::kw_operands:
+ case FormatToken::kw_operands:
return parseOperandsDirective(element, dirTok.getLoc(), context);
- case Token::kw_regions:
+ case FormatToken::kw_regions:
return parseRegionsDirective(element, dirTok.getLoc(), context);
- case Token::kw_results:
+ case FormatToken::kw_results:
return parseResultsDirective(element, dirTok.getLoc(), context);
- case Token::kw_successors:
+ case FormatToken::kw_successors:
return parseSuccessorsDirective(element, dirTok.getLoc(), context);
- case Token::kw_ref:
+ case FormatToken::kw_ref:
return parseReferenceDirective(element, dirTok.getLoc(), context);
- case Token::kw_type:
+ case FormatToken::kw_type:
return parseTypeDirective(element, dirTok, context);
default:
@@ -2992,7 +2704,7 @@ LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
ParserContext context) {
- Token literalTok = curToken;
+ FormatToken literalTok = curToken;
if (context != TopLevelContext) {
return emitError(
literalTok.getLoc(),
@@ -3014,7 +2726,7 @@ LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
}
// Check that the parsed literal is valid.
- if (!LiteralElement::isValidLiteral(value))
+ if (!isValidLiteral(value))
return emitError(literalTok.getLoc(), "expected valid literal");
element = std::make_unique<LiteralElement>(value);
@@ -3035,14 +2747,15 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
do {
if (failed(parseOptionalChildElement(thenElements, anchorIdx)))
return ::mlir::failure();
- } while (curToken.getKind() != Token::r_paren);
+ } while (curToken.getKind() != FormatToken::r_paren);
consumeToken();
// Parse the `else` elements of this optional group.
- if (curToken.getKind() == Token::colon) {
+ if (curToken.getKind() == FormatToken::colon) {
consumeToken();
- if (failed(parseToken(Token::l_paren, "expected '(' to start else branch "
- "of optional group")))
+ if (failed(parseToken(FormatToken::l_paren,
+ "expected '(' to start else branch "
+ "of optional group")))
return failure();
do {
llvm::SMLoc childLoc = curToken.getLoc();
@@ -3051,11 +2764,12 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
failed(verifyOptionalChildElement(elseElements.back().get(), childLoc,
/*isAnchor=*/false)))
return failure();
- } while (curToken.getKind() != Token::r_paren);
+ } while (curToken.getKind() != FormatToken::r_paren);
consumeToken();
}
- if (failed(parseToken(Token::question, "expected '?' after optional group")))
+ if (failed(parseToken(FormatToken::question,
+ "expected '?' after optional group")))
return ::mlir::failure();
// The optional group is required to have an anchor.
@@ -3090,7 +2804,7 @@ LogicalResult FormatParser::parseOptionalChildElement(
return ::mlir::failure();
// Check to see if this element is the anchor of the optional group.
- bool isAnchor = curToken.getKind() == Token::caret;
+ bool isAnchor = curToken.getKind() == FormatToken::caret;
if (isAnchor) {
if (anchorIdx)
return emitError(childLoc, "only one element can be marked as the anchor "
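`parseOptional` and `parseOptionalChildElement` above accept an optional group of this form:
- a parenthesized list of elements, with exactly one element marked by `^` as the anchor;
- an optional `:` followed by a parenthesized `else` list;
- a mandatory trailing `?`.

The sketch below checks that grammar over a pre-split token list. It simplifies the real parser under these assumptions:
- Tokens are plain strings rather than `FormatToken`s.
- Any token other than `(`, `)`, `^`, `:` and `?` counts as an element, so nested groups and directives are not handled.
- Error texts are paraphrased.
- `OptionalGroupSketch` and `OptionalGroupParser` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Parsed shape of one optional group (hypothetical type).
struct OptionalGroupSketch {
  std::vector<std::string> thenElements;
  std::size_t anchorIdx = 0;             // index into thenElements
  std::vector<std::string> elseElements; // empty if there is no `:` branch
};

// Checks `(` elem [`^`] ... `)` [`:` `(` elem ... `)`] `?` over string tokens.
// Returns an error message, or std::nullopt on success.
class OptionalGroupParser {
public:
  explicit OptionalGroupParser(std::vector<std::string> tokens)
      : toks(std::move(tokens)) {}

  std::optional<std::string> parse(OptionalGroupSketch &out) {
    if (!consumeIf("("))
      return "expected '(' to start optional group";
    std::optional<std::size_t> anchor;
    do { // `then` elements: at least one; the anchor is marked with `^`
      if (!isElement(peek()))
        return "expected element in optional group";
      out.thenElements.push_back(toks[pos++]);
      if (consumeIf("^")) {
        if (anchor)
          return "only one element can be marked as the anchor";
        anchor = out.thenElements.size() - 1;
      }
    } while (peek() != ")");
    ++pos; // consume ')'

    if (consumeIf(":")) { // optional `else` branch; `^` is not accepted here
      if (!consumeIf("("))
        return "expected '(' to start else branch of optional group";
      do {
        if (!isElement(peek()))
          return "expected element in else branch";
        out.elseElements.push_back(toks[pos++]);
      } while (peek() != ")");
      ++pos; // consume ')'
    }

    if (!consumeIf("?"))
      return "expected '?' after optional group";
    if (!anchor) // the group is required to have an anchor
      return "optional group requires an anchor";
    out.anchorIdx = *anchor;
    return std::nullopt;
  }

private:
  std::string peek() const {
    return pos < toks.size() ? toks[pos] : std::string();
  }
  bool consumeIf(const char *expected) {
    if (peek() != expected)
      return false;
    ++pos;
    return true;
  }
  static bool isElement(const std::string &t) {
    return !t.empty() && t != "(" && t != ")" && t != "^" && t != ":" &&
           t != "?";
  }

  std::vector<std::string> toks;
  std::size_t pos = 0;
};
```

For example, the tokens of "(`,` $x^)?" parse with the anchor at index 1. Dropping the `^` instead produces the missing-anchor error, which mirrors the check that follows the `?` in `parseOptional`.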
@@ -3194,16 +2908,16 @@ FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
return emitError(loc, "'custom' is only valid as a top-level directive");
// Parse the custom directive name.
- if (failed(
- parseToken(Token::less, "expected '<' before custom directive name")))
+ if (failed(parseToken(FormatToken::less,
+ "expected '<' before custom directive name")))
return ::mlir::failure();
- Token nameTok = curToken;
- if (failed(parseToken(Token::identifier,
+ FormatToken nameTok = curToken;
+ if (failed(parseToken(FormatToken::identifier,
"expected custom directive name identifier")) ||
- failed(parseToken(Token::greater,
+ failed(parseToken(FormatToken::greater,
"expected '>' after custom directive name")) ||
- failed(parseToken(Token::l_paren,
+ failed(parseToken(FormatToken::l_paren,
"expected '(' before custom directive parameters")))
return ::mlir::failure();
@@ -3212,12 +2926,12 @@ FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
do {
if (failed(parseCustomDirectiveParameter(elements)))
return ::mlir::failure();
- if (curToken.getKind() != Token::comma)
+ if (curToken.getKind() != FormatToken::comma)
break;
consumeToken();
} while (true);
- if (failed(parseToken(Token::r_paren,
+ if (failed(parseToken(FormatToken::r_paren,
"expected ')' after custom directive parameters")))
return ::mlir::failure();
@@ -3254,9 +2968,8 @@ LogicalResult FormatParser::parseCustomDirectiveParameter(
return ::mlir::success();
}
-LogicalResult
-FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
- Token tok, ParserContext context) {
+LogicalResult FormatParser::parseFunctionalTypeDirective(
+ std::unique_ptr<Element> &element, FormatToken tok, ParserContext context) {
llvm::SMLoc loc = tok.getLoc();
if (context != TopLevelContext)
return emitError(
@@ -3264,11 +2977,14 @@ FormatParser::parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
// Parse the main operand.
std::unique_ptr<Element> inputs, results;
- if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
+ if (failed(parseToken(FormatToken::l_paren,
+ "expected '(' before argument list")) ||
failed(parseTypeDirectiveOperand(inputs)) ||
- failed(parseToken(Token::comma, "expected ',' after inputs argument")) ||
+ failed(parseToken(FormatToken::comma,
+ "expected ',' after inputs argument")) ||
failed(parseTypeDirectiveOperand(results)) ||
- failed(parseToken(Token::r_paren, "expected ')' after argument list")))
+ failed(
+ parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return ::mlir::failure();
element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
std::move(results));
@@ -3299,9 +3015,11 @@ FormatParser::parseReferenceDirective(std::unique_ptr<Element> &element,
return emitError(loc, "'ref' is only valid within a `custom` directive");
std::unique_ptr<Element> operand;
- if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
+ if (failed(parseToken(FormatToken::l_paren,
+ "expected '(' before argument list")) ||
failed(parseElement(operand, RefDirectiveContext)) ||
- failed(parseToken(Token::r_paren, "expected ')' after argument list")))
+ failed(
+ parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return ::mlir::failure();
element = std::make_unique<RefDirective>(std::move(operand));
@@ -3360,17 +3078,19 @@ FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
}
LogicalResult
-FormatParser::parseTypeDirective(std::unique_ptr<Element> &element, Token tok,
- ParserContext context) {
+FormatParser::parseTypeDirective(std::unique_ptr<Element> &element,
+ FormatToken tok, ParserContext context) {
llvm::SMLoc loc = tok.getLoc();
if (context == TypeDirectiveContext)
return emitError(loc, "'type' cannot be used as a child of another `type`");
bool isRefChild = context == RefDirectiveContext;
std::unique_ptr<Element> operand;
- if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) ||
+ if (failed(parseToken(FormatToken::l_paren,
+ "expected '(' before argument list")) ||
failed(parseTypeDirectiveOperand(operand, isRefChild)) ||
- failed(parseToken(Token::r_paren, "expected ')' after argument list")))
+ failed(
+ parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return ::mlir::failure();
element = std::make_unique<TypeDirective>(std::move(operand));
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
index f776696ade5ba..eb19a15cabc85 100644
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -220,6 +220,7 @@ cc_library(
"//mlir:SideEffects",
"//mlir:StandardOps",
"//mlir:StandardOpsTransforms",
+ "//mlir:Support",
"//mlir:TensorDialect",
"//mlir:TransformUtils",
"//mlir:Transforms",