[Mlir-commits] [mlir] 0bc0ad8 - [mlir][ods] Unify Attr/TypeDef and Operation Format Parsing
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 31 23:28:41 PST 2022
Author: Mogball
Date: 2022-02-01T07:28:37Z
New Revision: 0bc0ad86e2cdef3585d9c47aed66ebbe9aec4f4c
URL: https://github.com/llvm/llvm-project/commit/0bc0ad86e2cdef3585d9c47aed66ebbe9aec4f4c
DIFF: https://github.com/llvm/llvm-project/commit/0bc0ad86e2cdef3585d9c47aed66ebbe9aec4f4c.diff
LOG: [mlir][ods] Unify Attr/TypeDef and Operation Format Parsing
Part 2 of 3 of unifying the assembly formats of attributes/types and operations. The last patch that introduced attribute/type formats (D111594) factored out the format lexer entirely. This patch factors out most of the format parsers so that the attribute/type and op parsers only need to implement handling for specific elements.
Certain things could be factored better (element verification, 'seen' variables), but the primary goal of this factoring is to let features be shared across both assembly formats.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D117971
Added:
Modified:
mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
mlir/test/mlir-tblgen/op-format-spec.td
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
mlir/tools/mlir-tblgen/FormatGen.cpp
mlir/tools/mlir-tblgen/FormatGen.h
mlir/tools/mlir-tblgen/OpFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
index 372aef6dfa3e5..012685fd05cba 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -28,7 +28,7 @@ def InvalidTypeB : InvalidType<"InvalidTypeB", "invalid_b"> {
/// Test format has invalid syntax.
def InvalidTypeC : InvalidType<"InvalidTypeC", "invalid_c"> {
let parameters = (ins "int":$v0, "int":$v1);
- // CHECK: expected literal, directive, or variable
+ // CHECK: expected literal, variable, directive, or optional group
let assemblyFormat = "`<` $v0, $v1 `>`";
}
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
index 6f539f599ff72..1c419424d6021 100644
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -97,7 +97,7 @@ def DirectiveFunctionalTypeInvalidA : TestFormat_Op<[{
def DirectiveFunctionalTypeInvalidB : TestFormat_Op<[{
functional-type
}]>;
-// CHECK: error: expected directive, literal, variable, or optional group
+// CHECK: error: expected literal, variable, directive, or optional group
def DirectiveFunctionalTypeInvalidC : TestFormat_Op<[{
functional-type(
}]>;
@@ -105,7 +105,7 @@ def DirectiveFunctionalTypeInvalidC : TestFormat_Op<[{
def DirectiveFunctionalTypeInvalidD : TestFormat_Op<[{
functional-type(operands
}]>;
-// CHECK: error: expected directive, literal, variable, or optional group
+// CHECK: error: expected literal, variable, directive, or optional group
def DirectiveFunctionalTypeInvalidE : TestFormat_Op<[{
functional-type(operands,
}]>;
@@ -262,7 +262,7 @@ def DirectiveSuccessorsInvalidA : TestFormat_Op<[{
def DirectiveTypeInvalidA : TestFormat_Op<[{
type
}]>;
-// CHECK: error: expected directive, literal, variable, or optional group
+// CHECK: error: expected literal, variable, directive, or optional group
def DirectiveTypeInvalidB : TestFormat_Op<[{
type(
}]>;
@@ -278,7 +278,7 @@ def DirectiveTypeValid : TestFormat_Op<[{
//===----------------------------------------------------------------------===//
// functional-type/type operands
-// CHECK: error: literals may only be used in a top-level section of the format
+// CHECK: error: literals may only be used in the top-level section of the format
def DirectiveTypeZOperandInvalidA : TestFormat_Op<[{
type(`literal`)
}]>;
@@ -334,7 +334,7 @@ def LiteralInvalidC : TestFormat_Op<[{
}]>;
// CHECK: error: unexpected end of file in literal
-// CHECK: error: expected directive, literal, variable, or optional group
+// CHECK: error: expected literal, variable, directive, or optional group
def LiteralInvalidD : TestFormat_Op<[{
`
}]>;
@@ -352,15 +352,15 @@ def LiteralValid : TestFormat_Op<[{
def OptionalInvalidA : TestFormat_Op<[{
type(($attr^)?) attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
-// CHECK: error: expected directive, literal, variable, or optional group
+// CHECK: error: expected literal, variable, directive, or optional group
def OptionalInvalidB : TestFormat_Op<[{
() attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
-// CHECK: error: optional group specified no anchor element
+// CHECK: error: optional group has no anchor element
def OptionalInvalidC : TestFormat_Op<[{
($attr)? attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
-// CHECK: error: first parsable element of an operand group must be an attribute, literal, operand, or region
+// CHECK: error: first parsable element of an optional group must be a literal or variable
def OptionalInvalidD : TestFormat_Op<[{
(type($operand) $operand^)? attr-dict
}]>, Arguments<(ins Optional<I64>:$operand)>;
@@ -370,15 +370,15 @@ def OptionalInvalidE : TestFormat_Op<[{
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
// CHECK: error: only one element can be marked as the anchor of an optional group
def OptionalInvalidF : TestFormat_Op<[{
- ($attr^ $attr2^) attr-dict
+ ($attr^ $attr2^)? attr-dict
}]>, Arguments<(ins OptionalAttr<I64Attr>:$attr, OptionalAttr<I64Attr>:$attr2)>;
// CHECK: error: only optional attributes can be used to anchor an optional group
def OptionalInvalidG : TestFormat_Op<[{
- ($attr^) attr-dict
+ ($attr^)? attr-dict
}]>, Arguments<(ins I64Attr:$attr)>;
// CHECK: error: only variable length operands can be used within an optional group
def OptionalInvalidH : TestFormat_Op<[{
- ($arg^) attr-dict
+ ($arg^)? attr-dict
}]>, Arguments<(ins I64:$arg)>;
// CHECK: error: only literals, types, and variables can be used within an optional group
def OptionalInvalidI : TestFormat_Op<[{
@@ -386,7 +386,7 @@ def OptionalInvalidI : TestFormat_Op<[{
}]>, Arguments<(ins Variadic<I64>:$arg)>;
// CHECK: error: only literals, types, and variables can be used within an optional group
def OptionalInvalidJ : TestFormat_Op<[{
- (attr-dict)
+ (attr-dict^)?
}]>;
// CHECK: error: expected '?' after optional group
def OptionalInvalidK : TestFormat_Op<[{
@@ -404,7 +404,7 @@ def OptionalInvalidM : TestFormat_Op<[{
def OptionalInvalidN : TestFormat_Op<[{
($arg^):
}]>, Arguments<(ins Variadic<I64>:$arg)>;
-// CHECK: error: expected directive, literal, variable, or optional group
+// CHECK: error: expected literal, variable, directive, or optional group
def OptionalInvalidO : TestFormat_Op<[{
($arg^):(`test`
}]>, Arguments<(ins Variadic<I64>:$arg)>;
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 2ac5df7be39e9..4ca20e59cc03b 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -31,61 +31,12 @@ using llvm::formatv;
//===----------------------------------------------------------------------===//
namespace {
-
-/// This class represents a single format element.
-class Element {
-public:
- /// LLVM-style RTTI.
- enum class Kind {
- /// This element is a directive.
- ParamsDirective,
- StructDirective,
-
- /// This element is a literal.
- Literal,
-
- /// This element is a variable.
- Variable,
- };
- Element(Kind kind) : kind(kind) {}
- virtual ~Element() = default;
-
- /// Return the kind of this element.
- Kind getKind() const { return kind; }
-
-private:
- /// The kind of this element.
- Kind kind;
-};
-
-/// This class represents an instance of a literal element.
-class LiteralElement : public Element {
-public:
- LiteralElement(StringRef literal)
- : Element(Kind::Literal), literal(literal) {}
-
- static bool classof(const Element *el) {
- return el->getKind() == Kind::Literal;
- }
-
- /// Get the literal spelling.
- StringRef getSpelling() const { return literal; }
-
-private:
- /// The spelling of the literal for this element.
- StringRef literal;
-};
-
/// This class represents an instance of a variable element. A variable refers
/// to an attribute or type parameter.
-class VariableElement : public Element {
+class ParameterElement
+ : public VariableElementBase<VariableElement::Parameter> {
public:
- VariableElement(AttrOrTypeParameter param)
- : Element(Kind::Variable), param(param) {}
-
- static bool classof(const Element *el) {
- return el->getKind() == Kind::Variable;
- }
+ ParameterElement(AttrOrTypeParameter param) : param(param) {}
/// Get the parameter in the element.
const AttrOrTypeParameter &getParam() const { return param; }
@@ -103,22 +54,18 @@ class VariableElement : public Element {
};
/// Base class for a directive that contains references to multiple variables.
-template <Element::Kind ElementKind>
-class ParamsDirectiveBase : public Element {
+template <DirectiveElement::Kind DirectiveKind>
+class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
public:
- using Base = ParamsDirectiveBase<ElementKind>;
+ using Base = ParamsDirectiveBase<DirectiveKind>;
- ParamsDirectiveBase(SmallVector<std::unique_ptr<Element>> &¶ms)
- : Element(ElementKind), params(std::move(params)) {}
-
- static bool classof(const Element *el) {
- return el->getKind() == ElementKind;
- }
+ ParamsDirectiveBase(std::vector<FormatElement *> &¶ms)
+ : params(std::move(params)) {}
/// Get the parameters contained in this directive.
auto getParams() const {
- return llvm::map_range(params, [](auto &el) {
- return cast<VariableElement>(el.get())->getParam();
+ return llvm::map_range(params, [](FormatElement *el) {
+ return cast<ParameterElement>(el)->getParam();
});
}
@@ -126,13 +73,11 @@ class ParamsDirectiveBase : public Element {
unsigned getNumParams() const { return params.size(); }
/// Take all of the parameters from this directive.
- SmallVector<std::unique_ptr<Element>> takeParams() {
- return std::move(params);
- }
+ std::vector<FormatElement *> takeParams() { return std::move(params); }
private:
/// The parameters captured by this directive.
- SmallVector<std::unique_ptr<Element>> params;
+ std::vector<FormatElement *> params;
};
/// This class represents a `params` directive that refers to all parameters
@@ -144,8 +89,7 @@ class ParamsDirectiveBase : public Element {
/// When used as an argument to another directive that accepts variables,
/// `params` can be used in place of manually listing all parameters of an
/// attribute or type.
-class ParamsDirective
- : public ParamsDirectiveBase<Element::Kind::ParamsDirective> {
+class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
public:
using Base::Base;
};
@@ -155,8 +99,7 @@ class ParamsDirective
///
/// `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
///
-class StructDirective
- : public ParamsDirectiveBase<Element::Kind::StructDirective> {
+class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> {
public:
using Base::Base;
};
@@ -237,7 +180,7 @@ namespace {
class AttrOrTypeFormat {
public:
AttrOrTypeFormat(const AttrOrTypeDef &def,
- std::vector<std::unique_ptr<Element>> &&elements)
+ std::vector<FormatElement *> &&elements)
: def(def), elements(std::move(elements)) {}
/// Generate the attribute or type parser.
@@ -247,7 +190,7 @@ class AttrOrTypeFormat {
private:
/// Generate the parser code for a specific format element.
- void genElementParser(Element *el, FmtContext &ctx, MethodBody &os);
+ void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a literal.
void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a variable.
@@ -259,7 +202,7 @@ class AttrOrTypeFormat {
void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a specific format element.
- void genElementPrinter(Element *el, FmtContext &ctx, MethodBody &os);
+ void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a literal.
void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a variable.
@@ -275,7 +218,7 @@ class AttrOrTypeFormat {
const AttrOrTypeDef &def;
/// The list of top-level format elements returned by the assembly format
/// parser.
- std::vector<std::unique_ptr<Element>> elements;
+ std::vector<FormatElement *> elements;
/// Flags for printing spaces.
bool shouldEmitSpace = false;
@@ -311,8 +254,8 @@ void AttrOrTypeFormat::genParser(MethodBody &os) {
&ctx);
/// Generate call to each parameter parser.
- for (auto &el : elements)
- genElementParser(el.get(), ctx, os);
+ for (FormatElement *el : elements)
+ genElementParser(el, ctx, os);
/// Generate call to the attribute or type builder. Use the checked getter
/// if one was generated.
@@ -328,11 +271,11 @@ void AttrOrTypeFormat::genParser(MethodBody &os) {
os << ");";
}
-void AttrOrTypeFormat::genElementParser(Element *el, FmtContext &ctx,
+void AttrOrTypeFormat::genElementParser(FormatElement *el, FmtContext &ctx,
MethodBody &os) {
if (auto *literal = dyn_cast<LiteralElement>(el))
return genLiteralParser(literal->getSpelling(), ctx, os);
- if (auto *var = dyn_cast<VariableElement>(el))
+ if (auto *var = dyn_cast<ParameterElement>(el))
return genVariableParser(var->getParam(), ctx, os);
if (auto *params = dyn_cast<ParamsDirective>(el))
return genParamsParser(params, ctx, os);
@@ -435,11 +378,11 @@ void AttrOrTypeFormat::genPrinter(MethodBody &os) {
/// Generate printers.
shouldEmitSpace = true;
lastWasPunctuation = false;
- for (auto &el : elements)
- genElementPrinter(el.get(), ctx, os);
+ for (FormatElement *el : elements)
+ genElementPrinter(el, ctx, os);
}
-void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx,
+void AttrOrTypeFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
MethodBody &os) {
if (auto *literal = dyn_cast<LiteralElement>(el))
return genLiteralPrinter(literal->getSpelling(), ctx, os);
@@ -447,7 +390,7 @@ void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx,
return genParamsPrinter(params, ctx, os);
if (auto *strct = dyn_cast<StructDirective>(el))
return genStructPrinter(strct, ctx, os);
- if (auto *var = dyn_cast<VariableElement>(el))
+ if (auto *var = dyn_cast<ParameterElement>(el))
return genVariablePrinter(var->getParam(), ctx, os,
var->shouldBeQualified());
@@ -492,7 +435,7 @@ void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
llvm::interleave(
el->getParams(),
[&](auto param) { this->genVariablePrinter(param, ctx, os); },
- [&]() { this->genLiteralPrinter(",", ctx, os); });
+ [&] { this->genLiteralPrinter(",", ctx, os); });
}
void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
@@ -504,75 +447,54 @@ void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
this->genLiteralPrinter("=", ctx, os);
this->genVariablePrinter(param, ctx, os);
},
- [&]() { this->genLiteralPrinter(",", ctx, os); });
+ [&] { this->genLiteralPrinter(",", ctx, os); });
}
//===----------------------------------------------------------------------===//
-// FormatParser
+// DefFormatParser
//===----------------------------------------------------------------------===//
namespace {
-class FormatParser {
+class DefFormatParser : public FormatParser {
public:
- FormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
- : lexer(mgr, def.getLoc()[0]), curToken(lexer.lexToken()), def(def),
+ DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
+ : FormatParser(mgr, def.getLoc()[0]), def(def),
seenParams(def.getNumParameters()) {}
/// Parse the attribute or type format and create the format elements.
FailureOr<AttrOrTypeFormat> parse();
-private:
- /// The current context of the parser when parsing an element.
- enum ParserContext {
- /// The element is being parsed in the default context - at the top of the
- /// format
- TopLevelContext,
- /// The element is being parsed as a child to a `struct` directive.
- StructDirective,
- };
-
- /// Emit an error.
- LogicalResult emitError(const Twine &msg) {
- lexer.emitError(curToken.getLoc(), msg);
- return failure();
+protected:
+ /// Verify the parsed elements.
+ LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
+ /// Verify the elements of a custom directive.
+ LogicalResult
+ verifyCustomDirectiveArguments(SMLoc loc,
+ ArrayRef<FormatElement *> arguments) override {
+ return emitError(loc, "'custom' not supported (yet)");
}
-
- /// Parse an expected token.
- LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
- if (curToken.getKind() != kind)
- return emitError(msg);
- consumeToken();
- return success();
+ /// Verify the elements of an optional group.
+ LogicalResult
+ verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
+ Optional<unsigned> anchorIndex) override {
+ return emitError(loc, "optional groups not (yet) supported");
}
- /// Advance the lexer to the next token.
- void consumeToken() {
- assert(curToken.getKind() != FormatToken::eof &&
- curToken.getKind() != FormatToken::error &&
- "shouldn't advance past EOF or errors");
- curToken = lexer.lexToken();
- }
+ /// Parse an attribute or type variable.
+ FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
+ Context ctx) override;
+ /// Parse an attribute or type format directive.
+ FailureOr<FormatElement *>
+ parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
- /// Parse any element.
- FailureOr<std::unique_ptr<Element>> parseElement(ParserContext ctx);
- /// Parse a literal element.
- FailureOr<std::unique_ptr<Element>> parseLiteral(ParserContext ctx);
- /// Parse a variable element.
- FailureOr<std::unique_ptr<Element>> parseVariable(ParserContext ctx);
- /// Parse a directive.
- FailureOr<std::unique_ptr<Element>> parseDirective(ParserContext ctx);
+private:
/// Parse a `params` directive.
- FailureOr<std::unique_ptr<Element>> parseParamsDirective();
+ FailureOr<FormatElement *> parseParamsDirective(SMLoc loc);
/// Parse a `qualified` directive.
- FailureOr<std::unique_ptr<Element>>
- parseQualifiedDirective(ParserContext ctx);
+ FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
/// Parse a `struct` directive.
- FailureOr<std::unique_ptr<Element>> parseStructDirective();
+ FailureOr<FormatElement *> parseStructDirective(SMLoc loc);
- /// The current format lexer.
- FormatLexer lexer;
- /// The current token in the stream.
- FormatToken curToken;
/// Attribute or type tablegen def.
const AttrOrTypeDef &def;
@@ -581,170 +503,132 @@ class FormatParser {
};
} // namespace
-FailureOr<AttrOrTypeFormat> FormatParser::parse() {
- std::vector<std::unique_ptr<Element>> elements;
- elements.reserve(16);
-
- /// Parse the format elements.
- while (curToken.getKind() != FormatToken::eof) {
- auto element = parseElement(TopLevelContext);
- if (failed(element))
- return failure();
-
- /// Add the format element and continue.
- elements.push_back(std::move(*element));
- }
-
- /// Check that all parameters have been seen.
+LogicalResult DefFormatParser::verify(SMLoc loc,
+ ArrayRef<FormatElement *> elements) {
for (auto &it : llvm::enumerate(def.getParameters())) {
if (!seenParams.test(it.index())) {
- return emitError("format is missing reference to parameter: " +
- it.value().getName());
+ return emitError(loc, "format is missing reference to parameter: " +
+ it.value().getName());
}
}
-
- return AttrOrTypeFormat(def, std::move(elements));
+ return success();
}
-FailureOr<std::unique_ptr<Element>>
-FormatParser::parseElement(ParserContext ctx) {
- if (curToken.getKind() == FormatToken::literal)
- return parseLiteral(ctx);
- if (curToken.getKind() == FormatToken::variable)
- return parseVariable(ctx);
- if (curToken.isKeyword())
- return parseDirective(ctx);
-
- return emitError("expected literal, directive, or variable");
-}
-
-FailureOr<std::unique_ptr<Element>>
-FormatParser::parseLiteral(ParserContext ctx) {
- if (ctx != TopLevelContext) {
- return emitError(
- "literals may only be used in the top-level section of the format");
- }
-
- /// Get the literal spelling without the surrounding "`".
- auto value = curToken.getSpelling().drop_front().drop_back();
- if (!isValidLiteral(value, [&](Twine diag) {
- (void)emitError("expected valid literal but got '" + value +
- "': " + diag);
- }))
+FailureOr<AttrOrTypeFormat> DefFormatParser::parse() {
+ FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
+ if (failed(elements))
return failure();
-
- consumeToken();
- return {std::make_unique<LiteralElement>(value)};
+ return AttrOrTypeFormat(def, std::move(*elements));
}
-FailureOr<std::unique_ptr<Element>>
-FormatParser::parseVariable(ParserContext ctx) {
- /// Get the parameter name without the preceding "$".
- auto name = curToken.getSpelling().drop_front();
-
+FailureOr<FormatElement *>
+DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
/// Lookup the parameter.
ArrayRef<AttrOrTypeParameter> params = def.getParameters();
auto *it = llvm::find_if(
params, [&](auto ¶m) { return param.getName() == name; });
/// Check that the parameter reference is valid.
- if (it == params.end())
- return emitError(def.getName() + " has no parameter named '" + name + "'");
+ if (it == params.end()) {
+ return emitError(loc,
+ def.getName() + " has no parameter named '" + name + "'");
+ }
auto idx = std::distance(params.begin(), it);
if (seenParams.test(idx))
- return emitError("duplicate parameter '" + name + "'");
+ return emitError(loc, "duplicate parameter '" + name + "'");
seenParams.set(idx);
- consumeToken();
- return {std::make_unique<VariableElement>(*it)};
+ return create<ParameterElement>(*it);
}
-FailureOr<std::unique_ptr<Element>>
-FormatParser::parseDirective(ParserContext ctx) {
+FailureOr<FormatElement *>
+DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
+ Context ctx) {
- switch (curToken.getKind()) {
+ switch (kind) {
case FormatToken::kw_qualified:
- return parseQualifiedDirective(ctx);
+ return parseQualifiedDirective(loc, ctx);
case FormatToken::kw_params:
- return parseParamsDirective();
+ return parseParamsDirective(loc);
case FormatToken::kw_struct:
if (ctx != TopLevelContext) {
return emitError(
+ loc,
"`struct` may only be used in the top-level section of the format");
}
- return parseStructDirective();
+ return parseStructDirective(loc);
+
default:
- return emitError("unknown directive in format: " + curToken.getSpelling());
+ return emitError(loc, "unsupported directive kind");
}
}
-FailureOr<std::unique_ptr<Element>>
-FormatParser::parseQualifiedDirective(ParserContext ctx) {
- consumeToken();
+FailureOr<FormatElement *>
+DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")))
return failure();
- FailureOr<std::unique_ptr<Element>> var = parseElement(ctx);
+ FailureOr<FormatElement *> var = parseElement(ctx);
if (failed(var))
return var;
- if (!isa<VariableElement>(*var))
- return emitError("`qualified` argument list expected a variable");
- cast<VariableElement>(var->get())->setShouldBeQualified();
+ if (!isa<ParameterElement>(*var))
+ return emitError(loc, "`qualified` argument list expected a variable");
+ cast<ParameterElement>(*var)->setShouldBeQualified();
if (failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return failure();
return var;
}
-FailureOr<std::unique_ptr<Element>> FormatParser::parseParamsDirective() {
- consumeToken();
+FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc) {
/// Collect all of the attribute's or type's parameters.
- SmallVector<std::unique_ptr<Element>> vars;
+ std::vector<FormatElement *> vars;
/// Ensure that none of the parameters have already been captured.
for (const auto &it : llvm::enumerate(def.getParameters())) {
if (seenParams.test(it.index())) {
- return emitError("`params` captures duplicate parameter: " +
- it.value().getName());
+ return emitError(loc, "`params` captures duplicate parameter: " +
+ it.value().getName());
}
seenParams.set(it.index());
- vars.push_back(std::make_unique<VariableElement>(it.value()));
+ vars.push_back(create<ParameterElement>(it.value()));
}
- return {std::make_unique<ParamsDirective>(std::move(vars))};
+ return create<ParamsDirective>(std::move(vars));
}
-FailureOr<std::unique_ptr<Element>> FormatParser::parseStructDirective() {
- consumeToken();
+FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc) {
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before `struct` argument list")))
return failure();
/// Parse variables captured by `struct`.
- SmallVector<std::unique_ptr<Element>> vars;
+ std::vector<FormatElement *> vars;
/// Parse first captured parameter or a `params` directive.
- FailureOr<std::unique_ptr<Element>> var = parseElement(StructDirective);
- if (failed(var) || !isa<VariableElement, ParamsDirective>(*var))
- return emitError("`struct` argument list expected a variable or directive");
+ FailureOr<FormatElement *> var = parseElement(StructDirectiveContext);
+ if (failed(var) || !isa<VariableElement, ParamsDirective>(*var)) {
+ return emitError(loc,
+ "`struct` argument list expected a variable or directive");
+ }
if (isa<VariableElement>(*var)) {
/// Parse any other parameters.
vars.push_back(std::move(*var));
- while (curToken.getKind() == FormatToken::comma) {
+ while (peekToken().is(FormatToken::comma)) {
consumeToken();
- var = parseElement(StructDirective);
+ var = parseElement(StructDirectiveContext);
if (failed(var) || !isa<VariableElement>(*var))
- return emitError("expected a variable in `struct` argument list");
+ return emitError(loc, "expected a variable in `struct` argument list");
vars.push_back(std::move(*var));
}
} else {
/// `struct(params)` captures all parameters in the attribute or type.
- vars = cast<ParamsDirective>(var->get())->takeParams();
+ vars = cast<ParamsDirective>(*var)->takeParams();
}
- if (curToken.getKind() != FormatToken::r_paren)
- return emitError("expected ')' at the end of an argument list");
+ if (failed(parseToken(FormatToken::r_paren,
+ "expected ')' at the end of an argument list")))
+ return failure();
- consumeToken();
- return {std::make_unique<::StructDirective>(std::move(vars))};
+ return create<StructDirective>(std::move(vars));
}
//===----------------------------------------------------------------------===//
@@ -756,11 +640,10 @@ void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
MethodBody &printer) {
llvm::SourceMgr mgr;
mgr.AddNewSourceBuffer(
- llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()),
- SMLoc());
+ llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc());
/// Parse the custom assembly format>
- FormatParser fmtParser(mgr, def);
+ DefFormatParser fmtParser(mgr, def);
FailureOr<AttrOrTypeFormat> format = fmtParser.parse();
if (failed(format)) {
if (formatErrorIsFatal)
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index d656253c7551b..006995d7a5d65 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -177,6 +177,201 @@ FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
return FormatToken(kind, str);
}
+//===----------------------------------------------------------------------===//
+// FormatParser
+//===----------------------------------------------------------------------===//
+
+FormatParser::~FormatParser() = default;
+
+FailureOr<std::vector<FormatElement *>> FormatParser::parse() {
+ SMLoc loc = curToken.getLoc();
+
+ // Parse each of the format elements into the main format.
+ std::vector<FormatElement *> elements;
+ while (curToken.getKind() != FormatToken::eof) {
+ FailureOr<FormatElement *> element = parseElement(TopLevelContext);
+ if (failed(element))
+ return failure();
+ elements.push_back(*element);
+ }
+
+ // Verify the format.
+ if (failed(verify(loc, elements)))
+ return failure();
+ return elements;
+}
+
+//===----------------------------------------------------------------------===//
+// Element Parsing
+
+FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) {
+ if (curToken.is(FormatToken::literal))
+ return parseLiteral(ctx);
+ if (curToken.is(FormatToken::variable))
+ return parseVariable(ctx);
+ if (curToken.isKeyword())
+ return parseDirective(ctx);
+ if (curToken.is(FormatToken::l_paren))
+ return parseOptionalGroup(ctx);
+ return emitError(curToken.getLoc(),
+ "expected literal, variable, directive, or optional group");
+}
+
+FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) {
+ FormatToken tok = curToken;
+ SMLoc loc = tok.getLoc();
+ consumeToken();
+
+ if (ctx != TopLevelContext) {
+ return emitError(
+ loc,
+ "literals may only be used in the top-level section of the format");
+ }
+ // Get the spelling without the surrounding backticks.
+ StringRef value = tok.getSpelling().drop_front().drop_back();
+
+ // The parsed literal is a space element (`` or ` `) or a newline.
+ if (value.empty() || value == " " || value == "\\n")
+ return create<WhitespaceElement>(value);
+
+ // Check that the parsed literal is valid.
+ if (!isValidLiteral(value, [&](Twine msg) {
+ (void)emitError(loc, "expected valid literal but got '" + value +
+ "': " + msg);
+ }))
+ return failure();
+ return create<LiteralElement>(value);
+}
+
+FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) {
+ FormatToken tok = curToken;
+ SMLoc loc = tok.getLoc();
+ consumeToken();
+
+ // Get the name of the variable without the leading `$`.
+ StringRef name = tok.getSpelling().drop_front();
+ return parseVariableImpl(loc, name, ctx);
+}
+
+FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) {
+ FormatToken tok = curToken;
+ SMLoc loc = tok.getLoc();
+ consumeToken();
+
+ if (tok.is(FormatToken::kw_custom))
+ return parseCustomDirective(loc, ctx);
+ return parseDirectiveImpl(loc, tok.getKind(), ctx);
+}
+
+FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
+ SMLoc loc = curToken.getLoc();
+ consumeToken();
+ if (ctx != TopLevelContext) {
+ return emitError(loc,
+ "optional groups can only be used as top-level elements");
+ }
+
+ // Parse the child elements for this optional group.
+ std::vector<FormatElement *> thenElements, elseElements;
+ Optional<unsigned> anchorIndex;
+ do {
+ FailureOr<FormatElement *> element = parseElement(TopLevelContext);
+ if (failed(element))
+ return failure();
+ // Check for an anchor.
+ if (curToken.is(FormatToken::caret)) {
+ if (anchorIndex)
+ return emitError(curToken.getLoc(), "only one element can be marked as "
+ "the anchor of an optional group");
+ anchorIndex = thenElements.size();
+ consumeToken();
+ }
+ thenElements.push_back(*element);
+ } while (!curToken.is(FormatToken::r_paren));
+ consumeToken();
+
+ // Parse the `else` elements of this optional group.
+ if (curToken.is(FormatToken::colon)) {
+ consumeToken();
+ if (failed(
+ parseToken(FormatToken::l_paren,
+ "expected '(' to start else branch of optional group")))
+ return failure();
+ do {
+ FailureOr<FormatElement *> element = parseElement(TopLevelContext);
+ if (failed(element))
+ return failure();
+ elseElements.push_back(*element);
+ } while (!curToken.is(FormatToken::r_paren));
+ consumeToken();
+ }
+ if (failed(parseToken(FormatToken::question,
+ "expected '?' after optional group")))
+ return failure();
+
+ // The optional group is required to have an anchor.
+ if (!anchorIndex)
+ return emitError(loc, "optional group has no anchor element");
+
+ // Verify the child elements.
+ if (failed(verifyOptionalGroupElements(loc, thenElements, anchorIndex)) ||
+ failed(verifyOptionalGroupElements(loc, elseElements, llvm::None)))
+ return failure();
+
+ // Get the first parsable element. It must be an element that can be
+ // optionally-parsed.
+ auto parseBegin = llvm::find_if_not(thenElements, [](FormatElement *element) {
+ return isa<WhitespaceElement>(element);
+ });
+ if (!isa<LiteralElement, VariableElement>(*parseBegin)) {
+ return emitError(loc, "first parsable element of an optional group must be "
+ "a literal or variable");
+ }
+
+ unsigned parseStart = std::distance(thenElements.begin(), parseBegin);
+ return create<OptionalElement>(std::move(thenElements),
+ std::move(elseElements), *anchorIndex,
+ parseStart);
+}
+
+FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
+ Context ctx) {
+ if (ctx != TopLevelContext)
+ return emitError(loc, "'custom' is only valid as a top-level directive");
+
+ FailureOr<FormatToken> nameTok;
+ if (failed(parseToken(FormatToken::less,
+ "expected '<' before custom directive name")) ||
+ failed(nameTok =
+ parseToken(FormatToken::identifier,
+ "expected custom directive name identifier")) ||
+ failed(parseToken(FormatToken::greater,
+ "expected '>' after custom directive name")) ||
+ failed(parseToken(FormatToken::l_paren,
+ "expected '(' before custom directive parameters")))
+ return failure();
+
+ // Parse the arguments.
+ std::vector<FormatElement *> arguments;
+ while (true) {
+ FailureOr<FormatElement *> argument = parseElement(CustomDirectiveContext);
+ if (failed(argument))
+ return failure();
+ arguments.push_back(*argument);
+ if (!curToken.is(FormatToken::comma))
+ break;
+ consumeToken();
+ }
+
+ if (failed(parseToken(FormatToken::r_paren,
+ "expected ')' after custom directive parameters")))
+ return failure();
+
+ if (failed(verifyCustomDirectiveArguments(loc, arguments)))
+ return failure();
+ return create<CustomDirective>(nameTok->getSpelling(), std::move(arguments));
+}
+
//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 4d290a9ee1f90..d03ceec43942b 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -15,9 +15,13 @@
#define MLIR_TOOLS_MLIRTBLGEN_FORMATGEN_H_
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/SMLoc.h"
+#include <memory>
+#include <vector>
namespace llvm {
class SourceMgr;
@@ -85,6 +89,9 @@ class FormatToken {
/// Return a location for this token.
SMLoc getLoc() const;
+  /// Returns true if the token is of the given kind. This only inspects the
+  /// token, so it is usable on const tokens.
+  bool is(Kind kind) const { return getKind() == kind; }
+
/// Return if this token is a keyword.
bool isKeyword() const {
return getKind() > Kind::keyword_start && getKind() < Kind::keyword_end;
@@ -115,8 +122,7 @@ class FormatLexer {
FormatToken emitError(SMLoc loc, const Twine &msg);
FormatToken emitError(const char *loc, const Twine &msg);
-  FormatToken emitErrorAndNote(SMLoc loc, const Twine &msg,
-                               const Twine &note);
+  FormatToken emitErrorAndNote(SMLoc loc, const Twine &msg, const Twine &note);
private:
/// Return the next character in the stream.
@@ -142,6 +148,362 @@ class FormatLexer {
const char *curPtr;
};
+//===----------------------------------------------------------------------===//
+// FormatElement
+//===----------------------------------------------------------------------===//
+
+/// This class represents a single format element.
+///
+/// If you squint and take a close look, you can see the outline of a `Format`
+/// dialect.
+///
+/// Note: this class declares no virtual destructor, so an element must be
+/// destroyed through a pointer to its most-derived type.
+class FormatElement {
+public:
+  /// The top-level kinds of format elements.
+  enum Kind { Literal, Variable, Whitespace, Directive, Optional };
+
+  /// Support LLVM-style RTTI. Every element is a `FormatElement`, so this
+  /// always succeeds.
+  static bool classof(const FormatElement *) { return true; }
+
+  /// Get the element kind.
+  Kind getKind() const { return kind; }
+
+protected:
+  /// Create a format element with the given kind. Explicit so that a `Kind`
+  /// never implicitly converts into an element.
+  explicit FormatElement(Kind kind) : kind(kind) {}
+
+private:
+  /// The kind of the element.
+  Kind kind;
+};
+
+/// Shared base for the concrete format element classes. The element kind is
+/// fixed by the template parameter, which also supplies the matching
+/// LLVM-style RTTI hook.
+template <FormatElement::Kind ElementKind>
+class FormatElementBase : public FormatElement {
+protected:
+  /// Construct the element, tagging it with `ElementKind`.
+  FormatElementBase() : FormatElement(ElementKind) {}
+
+public:
+  /// LLVM-style RTTI: `el` belongs to this class exactly when its element kind
+  /// is `ElementKind`.
+  static bool classof(const FormatElement *el) {
+    return el->getKind() == ElementKind;
+  }
+};
+
+/// This class represents a literal element. A literal is either one of the
+/// supported punctuation characters (e.g. `(` or `,`) or a string literal (e.g.
+/// `literal`).
+class LiteralElement : public FormatElementBase<FormatElement::Literal> {
+public:
+  /// Create a literal element with the given spelling.
+  explicit LiteralElement(StringRef spelling) : spelling(spelling) {}
+
+  /// Get the spelling of the literal.
+  StringRef getSpelling() const { return spelling; }
+
+private:
+  /// The spelling of the literal, i.e. the string contained within the
+  /// backticks. This is a non-owning reference; presumably it points into the
+  /// format source buffer (TODO confirm), which must outlive the element.
+  StringRef spelling;
+};
+
+/// This class represents a variable element. A variable refers to some part of
+/// the object being parsed, e.g. an attribute or operand on an operation or a
+/// parameter on an attribute.
+class VariableElement : public FormatElementBase<FormatElement::Variable> {
+public:
+  /// These are the kinds of variables.
+  enum Kind { Attribute, Operand, Region, Result, Successor, Parameter };
+
+  /// Get the kind of variable. This hides `FormatElement::getKind`; call that
+  /// one through a `FormatElement` pointer to obtain the element kind.
+  Kind getKind() const { return kind; }
+
+protected:
+  /// Create a variable with a kind. Explicit so that a `Kind` never implicitly
+  /// converts into an element.
+  explicit VariableElement(Kind kind) : kind(kind) {}
+
+private:
+  /// The kind of variable.
+  Kind kind;
+};
+
+/// Shared base for the concrete variable element classes. The variable kind is
+/// a template parameter, which also gives each subclass its LLVM-style RTTI.
+template <VariableElement::Kind VariableKind>
+class VariableElementBase : public VariableElement {
+public:
+  /// LLVM-style RTTI: accepts variable elements whose variable kind equals
+  /// `VariableKind` and rejects every other element.
+  static bool classof(const FormatElement *el) {
+    const auto *var = dyn_cast<VariableElement>(el);
+    return var && var->getKind() == VariableKind;
+  }
+
+protected:
+  /// Construct the element with `VariableKind` as its variable kind.
+  VariableElementBase() : VariableElement(VariableKind) {}
+};
+
+/// A whitespace element, such as a newline or a space. Whitespace is printed
+/// but never parsed. An empty value (written ``) elides the space that would
+/// otherwise have been printed automatically at that position.
+class WhitespaceElement : public FormatElementBase<FormatElement::Whitespace> {
+public:
+  /// Build a whitespace element holding `value`.
+  explicit WhitespaceElement(StringRef value) : text(value) {}
+
+  /// Return the whitespace text, which may be empty.
+  StringRef getValue() const { return text; }
+
+private:
+  /// The whitespace text; possibly empty.
+  StringRef text;
+};
+
+/// This class represents a directive element, e.g. `attr-dict` or a `custom`
+/// directive. The specific directive is identified by its directive kind.
+class DirectiveElement : public FormatElementBase<FormatElement::Directive> {
+public:
+  /// These are the kinds of directives.
+  enum Kind {
+    AttrDict,
+    Custom,
+    FunctionalType,
+    Operands,
+    Ref,
+    Regions,
+    Results,
+    Successors,
+    Type,
+    Params,
+    Struct
+  };
+
+  /// Get the directive kind. This hides `FormatElement::getKind`; call that
+  /// one through a `FormatElement` pointer to obtain the element kind.
+  Kind getKind() const { return kind; }
+
+protected:
+  /// Create a directive element with a kind. Explicit so that a `Kind` never
+  /// implicitly converts into an element.
+  explicit DirectiveElement(Kind kind) : kind(kind) {}
+
+private:
+  /// The directive kind.
+  Kind kind;
+};
+
+/// Shared base for the concrete directive element classes. The directive kind
+/// is a template parameter, which also provides LLVM-style RTTI.
+template <DirectiveElement::Kind DirectiveKind>
+class DirectiveElementBase : public DirectiveElement {
+public:
+  /// LLVM-style RTTI: accepts directive elements whose directive kind equals
+  /// `DirectiveKind` and rejects every other element.
+  static bool classof(const FormatElement *el) {
+    const auto *directive = dyn_cast<DirectiveElement>(el);
+    return directive && directive->getKind() == DirectiveKind;
+  }
+
+  /// Construct a directive of kind `DirectiveKind`. Public, so an alias of
+  /// this class (e.g. `OperandsDirective`) can be created directly.
+  DirectiveElementBase() : DirectiveElement(DirectiveKind) {}
+};
+
+/// A custom format directive whose behavior is implemented by the user in C++.
+/// Its list of arguments is forwarded to those C++ functions.
+class CustomDirective : public DirectiveElementBase<DirectiveElement::Custom> {
+public:
+  /// Build a custom directive called `name` that takes `arguments`.
+  CustomDirective(StringRef name, std::vector<FormatElement *> &&arguments)
+      : directiveName(name), args(std::move(arguments)) {}
+
+  /// Return the user-provided directive name.
+  StringRef getName() const { return directiveName; }
+
+  /// Return the directive's arguments, in source order.
+  ArrayRef<FormatElement *> getArguments() const { return args; }
+
+private:
+  /// The directive name. It selects the two C++ methods `parse{name}` and
+  /// `print{name}`, which are called with the arguments.
+  StringRef directiveName;
+  /// The arguments passed to the custom functions: variables (which the
+  /// functions are responsible for populating) or references to variables.
+  std::vector<FormatElement *> args;
+};
+
+/// A group of elements emitted only when an optional "anchor" variable is
+/// present, paired with a group of elements emitted when it is absent.
+class OptionalElement : public FormatElementBase<FormatElement::Optional> {
+public:
+  /// Build an optional group from its `then`/`else` children, the index of
+  /// the anchor within `thenElements`, and the index of the first element of
+  /// `thenElements` that is parsed.
+  OptionalElement(std::vector<FormatElement *> &&thenElements,
+                  std::vector<FormatElement *> &&elseElements,
+                  unsigned anchorIndex, unsigned parseStart)
+      : presentElems(std::move(thenElements)),
+        absentElems(std::move(elseElements)), anchorPos(anchorIndex),
+        firstParsedPos(parseStart) {}
+
+  /// Elements emitted when the anchor is present.
+  ArrayRef<FormatElement *> getThenElements() const { return presentElems; }
+
+  /// Elements emitted when the anchor is absent.
+  ArrayRef<FormatElement *> getElseElements() const { return absentElems; }
+
+  /// The anchor element, taken from the `then` elements.
+  FormatElement *getAnchor() const { return presentElems[anchorPos]; }
+
+  /// Index, within the `then` elements, of the first element to be parsed.
+  unsigned getParseStart() const { return firstParsedPos; }
+
+private:
+  /// Children emitted when the anchor is present.
+  std::vector<FormatElement *> presentElems;
+  /// Children emitted when the anchor is absent.
+  std::vector<FormatElement *> absentElems;
+  /// Position of the anchor inside `presentElems`.
+  unsigned anchorPos;
+  /// Position of the first non-whitespace element inside `presentElems`.
+  unsigned firstParsedPos;
+};
+
+//===----------------------------------------------------------------------===//
+// FormatParser
+//===----------------------------------------------------------------------===//
+
+/// Base class for a parser that implements an assembly format. This class
+/// defines a common assembly format syntax and the creation of format elements.
+/// Subclasses will need to implement parsing for the format elements they
+/// support.
+class FormatParser {
+public:
+  /// Vtable anchor.
+  virtual ~FormatParser();
+
+  /// Parse the assembly format.
+  FailureOr<std::vector<FormatElement *>> parse();
+
+protected:
+  /// The current context of the parser when parsing an element.
+  enum Context {
+    /// The element is being parsed in a "top-level" context, i.e. at the top of
+    /// the format or in an optional group.
+    TopLevelContext,
+    /// The element is being parsed as a custom directive child.
+    CustomDirectiveContext,
+    /// The element is being parsed as a type directive child.
+    TypeDirectiveContext,
+    /// The element is being parsed as a reference directive child.
+    RefDirectiveContext,
+    /// The element is being parsed as a struct directive child.
+    StructDirectiveContext
+  };
+
+  /// Create a format parser with the given source manager and a location.
+  explicit FormatParser(llvm::SourceMgr &mgr, llvm::SMLoc loc)
+      : lexer(mgr, loc), curToken(lexer.lexToken()) {}
+
+  /// Allocate and construct a format element. The parser owns the element,
+  /// and the element lives until the parser is destroyed.
+  template <typename FormatElementT, typename... Args>
+  FormatElementT *create(Args &&...args) {
+    // Elements can own heap memory (e.g. the child vectors of custom
+    // directives and optional groups), so their destructors must run. A bump
+    // allocator never runs destructors. FormatElement has no virtual
+    // destructor, so the element is deleted through its concrete type.
+    std::unique_ptr<FormatElement, ElementDeleter> owned(
+        new FormatElementT(std::forward<Args>(args)...),
+        &destroyElement<FormatElementT>);
+    auto *element = static_cast<FormatElementT *>(owned.get());
+    ownedElements.push_back(std::move(owned));
+    return element;
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Element Parsing
+
+  /// Parse a single element of any kind.
+  FailureOr<FormatElement *> parseElement(Context ctx);
+  /// Parse a literal.
+  FailureOr<FormatElement *> parseLiteral(Context ctx);
+  /// Parse a variable.
+  FailureOr<FormatElement *> parseVariable(Context ctx);
+  /// Parse a directive.
+  FailureOr<FormatElement *> parseDirective(Context ctx);
+  /// Parse an optional group.
+  FailureOr<FormatElement *> parseOptionalGroup(Context ctx);
+
+  /// Parse a custom directive.
+  FailureOr<FormatElement *> parseCustomDirective(llvm::SMLoc loc, Context ctx);
+
+  /// Parse a format-specific variable kind.
+  virtual FailureOr<FormatElement *>
+  parseVariableImpl(llvm::SMLoc loc, StringRef name, Context ctx) = 0;
+  /// Parse a format-specific directive kind.
+  virtual FailureOr<FormatElement *>
+  parseDirectiveImpl(llvm::SMLoc loc, FormatToken::Kind kind, Context ctx) = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Format Verification
+
+  /// Verify that the format is well-formed.
+  virtual LogicalResult verify(llvm::SMLoc loc,
+                               ArrayRef<FormatElement *> elements) = 0;
+  /// Verify the arguments to a custom directive.
+  virtual LogicalResult
+  verifyCustomDirectiveArguments(llvm::SMLoc loc,
+                                 ArrayRef<FormatElement *> arguments) = 0;
+  /// Verify the elements of an optional group.
+  virtual LogicalResult
+  verifyOptionalGroupElements(llvm::SMLoc loc,
+                              ArrayRef<FormatElement *> elements,
+                              Optional<unsigned> anchorIndex) = 0;
+
+  //===--------------------------------------------------------------------===//
+  // Lexer Utilities
+
+  /// Emit an error at the given location.
+  LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
+    lexer.emitError(loc, msg);
+    return failure();
+  }
+
+  /// Emit an error and a note at the given location.
+  LogicalResult emitErrorAndNote(llvm::SMLoc loc, const Twine &msg,
+                                 const Twine &note) {
+    lexer.emitErrorAndNote(loc, msg, note);
+    return failure();
+  }
+
+  /// Parse a single token of the expected kind.
+  FailureOr<FormatToken> parseToken(FormatToken::Kind kind, const Twine &msg) {
+    if (!curToken.is(kind))
+      return emitError(curToken.getLoc(), msg);
+    FormatToken tok = curToken;
+    consumeToken();
+    return tok;
+  }
+
+  /// Advance the lexer to the next token.
+  void consumeToken() {
+    assert(!curToken.is(FormatToken::eof) && !curToken.is(FormatToken::error) &&
+           "shouldn't advance past EOF or errors");
+    curToken = lexer.lexToken();
+  }
+
+  /// Get the current token.
+  FormatToken peekToken() { return curToken; }
+
+private:
+  /// Deleter that destroys an element through its concrete type.
+  using ElementDeleter = void (*)(FormatElement *);
+
+  /// Delete `element`, which must have been allocated as a `FormatElementT`.
+  template <typename FormatElementT>
+  static void destroyElement(FormatElement *element) {
+    delete static_cast<FormatElementT *>(element);
+  }
+
+  /// The format parser retains ownership of the format elements it creates.
+  std::vector<std::unique_ptr<FormatElement, ElementDeleter>> ownedElements;
+  /// The format lexer to use.
+  FormatLexer lexer;
+  /// The current token in the lexer.
+  FormatToken curToken;
+};
+
+//===----------------------------------------------------------------------===//
+// Utility Functions
+//===----------------------------------------------------------------------===//
+
/// Whether a space needs to be emitted before a literal. E.g., two keywords
/// back-to-back require a space separator, but a keyword followed by '<' does
/// not require a space.
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 0be172ef3fc81..7bbc2fae29eb8 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -9,6 +9,7 @@
#include "OpFormatGen.h"
#include "FormatGen.h"
#include "OpClass.h"
+#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/Format.h"
@@ -33,84 +34,37 @@ using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
-// Element
-//===----------------------------------------------------------------------===//
+// VariableElement
namespace {
-/// This class represents a single format element.
-class Element {
+/// This class represents an instance of an op variable element. A variable
+/// refers to something registered on the operation itself, e.g. an operand,
+/// result, attribute, region, or successor.
+template <typename VarT, VariableElement::Kind VariableKind>
+class OpVariableElement : public VariableElementBase<VariableKind> {
public:
- enum class Kind {
- /// This element is a directive.
- AttrDictDirective,
- CustomDirective,
- FunctionalTypeDirective,
- OperandsDirective,
- RefDirective,
- RegionsDirective,
- ResultsDirective,
- SuccessorsDirective,
- TypeDirective,
-
- /// This element is a literal.
- Literal,
-
- /// This element is a whitespace.
- Newline,
- Space,
-
- /// This element is an variable value.
- AttributeVariable,
- OperandVariable,
- RegionVariable,
- ResultVariable,
- SuccessorVariable,
-
- /// This element is an optional element.
- Optional,
- };
- Element(Kind kind) : kind(kind) {}
- virtual ~Element() = default;
+ using Base = OpVariableElement<VarT, VariableKind>;
- /// Return the kind of this element.
- Kind getKind() const { return kind; }
-
-private:
- /// The kind of this element.
- Kind kind;
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// VariableElement
+ /// Create an op variable element with the variable value.
+ OpVariableElement(const VarT *var) : var(var) {}
-namespace {
-/// This class represents an instance of an variable element. A variable refers
-/// to something registered on the operation itself, e.g. an argument, result,
-/// etc.
-template <typename VarT, Element::Kind kindVal>
-class VariableElement : public Element {
-public:
- VariableElement(const VarT *var) : Element(kindVal), var(var) {}
- static bool classof(const Element *element) {
- return element->getKind() == kindVal;
- }
+ /// Get the variable.
const VarT *getVar() { return var; }
protected:
+ /// The op variable, e.g. a type or attribute constraint.
const VarT *var;
};
/// This class represents a variable that refers to an attribute argument.
struct AttributeVariable
- : public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
- using VariableElement<NamedAttribute,
- Element::Kind::AttributeVariable>::VariableElement;
+ : public OpVariableElement<NamedAttribute, VariableElement::Attribute> {
+ using Base::Base;
/// Return the constant builder call for the type of this attribute, or None
/// if it doesn't have one.
- Optional<StringRef> getTypeBuilder() const {
- Optional<Type> attrType = var->attr.getValueType();
+ llvm::Optional<StringRef> getTypeBuilder() const {
+ llvm::Optional<Type> attrType = var->attr.getValueType();
return attrType ? attrType->getBuilderCall() : llvm::None;
}
@@ -132,54 +86,49 @@ struct AttributeVariable
/// This class represents a variable that refers to an operand argument.
using OperandVariable =
- VariableElement<NamedTypeConstraint, Element::Kind::OperandVariable>;
-
-/// This class represents a variable that refers to a region.
-using RegionVariable =
- VariableElement<NamedRegion, Element::Kind::RegionVariable>;
+ OpVariableElement<NamedTypeConstraint, VariableElement::Operand>;
/// This class represents a variable that refers to a result.
using ResultVariable =
- VariableElement<NamedTypeConstraint, Element::Kind::ResultVariable>;
+ OpVariableElement<NamedTypeConstraint, VariableElement::Result>;
+
+/// This class represents a variable that refers to a region.
+using RegionVariable = OpVariableElement<NamedRegion, VariableElement::Region>;
/// This class represents a variable that refers to a successor.
using SuccessorVariable =
- VariableElement<NamedSuccessor, Element::Kind::SuccessorVariable>;
+ OpVariableElement<NamedSuccessor, VariableElement::Successor>;
} // namespace
//===----------------------------------------------------------------------===//
// DirectiveElement
namespace {
-/// This class implements single kind directives.
-template <Element::Kind type> class DirectiveElement : public Element {
-public:
- DirectiveElement() : Element(type){};
- static bool classof(const Element *ele) { return ele->getKind() == type; }
-};
/// This class represents the `operands` directive. This directive represents
/// all of the operands of an operation.
-using OperandsDirective = DirectiveElement<Element::Kind::OperandsDirective>;
-
-/// This class represents the `regions` directive. This directive represents
-/// all of the regions of an operation.
-using RegionsDirective = DirectiveElement<Element::Kind::RegionsDirective>;
+using OperandsDirective = DirectiveElementBase<DirectiveElement::Operands>;
/// This class represents the `results` directive. This directive represents
/// all of the results of an operation.
-using ResultsDirective = DirectiveElement<Element::Kind::ResultsDirective>;
+using ResultsDirective = DirectiveElementBase<DirectiveElement::Results>;
+
+/// This class represents the `regions` directive. This directive represents
+/// all of the regions of an operation.
+using RegionsDirective = DirectiveElementBase<DirectiveElement::Regions>;
/// This class represents the `successors` directive. This directive represents
/// all of the successors of an operation.
-using SuccessorsDirective =
- DirectiveElement<Element::Kind::SuccessorsDirective>;
+using SuccessorsDirective = DirectiveElementBase<DirectiveElement::Successors>;
/// This class represents the `attr-dict` directive. This directive represents
/// the attribute dictionary of the operation.
class AttrDictDirective
- : public DirectiveElement<Element::Kind::AttrDictDirective> {
+ : public DirectiveElementBase<DirectiveElement::AttrDict> {
public:
explicit AttrDictDirective(bool withKeyword) : withKeyword(withKeyword) {}
+
+ /// Return whether the dictionary should be printed with the 'attributes'
+ /// keyword.
bool isWithKeyword() const { return withKeyword; }
private:
@@ -187,66 +136,41 @@ class AttrDictDirective
bool withKeyword;
};
-/// This class represents a custom format directive that is implemented by the
-/// user in C++.
-class CustomDirective : public Element {
-public:
- CustomDirective(StringRef name,
- std::vector<std::unique_ptr<Element>> &&arguments)
- : Element{Kind::CustomDirective}, name(name),
- arguments(std::move(arguments)) {}
-
- static bool classof(const Element *element) {
- return element->getKind() == Kind::CustomDirective;
- }
-
- /// Return the name of the custom directive.
- StringRef getName() const { return name; }
-
- /// Return the arguments to the custom directive.
- auto getArguments() const { return llvm::make_pointee_range(arguments); }
-
-private:
- /// The user provided name of the directive.
- StringRef name;
-
- /// The arguments to the custom directive.
- std::vector<std::unique_ptr<Element>> arguments;
-};
-
/// This class represents the `functional-type` directive. This directive takes
/// two arguments and formats them, respectively, as the inputs and results of a
/// FunctionType.
class FunctionalTypeDirective
- : public DirectiveElement<Element::Kind::FunctionalTypeDirective> {
+ : public DirectiveElementBase<DirectiveElement::FunctionalType> {
public:
- FunctionalTypeDirective(std::unique_ptr<Element> inputs,
- std::unique_ptr<Element> results)
- : inputs(std::move(inputs)), results(std::move(results)) {}
- Element *getInputs() const { return inputs.get(); }
- Element *getResults() const { return results.get(); }
+ FunctionalTypeDirective(FormatElement *inputs, FormatElement *results)
+ : inputs(inputs), results(results) {}
+
+ FormatElement *getInputs() const { return inputs; }
+ FormatElement *getResults() const { return results; }
private:
/// The input and result arguments.
- std::unique_ptr<Element> inputs, results;
+ FormatElement *inputs, *results;
};
/// This class represents the `ref` directive.
-class RefDirective : public DirectiveElement<Element::Kind::RefDirective> {
+class RefDirective : public DirectiveElementBase<DirectiveElement::Ref> {
public:
- RefDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
- Element *getOperand() const { return operand.get(); }
+ RefDirective(FormatElement *arg) : arg(arg) {}
+
+ FormatElement *getArg() const { return arg; }
private:
- /// The operand that is used to format the directive.
- std::unique_ptr<Element> operand;
+ /// The argument that is used to format the directive.
+ FormatElement *arg;
};
/// This class represents the `type` directive.
-class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
+class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
public:
- TypeDirective(std::unique_ptr<Element> arg) : operand(std::move(arg)) {}
- Element *getOperand() const { return operand.get(); }
+ TypeDirective(FormatElement *arg) : arg(arg) {}
+
+ FormatElement *getArg() const { return arg; }
/// Indicate if this type is printed "qualified" (that is it is
/// prefixed with the `!dialect.mnemonic`).
@@ -256,126 +180,13 @@ class TypeDirective : public DirectiveElement<Element::Kind::TypeDirective> {
}
private:
- /// The operand that is used to format the directive.
- std::unique_ptr<Element> operand;
+ /// The argument that is used to format the directive.
+ FormatElement *arg;
bool shouldBeQualifiedFlag = false;
};
} // namespace
-//===----------------------------------------------------------------------===//
-// LiteralElement
-
-namespace {
-/// This class represents an instance of a literal element.
-class LiteralElement : public Element {
-public:
- LiteralElement(StringRef literal)
- : Element{Kind::Literal}, literal(literal) {}
- static bool classof(const Element *element) {
- return element->getKind() == Kind::Literal;
- }
-
- /// Return the literal for this element.
- StringRef getLiteral() const { return literal; }
-
-private:
- /// The spelling of the literal for this element.
- StringRef literal;
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// WhitespaceElement
-
-namespace {
-/// This class represents a whitespace element, e.g. newline or space. It's a
-/// literal that is printed but never parsed.
-class WhitespaceElement : public Element {
-public:
- WhitespaceElement(Kind kind) : Element{kind} {}
- static bool classof(const Element *element) {
- Kind kind = element->getKind();
- return kind == Kind::Newline || kind == Kind::Space;
- }
-};
-
-/// This class represents an instance of a newline element. It's a literal that
-/// prints a newline. It is ignored by the parser.
-class NewlineElement : public WhitespaceElement {
-public:
- NewlineElement() : WhitespaceElement(Kind::Newline) {}
- static bool classof(const Element *element) {
- return element->getKind() == Kind::Newline;
- }
-};
-
-/// This class represents an instance of a space element. It's a literal that
-/// prints or omits printing a space. It is ignored by the parser.
-class SpaceElement : public WhitespaceElement {
-public:
- SpaceElement(bool value) : WhitespaceElement(Kind::Space), value(value) {}
- static bool classof(const Element *element) {
- return element->getKind() == Kind::Space;
- }
-
- /// Returns true if this element should print as a space. Otherwise, the
- /// element should omit printing a space between the surrounding elements.
- bool getValue() const { return value; }
-
-private:
- bool value;
-};
-} // namespace
-
-//===----------------------------------------------------------------------===//
-// OptionalElement
-
-namespace {
-/// This class represents a group of elements that are optionally emitted based
-/// upon an optional variable of the operation, and a group of elements that are
-/// emotted when the anchor element is not present.
-class OptionalElement : public Element {
-public:
- OptionalElement(std::vector<std::unique_ptr<Element>> &&thenElements,
- std::vector<std::unique_ptr<Element>> &&elseElements,
- unsigned anchor, unsigned parseStart)
- : Element{Kind::Optional}, thenElements(std::move(thenElements)),
- elseElements(std::move(elseElements)), anchor(anchor),
- parseStart(parseStart) {}
- static bool classof(const Element *element) {
- return element->getKind() == Kind::Optional;
- }
-
- /// Return the `then` elements of this grouping.
- auto getThenElements() const {
- return llvm::make_pointee_range(thenElements);
- }
-
- /// Return the `else` elements of this grouping.
- auto getElseElements() const {
- return llvm::make_pointee_range(elseElements);
- }
-
- /// Return the anchor of this optional group.
- Element *getAnchor() const { return thenElements[anchor].get(); }
-
- /// Return the index of the first element that needs to be parsed.
- unsigned getParseStart() const { return parseStart; }
-
-private:
- /// The child elements of `then` branch of this optional.
- std::vector<std::unique_ptr<Element>> thenElements;
- /// The child elements of `else` branch of this optional.
- std::vector<std::unique_ptr<Element>> elseElements;
- /// The index of the element that acts as the anchor for the optional group.
- unsigned anchor;
- /// The index of the first element that is parsed (is not a
- /// WhitespaceElement).
- unsigned parseStart;
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
// OperationFormat
//===----------------------------------------------------------------------===//
@@ -450,7 +261,7 @@ struct OperationFormat {
/// Generate the operation parser from this format.
void genParser(Operator &op, OpClass &opClass);
/// Generate the parser code for a specific format element.
- void genElementParser(Element *element, MethodBody &body,
+ void genElementParser(FormatElement *element, MethodBody &body,
FmtContext &attrTypeCtx,
GenContext genCtx = GenContext::Normal);
/// Generate the C++ to resolve the types of operands and results during
@@ -471,11 +282,11 @@ struct OperationFormat {
void genPrinter(Operator &op, OpClass &opClass);
/// Generate the printer code for a specific format element.
- void genElementPrinter(Element *element, MethodBody &body, Operator &op,
+ void genElementPrinter(FormatElement *element, MethodBody &body, Operator &op,
bool &shouldEmitSpace, bool &lastWasPunctuation);
/// The various elements in this format.
- std::vector<std::unique_ptr<Element>> elements;
+ std::vector<FormatElement *> elements;
/// A flag indicating if all operand/result types were seen. If the format
/// contains these, it can not contain individual type resolvers.
@@ -848,7 +659,8 @@ getArgumentLengthKind(const NamedTypeConstraint *var) {
/// Get the name used for the type list for the given type directive operand.
/// 'lengthKind' to the corresponding kind for the given argument.
-static StringRef getTypeListName(Element *arg, ArgumentLengthKind &lengthKind) {
+static StringRef getTypeListName(FormatElement *arg,
+ ArgumentLengthKind &lengthKind) {
if (auto *operand = dyn_cast<OperandVariable>(arg)) {
lengthKind = getArgumentLengthKind(operand->getVar());
return operand->getVar()->name;
@@ -891,26 +703,26 @@ static void genLiteralParser(StringRef value, MethodBody &body) {
}
/// Generate the storage code required for parsing the given element.
-static void genElementParserStorage(Element *element, const Operator &op,
+static void genElementParserStorage(FormatElement *element, const Operator &op,
MethodBody &body) {
if (auto *optional = dyn_cast<OptionalElement>(element)) {
- auto elements = optional->getThenElements();
+ ArrayRef<FormatElement *> elements = optional->getThenElements();
// If the anchor is a unit attribute, it won't be parsed directly so elide
// it.
auto *anchor = dyn_cast<AttributeVariable>(optional->getAnchor());
- Element *elidedAnchorElement = nullptr;
- if (anchor && anchor != &*elements.begin() && anchor->isUnitAttr())
+ FormatElement *elidedAnchorElement = nullptr;
+ if (anchor && anchor != elements.front() && anchor->isUnitAttr())
elidedAnchorElement = anchor;
- for (auto &childElement : elements)
- if (&childElement != elidedAnchorElement)
- genElementParserStorage(&childElement, op, body);
- for (auto &childElement : optional->getElseElements())
- genElementParserStorage(&childElement, op, body);
+ for (FormatElement *childElement : elements)
+ if (childElement != elidedAnchorElement)
+ genElementParserStorage(childElement, op, body);
+ for (FormatElement *childElement : optional->getElseElements())
+ genElementParserStorage(childElement, op, body);
} else if (auto *custom = dyn_cast<CustomDirective>(element)) {
- for (auto ¶mElement : custom->getArguments())
- genElementParserStorage(¶mElement, op, body);
+ for (FormatElement *paramElement : custom->getArguments())
+ genElementParserStorage(paramElement, op, body);
} else if (isa<OperandsDirective>(element)) {
body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
@@ -972,7 +784,7 @@ static void genElementParserStorage(Element *element, const Operator &op,
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
- StringRef name = getTypeListName(dir->getOperand(), lengthKind);
+ StringRef name = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind != ArgumentLengthKind::Single)
body << " ::mlir::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
else
@@ -990,12 +802,12 @@ static void genElementParserStorage(Element *element, const Operator &op,
}
/// Generate the parser for a parameter to a custom directive.
-static void genCustomParameterParser(Element &param, MethodBody &body) {
- if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
+static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
+ if (auto *attr = dyn_cast<AttributeVariable>(param)) {
body << attr->getVar()->name << "Attr";
- } else if (isa<AttrDictDirective>(&param)) {
+ } else if (isa<AttrDictDirective>(param)) {
body << "result.attributes";
- } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+ } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
StringRef name = operand->getVar()->name;
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
@@ -1007,26 +819,26 @@ static void genCustomParameterParser(Element &param, MethodBody &body) {
else
body << formatv("{0}RawOperands[0]", name);
- } else if (auto *region = dyn_cast<RegionVariable>(&param)) {
+ } else if (auto *region = dyn_cast<RegionVariable>(param)) {
StringRef name = region->getVar()->name;
if (region->getVar()->isVariadic())
body << llvm::formatv("{0}Regions", name);
else
body << llvm::formatv("*{0}Region", name);
- } else if (auto *successor = dyn_cast<SuccessorVariable>(&param)) {
+ } else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
StringRef name = successor->getVar()->name;
if (successor->getVar()->isVariadic())
body << llvm::formatv("{0}Successors", name);
else
body << llvm::formatv("{0}Successor", name);
- } else if (auto *dir = dyn_cast<RefDirective>(&param)) {
- genCustomParameterParser(*dir->getOperand(), body);
+ } else if (auto *dir = dyn_cast<RefDirective>(param)) {
+ genCustomParameterParser(dir->getArg(), body);
- } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
+ } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
ArgumentLengthKind lengthKind;
- StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
body << llvm::formatv("{0}TypeGroups", listName);
else if (lengthKind == ArgumentLengthKind::Variadic)
@@ -1048,48 +860,48 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
// * Add a local variable for optional operands and types. This provides a
// better API to the user defined parser methods.
// * Set the location of operand variables.
- for (Element &param : dir->getArguments()) {
- if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+ for (FormatElement *param : dir->getArguments()) {
+ if (auto *operand = dyn_cast<OperandVariable>(param)) {
auto *var = operand->getVar();
body << " " << var->name
<< "OperandsLoc = parser.getCurrentLocation();\n";
if (var->isOptional()) {
body << llvm::formatv(
- " llvm::Optional<::mlir::OpAsmParser::OperandType> "
+ " ::llvm::Optional<::mlir::OpAsmParser::OperandType> "
"{0}Operand;\n",
var->name);
} else if (var->isVariadicOfVariadic()) {
body << llvm::formatv(" "
- "llvm::SmallVector<llvm::SmallVector<::mlir::"
+ "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
"OpAsmParser::OperandType>> "
"{0}OperandGroups;\n",
var->name);
}
- } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
+ } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
ArgumentLengthKind lengthKind;
- StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv(
- " llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
+ " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
"{0}TypeGroups;\n",
listName);
}
- } else if (auto *dir = dyn_cast<RefDirective>(&param)) {
- Element *input = dir->getOperand();
+ } else if (auto *dir = dyn_cast<RefDirective>(param)) {
+ FormatElement *input = dir->getArg();
if (auto *operand = dyn_cast<OperandVariable>(input)) {
if (!operand->getVar()->isOptional())
continue;
body << llvm::formatv(
" {0} {1}Operand = {1}Operands.empty() ? {0}() : "
"{1}Operands[0];\n",
- "llvm::Optional<::mlir::OpAsmParser::OperandType>",
+ "::llvm::Optional<::mlir::OpAsmParser::OperandType>",
operand->getVar()->name);
} else if (auto *type = dyn_cast<TypeDirective>(input)) {
ArgumentLengthKind lengthKind;
- StringRef listName = getTypeListName(type->getOperand(), lengthKind);
+ StringRef listName = getTypeListName(type->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
"::mlir::Type() : {0}Types[0];\n",
@@ -1100,7 +912,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
}
body << " if (parse" << dir->getName() << "(parser";
- for (Element &param : dir->getArguments()) {
+ for (FormatElement *param : dir->getArguments()) {
body << ", ";
genCustomParameterParser(param, body);
}
@@ -1109,15 +921,15 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
<< " return ::mlir::failure();\n";
// After parsing, add handling for any of the optional constructs.
- for (Element &param : dir->getArguments()) {
- if (auto *attr = dyn_cast<AttributeVariable>(&param)) {
+ for (FormatElement *param : dir->getArguments()) {
+ if (auto *attr = dyn_cast<AttributeVariable>(param)) {
const NamedAttribute *var = attr->getVar();
if (var->attr.isOptional())
body << llvm::formatv(" if ({0}Attr)\n ", var->name);
body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
var->name);
- } else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
+ } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
const NamedTypeConstraint *var = operand->getVar();
if (var->isOptional()) {
body << llvm::formatv(" if ({0}Operand.hasValue())\n"
@@ -1131,9 +943,9 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) {
" }\n",
var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr());
}
- } else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
+ } else if (auto *dir = dyn_cast<TypeDirective>(param)) {
ArgumentLengthKind lengthKind;
- StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" if ({0}Type)\n"
" {0}Types.push_back({0}Type);\n",
@@ -1205,16 +1017,16 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
// Generate variables to store the operands and type within the format. This
// allows for referencing these variables in the presence of optional
// groupings.
- for (auto &element : elements)
- genElementParserStorage(&*element, op, body);
+ for (FormatElement *element : elements)
+ genElementParserStorage(element, op, body);
// A format context used when parsing attributes with buildable types.
FmtContext attrTypeCtx;
attrTypeCtx.withBuilder("parser.getBuilder()");
// Generate parsers for each of the elements.
- for (auto &element : elements)
- genElementParser(element.get(), body, attrTypeCtx);
+ for (FormatElement *element : elements)
+ genElementParser(element, body, attrTypeCtx);
// Generate the code to resolve the operand/result types and successors now
// that they have been parsed.
@@ -1226,23 +1038,23 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
body << " return ::mlir::success();\n";
}
-void OperationFormat::genElementParser(Element *element, MethodBody &body,
+void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
FmtContext &attrTypeCtx,
GenContext genCtx) {
/// Optional Group.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
- auto elements = llvm::drop_begin(optional->getThenElements(),
- optional->getParseStart());
+ ArrayRef<FormatElement *> elements =
+ optional->getThenElements().drop_front(optional->getParseStart());
// Generate a special optional parser for the first element to gate the
// parsing of the rest of the elements.
- Element *firstElement = &*elements.begin();
+ FormatElement *firstElement = elements.front();
if (auto *attrVar = dyn_cast<AttributeVariable>(firstElement)) {
genElementParser(attrVar, body, attrTypeCtx);
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
} else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
body << " if (succeeded(parser.parseOptional";
- genLiteralParser(literal->getLiteral(), body);
+ genLiteralParser(literal->getSpelling(), body);
body << ")) {\n";
} else if (auto *opVar = dyn_cast<OperandVariable>(firstElement)) {
genElementParser(opVar, body, attrTypeCtx);
@@ -1265,7 +1077,7 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
// If the anchor is a unit attribute, we don't need to print it. When
// parsing, we will add this attribute if this group is present.
- Element *elidedAnchorElement = nullptr;
+ FormatElement *elidedAnchorElement = nullptr;
auto *anchorAttr = dyn_cast<AttributeVariable>(optional->getAnchor());
if (anchorAttr && anchorAttr != firstElement && anchorAttr->isUnitAttr()) {
elidedAnchorElement = anchorAttr;
@@ -1277,20 +1089,17 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
// Generate the rest of the elements inside an optional group. Elements in
// an optional group after the guard are parsed as required.
- for (Element &childElement : llvm::drop_begin(elements, 1)) {
- if (&childElement != elidedAnchorElement) {
- genElementParser(&childElement, body, attrTypeCtx,
- GenContext::Optional);
- }
- }
+ for (FormatElement *childElement : llvm::drop_begin(elements, 1))
+ if (childElement != elidedAnchorElement)
+ genElementParser(childElement, body, attrTypeCtx, GenContext::Optional);
body << " }";
// Generate the else elements.
auto elseElements = optional->getElseElements();
if (!elseElements.empty()) {
body << " else {\n";
- for (Element &childElement : elseElements)
- genElementParser(&childElement, body, attrTypeCtx);
+ for (FormatElement *childElement : elseElements)
+ genElementParser(childElement, body, attrTypeCtx);
body << " }";
}
body << "\n";
@@ -1298,7 +1107,7 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
/// Literals.
} else if (LiteralElement *literal = dyn_cast<LiteralElement>(element)) {
body << " if (parser.parse";
- genLiteralParser(literal->getLiteral(), body);
+ genLiteralParser(literal->getSpelling(), body);
body << ")\n return ::mlir::failure();\n";
/// Whitespaces.
@@ -1398,7 +1207,7 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
- StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
+ StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
} else if (lengthKind == ArgumentLengthKind::Variadic) {
@@ -1408,7 +1217,7 @@ void OperationFormat::genElementParser(Element *element, MethodBody &body,
} else {
const char *parserCode =
dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
- TypeSwitch<Element *>(dir->getOperand())
+ TypeSwitch<FormatElement *>(dir->getArg())
.Case<OperandVariable, ResultVariable>([&](auto operand) {
body << formatv(parserCode,
operand->getVar()->constraint.getCPPClassName(),
@@ -1610,7 +1419,7 @@ void OperationFormat::genParserRegionResolution(Operator &op,
MethodBody &body) {
// Check for the case where all regions were parsed.
bool hasAllRegions = llvm::any_of(
- elements, [](auto &elt) { return isa<RegionsDirective>(elt.get()); });
+ elements, [](FormatElement *elt) { return isa<RegionsDirective>(elt); });
if (hasAllRegions) {
body << " result.addRegions(fullRegions);\n";
return;
@@ -1628,8 +1437,9 @@ void OperationFormat::genParserRegionResolution(Operator &op,
void OperationFormat::genParserSuccessorResolution(Operator &op,
MethodBody &body) {
// Check for the case where all successors were parsed.
- bool hasAllSuccessors = llvm::any_of(
- elements, [](auto &elt) { return isa<SuccessorsDirective>(elt.get()); });
+ bool hasAllSuccessors = llvm::any_of(elements, [](FormatElement *elt) {
+ return isa<SuccessorsDirective>(elt);
+ });
if (hasAllSuccessors) {
body << " result.addSuccessors(fullSuccessors);\n";
return;
@@ -1773,7 +1583,7 @@ static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace,
}
/// Generate the printer for a custom directive parameter.
-static void genCustomDirectiveParameterPrinter(Element *element,
+static void genCustomDirectiveParameterPrinter(FormatElement *element,
const Operator &op,
MethodBody &body) {
if (auto *attr = dyn_cast<AttributeVariable>(element)) {
@@ -1792,10 +1602,10 @@ static void genCustomDirectiveParameterPrinter(Element *element,
body << op.getGetterName(successor->getVar()->name) << "()";
} else if (auto *dir = dyn_cast<RefDirective>(element)) {
- genCustomDirectiveParameterPrinter(dir->getOperand(), op, body);
+ genCustomDirectiveParameterPrinter(dir->getArg(), op, body);
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
- auto *typeOperand = dir->getOperand();
+ auto *typeOperand = dir->getArg();
auto *operand = dyn_cast<OperandVariable>(typeOperand);
auto *var = operand ? operand->getVar()
: cast<ResultVariable>(typeOperand)->getVar();
@@ -1815,9 +1625,9 @@ static void genCustomDirectiveParameterPrinter(Element *element,
static void genCustomDirectivePrinter(CustomDirective *customDir,
const Operator &op, MethodBody &body) {
body << " print" << customDir->getName() << "(_odsPrinter, *this";
- for (Element &param : customDir->getArguments()) {
+ for (FormatElement *param : customDir->getArguments()) {
body << ", ";
- genCustomDirectiveParameterPrinter(&param, op, body);
+ genCustomDirectiveParameterPrinter(param, op, body);
}
body << ");\n";
}
@@ -1841,7 +1651,7 @@ static void genVariadicRegionPrinter(const Twine &regionListName,
}
/// Generate the C++ for an operand to a (*-)type directive.
-static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op,
+static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
MethodBody &body,
bool useArrayRef = true) {
if (isa<OperandsDirective>(arg))
@@ -1945,9 +1755,10 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
}
/// Generate the check for the anchor of an optional group.
-static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op,
+static void genOptionalGroupPrinterAnchor(FormatElement *anchor,
+ const Operator &op,
MethodBody &body) {
- TypeSwitch<Element *>(anchor)
+ TypeSwitch<FormatElement *>(anchor)
.Case<OperandVariable, ResultVariable>([&](auto *element) {
const NamedTypeConstraint *var = element->getVar();
std::string name = op.getGetterName(var->name);
@@ -1963,7 +1774,7 @@ static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op,
body << " if (!" << name << "().empty()) {\n";
})
.Case<TypeDirective>([&](TypeDirective *element) {
- genOptionalGroupPrinterAnchor(element->getOperand(), op, body);
+ genOptionalGroupPrinterAnchor(element->getArg(), op, body);
})
.Case<FunctionalTypeDirective>([&](FunctionalTypeDirective *element) {
genOptionalGroupPrinterAnchor(element->getInputs(), op, body);
@@ -1974,42 +1785,45 @@ static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op,
});
}
-void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
- Operator &op, bool &shouldEmitSpace,
+void OperationFormat::genElementPrinter(FormatElement *element,
+ MethodBody &body, Operator &op,
+ bool &shouldEmitSpace,
bool &lastWasPunctuation) {
if (LiteralElement *literal = dyn_cast<LiteralElement>(element))
- return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace,
+ return genLiteralPrinter(literal->getSpelling(), body, shouldEmitSpace,
lastWasPunctuation);
// Emit a whitespace element.
- if (isa<NewlineElement>(element)) {
- body << " _odsPrinter.printNewline();\n";
+ if (auto *space = dyn_cast<WhitespaceElement>(element)) {
+ if (space->getValue() == "\\n") {
+ body << " _odsPrinter.printNewline();\n";
+ } else {
+ genSpacePrinter(!space->getValue().empty(), body, shouldEmitSpace,
+ lastWasPunctuation);
+ }
return;
}
- if (SpaceElement *space = dyn_cast<SpaceElement>(element))
- return genSpacePrinter(space->getValue(), body, shouldEmitSpace,
- lastWasPunctuation);
// Emit an optional group.
if (OptionalElement *optional = dyn_cast<OptionalElement>(element)) {
// Emit the check for the presence of the anchor element.
- Element *anchor = optional->getAnchor();
+ FormatElement *anchor = optional->getAnchor();
genOptionalGroupPrinterAnchor(anchor, op, body);
// If the anchor is a unit attribute, we don't need to print it. When
// parsing, we will add this attribute if this group is present.
auto elements = optional->getThenElements();
- Element *elidedAnchorElement = nullptr;
+ FormatElement *elidedAnchorElement = nullptr;
auto *anchorAttr = dyn_cast<AttributeVariable>(anchor);
- if (anchorAttr && anchorAttr != &*elements.begin() &&
+ if (anchorAttr && anchorAttr != elements.front() &&
anchorAttr->isUnitAttr()) {
elidedAnchorElement = anchorAttr;
}
// Emit each of the elements.
- for (Element &childElement : elements) {
- if (&childElement != elidedAnchorElement) {
- genElementPrinter(&childElement, body, op, shouldEmitSpace,
+ for (FormatElement *childElement : elements) {
+ if (childElement != elidedAnchorElement) {
+ genElementPrinter(childElement, body, op, shouldEmitSpace,
lastWasPunctuation);
}
}
@@ -2019,8 +1833,8 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
auto elseElements = optional->getElseElements();
if (!elseElements.empty()) {
body << " else {\n";
- for (Element &childElement : elseElements) {
- genElementPrinter(&childElement, body, op, shouldEmitSpace,
+ for (FormatElement *childElement : elseElements) {
+ genElementPrinter(childElement, body, op, shouldEmitSpace,
lastWasPunctuation);
}
body << " }";
@@ -2111,7 +1925,7 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), "
"_odsPrinter);\n";
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
- if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand())) {
+ if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
if (operand->getVar()->isVariadicOfVariadic()) {
body << llvm::formatv(
" ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
@@ -2123,9 +1937,9 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
}
const NamedTypeConstraint *var = nullptr;
{
- if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand()))
+ if (auto *operand = dyn_cast<OperandVariable>(dir->getArg()))
var = operand->getVar();
- else if (auto *operand = dyn_cast<ResultVariable>(dir->getOperand()))
+ else if (auto *operand = dyn_cast<ResultVariable>(dir->getArg()))
var = operand->getVar();
}
if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
@@ -2147,7 +1961,7 @@ void OperationFormat::genElementPrinter(Element *element, MethodBody &body,
return;
}
body << " _odsPrinter << ";
- genTypeOperandPrinter(dir->getOperand(), op, body, /*useArrayRef=*/false)
+ genTypeOperandPrinter(dir->getArg(), op, body, /*useArrayRef=*/false)
<< ";\n";
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
body << " _odsPrinter.printFunctionalType(";
@@ -2167,13 +1981,12 @@ void OperationFormat::genPrinter(Operator &op, OpClass &opClass) {
// Flags for if we should emit a space, and if the last element was
// punctuation.
bool shouldEmitSpace = true, lastWasPunctuation = false;
- for (auto &element : elements)
- genElementPrinter(element.get(), body, op, shouldEmitSpace,
- lastWasPunctuation);
+ for (FormatElement *element : elements)
+ genElementPrinter(element, body, op, shouldEmitSpace, lastWasPunctuation);
}
//===----------------------------------------------------------------------===//
-// FormatParser
+// OpFormatParser
//===----------------------------------------------------------------------===//
/// Function to find an element within the given range that has the same name as
@@ -2186,30 +1999,35 @@ template <typename RangeT> static auto findArg(RangeT &&range, StringRef name) {
namespace {
/// This class implements a parser for an instance of an operation assembly
/// format.
-class FormatParser {
+class OpFormatParser : public FormatParser {
public:
- FormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
- : lexer(mgr, op.getLoc()[0]), curToken(lexer.lexToken()), fmt(format),
- op(op), seenOperandTypes(op.getNumOperands()),
+ OpFormatParser(llvm::SourceMgr &mgr, OperationFormat &format, Operator &op)
+ : FormatParser(mgr, op.getLoc()[0]), fmt(format), op(op),
+ seenOperandTypes(op.getNumOperands()),
seenResultTypes(op.getNumResults()) {}
- /// Parse the operation assembly format.
- LogicalResult parse();
+protected:
+ /// Verify the format elements.
+ LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
+ /// Verify the arguments to a custom directive.
+ LogicalResult
+ verifyCustomDirectiveArguments(SMLoc loc,
+ ArrayRef<FormatElement *> arguments) override;
+ /// Verify the elements of an optional group.
+ LogicalResult
+ verifyOptionalGroupElements(SMLoc loc, ArrayRef<FormatElement *> elements,
+ Optional<unsigned> anchorIndex) override;
+ LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
+ bool isAnchor);
+
+ /// Parse an operation variable.
+ FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
+ Context ctx) override;
+ /// Parse an operation format directive.
+ FailureOr<FormatElement *>
+ parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
private:
- /// The current context of the parser when parsing an element.
- enum ParserContext {
- /// The element is being parsed in a "top-level" context, i.e. at the top of
- /// the format or in an optional group.
- TopLevelContext,
- /// The element is being parsed as a custom directive child.
- CustomDirectiveContext,
- /// The element is being parsed as a type directive child.
- TypeDirectiveContext,
- /// The element is being parsed as a reference directive child.
- RefDirectiveContext
- };
-
/// This struct represents a type resolution instance. It includes a specific
/// type as well as an optional transformer to apply to that type in order to
/// properly resolve the type of a variable.
@@ -2218,16 +2036,14 @@ class FormatParser {
Optional<StringRef> transformer;
};
- /// An iterator over the elements of a format group.
- using ElementsIterT = llvm::pointee_iterator<
- std::vector<std::unique_ptr<Element>>::const_iterator>;
+ using ElementsItT = ArrayRef<FormatElement *>::iterator;
/// Verify the state of operation attributes within the format.
- LogicalResult verifyAttributes(SMLoc loc);
+ LogicalResult verifyAttributes(SMLoc loc, ArrayRef<FormatElement *> elements);
/// Verify the attribute elements at the back of the given stack of iterators.
LogicalResult verifyAttributes(
SMLoc loc,
- SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack);
+ SmallVectorImpl<std::pair<ElementsItT, ElementsItT>> &iteratorStack);
/// Verify the state of operation operands within the format.
LogicalResult
@@ -2266,85 +2082,28 @@ class FormatParser {
/// within the format.
ConstArgument findSeenArg(StringRef name);
- /// Parse a specific element.
- LogicalResult parseElement(std::unique_ptr<Element> &element,
- ParserContext context);
- LogicalResult parseVariable(std::unique_ptr<Element> &element,
- ParserContext context);
- LogicalResult parseDirective(std::unique_ptr<Element> &element,
- ParserContext context);
- LogicalResult parseLiteral(std::unique_ptr<Element> &element,
- ParserContext context);
- LogicalResult parseOptional(std::unique_ptr<Element> &element,
- ParserContext context);
- LogicalResult parseOptionalChildElement(
- std::vector<std::unique_ptr<Element>> &childElements,
- Optional<unsigned> &anchorIdx);
- LogicalResult verifyOptionalChildElement(Element *element,
- SMLoc childLoc, bool isAnchor);
-
/// Parse the various different directives.
- LogicalResult parseAttrDictDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context,
- bool withKeyword);
- LogicalResult parseCustomDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context);
- LogicalResult parseCustomDirectiveParameter(
- std::vector<std::unique_ptr<Element>> &parameters);
- LogicalResult parseFunctionalTypeDirective(std::unique_ptr<Element> &element,
- FormatToken tok,
- ParserContext context);
- LogicalResult parseOperandsDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context);
- LogicalResult parseQualifiedDirective(std::unique_ptr<Element> &element,
- FormatToken tok, ParserContext context);
- LogicalResult parseReferenceDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context);
- LogicalResult parseRegionsDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context);
- LogicalResult parseResultsDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context);
- LogicalResult parseSuccessorsDirective(std::unique_ptr<Element> &element,
- SMLoc loc,
- ParserContext context);
- LogicalResult parseTypeDirective(std::unique_ptr<Element> &element,
- FormatToken tok, ParserContext context);
- LogicalResult parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
- bool isRefChild = false);
-
- //===--------------------------------------------------------------------===//
- // Lexer Utilities
- //===--------------------------------------------------------------------===//
-
- /// Advance the current lexer onto the next token.
- void consumeToken() {
- assert(curToken.getKind() != FormatToken::eof &&
- curToken.getKind() != FormatToken::error &&
- "shouldn't advance past EOF or errors");
- curToken = lexer.lexToken();
- }
- LogicalResult parseToken(FormatToken::Kind kind, const Twine &msg) {
- if (curToken.getKind() != kind)
- return emitError(curToken.getLoc(), msg);
- consumeToken();
- return ::mlir::success();
- }
- LogicalResult emitError(SMLoc loc, const Twine &msg) {
- lexer.emitError(loc, msg);
- return ::mlir::failure();
- }
- LogicalResult emitErrorAndNote(SMLoc loc, const Twine &msg,
- const Twine &note) {
- lexer.emitErrorAndNote(loc, msg, note);
- return ::mlir::failure();
- }
+ FailureOr<FormatElement *> parseAttrDictDirective(SMLoc loc, Context context,
+ bool withKeyword);
+ FailureOr<FormatElement *> parseFunctionalTypeDirective(SMLoc loc,
+ Context context);
+ FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
+ FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
+ Context context);
+ FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc,
+ Context context);
+ FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
+ FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
+ FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
+ Context context);
+ FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context);
+ FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc,
+ bool isRefChild = false);
//===--------------------------------------------------------------------===//
// Fields
//===--------------------------------------------------------------------===//
- FormatLexer lexer;
- FormatToken curToken;
OperationFormat &fmt;
Operator &op;
@@ -2361,17 +2120,8 @@ class FormatParser {
};
} // namespace
-LogicalResult FormatParser::parse() {
- SMLoc loc = curToken.getLoc();
-
- // Parse each of the format elements into the main format.
- while (curToken.getKind() != FormatToken::eof) {
- std::unique_ptr<Element> element;
- if (failed(parseElement(element, TopLevelContext)))
- return ::mlir::failure();
- fmt.elements.push_back(std::move(element));
- }
-
+LogicalResult OpFormatParser::verify(SMLoc loc,
+ ArrayRef<FormatElement *> elements) {
// Check that the attribute dictionary is in the format.
if (!hasAttrDict)
return emitError(loc, "'attr-dict' directive not found in "
@@ -2404,29 +2154,29 @@ LogicalResult FormatParser::parse() {
}
// Verify the state of the various operation components.
- if (failed(verifyAttributes(loc)) ||
+ if (failed(verifyAttributes(loc, elements)) ||
failed(verifyResults(loc, variableTyResolver)) ||
failed(verifyOperands(loc, variableTyResolver)) ||
failed(verifyRegions(loc)) || failed(verifySuccessors(loc)))
- return ::mlir::failure();
+ return failure();
// Collect the set of used attributes in the format.
fmt.usedAttributes = seenAttrs.takeVector();
- return ::mlir::success();
+ return success();
}
-LogicalResult FormatParser::verifyAttributes(SMLoc loc) {
+LogicalResult
+OpFormatParser::verifyAttributes(SMLoc loc,
+ ArrayRef<FormatElement *> elements) {
// Check that there are no `:` literals after an attribute without a constant
// type. The attribute grammar contains an optional trailing colon type, which
// can lead to unexpected and generally unintended behavior. Given that, it is
// better to just error out here instead.
- using ElementsIterT = llvm::pointee_iterator<
- std::vector<std::unique_ptr<Element>>::const_iterator>;
- SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack;
- iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end());
+ SmallVector<std::pair<ElementsItT, ElementsItT>, 1> iteratorStack;
+ iteratorStack.emplace_back(elements.begin(), elements.end());
while (!iteratorStack.empty())
if (failed(verifyAttributes(loc, iteratorStack)))
- return ::mlir::failure();
+ return ::failure();
// Check for VariadicOfVariadic variables. The segment attribute of those
// variables will be infered.
@@ -2437,16 +2187,16 @@ LogicalResult FormatParser::verifyAttributes(SMLoc loc) {
}
}
- return ::mlir::success();
+ return success();
}
/// Verify the attribute elements at the back of the given stack of iterators.
-LogicalResult FormatParser::verifyAttributes(
+LogicalResult OpFormatParser::verifyAttributes(
SMLoc loc,
- SmallVectorImpl<std::pair<ElementsIterT, ElementsIterT>> &iteratorStack) {
+ SmallVectorImpl<std::pair<ElementsItT, ElementsItT>> &iteratorStack) {
auto &stackIt = iteratorStack.back();
- ElementsIterT &it = stackIt.first, e = stackIt.second;
+ ElementsItT &it = stackIt.first, e = stackIt.second;
while (it != e) {
- Element *element = &*(it++);
+ FormatElement *element = *(it++);
// Traverse into optional groups.
if (auto *optional = dyn_cast<OptionalElement>(element)) {
@@ -2455,7 +2205,7 @@ LogicalResult FormatParser::verifyAttributes(
auto elseElements = optional->getElseElements();
iteratorStack.emplace_back(elseElements.begin(), elseElements.end());
- return ::mlir::success();
+ return success();
}
// We are checking for an attribute element followed by a `:`, so there is
@@ -2470,7 +2220,7 @@ LogicalResult FormatParser::verifyAttributes(
// Check the next iterator within the stack for literal elements.
for (auto &nextItPair : iteratorStack) {
- ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
+ ElementsItT nextIt = nextItPair.first, nextE = nextItPair.second;
for (; nextIt != nextE; ++nextIt) {
// Skip any trailing whitespace, attribute dictionaries, or optional
// groups.
@@ -2479,8 +2229,8 @@ LogicalResult FormatParser::verifyAttributes(
continue;
// We are only interested in `:` literals.
- auto *literal = dyn_cast<LiteralElement>(&*nextIt);
- if (!literal || literal->getLiteral() != ":")
+ auto *literal = dyn_cast<LiteralElement>(*nextIt);
+ if (!literal || literal->getSpelling() != ":")
break;
// TODO: Use the location of the literal element itself.
@@ -2493,12 +2243,11 @@ LogicalResult FormatParser::verifyAttributes(
}
}
iteratorStack.pop_back();
- return ::mlir::success();
+ return success();
}
-LogicalResult FormatParser::verifyOperands(
- SMLoc loc,
- llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
+LogicalResult OpFormatParser::verifyOperands(
+ SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
// Check that all of the operands are within the format, and their types can
// be inferred.
auto &buildableTypes = fmt.buildableTypes;
@@ -2541,13 +2290,13 @@ LogicalResult FormatParser::verifyOperands(
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
fmt.operandTypes[i].setBuilderIdx(it.first->second);
}
- return ::mlir::success();
+ return success();
}
-LogicalResult FormatParser::verifyRegions(SMLoc loc) {
+LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
// Check that all of the regions are within the format.
if (hasAllRegions)
- return ::mlir::success();
+ return success();
for (unsigned i = 0, e = op.getNumRegions(); i != e; ++i) {
const NamedRegion &region = op.getRegion(i);
@@ -2559,22 +2308,21 @@ LogicalResult FormatParser::verifyRegions(SMLoc loc) {
"' directive to the custom assembly format");
}
}
- return ::mlir::success();
+ return success();
}
-LogicalResult FormatParser::verifyResults(
- SMLoc loc,
- llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
+LogicalResult OpFormatParser::verifyResults(
+ SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
// If we format all of the types together, there is nothing to check.
if (fmt.allResultTypes)
- return ::mlir::success();
+ return success();
// If no result types are specified and we can infer them, infer all result
// types
if (op.getNumResults() > 0 && seenResultTypes.count() == 0 &&
canInferResultTypes) {
fmt.infersResultTypes = true;
- return ::mlir::success();
+ return success();
}
// Check that all of the result types can be inferred.
@@ -2608,13 +2356,13 @@ LogicalResult FormatParser::verifyResults(
auto it = buildableTypes.insert({*builder, buildableTypes.size()});
fmt.resultTypes[i].setBuilderIdx(it.first->second);
}
- return ::mlir::success();
+ return success();
}
-LogicalResult FormatParser::verifySuccessors(SMLoc loc) {
+LogicalResult OpFormatParser::verifySuccessors(SMLoc loc) {
// Check that all of the successors are within the format.
if (hasAllSuccessors)
- return ::mlir::success();
+ return success();
for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
const NamedSuccessor &successor = op.getSuccessor(i);
@@ -2626,10 +2374,10 @@ LogicalResult FormatParser::verifySuccessors(SMLoc loc) {
"' directive to the custom assembly format");
}
}
- return ::mlir::success();
+ return success();
}
-void FormatParser::handleAllTypesMatchConstraint(
+void OpFormatParser::handleAllTypesMatchConstraint(
ArrayRef<StringRef> values,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
for (unsigned i = 0, e = values.size(); i != e; ++i) {
@@ -2646,7 +2394,7 @@ void FormatParser::handleAllTypesMatchConstraint(
}
}
-void FormatParser::handleSameTypesConstraint(
+void OpFormatParser::handleSameTypesConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
bool includeResults) {
const NamedTypeConstraint *resolver = nullptr;
@@ -2671,7 +2419,7 @@ void FormatParser::handleSameTypesConstraint(
}
}
-void FormatParser::handleTypesMatchConstraint(
+void OpFormatParser::handleTypesMatchConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
const llvm::Record &def) {
StringRef lhsName = def.getValueAsString("lhs");
@@ -2681,7 +2429,7 @@ void FormatParser::handleTypesMatchConstraint(
variableTyResolver[rhsName] = {arg, transformer};
}
-ConstArgument FormatParser::findSeenArg(StringRef name) {
+ConstArgument OpFormatParser::findSeenArg(StringRef name) {
if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
if (const NamedTypeConstraint *arg = findArg(op.getResults(), name))
@@ -2691,40 +2439,15 @@ ConstArgument FormatParser::findSeenArg(StringRef name) {
return nullptr;
}
-LogicalResult FormatParser::parseElement(std::unique_ptr<Element> &element,
- ParserContext context) {
- // Directives.
- if (curToken.isKeyword())
- return parseDirective(element, context);
- // Literals.
- if (curToken.getKind() == FormatToken::literal)
- return parseLiteral(element, context);
- // Optionals.
- if (curToken.getKind() == FormatToken::l_paren)
- return parseOptional(element, context);
- // Variables.
- if (curToken.getKind() == FormatToken::variable)
- return parseVariable(element, context);
- return emitError(curToken.getLoc(),
- "expected directive, literal, variable, or optional group");
-}
-
-LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
- ParserContext context) {
- FormatToken varTok = curToken;
- consumeToken();
-
- StringRef name = varTok.getSpelling().drop_front();
- SMLoc loc = varTok.getLoc();
-
- // Check that the parsed argument is something actually registered on the
- // op.
- /// Attributes
+FailureOr<FormatElement *>
+OpFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
+ // Check that the parsed argument is something actually registered on the op.
+ // Attributes
if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) {
- if (context == TypeDirectiveContext)
+ if (ctx == TypeDirectiveContext)
return emitError(
loc, "attributes cannot be used as children to a `type` directive");
- if (context == RefDirectiveContext) {
+ if (ctx == RefDirectiveContext) {
if (!seenAttrs.count(attr))
return emitError(loc, "attribute '" + name +
"' must be bound before it is referenced");
@@ -2732,280 +2455,92 @@ LogicalResult FormatParser::parseVariable(std::unique_ptr<Element> &element,
return emitError(loc, "attribute '" + name + "' is already bound");
}
- element = std::make_unique<AttributeVariable>(attr);
- return ::mlir::success();
+ return create<AttributeVariable>(attr);
}
- /// Operands
+ // Operands
if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) {
- if (context == TopLevelContext || context == CustomDirectiveContext) {
+ if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
if (fmt.allOperands || !seenOperands.insert(operand).second)
return emitError(loc, "operand '" + name + "' is already bound");
- } else if (context == RefDirectiveContext && !seenOperands.count(operand)) {
+ } else if (ctx == RefDirectiveContext && !seenOperands.count(operand)) {
return emitError(loc, "operand '" + name +
"' must be bound before it is referenced");
}
- element = std::make_unique<OperandVariable>(operand);
- return ::mlir::success();
+ return create<OperandVariable>(operand);
}
- /// Regions
+ // Regions
if (const NamedRegion *region = findArg(op.getRegions(), name)) {
- if (context == TopLevelContext || context == CustomDirectiveContext) {
+ if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
if (hasAllRegions || !seenRegions.insert(region).second)
return emitError(loc, "region '" + name + "' is already bound");
- } else if (context == RefDirectiveContext && !seenRegions.count(region)) {
+ } else if (ctx == RefDirectiveContext && !seenRegions.count(region)) {
return emitError(loc, "region '" + name +
"' must be bound before it is referenced");
} else {
return emitError(loc, "regions can only be used at the top level");
}
- element = std::make_unique<RegionVariable>(region);
- return ::mlir::success();
+ return create<RegionVariable>(region);
}
- /// Results.
+ // Results.
if (const auto *result = findArg(op.getResults(), name)) {
- if (context != TypeDirectiveContext)
+ if (ctx != TypeDirectiveContext)
return emitError(loc, "result variables can can only be used as a child "
"to a 'type' directive");
- element = std::make_unique<ResultVariable>(result);
- return ::mlir::success();
+ return create<ResultVariable>(result);
}
- /// Successors.
+ // Successors.
if (const auto *successor = findArg(op.getSuccessors(), name)) {
- if (context == TopLevelContext || context == CustomDirectiveContext) {
+ if (ctx == TopLevelContext || ctx == CustomDirectiveContext) {
if (hasAllSuccessors || !seenSuccessors.insert(successor).second)
return emitError(loc, "successor '" + name + "' is already bound");
- } else if (context == RefDirectiveContext &&
- !seenSuccessors.count(successor)) {
+ } else if (ctx == RefDirectiveContext && !seenSuccessors.count(successor)) {
return emitError(loc, "successor '" + name +
"' must be bound before it is referenced");
} else {
return emitError(loc, "successors can only be used at the top level");
}
- element = std::make_unique<SuccessorVariable>(successor);
- return ::mlir::success();
+ return create<SuccessorVariable>(successor);
}
return emitError(loc, "expected variable to refer to an argument, region, "
"result, or successor");
}
-LogicalResult FormatParser::parseDirective(std::unique_ptr<Element> &element,
- ParserContext context) {
- FormatToken dirTok = curToken;
- consumeToken();
-
- switch (dirTok.getKind()) {
+FailureOr<FormatElement *>
+OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
+ Context ctx) {
+ switch (kind) {
case FormatToken::kw_attr_dict:
- return parseAttrDictDirective(element, dirTok.getLoc(), context,
+ return parseAttrDictDirective(loc, ctx,
/*withKeyword=*/false);
case FormatToken::kw_attr_dict_w_keyword:
- return parseAttrDictDirective(element, dirTok.getLoc(), context,
+ return parseAttrDictDirective(loc, ctx,
/*withKeyword=*/true);
- case FormatToken::kw_custom:
- return parseCustomDirective(element, dirTok.getLoc(), context);
case FormatToken::kw_functional_type:
- return parseFunctionalTypeDirective(element, dirTok, context);
+ return parseFunctionalTypeDirective(loc, ctx);
case FormatToken::kw_operands:
- return parseOperandsDirective(element, dirTok.getLoc(), context);
+ return parseOperandsDirective(loc, ctx);
case FormatToken::kw_qualified:
- return parseQualifiedDirective(element, dirTok, context);
+ return parseQualifiedDirective(loc, ctx);
case FormatToken::kw_regions:
- return parseRegionsDirective(element, dirTok.getLoc(), context);
+ return parseRegionsDirective(loc, ctx);
case FormatToken::kw_results:
- return parseResultsDirective(element, dirTok.getLoc(), context);
+ return parseResultsDirective(loc, ctx);
case FormatToken::kw_successors:
- return parseSuccessorsDirective(element, dirTok.getLoc(), context);
+ return parseSuccessorsDirective(loc, ctx);
case FormatToken::kw_ref:
- return parseReferenceDirective(element, dirTok.getLoc(), context);
+ return parseReferenceDirective(loc, ctx);
case FormatToken::kw_type:
- return parseTypeDirective(element, dirTok, context);
+ return parseTypeDirective(loc, ctx);
default:
- llvm_unreachable("unknown directive token");
+ return emitError(loc, "unsupported directive kind");
}
}
-LogicalResult FormatParser::parseLiteral(std::unique_ptr<Element> &element,
- ParserContext context) {
- FormatToken literalTok = curToken;
- if (context != TopLevelContext) {
- return emitError(
- literalTok.getLoc(),
- "literals may only be used in a top-level section of the format");
- }
- consumeToken();
-
- StringRef value = literalTok.getSpelling().drop_front().drop_back();
-
- // The parsed literal is a space element (`` or ` `).
- if (value.empty() || (value.size() == 1 && value.front() == ' ')) {
- element = std::make_unique<SpaceElement>(!value.empty());
- return ::mlir::success();
- }
- // The parsed literal is a newline element.
- if (value == "\\n") {
- element = std::make_unique<NewlineElement>();
- return ::mlir::success();
- }
-
- // Check that the parsed literal is valid.
- if (!isValidLiteral(value, [&](Twine diag) {
- (void)emitError(literalTok.getLoc(),
- "expected valid literal but got '" + value +
- "': " + diag);
- }))
- return failure();
- element = std::make_unique<LiteralElement>(value);
- return ::mlir::success();
-}
-
-LogicalResult FormatParser::parseOptional(std::unique_ptr<Element> &element,
- ParserContext context) {
- SMLoc curLoc = curToken.getLoc();
- if (context != TopLevelContext)
- return emitError(curLoc, "optional groups can only be used as top-level "
- "elements");
- consumeToken();
-
- // Parse the child elements for this optional group.
- std::vector<std::unique_ptr<Element>> thenElements, elseElements;
- Optional<unsigned> anchorIdx;
- do {
- if (failed(parseOptionalChildElement(thenElements, anchorIdx)))
- return ::mlir::failure();
- } while (curToken.getKind() != FormatToken::r_paren);
- consumeToken();
-
- // Parse the `else` elements of this optional group.
- if (curToken.getKind() == FormatToken::colon) {
- consumeToken();
- if (failed(parseToken(FormatToken::l_paren,
- "expected '(' to start else branch "
- "of optional group")))
- return failure();
- do {
- SMLoc childLoc = curToken.getLoc();
- elseElements.push_back({});
- if (failed(parseElement(elseElements.back(), TopLevelContext)) ||
- failed(verifyOptionalChildElement(elseElements.back().get(), childLoc,
- /*isAnchor=*/false)))
- return failure();
- } while (curToken.getKind() != FormatToken::r_paren);
- consumeToken();
- }
-
- if (failed(parseToken(FormatToken::question,
- "expected '?' after optional group")))
- return ::mlir::failure();
-
- // The optional group is required to have an anchor.
- if (!anchorIdx)
- return emitError(curLoc, "optional group specified no anchor element");
-
- // The first parsable element of the group must be able to be parsed in an
- // optional fashion.
- auto parseBegin = llvm::find_if_not(thenElements, [](auto &element) {
- return isa<WhitespaceElement>(element.get());
- });
- Element *firstElement = parseBegin->get();
- if (!isa<AttributeVariable>(firstElement) &&
- !isa<LiteralElement>(firstElement) &&
- !isa<OperandVariable>(firstElement) && !isa<RegionVariable>(firstElement))
- return emitError(curLoc,
- "first parsable element of an operand group must be "
- "an attribute, literal, operand, or region");
-
- auto parseStart = parseBegin - thenElements.begin();
- element = std::make_unique<OptionalElement>(
- std::move(thenElements), std::move(elseElements), *anchorIdx, parseStart);
- return ::mlir::success();
-}
-
-LogicalResult FormatParser::parseOptionalChildElement(
- std::vector<std::unique_ptr<Element>> &childElements,
- Optional<unsigned> &anchorIdx) {
- SMLoc childLoc = curToken.getLoc();
- childElements.push_back({});
- if (failed(parseElement(childElements.back(), TopLevelContext)))
- return ::mlir::failure();
-
- // Check to see if this element is the anchor of the optional group.
- bool isAnchor = curToken.getKind() == FormatToken::caret;
- if (isAnchor) {
- if (anchorIdx)
- return emitError(childLoc, "only one element can be marked as the anchor "
- "of an optional group");
- anchorIdx = childElements.size() - 1;
- consumeToken();
- }
-
- return verifyOptionalChildElement(childElements.back().get(), childLoc,
- isAnchor);
-}
-
-LogicalResult FormatParser::verifyOptionalChildElement(Element *element,
- SMLoc childLoc,
- bool isAnchor) {
- return TypeSwitch<Element *, LogicalResult>(element)
- // All attributes can be within the optional group, but only optional
- // attributes can be the anchor.
- .Case([&](AttributeVariable *attrEle) {
- if (isAnchor && !attrEle->getVar()->attr.isOptional())
- return emitError(childLoc, "only optional attributes can be used to "
- "anchor an optional group");
- return ::mlir::success();
- })
- // Only optional-like(i.e. variadic) operands can be within an optional
- // group.
- .Case([&](OperandVariable *ele) {
- if (!ele->getVar()->isVariableLength())
- return emitError(childLoc, "only variable length operands can be "
- "used within an optional group");
- return ::mlir::success();
- })
- // Only optional-like(i.e. variadic) results can be within an optional
- // group.
- .Case([&](ResultVariable *ele) {
- if (!ele->getVar()->isVariableLength())
- return emitError(childLoc, "only variable length results can be "
- "used within an optional group");
- return ::mlir::success();
- })
- .Case([&](RegionVariable *) {
- // TODO: When ODS has proper support for marking "optional" regions, add
- // a check here.
- return ::mlir::success();
- })
- .Case([&](TypeDirective *ele) {
- return verifyOptionalChildElement(ele->getOperand(), childLoc,
- /*isAnchor=*/false);
- })
- .Case([&](FunctionalTypeDirective *ele) {
- if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc,
- /*isAnchor=*/false)))
- return failure();
- return verifyOptionalChildElement(ele->getResults(), childLoc,
- /*isAnchor=*/false);
- })
- // Literals, whitespace, and custom directives may be used, but they can't
- // anchor the group.
- .Case<LiteralElement, WhitespaceElement, CustomDirective,
- FunctionalTypeDirective, OptionalElement>([&](Element *) {
- if (isAnchor)
- return emitError(childLoc, "only variables and types can be used "
- "to anchor an optional group");
- return ::mlir::success();
- })
- .Default([&](Element *) {
- return emitError(childLoc, "only literals, types, and variables can be "
- "used within an optional group");
- });
-}
-
-LogicalResult
-FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context,
- bool withKeyword) {
+FailureOr<FormatElement *>
+OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
+ bool withKeyword) {
if (context == TypeDirectiveContext)
return emitError(loc, "'attr-dict' directive can only be used as a "
"top-level directive");
@@ -3022,104 +2557,50 @@ FormatParser::parseAttrDictDirective(std::unique_ptr<Element> &element,
hasAttrDict = true;
}
- element = std::make_unique<AttrDictDirective>(withKeyword);
- return ::mlir::success();
+ return create<AttrDictDirective>(withKeyword);
}
-LogicalResult
-FormatParser::parseCustomDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context) {
- SMLoc curLoc = curToken.getLoc();
- if (context != TopLevelContext)
- return emitError(loc, "'custom' is only valid as a top-level directive");
-
- // Parse the custom directive name.
- if (failed(parseToken(FormatToken::less,
- "expected '<' before custom directive name")))
- return ::mlir::failure();
-
- FormatToken nameTok = curToken;
- if (failed(parseToken(FormatToken::identifier,
- "expected custom directive name identifier")) ||
- failed(parseToken(FormatToken::greater,
- "expected '>' after custom directive name")) ||
- failed(parseToken(FormatToken::l_paren,
- "expected '(' before custom directive parameters")))
- return ::mlir::failure();
-
- // Parse the child elements for this optional group.=
- std::vector<std::unique_ptr<Element>> elements;
- do {
- if (failed(parseCustomDirectiveParameter(elements)))
- return ::mlir::failure();
- if (curToken.getKind() != FormatToken::comma)
- break;
- consumeToken();
- } while (true);
-
- if (failed(parseToken(FormatToken::r_paren,
- "expected ')' after custom directive parameters")))
- return ::mlir::failure();
-
- // After parsing all of the elements, ensure that all type directives refer
- // only to variables.
- for (auto &ele : elements) {
- if (auto *typeEle = dyn_cast<TypeDirective>(ele.get())) {
- if (!isa<OperandVariable, ResultVariable>(typeEle->getOperand())) {
- return emitError(curLoc, "type directives within a custom directive "
- "may only refer to variables");
+LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
+ SMLoc loc, ArrayRef<FormatElement *> arguments) {
+ for (FormatElement *argument : arguments) {
+ if (!isa<RefDirective, TypeDirective, AttrDictDirective, AttributeVariable,
+ OperandVariable, RegionVariable, SuccessorVariable>(argument)) {
+ // TODO: FormatElement should have location info attached.
+ return emitError(loc, "only variables and types may be used as "
+ "parameters to a custom directive");
+ }
+ if (auto *type = dyn_cast<TypeDirective>(argument)) {
+ if (!isa<OperandVariable, ResultVariable>(type->getArg())) {
+ return emitError(loc, "type directives within a custom directive may "
+ "only refer to variables");
}
}
}
-
- element = std::make_unique<CustomDirective>(nameTok.getSpelling(),
- std::move(elements));
- return ::mlir::success();
-}
-
-LogicalResult FormatParser::parseCustomDirectiveParameter(
- std::vector<std::unique_ptr<Element>> &parameters) {
- SMLoc childLoc = curToken.getLoc();
- parameters.push_back({});
- if (failed(parseElement(parameters.back(), CustomDirectiveContext)))
- return ::mlir::failure();
-
- // Verify that the element can be placed within a custom directive.
- if (!isa<RefDirective, TypeDirective, AttrDictDirective, AttributeVariable,
- OperandVariable, RegionVariable, SuccessorVariable>(
- parameters.back().get())) {
- return emitError(childLoc, "only variables and types may be used as "
- "parameters to a custom directive");
- }
- return ::mlir::success();
+ return success();
}
-LogicalResult FormatParser::parseFunctionalTypeDirective(
- std::unique_ptr<Element> &element, FormatToken tok, ParserContext context) {
- SMLoc loc = tok.getLoc();
+FailureOr<FormatElement *>
+OpFormatParser::parseFunctionalTypeDirective(SMLoc loc, Context context) {
if (context != TopLevelContext)
return emitError(
loc, "'functional-type' is only valid as a top-level directive");
// Parse the main operand.
- std::unique_ptr<Element> inputs, results;
+ FailureOr<FormatElement *> inputs, results;
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")) ||
- failed(parseTypeDirectiveOperand(inputs)) ||
+ failed(inputs = parseTypeDirectiveOperand(loc)) ||
failed(parseToken(FormatToken::comma,
"expected ',' after inputs argument")) ||
- failed(parseTypeDirectiveOperand(results)) ||
+ failed(results = parseTypeDirectiveOperand(loc)) ||
failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
- return ::mlir::failure();
- element = std::make_unique<FunctionalTypeDirective>(std::move(inputs),
- std::move(results));
- return ::mlir::success();
+ return failure();
+ return create<FunctionalTypeDirective>(*inputs, *results);
}
-LogicalResult
-FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context) {
+FailureOr<FormatElement *>
+OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
if (context == RefDirectiveContext) {
if (!fmt.allOperands)
return emitError(loc, "'ref' of 'operands' is not bound by a prior "
@@ -3130,31 +2611,27 @@ FormatParser::parseOperandsDirective(std::unique_ptr<Element> &element,
return emitError(loc, "'operands' directive creates overlap in format");
fmt.allOperands = true;
}
- element = std::make_unique<OperandsDirective>();
- return ::mlir::success();
+ return create<OperandsDirective>();
}
-LogicalResult
-FormatParser::parseReferenceDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context) {
+FailureOr<FormatElement *>
+OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) {
if (context != CustomDirectiveContext)
return emitError(loc, "'ref' is only valid within a `custom` directive");
- std::unique_ptr<Element> operand;
+ FailureOr<FormatElement *> arg;
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")) ||
- failed(parseElement(operand, RefDirectiveContext)) ||
+ failed(arg = parseElement(RefDirectiveContext)) ||
failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
- return ::mlir::failure();
+ return failure();
- element = std::make_unique<RefDirective>(std::move(operand));
- return ::mlir::success();
+ return create<RefDirective>(*arg);
}
-LogicalResult
-FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context) {
+FailureOr<FormatElement *>
+OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
if (context == TypeDirectiveContext)
return emitError(loc, "'regions' is only valid as a top-level directive");
if (context == RefDirectiveContext) {
@@ -3168,23 +2645,19 @@ FormatParser::parseRegionsDirective(std::unique_ptr<Element> &element,
return emitError(loc, "'regions' directive creates overlap in format");
hasAllRegions = true;
}
- element = std::make_unique<RegionsDirective>();
- return ::mlir::success();
+ return create<RegionsDirective>();
}
-LogicalResult
-FormatParser::parseResultsDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context) {
+FailureOr<FormatElement *>
+OpFormatParser::parseResultsDirective(SMLoc loc, Context context) {
if (context != TypeDirectiveContext)
return emitError(loc, "'results' directive can can only be used as a child "
"to a 'type' directive");
- element = std::make_unique<ResultsDirective>();
- return ::mlir::success();
+ return create<ResultsDirective>();
}
-LogicalResult
-FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
- SMLoc loc, ParserContext context) {
+FailureOr<FormatElement *>
+OpFormatParser::parseSuccessorsDirective(SMLoc loc, Context context) {
if (context == TypeDirectiveContext)
return emitError(loc,
"'successors' is only valid as a top-level directive");
@@ -3199,62 +2672,59 @@ FormatParser::parseSuccessorsDirective(std::unique_ptr<Element> &element,
return emitError(loc, "'successors' directive creates overlap in format");
hasAllSuccessors = true;
}
- element = std::make_unique<SuccessorsDirective>();
- return ::mlir::success();
+ return create<SuccessorsDirective>();
}
-LogicalResult
-FormatParser::parseTypeDirective(std::unique_ptr<Element> &element,
- FormatToken tok, ParserContext context) {
- SMLoc loc = tok.getLoc();
+FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
+ Context context) {
if (context == TypeDirectiveContext)
return emitError(loc, "'type' cannot be used as a child of another `type`");
bool isRefChild = context == RefDirectiveContext;
- std::unique_ptr<Element> operand;
+ FailureOr<FormatElement *> operand;
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")) ||
- failed(parseTypeDirectiveOperand(operand, isRefChild)) ||
+ failed(operand = parseTypeDirectiveOperand(loc, isRefChild)) ||
failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
- return ::mlir::failure();
+ return failure();
- element = std::make_unique<TypeDirective>(std::move(operand));
- return ::mlir::success();
+ return create<TypeDirective>(*operand);
}
-LogicalResult
-FormatParser::parseQualifiedDirective(std::unique_ptr<Element> &element,
- FormatToken tok, ParserContext context) {
+FailureOr<FormatElement *>
+OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) {
+ FailureOr<FormatElement *> element;
if (failed(parseToken(FormatToken::l_paren,
"expected '(' before argument list")) ||
- failed(parseElement(element, context)) ||
+ failed(element = parseElement(context)) ||
failed(
parseToken(FormatToken::r_paren, "expected ')' after argument list")))
return failure();
- if (auto *attr = dyn_cast<AttributeVariable>(element.get())) {
- attr->setShouldBeQualified();
- } else if (auto *type = dyn_cast<TypeDirective>(element.get())) {
- type->setShouldBeQualified();
- } else {
- return emitError(
- tok.getLoc(),
- "'qualified' directive expects an attribute or a `type` directive");
- }
- return success();
+ return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element)
+ .Case<AttributeVariable, TypeDirective>([](auto *element) {
+ element->setShouldBeQualified();
+ return element;
+ })
+ .Default([&](auto *element) {
+ return emitError(
+ loc,
+ "'qualified' directive expects an attribute or a `type` directive");
+ });
}
-LogicalResult
-FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
- bool isRefChild) {
- SMLoc loc = curToken.getLoc();
- if (failed(parseElement(element, TypeDirectiveContext)))
- return ::mlir::failure();
- if (isa<LiteralElement>(element.get()))
+FailureOr<FormatElement *>
+OpFormatParser::parseTypeDirectiveOperand(SMLoc loc, bool isRefChild) {
+ FailureOr<FormatElement *> result = parseElement(TypeDirectiveContext);
+ if (failed(result))
+ return failure();
+
+ FormatElement *element = *result;
+ if (isa<LiteralElement>(element))
return emitError(
loc, "'type' directive operand expects variable or directive operand");
- if (auto *var = dyn_cast<OperandVariable>(element.get())) {
+ if (auto *var = dyn_cast<OperandVariable>(element)) {
unsigned opIdx = var->getVar() - op.operand_begin();
if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx)))
return emitError(loc, "'type' of '" + var->getVar()->name +
@@ -3263,7 +2733,7 @@ FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
return emitError(loc, "'ref' of 'type($" + var->getVar()->name +
")' is not bound by a prior 'type' directive");
seenOperandTypes.set(opIdx);
- } else if (auto *var = dyn_cast<ResultVariable>(element.get())) {
+ } else if (auto *var = dyn_cast<ResultVariable>(element)) {
unsigned resIdx = var->getVar() - op.result_begin();
if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx)))
return emitError(loc, "'type' of '" + var->getVar()->name +
@@ -3289,7 +2759,78 @@ FormatParser::parseTypeDirectiveOperand(std::unique_ptr<Element> &element,
} else {
return emitError(loc, "invalid argument to 'type' directive");
}
- return ::mlir::success();
+ return element;
+}
+
+LogicalResult
+OpFormatParser::verifyOptionalGroupElements(SMLoc loc,
+ ArrayRef<FormatElement *> elements,
+ Optional<unsigned> anchorIndex) {
+ for (auto &it : llvm::enumerate(elements)) {
+ if (failed(verifyOptionalGroupElement(
+ loc, it.value(), anchorIndex && *anchorIndex == it.index())))
+ return failure();
+ }
+ return success();
+}
+
+LogicalResult OpFormatParser::verifyOptionalGroupElement(SMLoc loc,
+ FormatElement *element,
+ bool isAnchor) {
+ return TypeSwitch<FormatElement *, LogicalResult>(element)
+ // All attributes can be within the optional group, but only optional
+ // attributes can be the anchor.
+ .Case([&](AttributeVariable *attrEle) {
+ if (isAnchor && !attrEle->getVar()->attr.isOptional())
+ return emitError(loc, "only optional attributes can be used to "
+ "anchor an optional group");
+ return success();
+ })
+ // Only optional-like(i.e. variadic) operands can be within an optional
+ // group.
+ .Case([&](OperandVariable *ele) {
+ if (!ele->getVar()->isVariableLength())
+ return emitError(loc, "only variable length operands can be used "
+ "within an optional group");
+ return success();
+ })
+ // Only optional-like(i.e. variadic) results can be within an optional
+ // group.
+ .Case([&](ResultVariable *ele) {
+ if (!ele->getVar()->isVariableLength())
+ return emitError(loc, "only variable length results can be used "
+ "within an optional group");
+ return success();
+ })
+ .Case([&](RegionVariable *) {
+ // TODO: When ODS has proper support for marking "optional" regions, add
+ // a check here.
+ return success();
+ })
+ .Case([&](TypeDirective *ele) {
+ return verifyOptionalGroupElement(loc, ele->getArg(),
+ /*isAnchor=*/false);
+ })
+ .Case([&](FunctionalTypeDirective *ele) {
+ if (failed(verifyOptionalGroupElement(loc, ele->getInputs(),
+ /*isAnchor=*/false)))
+ return failure();
+ return verifyOptionalGroupElement(loc, ele->getResults(),
+ /*isAnchor=*/false);
+ })
+ // Literals, whitespace, and custom directives may be used, but they can't
+ // anchor the group.
+ .Case<LiteralElement, WhitespaceElement, CustomDirective,
+ FunctionalTypeDirective, OptionalElement>([&](FormatElement *) {
+ if (isAnchor)
+ return emitError(loc, "only variables and types can be used "
+ "to anchor an optional group");
+ return success();
+ })
+ .Default([&](FormatElement *) {
+ return emitError(loc, "only literals, types, and variables can be "
+ "used within an optional group");
+ });
}
//===----------------------------------------------------------------------===//
@@ -3308,7 +2849,9 @@ void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
mgr.AddNewSourceBuffer(
llvm::MemoryBuffer::getMemBuffer(op.getAssemblyFormat()), SMLoc());
OperationFormat format(op);
- if (failed(FormatParser(mgr, format, op).parse())) {
+ OpFormatParser parser(mgr, format, op);
+ FailureOr<std::vector<FormatElement *>> elements = parser.parse();
+ if (failed(elements)) {
// Exit the process if format errors are treated as fatal.
if (formatErrorIsFatal) {
// Invoke the interrupt handlers to run the file cleanup handlers.
@@ -3317,6 +2860,7 @@ void mlir::tblgen::generateOpFormat(const Operator &constOp, OpClass &opClass) {
}
return;
}
+ format.elements = std::move(*elements);
// Generate the printer and parser based on the parsed format.
format.genParser(op, opClass);
More information about the Mlir-commits
mailing list