[Mlir-commits] [mlir] 5312063 - [mlir:OpAsm] Factor out the common bits of (Op/Dialect)Asm(Parser/Printer)
River Riddle
llvmlistbot at llvm.org
Fri Sep 24 13:16:23 PDT 2021
Author: River Riddle
Date: 2021-09-24T20:12:19Z
New Revision: 531206310a27477f088f672f5e6fd688d77d9292
URL: https://github.com/llvm/llvm-project/commit/531206310a27477f088f672f5e6fd688d77d9292
DIFF: https://github.com/llvm/llvm-project/commit/531206310a27477f088f672f5e6fd688d77d9292.diff
LOG: [mlir:OpAsm] Factor out the common bits of (Op/Dialect)Asm(Parser/Printer)
This has a few benefits:
* It allows for defining parser/printer code blocks that
can be shared between operations and attributes/types.
* It removes the weird duplication of generic parser/printer hooks,
which means that newly added hooks only require touching one class.
Differential Revision: https://reviews.llvm.org/D110375
Added:
mlir/lib/Parser/AsmParserImpl.h
Modified:
mlir/include/mlir/IR/DialectImplementation.h
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/Parser/DialectSymbolParser.cpp
mlir/lib/Parser/Parser.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index 2d19f693d08c4..728e24605c29f 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -15,14 +15,9 @@
#define MLIR_IR_DIALECTIMPLEMENTATION_H
#include "mlir/IR/OpImplementation.h"
-#include "llvm/ADT/Twine.h"
-#include "llvm/Support/SMLoc.h"
-#include "llvm/Support/raw_ostream.h"
namespace mlir {
-class Builder;
-
//===----------------------------------------------------------------------===//
// DialectAsmPrinter
//===----------------------------------------------------------------------===//
@@ -30,360 +25,26 @@ class Builder;
/// This is a pure-virtual base class that exposes the asmprinter hooks
/// necessary to implement a custom printAttribute/printType() method on a
/// dialect.
-class DialectAsmPrinter {
+class DialectAsmPrinter : public AsmPrinter {
public:
- DialectAsmPrinter() {}
- virtual ~DialectAsmPrinter();
- virtual raw_ostream &getStream() const = 0;
-
- /// Print the given attribute to the stream.
- virtual void printAttribute(Attribute attr) = 0;
-
- /// Print the given attribute without its type. The corresponding parser must
- /// provide a valid type for the attribute.
- virtual void printAttributeWithoutType(Attribute attr) = 0;
-
- /// Print the given floating point value in a stabilized form that can be
- /// roundtripped through the IR. This is the companion to the 'parseFloat'
- /// hook on the DialectAsmParser.
- virtual void printFloat(const APFloat &value) = 0;
-
- /// Print the given type to the stream.
- virtual void printType(Type type) = 0;
-
-private:
- DialectAsmPrinter(const DialectAsmPrinter &) = delete;
- void operator=(const DialectAsmPrinter &) = delete;
+ using AsmPrinter::AsmPrinter;
+ ~DialectAsmPrinter() override;
};
-// Make the implementations convenient to use.
-inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) {
- p.printAttribute(attr);
- return p;
-}
-
-inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p,
- const APFloat &value) {
- p.printFloat(value);
- return p;
-}
-inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) {
- return p << APFloat(value);
-}
-inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) {
- return p << APFloat(value);
-}
-
-inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) {
- p.printType(type);
- return p;
-}
-
-// Support printing anything that isn't convertible to one of the above types,
-// even if it isn't exactly one of them. For example, we want to print
-// FunctionType with the Type version above, not have it match this.
-template <typename T, typename std::enable_if<
- !std::is_convertible<T &, Attribute &>::value &&
- !std::is_convertible<T &, Type &>::value &&
- !std::is_convertible<T &, APFloat &>::value &&
- !llvm::is_one_of<T, double, float>::value,
- T>::type * = nullptr>
-inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) {
- p.getStream() << other;
- return p;
-}
-
//===----------------------------------------------------------------------===//
// DialectAsmParser
//===----------------------------------------------------------------------===//
-/// The DialectAsmParser has methods for interacting with the asm parser:
-/// parsing things from it, emitting errors etc. It has an intentionally
-/// high-level API that is designed to reduce/constrain syntax innovation in
-/// individual attributes or types.
-class DialectAsmParser {
+/// The DialectAsmParser has methods for interacting with the asm parser when
+/// parsing attributes and types.
+class DialectAsmParser : public AsmParser {
public:
- virtual ~DialectAsmParser();
-
- /// Emit a diagnostic at the specified location and return failure.
- virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
- const Twine &message = {}) = 0;
-
- /// Return a builder which provides useful access to MLIRContext, global
- /// objects like types and attributes.
- virtual Builder &getBuilder() const = 0;
-
- /// Get the location of the next token and store it into the argument. This
- /// always succeeds.
- virtual llvm::SMLoc getCurrentLocation() = 0;
- ParseResult getCurrentLocation(llvm::SMLoc *loc) {
- *loc = getCurrentLocation();
- return success();
- }
-
- /// Return the location of the original name token.
- virtual llvm::SMLoc getNameLoc() const = 0;
-
- /// Re-encode the given source location as an MLIR location and return it.
- /// Note: This method should only be used when a `Location` is necessary, as
- /// the encoding process is not efficient. In other cases a more suitable
- /// alternative should be used, such as the `getChecked` methods defined
- /// below.
- virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
+ using AsmParser::AsmParser;
+ ~DialectAsmParser() override;
/// Returns the full specification of the symbol being parsed. This allows for
/// using a separate parser if necessary.
virtual StringRef getFullSymbolSpec() const = 0;
-
- // These methods emit an error and return failure or success. This allows
- // these to be chained together into a linear sequence of || expressions in
- // many cases.
-
- /// Parse a floating point value from the stream.
- virtual ParseResult parseFloat(double &result) = 0;
-
- /// Parse an integer value from the stream.
- template <typename IntT>
- ParseResult parseInteger(IntT &result) {
- auto loc = getCurrentLocation();
- OptionalParseResult parseResult = parseOptionalInteger(result);
- if (!parseResult.hasValue())
- return emitError(loc, "expected integer value");
- return *parseResult;
- }
-
- /// Parse an optional integer value from the stream.
- virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
-
- template <typename IntT>
- OptionalParseResult parseOptionalInteger(IntT &result) {
- auto loc = getCurrentLocation();
-
- // Parse the unsigned variant.
- APInt uintResult;
- OptionalParseResult parseResult = parseOptionalInteger(uintResult);
- if (!parseResult.hasValue() || failed(*parseResult))
- return parseResult;
-
- // Try to convert to the provided integer type. sextOrTrunc is correct even
- // for unsigned types because parseOptionalInteger ensures the sign bit is
- // zero for non-negated integers.
- result =
- (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
- if (APInt(uintResult.getBitWidth(), result) != uintResult)
- return emitError(loc, "integer value too large");
- return success();
- }
-
- /// Invoke the `getChecked` method of the given Attribute or Type class, using
- /// the provided location to emit errors in the case of failure. Note that
- /// unlike `OpBuilder::getType`, this method does not implicitly insert a
- /// context parameter.
- template <typename T, typename... ParamsT>
- T getChecked(llvm::SMLoc loc, ParamsT &&... params) {
- return T::getChecked([&] { return emitError(loc); },
- std::forward<ParamsT>(params)...);
- }
- /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
- /// errors.
- template <typename T, typename... ParamsT>
- T getChecked(ParamsT &&... params) {
- return T::getChecked([&] { return emitError(getNameLoc()); },
- std::forward<ParamsT>(params)...);
- }
-
- //===--------------------------------------------------------------------===//
- // Token Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse a '->' token.
- virtual ParseResult parseArrow() = 0;
-
- /// Parse a '->' token if present
- virtual ParseResult parseOptionalArrow() = 0;
-
- /// Parse a '{' token.
- virtual ParseResult parseLBrace() = 0;
-
- /// Parse a '{' token if present
- virtual ParseResult parseOptionalLBrace() = 0;
-
- /// Parse a `}` token.
- virtual ParseResult parseRBrace() = 0;
-
- /// Parse a `}` token if present
- virtual ParseResult parseOptionalRBrace() = 0;
-
- /// Parse a `:` token.
- virtual ParseResult parseColon() = 0;
-
- /// Parse a `:` token if present.
- virtual ParseResult parseOptionalColon() = 0;
-
- /// Parse a `,` token.
- virtual ParseResult parseComma() = 0;
-
- /// Parse a `,` token if present.
- virtual ParseResult parseOptionalComma() = 0;
-
- /// Parse a `=` token.
- virtual ParseResult parseEqual() = 0;
-
- /// Parse a `=` token if present.
- virtual ParseResult parseOptionalEqual() = 0;
-
- /// Parse a quoted string token.
- ParseResult parseString(std::string *string) {
- auto loc = getCurrentLocation();
- if (parseOptionalString(string))
- return emitError(loc, "expected string");
- return success();
- }
-
- /// Parse a quoted string token if present.
- virtual ParseResult parseOptionalString(std::string *string) = 0;
-
- /// Parse a given keyword.
- ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
- auto loc = getCurrentLocation();
- if (parseOptionalKeyword(keyword))
- return emitError(loc, "expected '") << keyword << "'" << msg;
- return success();
- }
-
- /// Parse a keyword into 'keyword'.
- ParseResult parseKeyword(StringRef *keyword) {
- auto loc = getCurrentLocation();
- if (parseOptionalKeyword(keyword))
- return emitError(loc, "expected valid keyword");
- return success();
- }
-
- /// Parse the given keyword if present.
- virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
-
- /// Parse a keyword, if present, into 'keyword'.
- virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
-
- /// Parse a '<' token.
- virtual ParseResult parseLess() = 0;
-
- /// Parse a `<` token if present.
- virtual ParseResult parseOptionalLess() = 0;
-
- /// Parse a '>' token.
- virtual ParseResult parseGreater() = 0;
-
- /// Parse a `>` token if present.
- virtual ParseResult parseOptionalGreater() = 0;
-
- /// Parse a `(` token.
- virtual ParseResult parseLParen() = 0;
-
- /// Parse a `(` token if present.
- virtual ParseResult parseOptionalLParen() = 0;
-
- /// Parse a `)` token.
- virtual ParseResult parseRParen() = 0;
-
- /// Parse a `)` token if present.
- virtual ParseResult parseOptionalRParen() = 0;
-
- /// Parse a `[` token.
- virtual ParseResult parseLSquare() = 0;
-
- /// Parse a `[` token if present.
- virtual ParseResult parseOptionalLSquare() = 0;
-
- /// Parse a `]` token.
- virtual ParseResult parseRSquare() = 0;
-
- /// Parse a `]` token if present.
- virtual ParseResult parseOptionalRSquare() = 0;
-
- /// Parse a `...` token if present;
- virtual ParseResult parseOptionalEllipsis() = 0;
-
- /// Parse a `?` token.
- virtual ParseResult parseOptionalQuestion() = 0;
-
- /// Parse a `*` token.
- virtual ParseResult parseOptionalStar() = 0;
-
- //===--------------------------------------------------------------------===//
- // Attribute Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse an arbitrary attribute and return it in result.
- virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
-
- /// Parse an attribute of a specific kind and type.
- template <typename AttrType>
- ParseResult parseAttribute(AttrType &result, Type type = {}) {
- llvm::SMLoc loc = getCurrentLocation();
-
- // Parse any kind of attribute.
- Attribute attr;
- if (parseAttribute(attr, type))
- return failure();
-
- // Check for the right kind of attribute.
- result = attr.dyn_cast<AttrType>();
- if (!result)
- return emitError(loc, "invalid kind of attribute specified");
- return success();
- }
-
- /// Parse an affine map instance into 'map'.
- virtual ParseResult parseAffineMap(AffineMap &map) = 0;
-
- /// Parse an integer set instance into 'set'.
- virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
-
- //===--------------------------------------------------------------------===//
- // Type Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse a type.
- virtual ParseResult parseType(Type &result) = 0;
-
- /// Parse a type of a specific kind, e.g. a FunctionType.
- template <typename TypeType>
- ParseResult parseType(TypeType &result) {
- llvm::SMLoc loc = getCurrentLocation();
-
- // Parse any kind of type.
- Type type;
- if (parseType(type))
- return failure();
-
- // Check for the right kind of attribute.
- result = type.dyn_cast<TypeType>();
- if (!result)
- return emitError(loc, "invalid kind of type specified");
- return success();
- }
-
- /// Parse a type if present.
- virtual OptionalParseResult parseOptionalType(Type &result) = 0;
-
- /// Parse a 'x' separated dimension list. This populates the dimension list,
- /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
- /// `?` otherwise.
- ///
- /// dimension-list ::= (dimension `x`)*
- /// dimension ::= `?` | integer
- ///
- /// When `allowDynamic` is not set, this is used to parse:
- ///
- /// static-dimension-list ::= (integer `x`)*
- virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
- bool allowDynamic = true) = 0;
-
- /// Parse an 'x' token in a dimension list, handling the case where the x is
- /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
- /// next token.
- virtual ParseResult parseXInDimensionList() = 0;
};
} // end namespace mlir
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 979b86c60194f..ca6fc1bab1f8c 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -24,17 +24,179 @@ namespace mlir {
class Builder;
+//===----------------------------------------------------------------------===//
+// AsmPrinter
+//===----------------------------------------------------------------------===//
+
+/// This base class exposes generic asm printer hooks, usable across the various
+/// derived printers.
+class AsmPrinter {
+public:
+ /// This class contains the internal default implementation of the base
+ /// printer methods.
+ class Impl;
+
+ /// Initialize the printer with the given internal implementation.
+ AsmPrinter(Impl &impl) : impl(&impl) {}
+ virtual ~AsmPrinter();
+
+ /// Return the raw output stream used by this printer.
+ virtual raw_ostream &getStream() const;
+
+ /// Print the given floating point value in a stabilized form that can be
+ /// roundtripped through the IR. This is the companion to the 'parseFloat'
+ /// hook on the AsmParser.
+ virtual void printFloat(const APFloat &value);
+
+ virtual void printType(Type type);
+ virtual void printAttribute(Attribute attr);
+
+ /// Print the given attribute without its type. The corresponding parser must
+ /// provide a valid type for the attribute.
+ virtual void printAttributeWithoutType(Attribute attr);
+
+ /// Print the given string as a symbol reference, i.e. a form representable by
+ /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
+ /// with '@'. The reference is surrounded with ""'s and escaped if it has any
+ /// special or non-printable characters in it.
+ virtual void printSymbolName(StringRef symbolRef);
+
+ /// Print an optional arrow followed by a type list.
+ template <typename TypeRange>
+ void printOptionalArrowTypeList(TypeRange &&types) {
+ if (types.begin() != types.end())
+ printArrowTypeList(types);
+ }
+ template <typename TypeRange>
+ void printArrowTypeList(TypeRange &&types) {
+ auto &os = getStream() << " -> ";
+
+ bool wrapped = !llvm::hasSingleElement(types) ||
+ (*types.begin()).template isa<FunctionType>();
+ if (wrapped)
+ os << '(';
+ llvm::interleaveComma(types, *this);
+ if (wrapped)
+ os << ')';
+ }
+
+ /// Print the two given type ranges in a functional form.
+ template <typename InputRangeT, typename ResultRangeT>
+ void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
+ auto &os = getStream();
+ os << '(';
+ llvm::interleaveComma(inputs, *this);
+ os << ')';
+ printArrowTypeList(results);
+ }
+
+protected:
+ /// Initialize the printer with no internal implementation. In this case, all
+ /// virtual methods of this class must be overridden.
+ AsmPrinter() : impl(nullptr) {}
+
+private:
+ AsmPrinter(const AsmPrinter &) = delete;
+ void operator=(const AsmPrinter &) = delete;
+
+ /// The internal implementation of the printer.
+ Impl *impl;
+};
+
+template <typename AsmPrinterT>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+ AsmPrinterT &>
+operator<<(AsmPrinterT &p, Type type) {
+ p.printType(type);
+ return p;
+}
+
+template <typename AsmPrinterT>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+ AsmPrinterT &>
+operator<<(AsmPrinterT &p, Attribute attr) {
+ p.printAttribute(attr);
+ return p;
+}
+
+template <typename AsmPrinterT>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+ AsmPrinterT &>
+operator<<(AsmPrinterT &p, const APFloat &value) {
+ p.printFloat(value);
+ return p;
+}
+template <typename AsmPrinterT>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+ AsmPrinterT &>
+operator<<(AsmPrinterT &p, float value) {
+ return p << APFloat(value);
+}
+template <typename AsmPrinterT>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+ AsmPrinterT &>
+operator<<(AsmPrinterT &p, double value) {
+ return p << APFloat(value);
+}
+
+// Support printing anything that isn't convertible to one of the other
+// streamable types, even if it isn't exactly one of them. For example, we want
+// to print FunctionType with the Type version above, not have it match this.
+template <
+ typename AsmPrinterT, typename T,
+ typename std::enable_if<!std::is_convertible<T &, Value &>::value &&
+ !std::is_convertible<T &, Type &>::value &&
+ !std::is_convertible<T &, Attribute &>::value &&
+ !std::is_convertible<T &, ValueRange>::value &&
+ !std::is_convertible<T &, APFloat &>::value &&
+ !llvm::is_one_of<T, bool, float, double>::value,
+ T>::type * = nullptr>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+ AsmPrinterT &>
+operator<<(AsmPrinterT &p, const T &other) {
+ p.getStream() << other;
+ return p;
+}
+
+template <typename AsmPrinterT>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+ AsmPrinterT &>
+operator<<(AsmPrinterT &p, bool value) {
+ return p << (value ? StringRef("true") : "false");
+}
+
+template <typename AsmPrinterT, typename ValueRangeT>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+ AsmPrinterT &>
+operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
+ llvm::interleaveComma(types, p);
+ return p;
+}
+template <typename AsmPrinterT>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+ AsmPrinterT &>
+operator<<(AsmPrinterT &p, const TypeRange &types) {
+ llvm::interleaveComma(types, p);
+ return p;
+}
+template <typename AsmPrinterT>
+inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
+ AsmPrinterT &>
+operator<<(AsmPrinterT &p, ArrayRef<Type> types) {
+ llvm::interleaveComma(types, p);
+ return p;
+}
+
//===----------------------------------------------------------------------===//
// OpAsmPrinter
//===----------------------------------------------------------------------===//
/// This is a pure-virtual base class that exposes the asmprinter hooks
/// necessary to implement a custom print() method.
-class OpAsmPrinter {
+class OpAsmPrinter : public AsmPrinter {
public:
- OpAsmPrinter() {}
- virtual ~OpAsmPrinter();
- virtual raw_ostream &getStream() const = 0;
+ using AsmPrinter::AsmPrinter;
+ ~OpAsmPrinter() override;
/// Print a newline and indent the printer to the start of the current
/// operation.
@@ -70,12 +232,6 @@ class OpAsmPrinter {
printOperand(*it);
}
}
- virtual void printType(Type type) = 0;
- virtual void printAttribute(Attribute attr) = 0;
-
- /// Print the given attribute without its type. The corresponding parser must
- /// provide a valid type for the attribute.
- virtual void printAttributeWithoutType(Attribute attr) = 0;
/// Print the given successor.
virtual void printSuccessor(Block *successor) = 0;
@@ -131,47 +287,9 @@ class OpAsmPrinter {
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
ValueRange symOperands) = 0;
- /// Print an optional arrow followed by a type list.
- template <typename TypeRange>
- void printOptionalArrowTypeList(TypeRange &&types) {
- if (types.begin() != types.end())
- printArrowTypeList(types);
- }
- template <typename TypeRange>
- void printArrowTypeList(TypeRange &&types) {
- auto &os = getStream() << " -> ";
-
- bool wrapped = !llvm::hasSingleElement(types) ||
- (*types.begin()).template isa<FunctionType>();
- if (wrapped)
- os << '(';
- llvm::interleaveComma(types, *this);
- if (wrapped)
- os << ')';
- }
-
/// Print the complete type of an operation in functional form.
void printFunctionalType(Operation *op);
-
- /// Print the two given type ranges in a functional form.
- template <typename InputRangeT, typename ResultRangeT>
- void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
- auto &os = getStream();
- os << '(';
- llvm::interleaveComma(inputs, *this);
- os << ')';
- printArrowTypeList(results);
- }
-
- /// Print the given string as a symbol reference, i.e. a form representable by
- /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
- /// with '@'. The reference is surrounded with ""'s and escaped if it has any
- /// special or non-printable characters in it.
- virtual void printSymbolName(StringRef symbolRef) = 0;
-
-private:
- OpAsmPrinter(const OpAsmPrinter &) = delete;
- void operator=(const OpAsmPrinter &) = delete;
+ using AsmPrinter::printFunctionalType;
};
// Make the implementations convenient to use.
@@ -189,77 +307,28 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
return p;
}
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
- p.printType(type);
- return p;
-}
-
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
- p.printAttribute(attr);
- return p;
-}
-
-// Support printing anything that isn't convertible to one of the above types,
-// even if it isn't exactly one of them. For example, we want to print
-// FunctionType with the Type version above, not have it match this.
-template <typename T, typename std::enable_if<
- !std::is_convertible<T &, Value &>::value &&
- !std::is_convertible<T &, Type &>::value &&
- !std::is_convertible<T &, Attribute &>::value &&
- !std::is_convertible<T &, ValueRange>::value &&
- !llvm::is_one_of<T, bool>::value,
- T>::type * = nullptr>
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &other) {
- p.getStream() << other;
- return p;
-}
-
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) {
- return p << (value ? StringRef("true") : "false");
-}
-
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
p.printSuccessor(value);
return p;
}
-template <typename ValueRangeT>
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p,
- const ValueTypeRange<ValueRangeT> &types) {
- llvm::interleaveComma(types, p);
- return p;
-}
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const TypeRange &types) {
- llvm::interleaveComma(types, p);
- return p;
-}
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p, ArrayRef<Type> types) {
- llvm::interleaveComma(types, p);
- return p;
-}
-
//===----------------------------------------------------------------------===//
-// OpAsmParser
+// AsmParser
//===----------------------------------------------------------------------===//
-/// The OpAsmParser has methods for interacting with the asm parser: parsing
-/// things from it, emitting errors etc. It has an intentionally high-level API
-/// that is designed to reduce/constrain syntax innovation in individual
-/// operations.
-///
-/// For example, consider an op like this:
-///
-/// %x = load %p[%1, %2] : memref<...>
-///
-/// The "%x = load" tokens are already parsed and therefore invisible to the
-/// custom op parser. This can be supported by calling `parseOperandList` to
-/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
-/// parse the indices, then calling `parseColonTypeList` to parse the result
-/// type.
-///
-class OpAsmParser {
+/// This base class exposes generic asm parser hooks, usable across the various
+/// derived parsers.
+class AsmParser {
public:
- virtual ~OpAsmParser();
+ AsmParser() = default;
+ virtual ~AsmParser();
+
+ /// Return the location of the original name token.
+ virtual llvm::SMLoc getNameLoc() const = 0;
+
+ //===--------------------------------------------------------------------===//
+ // Utilities
+ //===--------------------------------------------------------------------===//
/// Emit a diagnostic at the specified location and return failure.
virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
@@ -277,44 +346,11 @@ class OpAsmParser {
return success();
}
- /// Return the name of the specified result in the specified syntax, as well
- /// as the sub-element in the name. It returns an empty string and ~0U for
- /// invalid result numbers. For example, in this operation:
- ///
- /// %x, %y:2, %z = foo.op
- ///
- /// getResultName(0) == {"x", 0 }
- /// getResultName(1) == {"y", 0 }
- /// getResultName(2) == {"y", 1 }
- /// getResultName(3) == {"z", 0 }
- /// getResultName(4) == {"", ~0U }
- virtual std::pair<StringRef, unsigned>
- getResultName(unsigned resultNo) const = 0;
-
- /// Return the number of declared SSA results. This returns 4 for the foo.op
- /// example in the comment for `getResultName`.
- virtual size_t getNumResults() const = 0;
-
- /// Return the location of the original name token.
- virtual llvm::SMLoc getNameLoc() const = 0;
-
/// Re-encode the given source location as an MLIR location and return it.
/// Note: This method should only be used when a `Location` is necessary, as
/// the encoding process is not efficient.
virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
- // These methods emit an error and return failure or success. This allows
- // these to be chained together into a linear sequence of || expressions in
- // many cases.
-
- /// Parse an operation in its generic form.
- /// The parsed operation is parsed in the current context and inserted in the
- /// provided block and insertion point. The results produced by this operation
- /// aren't mapped to any named value in the parser. Returns nullptr on
- /// failure.
- virtual Operation *parseGenericOperation(Block *insertBlock,
- Block::iterator insertPt) = 0;
-
//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
@@ -385,6 +421,17 @@ class OpAsmParser {
/// Parse a '*' token if present.
virtual ParseResult parseOptionalStar() = 0;
+ /// Parse a quoted string token.
+ ParseResult parseString(std::string *string) {
+ auto loc = getCurrentLocation();
+ if (parseOptionalString(string))
+ return emitError(loc, "expected string");
+ return success();
+ }
+
+ /// Parse a quoted string token if present.
+ virtual ParseResult parseOptionalString(std::string *string) = 0;
+
/// Parse a given keyword.
ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
auto loc = getCurrentLocation();
@@ -440,6 +487,9 @@ class OpAsmParser {
/// Parse a `...` token if present;
virtual ParseResult parseOptionalEllipsis() = 0;
+ /// Parse a floating point value from the stream.
+ virtual ParseResult parseFloat(double &result) = 0;
+
/// Parse an integer value from the stream.
template <typename IntT>
ParseResult parseInteger(IntT &result) {
@@ -514,6 +564,27 @@ class OpAsmParser {
return parseCommaSeparatedList(Delimiter::None, parseElementFn);
}
+ //===--------------------------------------------------------------------===//
+ // Attribute/Type Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Invoke the `getChecked` method of the given Attribute or Type class, using
+ /// the provided location to emit errors in the case of failure. Note that
+ /// unlike `OpBuilder::getType`, this method does not implicitly insert a
+ /// context parameter.
+ template <typename T, typename... ParamsT>
+ T getChecked(llvm::SMLoc loc, ParamsT &&... params) {
+ return T::getChecked([&] { return emitError(loc); },
+ std::forward<ParamsT>(params)...);
+ }
+ /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
+ /// errors.
+ template <typename T, typename... ParamsT>
+ T getChecked(ParamsT &&... params) {
+ return T::getChecked([&] { return emitError(getNameLoc()); },
+ std::forward<ParamsT>(params)...);
+ }
+
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
@@ -634,6 +705,180 @@ class OpAsmParser {
virtual ParseResult
parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
+ //===--------------------------------------------------------------------===//
+ // Type Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a type.
+ virtual ParseResult parseType(Type &result) = 0;
+
+ /// Parse an optional type.
+ virtual OptionalParseResult parseOptionalType(Type &result) = 0;
+
+ /// Parse a type of a specific kind.
+ template <typename TypeT>
+ ParseResult parseType(TypeT &result) {
+ llvm::SMLoc loc = getCurrentLocation();
+
+ // Parse any kind of type.
+ Type type;
+ if (parseType(type))
+ return failure();
+
+ // Check for the right kind of type.
+ result = type.dyn_cast<TypeT>();
+ if (!result)
+ return emitError(loc, "invalid kind of type specified");
+
+ return success();
+ }
+
+ /// Parse a type list.
+ ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
+ do {
+ Type type;
+ if (parseType(type))
+ return failure();
+ result.push_back(type);
+ } while (succeeded(parseOptionalComma()));
+ return success();
+ }
+
+ /// Parse an arrow followed by a type list.
+ virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
+
+ /// Parse an optional arrow followed by a type list.
+ virtual ParseResult
+ parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
+
+ /// Parse a colon followed by a type.
+ virtual ParseResult parseColonType(Type &result) = 0;
+
+ /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
+ template <typename TypeType>
+ ParseResult parseColonType(TypeType &result) {
+ llvm::SMLoc loc = getCurrentLocation();
+
+ // Parse any kind of type.
+ Type type;
+ if (parseColonType(type))
+ return failure();
+
+ // Check for the right kind of type.
+ result = type.dyn_cast<TypeType>();
+ if (!result)
+ return emitError(loc, "invalid kind of type specified");
+
+ return success();
+ }
+
+ /// Parse a colon followed by a type list, which must have at least one type.
+ virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
+
+ /// Parse an optional colon followed by a type list, which if present must
+ /// have at least one type.
+ virtual ParseResult
+ parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
+
+ /// Parse a keyword followed by a type.
+ ParseResult parseKeywordType(const char *keyword, Type &result) {
+ return failure(parseKeyword(keyword) || parseType(result));
+ }
+
+ /// Add the specified type to the end of the specified type list and return
+ /// success. This is a helper designed to allow parse methods to be simple
+ /// and chain through || operators.
+ ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
+ result.push_back(type);
+ return success();
+ }
+
+ /// Add the specified types to the end of the specified type list and return
+ /// success. This is a helper designed to allow parse methods to be simple
+ /// and chain through || operators.
+ ParseResult addTypesToList(ArrayRef<Type> types,
+ SmallVectorImpl<Type> &result) {
+ result.append(types.begin(), types.end());
+ return success();
+ }
+
+ /// Parse a 'x' separated dimension list. This populates the dimension list,
+ /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
+ /// `?` otherwise.
+ ///
+ /// dimension-list ::= (dimension `x`)*
+ /// dimension ::= `?` | integer
+ ///
+ /// When `allowDynamic` is not set, this is used to parse:
+ ///
+ /// static-dimension-list ::= (integer `x`)*
+ virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
+ bool allowDynamic = true) = 0;
+
+ /// Parse an 'x' token in a dimension list, handling the case where the x is
+ /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
+ /// next token.
+ virtual ParseResult parseXInDimensionList() = 0;
+
+private:
+ AsmParser(const AsmParser &) = delete;
+ void operator=(const AsmParser &) = delete;
+};
+
+//===----------------------------------------------------------------------===//
+// OpAsmParser
+//===----------------------------------------------------------------------===//
+
+/// The OpAsmParser has methods for interacting with the asm parser: parsing
+/// things from it, emitting errors etc. It has an intentionally high-level API
+/// that is designed to reduce/constrain syntax innovation in individual
+/// operations.
+///
+/// For example, consider an op like this:
+///
+/// %x = load %p[%1, %2] : memref<...>
+///
+/// The "%x = load" tokens are already parsed and therefore invisible to the
+/// custom op parser. This can be supported by calling `parseOperandList` to
+/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
+/// parse the indices, then calling `parseColonTypeList` to parse the result
+/// type.
+///
+class OpAsmParser : public AsmParser {
+public:
+ using AsmParser::AsmParser;
+ ~OpAsmParser() override;
+
+ /// Return the name of the specified result in the specified syntax, as well
+ /// as the sub-element in the name. It returns an empty string and ~0U for
+ /// invalid result numbers. For example, in this operation:
+ ///
+ /// %x, %y:2, %z = foo.op
+ ///
+ /// getResultName(0) == {"x", 0 }
+ /// getResultName(1) == {"y", 0 }
+ /// getResultName(2) == {"y", 1 }
+ /// getResultName(3) == {"z", 0 }
+ /// getResultName(4) == {"", ~0U }
+ virtual std::pair<StringRef, unsigned>
+ getResultName(unsigned resultNo) const = 0;
+
+ /// Return the number of declared SSA results. This returns 4 for the foo.op
+ /// example in the comment for `getResultName`.
+ virtual size_t getNumResults() const = 0;
+
+ // These methods emit an error and return failure or success. This allows
+ // these to be chained together into a linear sequence of || expressions in
+ // many cases.
+
+ /// Parse an operation in its generic form.
+ /// The parsed operation is parsed in the current context and inserted in the
+ /// provided block and insertion point. The results produced by this operation
+ /// aren't mapped to any named value in the parser. Returns nullptr on
+ /// failure.
+ virtual Operation *parseGenericOperation(Block *insertBlock,
+ Block::iterator insertPt) = 0;
+
//===--------------------------------------------------------------------===//
// Operand Parsing
//===--------------------------------------------------------------------===//
@@ -813,77 +1058,6 @@ class OpAsmParser {
// Type Parsing
//===--------------------------------------------------------------------===//
- /// Parse a type.
- virtual ParseResult parseType(Type &result) = 0;
-
- /// Parse an optional type.
- virtual OptionalParseResult parseOptionalType(Type &result) = 0;
-
- /// Parse a type of a specific type.
- template <typename TypeT>
- ParseResult parseType(TypeT &result) {
- llvm::SMLoc loc = getCurrentLocation();
-
- // Parse any kind of type.
- Type type;
- if (parseType(type))
- return failure();
-
- // Check for the right kind of attribute.
- result = type.dyn_cast<TypeT>();
- if (!result)
- return emitError(loc, "invalid kind of type specified");
-
- return success();
- }
-
- /// Parse a type list.
- ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
- do {
- Type type;
- if (parseType(type))
- return failure();
- result.push_back(type);
- } while (succeeded(parseOptionalComma()));
- return success();
- }
-
- /// Parse an arrow followed by a type list.
- virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
-
- /// Parse an optional arrow followed by a type list.
- virtual ParseResult
- parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
-
- /// Parse a colon followed by a type.
- virtual ParseResult parseColonType(Type &result) = 0;
-
- /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
- template <typename TypeType>
- ParseResult parseColonType(TypeType &result) {
- llvm::SMLoc loc = getCurrentLocation();
-
- // Parse any kind of type.
- Type type;
- if (parseColonType(type))
- return failure();
-
- // Check for the right kind of attribute.
- result = type.dyn_cast<TypeType>();
- if (!result)
- return emitError(loc, "invalid kind of type specified");
-
- return success();
- }
-
- /// Parse a colon followed by a type list, which must have at least one type.
- virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
-
- /// Parse an optional colon followed by a type list, which if present must
- /// have at least one type.
- virtual ParseResult
- parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
-
/// Parse a list of assignments of the form
/// (%x1 = %y1, %x2 = %y2, ...)
ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
@@ -914,27 +1088,6 @@ class OpAsmParser {
parseOptionalAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
SmallVectorImpl<OperandType> &rhs,
SmallVectorImpl<Type> &types) = 0;
- /// Parse a keyword followed by a type.
- ParseResult parseKeywordType(const char *keyword, Type &result) {
- return failure(parseKeyword(keyword) || parseType(result));
- }
-
- /// Add the specified type to the end of the specified type list and return
- /// success. This is a helper designed to allow parse methods to be simple
- /// and chain through || operators.
- ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
- result.push_back(type);
- return success();
- }
-
- /// Add the specified types to the end of the specified type list and return
- /// success. This is a helper designed to allow parse methods to be simple
- /// and chain through || operators.
- ParseResult addTypesToList(ArrayRef<Type> types,
- SmallVectorImpl<Type> &result) {
- result.append(types.begin(), types.end());
- return success();
- }
private:
/// Parse either an operand list or a region argument list depending on
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 7a90804cd047c..b6438fb7071bf 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -52,6 +52,18 @@ void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
void OperationName::dump() const { print(llvm::errs()); }
+//===--------------------------------------------------------------------===//
+// AsmParser
+//===--------------------------------------------------------------------===//
+
+AsmParser::~AsmParser() {}
+DialectAsmParser::~DialectAsmParser() {}
+OpAsmParser::~OpAsmParser() {}
+
+//===--------------------------------------------------------------------===//
+// DialectAsmPrinter
+//===--------------------------------------------------------------------===//
+
DialectAsmPrinter::~DialectAsmPrinter() {}
//===--------------------------------------------------------------------===//
@@ -250,12 +262,12 @@ namespace {
struct NewLineCounter {
unsigned curLine = 1;
};
-} // end anonymous namespace
static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
++newLine.curLine;
return os << '\n';
}
+} // end anonymous namespace
//===----------------------------------------------------------------------===//
// AliasInitializer
@@ -492,6 +504,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter {
/// The following are hooks of `OpAsmPrinter` that are not necessary for
/// determining potential aliases.
+ void printFloat(const APFloat &value) override {}
void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
void printNewline() override {}
@@ -1202,18 +1215,17 @@ AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
AsmState::~AsmState() {}
//===----------------------------------------------------------------------===//
-// ModulePrinter
+// AsmPrinter::Impl
//===----------------------------------------------------------------------===//
-namespace {
-class ModulePrinter {
+namespace mlir {
+class AsmPrinter::Impl {
public:
- ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
- AsmStateImpl *state = nullptr)
+ Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None,
+ AsmStateImpl *state = nullptr)
: os(os), printerFlags(flags), state(state) {}
- explicit ModulePrinter(ModulePrinter &printer)
- : os(printer.os), printerFlags(printer.printerFlags),
- state(printer.state) {}
+ explicit Impl(Impl &other)
+ : Impl(other.os, other.printerFlags, other.state) {}
/// Returns the output stream of the printer.
raw_ostream &getStream() { return os; }
@@ -1298,9 +1310,9 @@ class ModulePrinter {
/// A tracker for the number of new lines emitted during printing.
NewLineCounter newLine;
};
-} // end anonymous namespace
+} // namespace mlir
-void ModulePrinter::printTrailingLocation(Location loc, bool allowAlias) {
+void AsmPrinter::Impl::printTrailingLocation(Location loc, bool allowAlias) {
// Check to see if we are printing debug information.
if (!printerFlags.shouldPrintDebugInfo())
return;
@@ -1309,7 +1321,7 @@ void ModulePrinter::printTrailingLocation(Location loc, bool allowAlias) {
printLocation(loc, /*allowAlias=*/allowAlias);
}
-void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
+void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty) {
TypeSwitch<LocationAttr>(loc)
.Case<OpaqueLoc>([&](OpaqueLoc loc) {
printLocationInternal(loc.getFallbackLocation(), pretty);
@@ -1430,7 +1442,7 @@ static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
os << str;
}
-void ModulePrinter::printLocation(LocationAttr loc, bool allowAlias) {
+void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
if (printerFlags.shouldPrintDebugInfoPrettyForm())
return printLocationInternal(loc, /*pretty=*/true);
@@ -1578,8 +1590,8 @@ static void printElidedElementsAttr(raw_ostream &os) {
os << R"(opaque<"_", "0xDEADBEEF">)";
}
-void ModulePrinter::printAttribute(Attribute attr,
- AttrTypeElision typeElision) {
+void AsmPrinter::Impl::printAttribute(Attribute attr,
+ AttrTypeElision typeElision) {
if (!attr) {
os << "<<NULL ATTRIBUTE>>";
return;
@@ -1780,8 +1792,8 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
os << ']';
}
-void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
- bool allowHex) {
+void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
+ bool allowHex) {
if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
return printDenseStringElementsAttr(stringAttr);
@@ -1789,8 +1801,8 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
allowHex);
}
-void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
- bool allowHex) {
+void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
+ DenseIntOrFPElementsAttr attr, bool allowHex) {
auto type = attr.getType();
auto elementType = type.getElementType();
@@ -1860,7 +1872,8 @@ void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
}
}
-void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
+void AsmPrinter::Impl::printDenseStringElementsAttr(
+ DenseStringElementsAttr attr) {
ArrayRef<StringRef> data = attr.getRawStringData();
auto printFn = [&](unsigned index) {
os << "\"";
@@ -1870,7 +1883,7 @@ void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
}
-void ModulePrinter::printType(Type type) {
+void AsmPrinter::Impl::printType(Type type) {
if (!type) {
os << "<<NULL TYPE>>";
return;
@@ -1986,9 +1999,9 @@ void ModulePrinter::printType(Type type) {
.Default([&](Type type) { return printDialectType(type); });
}
-void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
- ArrayRef<StringRef> elidedAttrs,
- bool withKeyword) {
+void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
+ ArrayRef<StringRef> elidedAttrs,
+ bool withKeyword) {
// If there are no attributes, then there is nothing to be done.
if (attrs.empty())
return;
@@ -2020,7 +2033,7 @@ void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
printFilteredAttributesFn(filteredAttrs);
}
-void ModulePrinter::printNamedAttribute(NamedAttribute attr) {
+void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
if (isBareIdentifier(attr.first)) {
os << attr.first;
} else {
@@ -2037,81 +2050,82 @@ void ModulePrinter::printNamedAttribute(NamedAttribute attr) {
printAttribute(attr.second);
}
-//===----------------------------------------------------------------------===//
-// CustomDialectAsmPrinter
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class provides the main specialization of the DialectAsmPrinter that is
-/// used to provide support for print attributes and types. This hooks allows
-/// for dialects to hook into the main ModulePrinter.
-struct CustomDialectAsmPrinter : public DialectAsmPrinter {
-public:
- CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
- ~CustomDialectAsmPrinter() override {}
-
- raw_ostream &getStream() const override { return printer.getStream(); }
-
- /// Print the given attribute to the stream.
- void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
-
- /// Print the given attribute without its type. The corresponding parser must
- /// provide a valid type for the attribute.
- void printAttributeWithoutType(Attribute attr) override {
- printer.printAttribute(attr, ModulePrinter::AttrTypeElision::Must);
- }
-
- /// Print the given floating point value in a stablized form.
- void printFloat(const APFloat &value) override {
- printFloatValue(value, getStream());
- }
-
- /// Print the given type to the stream.
- void printType(Type type) override { printer.printType(type); }
-
- /// The main module printer.
- ModulePrinter &printer;
-};
-} // end anonymous namespace
-
-void ModulePrinter::printDialectAttribute(Attribute attr) {
+void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
auto &dialect = attr.getDialect();
// Ask the dialect to serialize the attribute to a string.
std::string attrName;
{
llvm::raw_string_ostream attrNameStr(attrName);
- ModulePrinter subPrinter(attrNameStr, printerFlags, state);
- CustomDialectAsmPrinter printer(subPrinter);
+ Impl subPrinter(attrNameStr, printerFlags, state);
+ DialectAsmPrinter printer(subPrinter);
dialect.printAttribute(attr, printer);
}
printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
}
-void ModulePrinter::printDialectType(Type type) {
+void AsmPrinter::Impl::printDialectType(Type type) {
auto &dialect = type.getDialect();
// Ask the dialect to serialize the type to a string.
std::string typeName;
{
llvm::raw_string_ostream typeNameStr(typeName);
- ModulePrinter subPrinter(typeNameStr, printerFlags, state);
- CustomDialectAsmPrinter printer(subPrinter);
+ Impl subPrinter(typeNameStr, printerFlags, state);
+ DialectAsmPrinter printer(subPrinter);
dialect.printType(type, printer);
}
printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
}
+//===--------------------------------------------------------------------===//
+// AsmPrinter
+//===--------------------------------------------------------------------===//
+
+AsmPrinter::~AsmPrinter() {}
+
+raw_ostream &AsmPrinter::getStream() const {
+ assert(impl && "expected AsmPrinter::getStream to be overriden");
+ return impl->getStream();
+}
+
+/// Print the given floating point value in a stabilized form.
+void AsmPrinter::printFloat(const APFloat &value) {
+ assert(impl && "expected AsmPrinter::printFloat to be overriden");
+ printFloatValue(value, impl->getStream());
+}
+
+void AsmPrinter::printType(Type type) {
+ assert(impl && "expected AsmPrinter::printType to be overriden");
+ impl->printType(type);
+}
+
+void AsmPrinter::printAttribute(Attribute attr) {
+ assert(impl && "expected AsmPrinter::printAttribute to be overriden");
+ impl->printAttribute(attr);
+}
+
+void AsmPrinter::printAttributeWithoutType(Attribute attr) {
+ assert(impl &&
+ "expected AsmPrinter::printAttributeWithoutType to be overriden");
+ impl->printAttribute(attr, Impl::AttrTypeElision::Must);
+}
+
+void AsmPrinter::printSymbolName(StringRef symbolRef) {
+ assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
+ ::printSymbolReference(symbolRef, impl->getStream());
+}
+
//===----------------------------------------------------------------------===//
// Affine expressions and maps
//===----------------------------------------------------------------------===//
-void ModulePrinter::printAffineExpr(
+void AsmPrinter::Impl::printAffineExpr(
AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
}
-void ModulePrinter::printAffineExprInternal(
+void AsmPrinter::Impl::printAffineExprInternal(
AffineExpr expr, BindingStrength enclosingTightness,
function_ref<void(unsigned, bool)> printValueName) {
const char *binopSpelling = nullptr;
@@ -2244,12 +2258,12 @@ void ModulePrinter::printAffineExprInternal(
os << ')';
}
-void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
+void AsmPrinter::Impl::printAffineConstraint(AffineExpr expr, bool isEq) {
printAffineExprInternal(expr, BindingStrength::Weak);
isEq ? os << " == 0" : os << " >= 0";
}
-void ModulePrinter::printAffineMap(AffineMap map) {
+void AsmPrinter::Impl::printAffineMap(AffineMap map) {
// Dimension identifiers.
os << '(';
for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
@@ -2275,7 +2289,7 @@ void ModulePrinter::printAffineMap(AffineMap map) {
os << ')';
}
-void ModulePrinter::printIntegerSet(IntegerSet set) {
+void AsmPrinter::Impl::printIntegerSet(IntegerSet set) {
// Dimension identifiers.
os << '(';
for (unsigned i = 1; i < set.getNumDims(); ++i)
@@ -2313,11 +2327,14 @@ void ModulePrinter::printIntegerSet(IntegerSet set) {
namespace {
/// This class contains the logic for printing operations, regions, and blocks.
-class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
+class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
public:
+ using Impl = AsmPrinter::Impl;
+ using Impl::printType;
+
explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
AsmStateImpl &state)
- : ModulePrinter(os, flags, &state) {}
+ : Impl(os, flags, &state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
/// Print the given top-level operation.
void printTopLevelOperation(Operation *op);
@@ -2346,9 +2363,6 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
// OpAsmPrinter methods
//===--------------------------------------------------------------------===//
- /// Return the current stream of the printer.
- raw_ostream &getStream() const override { return os; }
-
/// Print a newline and indent the printer to the start of the current
/// operation.
void printNewline() override {
@@ -2356,20 +2370,6 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
os.indent(currentIndent);
}
- /// Print the given type.
- void printType(Type type) override { ModulePrinter::printType(type); }
-
- /// Print the given attribute.
- void printAttribute(Attribute attr) override {
- ModulePrinter::printAttribute(attr);
- }
-
- /// Print the given attribute without its type. The corresponding parser must
- /// provide a valid type for the attribute.
- void printAttributeWithoutType(Attribute attr) override {
- ModulePrinter::printAttribute(attr, AttrTypeElision::Must);
- }
-
/// Print a block argument in the usual format of:
/// %ssaName : type {attr1=42} loc("here")
/// where location printing is controlled by the standard internal option.
@@ -2388,13 +2388,13 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
/// Print an optional attribute dictionary with a given set of elided values.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
- ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
+ Impl::printOptionalAttrDict(attrs, elidedAttrs);
}
void printOptionalAttrDictWithKeyword(
ArrayRef<NamedAttribute> attrs,
ArrayRef<StringRef> elidedAttrs = {}) override {
- ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs,
- /*withKeyword=*/true);
+ Impl::printOptionalAttrDict(attrs, elidedAttrs,
+ /*withKeyword=*/true);
}
/// Print the given successor.
@@ -2427,11 +2427,6 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
ValueRange symOperands) override;
- /// Print the given string as a symbol reference.
- void printSymbolName(StringRef symbolRef) override {
- ::printSymbolReference(symbolRef, os);
- }
-
private:
// Contains the stack of default dialects to use when printing regions.
// A new dialect is pushed to the stack before parsing regions nested under an
@@ -2732,7 +2727,7 @@ void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
//===----------------------------------------------------------------------===//
void Attribute::print(raw_ostream &os) const {
- ModulePrinter(os).printAttribute(*this);
+ AsmPrinter::Impl(os).printAttribute(*this);
}
void Attribute::dump() const {
@@ -2740,7 +2735,9 @@ void Attribute::dump() const {
llvm::errs() << "\n";
}
-void Type::print(raw_ostream &os) const { ModulePrinter(os).printType(*this); }
+void Type::print(raw_ostream &os) const {
+ AsmPrinter::Impl(os).printType(*this);
+}
void Type::dump() const { print(llvm::errs()); }
@@ -2759,7 +2756,7 @@ void AffineExpr::print(raw_ostream &os) const {
os << "<<NULL AFFINE EXPR>>";
return;
}
- ModulePrinter(os).printAffineExpr(*this);
+ AsmPrinter::Impl(os).printAffineExpr(*this);
}
void AffineExpr::dump() const {
@@ -2772,11 +2769,11 @@ void AffineMap::print(raw_ostream &os) const {
os << "<<NULL AFFINE MAP>>";
return;
}
- ModulePrinter(os).printAffineMap(*this);
+ AsmPrinter::Impl(os).printAffineMap(*this);
}
void IntegerSet::print(raw_ostream &os) const {
- ModulePrinter(os).printIntegerSet(*this);
+ AsmPrinter::Impl(os).printIntegerSet(*this);
}
void Value::print(raw_ostream &os) {
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index 2f2997e11fb83..0170b99c0d1a6 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -24,8 +24,6 @@
using namespace mlir;
using namespace detail;
-DialectAsmParser::~DialectAsmParser() {}
-
//===----------------------------------------------------------------------===//
// DialectRegistry
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 6cba5742a3e0f..0fa51ea979430 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -19,8 +19,6 @@
using namespace mlir;
-OpAsmParser::~OpAsmParser() {}
-
//===----------------------------------------------------------------------===//
// OperationName
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h
new file mode 100644
index 0000000000000..6db297b92200d
--- /dev/null
+++ b/mlir/lib/Parser/AsmParserImpl.h
@@ -0,0 +1,501 @@
+//===- AsmParserImpl.h - MLIR AsmParserImpl Class ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_PARSER_ASMPARSERIMPL_H
+#define MLIR_LIB_PARSER_ASMPARSERIMPL_H
+
+#include "Parser.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Parser/AsmParserState.h"
+
+namespace mlir {
+namespace detail {
+//===----------------------------------------------------------------------===//
+// AsmParserImpl
+//===----------------------------------------------------------------------===//
+
+/// This class provides the implementation of the generic parser methods within
+/// AsmParser.
+template <typename BaseT>
+class AsmParserImpl : public BaseT {
+public:
+ AsmParserImpl(llvm::SMLoc nameLoc, Parser &parser)
+ : nameLoc(nameLoc), parser(parser) {}
+ ~AsmParserImpl() override {}
+
+ /// Return the location of the original name token.
+ llvm::SMLoc getNameLoc() const override { return nameLoc; }
+
+ //===--------------------------------------------------------------------===//
+ // Utilities
+ //===--------------------------------------------------------------------===//
+
+ /// Return true if any errors were emitted during parsing.
+ bool didEmitError() const { return emittedError; }
+
+ /// Emit a diagnostic at the specified location and return failure.
+ InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
+ emittedError = true;
+ return parser.emitError(loc, message);
+ }
+
+ /// Return a builder which provides useful access to MLIRContext, global
+ /// objects like types and attributes.
+ Builder &getBuilder() const override { return parser.builder; }
+
+ /// Return the location of the next token to be parsed. This query always
+ /// succeeds.
+ llvm::SMLoc getCurrentLocation() override {
+ return parser.getToken().getLoc();
+ }
+
+ /// Re-encode the given source location as an MLIR location and return it.
+ Location getEncodedSourceLoc(llvm::SMLoc loc) override {
+ return parser.getEncodedSourceLocation(loc);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Token Parsing
+ //===--------------------------------------------------------------------===//
+
+ using Delimiter = AsmParser::Delimiter;
+
+ /// Parse a `->` token.
+ ParseResult parseArrow() override {
+ return parser.parseToken(Token::arrow, "expected '->'");
+ }
+
+ /// Parses a `->` if present.
+ ParseResult parseOptionalArrow() override {
+ return success(parser.consumeIf(Token::arrow));
+ }
+
+ /// Parse a '{' token.
+ ParseResult parseLBrace() override {
+ return parser.parseToken(Token::l_brace, "expected '{'");
+ }
+
+ /// Parse a '{' token if present.
+ ParseResult parseOptionalLBrace() override {
+ return success(parser.consumeIf(Token::l_brace));
+ }
+
+ /// Parse a `}` token.
+ ParseResult parseRBrace() override {
+ return parser.parseToken(Token::r_brace, "expected '}'");
+ }
+
+ /// Parse a `}` token if present.
+ ParseResult parseOptionalRBrace() override {
+ return success(parser.consumeIf(Token::r_brace));
+ }
+
+ /// Parse a `:` token.
+ ParseResult parseColon() override {
+ return parser.parseToken(Token::colon, "expected ':'");
+ }
+
+ /// Parse a `:` token if present.
+ ParseResult parseOptionalColon() override {
+ return success(parser.consumeIf(Token::colon));
+ }
+
+ /// Parse a `,` token.
+ ParseResult parseComma() override {
+ return parser.parseToken(Token::comma, "expected ','");
+ }
+
+ /// Parse a `,` token if present.
+ ParseResult parseOptionalComma() override {
+ return success(parser.consumeIf(Token::comma));
+ }
+
+ /// Parses a `...` if present.
+ ParseResult parseOptionalEllipsis() override {
+ return success(parser.consumeIf(Token::ellipsis));
+ }
+
+ /// Parse a `=` token.
+ ParseResult parseEqual() override {
+ return parser.parseToken(Token::equal, "expected '='");
+ }
+
+ /// Parse a `=` token if present.
+ ParseResult parseOptionalEqual() override {
+ return success(parser.consumeIf(Token::equal));
+ }
+
+ /// Parse a '<' token.
+ ParseResult parseLess() override {
+ return parser.parseToken(Token::less, "expected '<'");
+ }
+
+ /// Parse a `<` token if present.
+ ParseResult parseOptionalLess() override {
+ return success(parser.consumeIf(Token::less));
+ }
+
+ /// Parse a '>' token.
+ ParseResult parseGreater() override {
+ return parser.parseToken(Token::greater, "expected '>'");
+ }
+
+ /// Parse a `>` token if present.
+ ParseResult parseOptionalGreater() override {
+ return success(parser.consumeIf(Token::greater));
+ }
+
+ /// Parse a `(` token.
+ ParseResult parseLParen() override {
+ return parser.parseToken(Token::l_paren, "expected '('");
+ }
+
+ /// Parses a '(' if present.
+ ParseResult parseOptionalLParen() override {
+ return success(parser.consumeIf(Token::l_paren));
+ }
+
+ /// Parse a `)` token.
+ ParseResult parseRParen() override {
+ return parser.parseToken(Token::r_paren, "expected ')'");
+ }
+
+ /// Parses a ')' if present.
+ ParseResult parseOptionalRParen() override {
+ return success(parser.consumeIf(Token::r_paren));
+ }
+
+ /// Parse a `[` token.
+ ParseResult parseLSquare() override {
+ return parser.parseToken(Token::l_square, "expected '['");
+ }
+
+ /// Parses a '[' if present.
+ ParseResult parseOptionalLSquare() override {
+ return success(parser.consumeIf(Token::l_square));
+ }
+
+ /// Parse a `]` token.
+ ParseResult parseRSquare() override {
+ return parser.parseToken(Token::r_square, "expected ']'");
+ }
+
+ /// Parses a ']' if present.
+ ParseResult parseOptionalRSquare() override {
+ return success(parser.consumeIf(Token::r_square));
+ }
+
+ /// Parses a '?' token.
+ ParseResult parseQuestion() override {
+ return parser.parseToken(Token::question, "expected '?'");
+ }
+
+ /// Parses a '?' if present.
+ ParseResult parseOptionalQuestion() override {
+ return success(parser.consumeIf(Token::question));
+ }
+
+ /// Parses a '*' token.
+ ParseResult parseStar() override {
+ return parser.parseToken(Token::star, "expected '*'");
+ }
+
+ /// Parses a '*' if present.
+ ParseResult parseOptionalStar() override {
+ return success(parser.consumeIf(Token::star));
+ }
+
+ /// Parses a '+' token.
+ ParseResult parsePlus() override {
+ return parser.parseToken(Token::plus, "expected '+'");
+ }
+
+ /// Parses a '+' token if present.
+ ParseResult parseOptionalPlus() override {
+ return success(parser.consumeIf(Token::plus));
+ }
+
+ /// Parses a quoted string token if present.
+ ParseResult parseOptionalString(std::string *string) override {
+ if (!parser.getToken().is(Token::string))
+ return failure();
+
+ if (string)
+ *string = parser.getToken().getStringValue();
+ parser.consumeToken();
+ return success();
+ }
+
+ /// Returns true if the current token corresponds to a keyword.
+ bool isCurrentTokenAKeyword() const {
+ return parser.getToken().isAny(Token::bare_identifier, Token::inttype) ||
+ parser.getToken().isKeyword();
+ }
+
+ /// Parse the given keyword if present.
+ ParseResult parseOptionalKeyword(StringRef keyword) override {
+ // Check that the current token has the same spelling.
+ if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
+ return failure();
+ parser.consumeToken();
+ return success();
+ }
+
+ /// Parse a keyword, if present, into 'keyword'.
+ ParseResult parseOptionalKeyword(StringRef *keyword) override {
+ // Check that the current token is a keyword.
+ if (!isCurrentTokenAKeyword())
+ return failure();
+
+ *keyword = parser.getTokenSpelling();
+ parser.consumeToken();
+ return success();
+ }
+
+ /// Parse a keyword if it is one of the 'allowedKeywords'.
+ ParseResult
+ parseOptionalKeyword(StringRef *keyword,
+ ArrayRef<StringRef> allowedKeywords) override {
+ // Check that the current token is a keyword.
+ if (!isCurrentTokenAKeyword())
+ return failure();
+
+ StringRef currentKeyword = parser.getTokenSpelling();
+ if (llvm::is_contained(allowedKeywords, currentKeyword)) {
+ *keyword = currentKeyword;
+ parser.consumeToken();
+ return success();
+ }
+
+ return failure();
+ }
+
+ /// Parse a floating point value from the stream.
+ ParseResult parseFloat(double &result) override {
+ bool isNegative = parser.consumeIf(Token::minus);
+ Token curTok = parser.getToken();
+ llvm::SMLoc loc = curTok.getLoc();
+
+ // Check for a floating point value.
+ if (curTok.is(Token::floatliteral)) {
+ auto val = curTok.getFloatingPointValue();
+ if (!val.hasValue())
+ return emitError(loc, "floating point value too large");
+ parser.consumeToken(Token::floatliteral);
+ result = isNegative ? -*val : *val;
+ return success();
+ }
+
+ // Check for a hexadecimal float value.
+ if (curTok.is(Token::integer)) {
+ Optional<APFloat> apResult;
+ if (failed(parser.parseFloatFromIntegerLiteral(
+ apResult, curTok, isNegative, APFloat::IEEEdouble(),
+ /*typeSizeInBits=*/64)))
+ return failure();
+
+ parser.consumeToken(Token::integer);
+ result = apResult->convertToDouble();
+ return success();
+ }
+
+ return emitError(loc, "expected floating point literal");
+ }
+
+ /// Parse an optional integer value from the stream.
+ OptionalParseResult parseOptionalInteger(APInt &result) override {
+ return parser.parseOptionalInteger(result);
+ }
+
+ /// Parse a list of comma-separated items with an optional delimiter. If a
+ /// delimiter is provided, then an empty list is allowed. If not, then at
+ /// least one element will be parsed.
+ ParseResult parseCommaSeparatedList(Delimiter delimiter,
+ function_ref<ParseResult()> parseElt,
+ StringRef contextMessage) override {
+ return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Attribute Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse an arbitrary attribute and return it in result.
+ ParseResult parseAttribute(Attribute &result, Type type) override {
+ result = parser.parseAttribute(type);
+ return success(static_cast<bool>(result));
+ }
+
+ /// Parse an optional attribute, adding it to 'attrs' on success.
+ template <typename AttrT>
+ OptionalParseResult
+ parseOptionalAttributeAndAddToList(AttrT &result, Type type,
+ StringRef attrName, NamedAttrList &attrs) {
+ OptionalParseResult parseResult =
+ parser.parseOptionalAttribute(result, type);
+ if (parseResult.hasValue() && succeeded(*parseResult))
+ attrs.push_back(parser.builder.getNamedAttr(attrName, result));
+ return parseResult;
+ }
+ OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
+ StringRef attrName,
+ NamedAttrList &attrs) override {
+ return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
+ }
+ OptionalParseResult parseOptionalAttribute(ArrayAttr &result, Type type,
+ StringRef attrName,
+ NamedAttrList &attrs) override {
+ return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
+ }
+ OptionalParseResult parseOptionalAttribute(StringAttr &result, Type type,
+ StringRef attrName,
+ NamedAttrList &attrs) override {
+ return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
+ }
+
+ /// Parse a named dictionary into 'result' if it is present.
+ ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
+ if (parser.getToken().isNot(Token::l_brace))
+ return success();
+ return parser.parseAttributeDict(result);
+ }
+
+ /// Parse a named dictionary into 'result' if the `attributes` keyword is
+ /// present.
+ ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result) override {
+ if (failed(parseOptionalKeyword("attributes")))
+ return success();
+ return parser.parseAttributeDict(result);
+ }
+
+ /// Parse an affine map instance into 'map'.
+ ParseResult parseAffineMap(AffineMap &map) override {
+ return parser.parseAffineMapReference(map);
+ }
+
+ /// Parse an integer set instance into 'set'.
+ ParseResult printIntegerSet(IntegerSet &set) override {
+ return parser.parseIntegerSetReference(set);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Identifier Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse an optional @-identifier and store it (without the '@' symbol) in a
+ /// string attribute named 'attrName'.
+ ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName,
+ NamedAttrList &attrs) override {
+ Token atToken = parser.getToken();
+ if (atToken.isNot(Token::at_identifier))
+ return failure();
+
+ result = getBuilder().getStringAttr(atToken.getSymbolReference());
+ attrs.push_back(getBuilder().getNamedAttr(attrName, result));
+ parser.consumeToken();
+
+ // If we are populating the assembly parser state, record this as a symbol
+ // reference.
+ if (parser.getState().asmState) {
+ parser.getState().asmState->addUses(SymbolRefAttr::get(result),
+ atToken.getLocRange());
+ }
+ return success();
+ }
+
+ /// Parse a loc(...) specifier if present, filling in result if so.
+ ParseResult
+ parseOptionalLocationSpecifier(Optional<Location> &result) override {
+ // If there is a 'loc' we parse a trailing location.
+ if (!parser.consumeIf(Token::kw_loc))
+ return success();
+ LocationAttr directLoc;
+ if (parser.parseToken(Token::l_paren, "expected '(' in location") ||
+ parser.parseLocationInstance(directLoc) ||
+ parser.parseToken(Token::r_paren, "expected ')' in location"))
+ return failure();
+
+ result = directLoc;
+ return success();
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Type Parsing
+ //===--------------------------------------------------------------------===//
+
+ /// Parse a type.
+ ParseResult parseType(Type &result) override {
+ return failure(!(result = parser.parseType()));
+ }
+
+ /// Parse an optional type.
+ OptionalParseResult parseOptionalType(Type &result) override {
+ return parser.parseOptionalType(result);
+ }
+
+ /// Parse an arrow followed by a type list.
+ ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override {
+ if (parseArrow() || parser.parseFunctionResultTypes(result))
+ return failure();
+ return success();
+ }
+
+ /// Parse an optional arrow followed by a type list.
+ ParseResult
+ parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override {
+ if (!parser.consumeIf(Token::arrow))
+ return success();
+ return parser.parseFunctionResultTypes(result);
+ }
+
+ /// Parse a colon followed by a type.
+ ParseResult parseColonType(Type &result) override {
+ return failure(parser.parseToken(Token::colon, "expected ':'") ||
+ !(result = parser.parseType()));
+ }
+
+ /// Parse a colon followed by a type list, which must have at least one type.
+ ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
+ if (parser.parseToken(Token::colon, "expected ':'"))
+ return failure();
+ return parser.parseTypeListNoParens(result);
+ }
+
+ /// Parse an optional colon followed by a type list, which if present must
+ /// have at least one type.
+ ParseResult
+ parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override {
+ if (!parser.consumeIf(Token::colon))
+ return success();
+ return parser.parseTypeListNoParens(result);
+ }
+
+ ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
+ bool allowDynamic) override {
+ return parser.parseDimensionListRanked(dimensions, allowDynamic);
+ }
+
+ ParseResult parseXInDimensionList() override {
+ return parser.parseXInDimensionList();
+ }
+
+protected:
+ /// The source location of the original name token.
+ llvm::SMLoc nameLoc;
+
+ /// The main parser.
+ Parser &parser;
+
+ /// A flag that indicates if any errors were emitted during parsing.
+ bool emittedError = false;
+};
+} // namespace detail
+} // end namespace mlir
+
+#endif // MLIR_LIB_PARSER_ASMPARSERIMPL_H
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 5fea113094001..b6b32072eb2e1 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -11,7 +11,7 @@
//
//===----------------------------------------------------------------------===//
-#include "Parser.h"
+#include "AsmParserImpl.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
@@ -27,304 +27,20 @@ namespace {
/// This class provides the main implementation of the DialectAsmParser that
/// allows for dialects to parse attributes and types. This allows for dialect
/// hooking into the main MLIR parsing logic.
-class CustomDialectAsmParser : public DialectAsmParser {
+class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
public:
CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
- : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()),
- parser(parser) {}
+ : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
+ fullSpec(fullSpec) {}
~CustomDialectAsmParser() override {}
- /// Emit a diagnostic at the specified location and return failure.
- InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
- return parser.emitError(loc, message);
- }
-
- /// Return a builder which provides useful access to MLIRContext, global
- /// objects like types and attributes.
- Builder &getBuilder() const override { return parser.builder; }
-
- /// Get the location of the next token and store it into the argument. This
- /// always succeeds.
- llvm::SMLoc getCurrentLocation() override {
- return parser.getToken().getLoc();
- }
-
- /// Return the location of the original name token.
- llvm::SMLoc getNameLoc() const override { return nameLoc; }
-
- /// Re-encode the given source location as an MLIR location and return it.
- Location getEncodedSourceLoc(llvm::SMLoc loc) override {
- return parser.getEncodedSourceLocation(loc);
- }
-
/// Returns the full specification of the symbol being parsed. This allows
/// for using a separate parser if necessary.
StringRef getFullSymbolSpec() const override { return fullSpec; }
- /// Parse a floating point value from the stream.
- ParseResult parseFloat(double &result) override {
- bool isNegative = parser.consumeIf(Token::minus);
- Token curTok = parser.getToken();
- llvm::SMLoc loc = curTok.getLoc();
-
- // Check for a floating point value.
- if (curTok.is(Token::floatliteral)) {
- auto val = curTok.getFloatingPointValue();
- if (!val.hasValue())
- return emitError(loc, "floating point value too large");
- parser.consumeToken(Token::floatliteral);
- result = isNegative ? -*val : *val;
- return success();
- }
-
- // Check for a hexadecimal float value.
- if (curTok.is(Token::integer)) {
- Optional<APFloat> apResult;
- if (failed(parser.parseFloatFromIntegerLiteral(
- apResult, curTok, isNegative, APFloat::IEEEdouble(),
- /*typeSizeInBits=*/64)))
- return failure();
-
- parser.consumeToken(Token::integer);
- result = apResult->convertToDouble();
- return success();
- }
-
- return emitError(loc, "expected floating point literal");
- }
-
- /// Parse an optional integer value from the stream.
- OptionalParseResult parseOptionalInteger(APInt &result) override {
- return parser.parseOptionalInteger(result);
- }
-
- //===--------------------------------------------------------------------===//
- // Token Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse a `->` token.
- ParseResult parseArrow() override {
- return parser.parseToken(Token::arrow, "expected '->'");
- }
-
- /// Parses a `->` if present.
- ParseResult parseOptionalArrow() override {
- return success(parser.consumeIf(Token::arrow));
- }
-
- /// Parse a '{' token.
- ParseResult parseLBrace() override {
- return parser.parseToken(Token::l_brace, "expected '{'");
- }
-
- /// Parse a '{' token if present
- ParseResult parseOptionalLBrace() override {
- return success(parser.consumeIf(Token::l_brace));
- }
-
- /// Parse a `}` token.
- ParseResult parseRBrace() override {
- return parser.parseToken(Token::r_brace, "expected '}'");
- }
-
- /// Parse a `}` token if present
- ParseResult parseOptionalRBrace() override {
- return success(parser.consumeIf(Token::r_brace));
- }
-
- /// Parse a `:` token.
- ParseResult parseColon() override {
- return parser.parseToken(Token::colon, "expected ':'");
- }
-
- /// Parse a `:` token if present.
- ParseResult parseOptionalColon() override {
- return success(parser.consumeIf(Token::colon));
- }
-
- /// Parse a `,` token.
- ParseResult parseComma() override {
- return parser.parseToken(Token::comma, "expected ','");
- }
-
- /// Parse a `,` token if present.
- ParseResult parseOptionalComma() override {
- return success(parser.consumeIf(Token::comma));
- }
-
- /// Parses a `...` if present.
- ParseResult parseOptionalEllipsis() override {
- return success(parser.consumeIf(Token::ellipsis));
- }
-
- /// Parse a `=` token.
- ParseResult parseEqual() override {
- return parser.parseToken(Token::equal, "expected '='");
- }
-
- /// Parse a `=` token if present.
- ParseResult parseOptionalEqual() override {
- return success(parser.consumeIf(Token::equal));
- }
-
- /// Parse a '<' token.
- ParseResult parseLess() override {
- return parser.parseToken(Token::less, "expected '<'");
- }
-
- /// Parse a `<` token if present.
- ParseResult parseOptionalLess() override {
- return success(parser.consumeIf(Token::less));
- }
-
- /// Parse a '>' token.
- ParseResult parseGreater() override {
- return parser.parseToken(Token::greater, "expected '>'");
- }
-
- /// Parse a `>` token if present.
- ParseResult parseOptionalGreater() override {
- return success(parser.consumeIf(Token::greater));
- }
-
- /// Parse a `(` token.
- ParseResult parseLParen() override {
- return parser.parseToken(Token::l_paren, "expected '('");
- }
-
- /// Parses a '(' if present.
- ParseResult parseOptionalLParen() override {
- return success(parser.consumeIf(Token::l_paren));
- }
-
- /// Parse a `)` token.
- ParseResult parseRParen() override {
- return parser.parseToken(Token::r_paren, "expected ')'");
- }
-
- /// Parses a ')' if present.
- ParseResult parseOptionalRParen() override {
- return success(parser.consumeIf(Token::r_paren));
- }
-
- /// Parse a `[` token.
- ParseResult parseLSquare() override {
- return parser.parseToken(Token::l_square, "expected '['");
- }
-
- /// Parses a '[' if present.
- ParseResult parseOptionalLSquare() override {
- return success(parser.consumeIf(Token::l_square));
- }
-
- /// Parse a `]` token.
- ParseResult parseRSquare() override {
- return parser.parseToken(Token::r_square, "expected ']'");
- }
-
- /// Parses a ']' if present.
- ParseResult parseOptionalRSquare() override {
- return success(parser.consumeIf(Token::r_square));
- }
-
- /// Parses a '?' if present.
- ParseResult parseOptionalQuestion() override {
- return success(parser.consumeIf(Token::question));
- }
-
- /// Parses a '*' if present.
- ParseResult parseOptionalStar() override {
- return success(parser.consumeIf(Token::star));
- }
-
- /// Parses a quoted string token if present.
- ParseResult parseOptionalString(std::string *string) override {
- if (!parser.getToken().is(Token::string))
- return failure();
-
- if (string)
- *string = parser.getToken().getStringValue();
- parser.consumeToken();
- return success();
- }
-
- /// Returns true if the current token corresponds to a keyword.
- bool isCurrentTokenAKeyword() const {
- return parser.getToken().isAny(Token::bare_identifier, Token::inttype) ||
- parser.getToken().isKeyword();
- }
-
- /// Parse the given keyword if present.
- ParseResult parseOptionalKeyword(StringRef keyword) override {
- // Check that the current token has the same spelling.
- if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
- return failure();
- parser.consumeToken();
- return success();
- }
-
- /// Parse a keyword, if present, into 'keyword'.
- ParseResult parseOptionalKeyword(StringRef *keyword) override {
- // Check that the current token is a keyword.
- if (!isCurrentTokenAKeyword())
- return failure();
-
- *keyword = parser.getTokenSpelling();
- parser.consumeToken();
- return success();
- }
-
- //===--------------------------------------------------------------------===//
- // Attribute Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse an arbitrary attribute and return it in result.
- ParseResult parseAttribute(Attribute &result, Type type) override {
- result = parser.parseAttribute(type);
- return success(static_cast<bool>(result));
- }
-
- /// Parse an affine map instance into 'map'.
- ParseResult parseAffineMap(AffineMap &map) override {
- return parser.parseAffineMapReference(map);
- }
-
- /// Parse an integer set instance into 'set'.
- ParseResult printIntegerSet(IntegerSet &set) override {
- return parser.parseIntegerSetReference(set);
- }
-
- //===--------------------------------------------------------------------===//
- // Type Parsing
- //===--------------------------------------------------------------------===//
-
- ParseResult parseType(Type &result) override {
- result = parser.parseType();
- return success(static_cast<bool>(result));
- }
-
- ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
- bool allowDynamic) override {
- return parser.parseDimensionListRanked(dimensions, allowDynamic);
- }
-
- ParseResult parseXInDimensionList() override {
- return parser.parseXInDimensionList();
- }
-
- OptionalParseResult parseOptionalType(Type &result) override {
- return parser.parseOptionalType(result);
- }
-
private:
/// The full symbol specification.
StringRef fullSpec;
-
- /// The source location of the dialect symbol.
- SMLoc nameLoc;
-
- /// The main parser.
- Parser &parser;
};
} // namespace
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 6c8375427e791..881892d5bc252 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "Parser.h"
+#include "AsmParserImpl.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
@@ -1093,15 +1094,15 @@ Operation *OperationParser::parseGenericOperation(Block *insertBlock,
}
namespace {
-class CustomOpAsmParser : public OpAsmParser {
+class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
public:
CustomOpAsmParser(
SMLoc nameLoc, ArrayRef<OperationParser::ResultRecord> resultIDs,
function_ref<ParseResult(OpAsmParser &, OperationState &)> parseAssembly,
bool isIsolatedFromAbove, StringRef opName, OperationParser &parser)
- : nameLoc(nameLoc), resultIDs(resultIDs), parseAssembly(parseAssembly),
- isIsolatedFromAbove(isIsolatedFromAbove), opName(opName),
- parser(parser) {
+ : AsmParserImpl<OpAsmParser>(nameLoc, parser), resultIDs(resultIDs),
+ parseAssembly(parseAssembly), isIsolatedFromAbove(isIsolatedFromAbove),
+ opName(opName), parser(parser) {
(void)isIsolatedFromAbove; // Only used in assert, silence unused warning.
}
@@ -1131,21 +1132,6 @@ class CustomOpAsmParser : public OpAsmParser {
// Utilities
//===--------------------------------------------------------------------===//
- /// Return if any errors were emitted during parsing.
- bool didEmitError() const { return emittedError; }
-
- /// Emit a diagnostic at the specified location and return failure.
- InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
- emittedError = true;
- return parser.emitError(loc, "custom op '" + opName + "' " + message);
- }
-
- llvm::SMLoc getCurrentLocation() override {
- return parser.getToken().getLoc();
- }
-
- Builder &getBuilder() const override { return parser.builder; }
-
/// Return the name of the specified result in the specified syntax, as well
/// as the subelement in the name. For example, in this operation:
///
@@ -1181,331 +1167,10 @@ class CustomOpAsmParser : public OpAsmParser {
return count;
}
- llvm::SMLoc getNameLoc() const override { return nameLoc; }
-
- /// Re-encode the given source location as an MLIR location and return it.
- Location getEncodedSourceLoc(llvm::SMLoc loc) override {
- return parser.getEncodedSourceLocation(loc);
- }
-
- //===--------------------------------------------------------------------===//
- // Token Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse a `->` token.
- ParseResult parseArrow() override {
- return parser.parseToken(Token::arrow, "expected '->'");
- }
-
- /// Parses a `->` if present.
- ParseResult parseOptionalArrow() override {
- return success(parser.consumeIf(Token::arrow));
- }
-
- /// Parse a '{' token.
- ParseResult parseLBrace() override {
- return parser.parseToken(Token::l_brace, "expected '{'");
- }
-
- /// Parse a '{' token if present
- ParseResult parseOptionalLBrace() override {
- return success(parser.consumeIf(Token::l_brace));
- }
-
- /// Parse a `}` token.
- ParseResult parseRBrace() override {
- return parser.parseToken(Token::r_brace, "expected '}'");
- }
-
- /// Parse a `}` token if present
- ParseResult parseOptionalRBrace() override {
- return success(parser.consumeIf(Token::r_brace));
- }
-
- /// Parse a `:` token.
- ParseResult parseColon() override {
- return parser.parseToken(Token::colon, "expected ':'");
- }
-
- /// Parse a `:` token if present.
- ParseResult parseOptionalColon() override {
- return success(parser.consumeIf(Token::colon));
- }
-
- /// Parse a `,` token.
- ParseResult parseComma() override {
- return parser.parseToken(Token::comma, "expected ','");
- }
-
- /// Parse a `,` token if present.
- ParseResult parseOptionalComma() override {
- return success(parser.consumeIf(Token::comma));
- }
-
- /// Parses a `...` if present.
- ParseResult parseOptionalEllipsis() override {
- return success(parser.consumeIf(Token::ellipsis));
- }
-
- /// Parse a `=` token.
- ParseResult parseEqual() override {
- return parser.parseToken(Token::equal, "expected '='");
- }
-
- /// Parse a `=` token if present.
- ParseResult parseOptionalEqual() override {
- return success(parser.consumeIf(Token::equal));
- }
-
- /// Parse a '<' token.
- ParseResult parseLess() override {
- return parser.parseToken(Token::less, "expected '<'");
- }
-
- /// Parse a '<' token if present.
- ParseResult parseOptionalLess() override {
- return success(parser.consumeIf(Token::less));
- }
-
- /// Parse a '>' token.
- ParseResult parseGreater() override {
- return parser.parseToken(Token::greater, "expected '>'");
- }
-
- /// Parse a '>' token if present.
- ParseResult parseOptionalGreater() override {
- return success(parser.consumeIf(Token::greater));
- }
-
- /// Parse a `(` token.
- ParseResult parseLParen() override {
- return parser.parseToken(Token::l_paren, "expected '('");
- }
-
- /// Parses a '(' if present.
- ParseResult parseOptionalLParen() override {
- return success(parser.consumeIf(Token::l_paren));
- }
-
- /// Parse a `)` token.
- ParseResult parseRParen() override {
- return parser.parseToken(Token::r_paren, "expected ')'");
- }
-
- /// Parses a ')' if present.
- ParseResult parseOptionalRParen() override {
- return success(parser.consumeIf(Token::r_paren));
- }
-
- /// Parse a `[` token.
- ParseResult parseLSquare() override {
- return parser.parseToken(Token::l_square, "expected '['");
- }
-
- /// Parses a '[' if present.
- ParseResult parseOptionalLSquare() override {
- return success(parser.consumeIf(Token::l_square));
- }
-
- /// Parse a `]` token.
- ParseResult parseRSquare() override {
- return parser.parseToken(Token::r_square, "expected ']'");
- }
-
- /// Parses a ']' if present.
- ParseResult parseOptionalRSquare() override {
- return success(parser.consumeIf(Token::r_square));
- }
-
- /// Parses a '?' token.
- ParseResult parseQuestion() override {
- return parser.parseToken(Token::question, "expected '?'");
- }
-
- /// Parses a '?' token if present.
- ParseResult parseOptionalQuestion() override {
- return success(parser.consumeIf(Token::question));
- }
-
- /// Parses a '+' token.
- ParseResult parsePlus() override {
- return parser.parseToken(Token::plus, "expected '+'");
- }
-
- /// Parses a '+' token if present.
- ParseResult parseOptionalPlus() override {
- return success(parser.consumeIf(Token::plus));
- }
-
- /// Parses a '*' token.
- ParseResult parseStar() override {
- return parser.parseToken(Token::star, "expected '*'");
- }
-
- /// Parses a '*' token if present.
- ParseResult parseOptionalStar() override {
- return success(parser.consumeIf(Token::star));
- }
-
- /// Parse an optional integer value from the stream.
- OptionalParseResult parseOptionalInteger(APInt &result) override {
- return parser.parseOptionalInteger(result);
- }
-
- /// Parse a list of comma-separated items with an optional delimiter. If a
- /// delimiter is provided, then an empty list is allowed. If not, then at
- /// least one element will be parsed.
- ParseResult parseCommaSeparatedList(Delimiter delimiter,
- function_ref<ParseResult()> parseElt,
- StringRef contextMessage) override {
- return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
- }
-
- //===--------------------------------------------------------------------===//
- // Attribute Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse an arbitrary attribute of a given type and return it in result.
- ParseResult parseAttribute(Attribute &result, Type type) override {
- result = parser.parseAttribute(type);
- return success(static_cast<bool>(result));
- }
-
- /// Parse an optional attribute.
- template <typename AttrT>
- OptionalParseResult
- parseOptionalAttributeAndAddToList(AttrT &result, Type type,
- StringRef attrName, NamedAttrList &attrs) {
- OptionalParseResult parseResult =
- parser.parseOptionalAttribute(result, type);
- if (parseResult.hasValue() && succeeded(*parseResult))
- attrs.push_back(parser.builder.getNamedAttr(attrName, result));
- return parseResult;
- }
- OptionalParseResult parseOptionalAttribute(Attribute &result, Type type,
- StringRef attrName,
- NamedAttrList &attrs) override {
- return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
- }
- OptionalParseResult parseOptionalAttribute(ArrayAttr &result, Type type,
- StringRef attrName,
- NamedAttrList &attrs) override {
- return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
- }
- OptionalParseResult parseOptionalAttribute(StringAttr &result, Type type,
- StringRef attrName,
- NamedAttrList &attrs) override {
- return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
- }
-
- /// Parse a named dictionary into 'result' if it is present.
- ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
- if (parser.getToken().isNot(Token::l_brace))
- return success();
- return parser.parseAttributeDict(result);
- }
-
- /// Parse a named dictionary into 'result' if the `attributes` keyword is
- /// present.
- ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result) override {
- if (failed(parseOptionalKeyword("attributes")))
- return success();
- return parser.parseAttributeDict(result);
- }
-
- /// Parse an affine map instance into 'map'.
- ParseResult parseAffineMap(AffineMap &map) override {
- return parser.parseAffineMapReference(map);
- }
-
- /// Parse an integer set instance into 'set'.
- ParseResult printIntegerSet(IntegerSet &set) override {
- return parser.parseIntegerSetReference(set);
- }
-
- //===--------------------------------------------------------------------===//
- // Identifier Parsing
- //===--------------------------------------------------------------------===//
-
- /// Returns true if the current token corresponds to a keyword.
- bool isCurrentTokenAKeyword() const {
- return parser.getToken().is(Token::bare_identifier) ||
- parser.getToken().isKeyword();
- }
-
- /// Parse the given keyword if present.
- ParseResult parseOptionalKeyword(StringRef keyword) override {
- // Check that the current token has the same spelling.
- if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
- return failure();
- parser.consumeToken();
- return success();
- }
-
- /// Parse a keyword, if present, into 'keyword'.
- ParseResult parseOptionalKeyword(StringRef *keyword) override {
- // Check that the current token is a keyword.
- if (!isCurrentTokenAKeyword())
- return failure();
-
- *keyword = parser.getTokenSpelling();
- parser.consumeToken();
- return success();
- }
-
- /// Parse a keyword if it is one of the 'allowedKeywords'.
- ParseResult
- parseOptionalKeyword(StringRef *keyword,
- ArrayRef<StringRef> allowedKeywords) override {
- // Check that the current token is a keyword.
- if (!isCurrentTokenAKeyword())
- return failure();
-
- StringRef currentKeyword = parser.getTokenSpelling();
- if (llvm::is_contained(allowedKeywords, currentKeyword)) {
- *keyword = currentKeyword;
- parser.consumeToken();
- return success();
- }
-
- return failure();
- }
-
- /// Parse an optional @-identifier and store it (without the '@' symbol) in a
- /// string attribute named 'attrName'.
- ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName,
- NamedAttrList &attrs) override {
- Token atToken = parser.getToken();
- if (atToken.isNot(Token::at_identifier))
- return failure();
-
- result = getBuilder().getStringAttr(atToken.getSymbolReference());
- attrs.push_back(getBuilder().getNamedAttr(attrName, result));
- parser.consumeToken();
-
- // If we are populating the assembly parser state, record this as a symbol
- // reference.
- if (parser.getState().asmState) {
- parser.getState().asmState->addUses(SymbolRefAttr::get(result),
- atToken.getLocRange());
- }
- return success();
- }
-
- /// Parse a loc(...) specifier if present, filling in result if so.
- ParseResult
- parseOptionalLocationSpecifier(Optional<Location> &result) override {
- // If there is a 'loc' we parse a trailing location.
- if (!parser.consumeIf(Token::kw_loc))
- return success();
- LocationAttr directLoc;
- if (parser.parseToken(Token::l_paren, "expected '(' in location") ||
- parser.parseLocationInstance(directLoc) ||
- parser.parseToken(Token::r_paren, "expected ')' in location"))
- return failure();
-
- result = directLoc;
- return success();
+ /// Emit a diagnostic at the specified location and return failure.
+ InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
+ return AsmParserImpl<OpAsmParser>::emitError(loc, "custom op '" + opName +
+ "' " + message);
}
//===--------------------------------------------------------------------===//
@@ -1779,53 +1444,6 @@ class CustomOpAsmParser : public OpAsmParser {
// Type Parsing
//===--------------------------------------------------------------------===//
- /// Parse a type.
- ParseResult parseType(Type &result) override {
- return failure(!(result = parser.parseType()));
- }
-
- /// Parse an optional type.
- OptionalParseResult parseOptionalType(Type &result) override {
- return parser.parseOptionalType(result);
- }
-
- /// Parse an arrow followed by a type list.
- ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) override {
- if (parseArrow() || parser.parseFunctionResultTypes(result))
- return failure();
- return success();
- }
-
- /// Parse an optional arrow followed by a type list.
- ParseResult
- parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) override {
- if (!parser.consumeIf(Token::arrow))
- return success();
- return parser.parseFunctionResultTypes(result);
- }
-
- /// Parse a colon followed by a type.
- ParseResult parseColonType(Type &result) override {
- return failure(parser.parseToken(Token::colon, "expected ':'") ||
- !(result = parser.parseType()));
- }
-
- /// Parse a colon followed by a type list, which must have at least one type.
- ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
- if (parser.parseToken(Token::colon, "expected ':'"))
- return failure();
- return parser.parseTypeListNoParens(result);
- }
-
- /// Parse an optional colon followed by a type list, which if present must
- /// have at least one type.
- ParseResult
- parseOptionalColonTypeList(SmallVectorImpl<Type> &result) override {
- if (!parser.consumeIf(Token::colon))
- return success();
- return parser.parseTypeListNoParens(result);
- }
-
/// Parse a list of assignments of the form
/// (%x1 = %y1, %x2 = %y2, ...).
OptionalParseResult
@@ -1870,9 +1488,6 @@ class CustomOpAsmParser : public OpAsmParser {
}
private:
- /// The source location of the operation name.
- SMLoc nameLoc;
-
/// Information about the result name specifiers.
ArrayRef<OperationParser::ResultRecord> resultIDs;
@@ -1881,11 +1496,8 @@ class CustomOpAsmParser : public OpAsmParser {
bool isIsolatedFromAbove;
StringRef opName;
- /// The main operation parser.
+ /// The backing operation parser.
OperationParser &parser;
-
- /// A flag that indicates if any errors were emitted during parsing.
- bool emittedError = false;
};
} // end anonymous namespace.
More information about the Mlir-commits mailing list