[Mlir-commits] [mlir] 5111468 - [mlir][NFC] Split Parser into several different files.
River Riddle
llvmlistbot at llvm.org
Wed Jun 10 17:18:05 PDT 2020
Author: River Riddle
Date: 2020-06-10T17:17:13-07:00
New Revision: 51114686d51225e26b13c6870ecb8b6795413ba4
URL: https://github.com/llvm/llvm-project/commit/51114686d51225e26b13c6870ecb8b6795413ba4
DIFF: https://github.com/llvm/llvm-project/commit/51114686d51225e26b13c6870ecb8b6795413ba4.diff
LOG: [mlir][NFC] Split Parser into several different files.
Summary: At this point Parser has grown to be over 5000 lines and can be very difficult to navigate/update/etc. This commit splits Parser.cpp into several sub files focused on parsing specific types of entities; e.g., Attributes, Types, etc.
Differential Revision: https://reviews.llvm.org/D81299
Added:
mlir/lib/Parser/AffineParser.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/DialectSymbolParser.cpp
mlir/lib/Parser/LocationParser.cpp
mlir/lib/Parser/Parser.h
mlir/lib/Parser/ParserState.h
mlir/lib/Parser/TypeParser.cpp
Modified:
mlir/lib/Parser/CMakeLists.txt
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Token.cpp
mlir/lib/Parser/Token.h
Removed:
################################################################################
diff --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp
new file mode 100644
index 000000000000..99869a262fa0
--- /dev/null
+++ b/mlir/lib/Parser/AffineParser.cpp
@@ -0,0 +1,726 @@
+//===- AffineParser.cpp - MLIR Affine Parser ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a parser for Affine structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Parser.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/IntegerSet.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+using llvm::SMLoc;
+
+namespace {
+
+/// Additive (lower precedence) binary operators; all members share a single
+/// precedence level. `LNoOp` is the zero enumerator so a returned value can be
+/// tested directly in a boolean context to mean "no operator found".
+enum AffineLowPrecOp {
+  /// Sentinel: no low precedence operator.
+  LNoOp,
+  Add,
+  Sub
+};
+
+/// Multiplicative (higher precedence) binary operators; all members share a
+/// single precedence level. `HNoOp` is the zero enumerator so a returned value
+/// can be tested directly in a boolean context to mean "no operator found".
+enum AffineHighPrecOp {
+  /// Sentinel: no high precedence operator.
+  HNoOp,
+  Mul,
+  FloorDiv,
+  CeilDiv,
+  Mod
+};
+
+/// Parser specialized for affine structures: affine maps, affine expressions
+/// and integer sets. It holds the state that only matters while the body of
+/// one such structure is being parsed: the identifiers currently in scope and,
+/// when SSA ids are allowed, how many dim/symbol operands have been seen.
+class AffineParser : public Parser {
+public:
+  /// When `allowParsingSSAIds` is set, `%id` and `symbol(%id)` operands are
+  /// accepted inside expressions, and `parseElement` is called (with `true`
+  /// for symbols) to parse each newly encountered SSA id.
+  AffineParser(ParserState &state, bool allowParsingSSAIds = false,
+               function_ref<ParseResult(bool)> parseElement = nullptr)
+      : Parser(state), allowParsingSSAIds(allowParsingSSAIds),
+        parseElement(parseElement) {}
+
+  AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols);
+  ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set);
+  IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols);
+  ParseResult parseAffineMapOfSSAIds(AffineMap &map,
+                                     OpAsmParser::Delimiter delimiter);
+  void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds,
+                              unsigned &numDims);
+
+private:
+  // Operator recognition.
+  AffineLowPrecOp consumeIfLowPrecOp();
+  AffineHighPrecOp consumeIfHighPrecOp();
+
+  // Dim/symbol identifier lists that precede the body of a map or set.
+  ParseResult parseDimIdList(unsigned &numDims);
+  ParseResult parseSymbolIdList(unsigned &numSymbols);
+  ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims,
+                                              unsigned &numSymbols);
+  ParseResult parseIdentifierDefinition(AffineExpr idExpr);
+
+  // Expression productions.
+  AffineExpr parseAffineExpr();
+  AffineExpr parseParentheticalExpr();
+  AffineExpr parseNegateExpression(AffineExpr lhs);
+  AffineExpr parseIntegerExpr();
+  AffineExpr parseBareIdExpr();
+  AffineExpr parseSSAIdExpr(bool isSymbol);
+  AffineExpr parseSymbolSSAIdExpr();
+
+  // Binary operator construction and precedence climbing.
+  AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
+                                   AffineExpr rhs, SMLoc opLoc);
+  AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
+                                   AffineExpr rhs);
+  AffineExpr parseAffineOperandExpr(AffineExpr lhs);
+  AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
+  AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
+                                       SMLoc llhsOpLoc);
+  AffineExpr parseAffineConstraint(bool *isEq);
+
+  /// Whether `%id` / `symbol(%id)` operands may appear in expressions.
+  bool allowParsingSSAIds;
+  /// Client callback used to parse a newly seen SSA id; the argument tells
+  /// whether it is used as a symbol.
+  function_ref<ParseResult(bool)> parseElement;
+  /// Number of SSA ids bound to dims so far.
+  unsigned numDimOperands = 0;
+  /// Number of SSA ids bound to symbols so far.
+  unsigned numSymbolOperands = 0;
+  /// Identifiers (bare or SSA) in scope, with the expression each denotes.
+  SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
+};
+} // end anonymous namespace
+
+/// Build `lhs op rhs` for a multiplicative operator (`*`, `floordiv`,
+/// `ceildiv`, `mod`), enforcing the affine restrictions on its operands.
+/// If an operand violates them, an error is reported at `opLoc` (the location
+/// of the operator token) and a null expression is returned.
+AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
+                                               AffineExpr lhs, AffineExpr rhs,
+                                               SMLoc opLoc) {
+  // TODO: make the error location info accurate.
+  auto reject = [&](const char *message) -> AffineExpr {
+    emitError(opLoc, message);
+    return nullptr;
+  };
+
+  switch (op) {
+  case Mul:
+    // A product stays affine as long as one factor is symbolic or constant.
+    if (lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant())
+      return lhs * rhs;
+    return reject("non-affine expression: at least one of the multiply "
+                  "operands has to be either a constant or symbolic");
+  case FloorDiv:
+    if (rhs.isSymbolicOrConstant())
+      return lhs.floorDiv(rhs);
+    return reject("non-affine expression: right operand of floordiv "
+                  "has to be either a constant or symbolic");
+  case CeilDiv:
+    if (rhs.isSymbolicOrConstant())
+      return lhs.ceilDiv(rhs);
+    return reject("non-affine expression: right operand of ceildiv "
+                  "has to be either a constant or symbolic");
+  case Mod:
+    if (rhs.isSymbolicOrConstant())
+      return lhs % rhs;
+    return reject("non-affine expression: right operand of mod "
+                  "has to be either a constant or symbolic");
+  case HNoOp:
+    llvm_unreachable("can't create affine expression for null high prec op");
+  }
+  llvm_unreachable("Unknown AffineHighPrecOp");
+}
+
+/// Build `lhs op rhs` for an additive operator (`+` or `-`). Additive
+/// operators place no restrictions on their operands, so no diagnostic is
+/// ever emitted here.
+AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
+                                               AffineExpr lhs, AffineExpr rhs) {
+  if (op == AffineLowPrecOp::Add)
+    return lhs + rhs;
+  if (op == AffineLowPrecOp::Sub)
+    return lhs - rhs;
+  if (op == AffineLowPrecOp::LNoOp)
+    llvm_unreachable("can't create affine expression for null low prec op");
+  llvm_unreachable("Unknown AffineLowPrecOp");
+}
+
+/// If the current token is an additive operator (`+` or `-`), consume it and
+/// return the matching operator; otherwise leave the token stream untouched
+/// and return `LNoOp`. There are only two precedence levels.
+AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
+  if (consumeIf(Token::plus))
+    return AffineLowPrecOp::Add;
+  if (consumeIf(Token::minus))
+    return AffineLowPrecOp::Sub;
+  return AffineLowPrecOp::LNoOp;
+}
+
+/// If the current token is a multiplicative operator (`*`, `floordiv`,
+/// `ceildiv` or `mod`), consume it and return the matching operator;
+/// otherwise leave the token stream untouched and return `HNoOp`. There are
+/// only two precedence levels.
+AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
+  if (consumeIf(Token::star))
+    return Mul;
+  if (consumeIf(Token::kw_floordiv))
+    return FloorDiv;
+  if (consumeIf(Token::kw_ceildiv))
+    return CeilDiv;
+  if (consumeIf(Token::kw_mod))
+    return Mod;
+  return HNoOp;
+}
+
+/// Parse a high precedence op expression list: mul, div, and mod are high
+/// precedence binary ops, i.e., parse a
+///   expr_1 op_1 expr_2 op_2 ... expr_n
+/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
+/// All affine binary ops are left associative.
+/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
+/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
+/// null. llhsOpLoc is the location of the llhsOp token; it is where an error
+/// is reported when (llhs llhsOp lhs) is not a valid affine expression.
+///
+/// Returns a null expression on failure.
+AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
+                                                   AffineHighPrecOp llhsOp,
+                                                   SMLoc llhsOpLoc) {
+  AffineExpr lhs = parseAffineOperandExpr(llhs);
+  if (!lhs)
+    return nullptr;
+
+  // Found an LHS. Parse the remaining expression.
+  SMLoc opLoc = getToken().getLoc();
+  if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
+    if (llhs) {
+      // Fold the pending `llhs llhsOp lhs` first to keep left associativity.
+      // A diagnostic for this fold concerns the `llhsOp` token, so report it
+      // at llhsOpLoc rather than at the operator that follows.
+      AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
+      if (!expr)
+        return nullptr;
+      return parseAffineHighPrecOpExpr(expr, op, opLoc);
+    }
+    // No llhs: `lhs op ...` starts a new chain.
+    return parseAffineHighPrecOpExpr(lhs, op, opLoc);
+  }
+
+  // This is the last operand in this expression.
+  if (llhs)
+    return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
+
+  // No llhs, 'lhs' itself is the expression.
+  return lhs;
+}
+
+/// Parse a parenthesized affine expression.
+///
+///   affine-expr ::= `(` affine-expr `)`
+///
+/// Returns a null expression on failure.
+AffineExpr AffineParser::parseParentheticalExpr() {
+  if (parseToken(Token::l_paren, "expected '('"))
+    return nullptr;
+
+  // Reject `()` explicitly so the diagnostic names the actual problem.
+  if (getToken().is(Token::r_paren)) {
+    emitError("no expression inside parentheses");
+    return nullptr;
+  }
+
+  AffineExpr inner = parseAffineExpr();
+  if (!inner || parseToken(Token::r_paren, "expected ')'"))
+    return nullptr;
+  return inner;
+}
+
+/// Parse a unary negation and return it as `-1 * operand`.
+///
+///   affine-expr ::= `-` affine-expr
+///
+/// Negation binds tighter than every binary operator but looser than
+/// parentheses, so only a single operand (not a full expression) is parsed
+/// after the `-`. `lhs` is forwarded so operand diagnostics know whether a
+/// binary operator preceded this negation.
+AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
+  if (parseToken(Token::minus, "expected '-'"))
+    return nullptr;
+
+  AffineExpr operand = parseAffineOperandExpr(lhs);
+  if (operand)
+    return (-1) * operand;
+
+  // parseAffineOperandExpr has already complained; this extra diagnostic
+  // makes it clear that the failure happened under a negation.
+  emitError("missing operand of negation");
+  return nullptr;
+}
+
+/// Parse a bare identifier used inside an affine expression and return the
+/// dim or symbol expression it was bound to when it was defined.
+///
+///   affine-expr ::= bare-id
+AffineExpr AffineParser::parseBareIdExpr() {
+  if (getToken().isNot(Token::bare_identifier)) {
+    emitError("expected bare identifier");
+    return nullptr;
+  }
+
+  StringRef spelling = getTokenSpelling();
+  for (const auto &binding : dimsAndSymbols) {
+    if (binding.first != spelling)
+      continue;
+    consumeToken(Token::bare_identifier);
+    return binding.second;
+  }
+
+  emitError("use of undeclared identifier");
+  return nullptr;
+}
+
+/// Parse an SSA id (`%name`) appearing in an affine expression. The first
+/// occurrence of an id is handed to the `parseElement` callback and bound to
+/// the next free dim (or symbol, if `isSymbol`) position. Later occurrences
+/// reuse that expression.
+AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) {
+  if (!allowParsingSSAIds) {
+    emitError("unexpected ssa identifier");
+    return nullptr;
+  }
+  if (getToken().isNot(Token::percent_identifier)) {
+    emitError("expected ssa identifier");
+    return nullptr;
+  }
+
+  StringRef ssaName = getTokenSpelling();
+
+  // Already seen: reuse the dim/symbol expression allocated earlier.
+  for (const auto &binding : dimsAndSymbols) {
+    if (binding.first == ssaName) {
+      consumeToken(Token::percent_identifier);
+      return binding.second;
+    }
+  }
+
+  // First occurrence: let the client parse the SSA use, then allocate a
+  // fresh dim or symbol position for it.
+  if (parseElement(isSymbol)) {
+    emitError("failed to parse ssa identifier");
+    return nullptr;
+  }
+  AffineExpr idExpr;
+  if (isSymbol)
+    idExpr = getAffineSymbolExpr(numSymbolOperands++, getContext());
+  else
+    idExpr = getAffineDimExpr(numDimOperands++, getContext());
+  dimsAndSymbols.emplace_back(ssaName, idExpr);
+  return idExpr;
+}
+
+/// Parse `symbol` `(` ssa-id `)`, which marks an SSA id as a symbol operand
+/// of the affine map. Returns a null expression on failure.
+AffineExpr AffineParser::parseSymbolSSAIdExpr() {
+  if (parseToken(Token::kw_symbol, "expected symbol keyword") ||
+      parseToken(Token::l_paren, "expected '(' at start of SSA symbol"))
+    return nullptr;
+
+  AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true);
+  if (!symbolExpr ||
+      parseToken(Token::r_paren, "expected ')' at end of SSA symbol"))
+    return nullptr;
+  return symbolExpr;
+}
+
+/// Parse a non-negative integer literal appearing in an affine expression.
+///
+///   affine-expr ::= integer-literal
+///
+/// The value must fit in a signed 64-bit integer. Literals that do not fit
+/// (including those for which no uint64 value is available) are rejected.
+AffineExpr AffineParser::parseIntegerExpr() {
+  auto value = getToken().getUInt64IntegerValue();
+  bool fitsInInt64 =
+      value.hasValue() && static_cast<int64_t>(value.getValue()) >= 0;
+  if (!fitsInInt64) {
+    emitError("constant too large for index");
+    return nullptr;
+  }
+
+  consumeToken(Token::integer);
+  return builder.getAffineConstantExpr(static_cast<int64_t>(value.getValue()));
+}
+
+/// Parse one operand of an affine expression: a bare identifier, an SSA id,
+/// `symbol(...)`, an integer literal, a parenthesized expression or a
+/// negation.
+///
+/// `lhs` is non-null when the operand is the right-hand side of a binary
+/// operator. It selects the diagnostic emitted when no operand is found and is
+/// forwarded to negation parsing.
+//
+// Example: in `i + j + k + l` each identifier is an operand. In
+// `i + j*k + l`, `j*k` is an operator expression handled by
+// parseAffineHighPrecOpExpr(), whereas in `i + (j*k) + -l` both `(j*k)` and
+// `-l` are operands parsed here.
+AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
+  bool startsWithBinaryOp = false;
+  switch (getToken().getKind()) {
+  case Token::bare_identifier:
+    return parseBareIdExpr();
+  case Token::kw_symbol:
+    return parseSymbolSSAIdExpr();
+  case Token::percent_identifier:
+    return parseSSAIdExpr(/*isSymbol=*/false);
+  case Token::integer:
+    return parseIntegerExpr();
+  case Token::l_paren:
+    return parseParentheticalExpr();
+  case Token::minus:
+    return parseNegateExpression(lhs);
+  case Token::kw_ceildiv:
+  case Token::kw_floordiv:
+  case Token::kw_mod:
+  case Token::plus:
+  case Token::star:
+    startsWithBinaryOp = true;
+    break;
+  default:
+    break;
+  }
+
+  // No operand starts here; choose the diagnostic from the context.
+  if (lhs)
+    emitError("missing right operand of binary operator");
+  else if (startsWithBinaryOp)
+    emitError("missing left operand of binary operator");
+  else
+    emitError("expected affine expression");
+  return nullptr;
+}
+
+/// Parse affine expressions that are bare-id's, integer constants,
+/// parenthetical affine expressions, and affine op expressions that are a
+/// composition of those.
+///
+/// All binary op's associate from left to right.
+///
+/// {add, sub} have lower precedence than {mul, div, and mod}.
+///
+/// Add and sub are themselves at the same precedence level. Mul, floordiv,
+/// ceildiv, and mod are at the same higher precedence level. Negation has
+/// higher precedence than any binary op.
+///
+/// llhs: the affine expression appearing on the left of the one being parsed.
+/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
+/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
+/// if llhs is non-null; otherwise lhs is returned. This is to deal with left
+/// associativity.
+///
+/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
+/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
+/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
+///
+/// Returns a null expression if any part of the parse fails.
+AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
+ AffineLowPrecOp llhsOp) {
+ // Parse the next operand. `llhs` is passed only so that a missing operand
+ // is diagnosed as a missing right operand when an operator preceded it.
+ AffineExpr lhs;
+ if (!(lhs = parseAffineOperandExpr(llhs)))
+ return nullptr;
+
+ // Found an LHS. Deal with the ops.
+ // Case 1: another additive op follows. Fold (llhs llhsOp lhs) now so that
+ // add/sub associate to the left, then continue with the new operator.
+ if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
+ if (llhs) {
+ AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
+ return parseAffineLowPrecOpExpr(sum, lOp);
+ }
+ // No LLHS, get RHS and form the expression.
+ return parseAffineLowPrecOpExpr(lhs, lOp);
+ }
+ // Case 2: a multiplicative op follows. Record its location before it is
+ // consumed so that diagnostics for that operator point at it.
+ auto opLoc = getToken().getLoc();
+ if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
+ // We have a higher precedence op here. Get the rhs operand for the llhs
+ // through parseAffineHighPrecOpExpr.
+ AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
+ if (!highRes)
+ return nullptr;
+
+ // If llhs is null, the product forms the first operand of the yet to be
+ // found expression. If non-null, the op to associate with llhs is llhsOp.
+ AffineExpr expr =
+ llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
+
+ // Recurse for subsequent low prec op's after the affine high prec op
+ // expression.
+ if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
+ return parseAffineLowPrecOpExpr(expr, nextOp);
+ return expr;
+ }
+ // Case 3: no operator follows.
+ // Last operand in the expression list.
+ if (llhs)
+ return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
+ // No llhs, 'lhs' itself is the expression.
+ return lhs;
+}
+
+/// Parse an affine expression.
+///  affine-expr ::= `(` affine-expr `)`
+///                | `-` affine-expr
+///                | affine-expr `+` affine-expr
+///                | affine-expr `-` affine-expr
+///                | affine-expr `*` affine-expr
+///                | affine-expr `floordiv` affine-expr
+///                | affine-expr `ceildiv` affine-expr
+///                | affine-expr `mod` affine-expr
+///                | bare-id
+///                | integer-literal
+///                | ssa-id                    (only if SSA ids are allowed)
+///                | `symbol` `(` ssa-id `)`   (only if SSA ids are allowed)
+///
+/// Additional conditions are checked depending on the production: at least
+/// one operand of `*` has to be either constant or symbolic, and the right
+/// operand of `floordiv`, `ceildiv`, and `mod` has to be either constant or
+/// symbolic. Integer literals must fit in a signed 64-bit integer.
+///
+/// Returns a null expression on failure.
+AffineExpr AffineParser::parseAffineExpr() {
+  return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
+}
+
+/// Parse one identifier from the dim/symbol lists that precede the body of an
+/// affine map or integer set, and bind it to `idExpr` for the rest of that
+/// structure. Redefining an identifier already in scope is an error.
+ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) {
+  if (getToken().isNot(Token::bare_identifier))
+    return emitError("expected bare identifier");
+
+  StringRef name = getTokenSpelling();
+  for (const auto &binding : dimsAndSymbols)
+    if (binding.first == name)
+      return emitError("redefinition of identifier '" + name + "'");
+
+  consumeToken(Token::bare_identifier);
+  dimsAndSymbols.emplace_back(name, idExpr);
+  return success();
+}
+
+/// Parse the parenthesized list of dimensional identifiers of an affine map or
+/// integer set, e.g. `(d0, d1)`. Each identifier is bound to the dim
+/// expression at position `numDims`, and `numDims` is then incremented.
+ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
+  if (parseToken(Token::l_paren,
+                 "expected '(' at start of dimensional identifiers list"))
+    return failure();
+
+  return parseCommaSeparatedListUntil(Token::r_paren, [&]() -> ParseResult {
+    AffineExpr nextDim = getAffineDimExpr(numDims++, getContext());
+    return parseIdentifierDefinition(nextDim);
+  });
+}
+
+/// Parse the square-bracketed list of symbolic identifiers, e.g. `[s0, s1]`.
+/// The caller must have checked that the current token is `[`. Each
+/// identifier is bound to the symbol expression at position `numSymbols`, and
+/// `numSymbols` is then incremented.
+ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
+  consumeToken(Token::l_square);
+  return parseCommaSeparatedListUntil(Token::r_square, [&]() -> ParseResult {
+    AffineExpr nextSymbol = getAffineSymbolExpr(numSymbols++, getContext());
+    return parseIdentifierDefinition(nextSymbol);
+  });
+}
+
+/// Parse the dimensional identifier list followed by an optional symbolic
+/// identifier list, e.g. `(d0, d1)` or `(d0, d1)[s0]`. `numDims` (and
+/// `numSymbols`, when the symbol list is present) are incremented once per
+/// parsed identifier. `numSymbols` is set to 0 when there is no symbol list.
+ParseResult
+AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims,
+                                              unsigned &numSymbols) {
+  if (parseDimIdList(numDims))
+    return failure();
+
+  if (getToken().is(Token::l_square))
+    return parseSymbolIdList(numSymbols);
+
+  numSymbols = 0;
+  return success();
+}
+
+/// Parse, inline, either an affine map (`(dims)[syms] -> (exprs)`) or an
+/// integer set (`(dims)[syms] : (constraints)`). Both share the identifier
+/// list prefix, so the choice is made on the token that follows it. This is
+/// needed when parsing attributes, where it is not known in advance which of
+/// the two is coming. `map` is assigned on the `->` path and `set` on the
+/// `:` path.
+ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
+                                                           IntegerSet &set) {
+  unsigned numDims = 0;
+  unsigned numSymbols = 0;
+  if (parseDimAndOptionalSymbolIdList(numDims, numSymbols))
+    return failure();
+
+  // `->` introduces the result expressions of an affine map.
+  if (getToken().is(Token::arrow)) {
+    parseToken(Token::arrow, "expected '->' or '['");
+    map = parseAffineMapRange(numDims, numSymbols);
+    return map ? success() : failure();
+  }
+
+  // Otherwise only `:` (integer set constraints) is acceptable.
+  if (getToken().isNot(Token::colon))
+    return emitError("expected '->' or ':'");
+  if (parseToken(Token::colon, "expected ':' or '['"))
+    return failure();
+
+  set = parseIntegerSetConstraints(numDims, numSymbols);
+  return set ? success() : failure();
+}
+
+/// Parse a `[...]` or `(...)` delimited, possibly empty, list of affine
+/// expressions whose dims and symbols are SSA ids, and build the
+/// corresponding AffineMap. SSA ids are numbered in order of first
+/// appearance, with dims and symbols counted separately.
+ParseResult
+AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
+                                     OpAsmParser::Delimiter delimiter) {
+  // Consume the opening delimiter and remember which token closes the list.
+  Token::Kind closingToken;
+  if (delimiter == OpAsmParser::Delimiter::Square) {
+    if (parseToken(Token::l_square, "expected '['"))
+      return failure();
+    closingToken = Token::r_square;
+  } else if (delimiter == OpAsmParser::Delimiter::Paren) {
+    if (parseToken(Token::l_paren, "expected '('"))
+      return failure();
+    closingToken = Token::r_paren;
+  } else {
+    return emitError("unexpected delimiter");
+  }
+
+  SmallVector<AffineExpr, 4> results;
+  auto parseResultExpr = [&]() -> ParseResult {
+    AffineExpr expr = parseAffineExpr();
+    results.push_back(expr);
+    return expr ? success() : failure();
+  };
+
+  // multi-dim-affine-expr ::= `(` `)`
+  //                         | `(` affine-expr (`,` affine-expr)* `)`
+  if (parseCommaSeparatedListUntil(closingToken, parseResultExpr,
+                                   /*allowEmptyList=*/true))
+    return failure();
+
+  // Every SSA id that was bound and is not a dim is a symbol.
+  unsigned numSymbols = dimsAndSymbols.size() - numDimOperands;
+  map = AffineMap::get(numDimOperands, numSymbols, results, getContext());
+  return success();
+}
+
+/// Parse the range and sizes affine map definition inline.
+///
+///  affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
+///
+///  multi-dim-affine-expr ::= `(` `)`
+///  multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
+///
+/// `numDims` and `numSymbols` are the sizes of the identifier lists parsed
+/// before the `->`. Returns a null AffineMap on failure.
+AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
+                                            unsigned numSymbols) {
+  // Stop at a missing `(` instead of continuing to parse the expression list
+  // from an unexpected position.
+  if (parseToken(Token::l_paren, "expected '(' at start of affine map range"))
+    return AffineMap();
+
+  SmallVector<AffineExpr, 4> exprs;
+  auto parseElt = [&]() -> ParseResult {
+    AffineExpr elt = parseAffineExpr();
+    if (!elt)
+      return failure();
+    exprs.push_back(elt);
+    return success();
+  };
+
+  // Parse a multi-dimensional affine expression (a comma-separated list of
+  // 1-d affine expressions). Grammar:
+  // multi-dim-affine-expr ::= `(` `)`
+  //                         | `(` affine-expr (`,` affine-expr)* `)`
+  if (parseCommaSeparatedListUntil(Token::r_paren, parseElt,
+                                   /*allowEmptyList=*/true))
+    return AffineMap();
+
+  // Parsed a valid affine map.
+  return AffineMap::get(numDims, numSymbols, exprs, getContext());
+}
+
+/// Parse a single affine constraint and report its kind through `isEq`
+/// (`true` for `== 0`, `false` for `>= 0`). `*isEq` is left untouched on
+/// failure, in which case a null expression is returned.
+///
+///   affine-constraint ::= affine-expr `>=` `0`
+///                       | affine-expr `==` `0`
+AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
+  AffineExpr expr = parseAffineExpr();
+  if (!expr)
+    return nullptr;
+
+  // After a comparison operator, require the integer literal `0`. On success,
+  // record the comparison kind and consume the literal.
+  auto finishWithZero = [&](bool isEquality,
+                            const char *errorMessage) -> AffineExpr {
+    auto rhsValue = getToken().getUnsignedIntegerValue();
+    if (!rhsValue.hasValue() || rhsValue.getValue() != 0) {
+      emitError(errorMessage);
+      return nullptr;
+    }
+    consumeToken(Token::integer);
+    *isEq = isEquality;
+    return expr;
+  };
+
+  // `>=` is matched as a `>` token followed by an `=` token.
+  if (consumeIf(Token::greater) && consumeIf(Token::equal) &&
+      getToken().is(Token::integer))
+    return finishWithZero(/*isEquality=*/false, "expected '0' after '>='");
+
+  // `==` is matched as two consecutive `=` tokens.
+  if (consumeIf(Token::equal) && consumeIf(Token::equal) &&
+      getToken().is(Token::integer))
+    return finishWithZero(/*isEquality=*/true, "expected '0' after '=='");
+
+  emitError("expected '== 0' or '>= 0' at end of affine constraint");
+  return nullptr;
+}
+
+/// Parse the parenthesized, comma-separated constraint list of an integer set
+/// and build the set from it.
+///
+///   integer-set-inline
+///     ::= dim-and-symbol-id-lists `:` `(` affine-constraint-conjunction? `)`
+///   affine-constraint-conjunction
+///     ::= affine-constraint (`,` affine-constraint)*
+///
+/// Returns a null IntegerSet on failure.
+IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
+                                                    unsigned numSymbols) {
+  if (parseToken(Token::l_paren,
+                 "expected '(' at start of integer set constraint list"))
+    return IntegerSet();
+
+  SmallVector<AffineExpr, 4> constraints;
+  SmallVector<bool, 4> isEqs;
+  auto parseConstraint = [&]() -> ParseResult {
+    bool isEq = false;
+    AffineExpr constraint = parseAffineConstraint(&isEq);
+    if (!constraint)
+      return failure();
+    constraints.push_back(constraint);
+    isEqs.push_back(isEq);
+    return success();
+  };
+
+  if (parseCommaSeparatedListUntil(Token::r_paren, parseConstraint,
+                                   /*allowEmptyList=*/true))
+    return IntegerSet();
+
+  // An empty list is the degenerate "always true" set, encoded as the single
+  // equality `0 == 0`.
+  if (constraints.empty()) {
+    AffineExpr zero = getAffineConstantExpr(0, getContext());
+    return IntegerSet::get(numDims, numSymbols, zero, true);
+  }
+
+  return IntegerSet::get(numDims, numSymbols, constraints, isEqs);
+}
+
+//===----------------------------------------------------------------------===//
+// Parser
+//===----------------------------------------------------------------------===//
+
+/// Parse a reference that may denote either an affine map or an integer set.
+/// Which one it is depends on whether `->` or `:` follows the dim/symbol
+/// identifier lists; the corresponding output parameter is filled in.
+ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map,
+                                                        IntegerSet &set) {
+  AffineParser affineParser(state);
+  return affineParser.parseAffineMapOrIntegerSetInline(map, set);
+}
+/// Parse an affine map reference. An integer set in this position is
+/// reported as an error at the start of the reference.
+ParseResult Parser::parseAffineMapReference(AffineMap &map) {
+  llvm::SMLoc startLoc = getToken().getLoc();
+  IntegerSet unexpectedSet;
+  if (parseAffineMapOrIntegerSetReference(map, unexpectedSet))
+    return failure();
+  if (!unexpectedSet)
+    return success();
+  return emitError(startLoc, "expected AffineMap, but got IntegerSet");
+}
+/// Parse an integer set reference. An affine map in this position is
+/// reported as an error at the start of the reference.
+ParseResult Parser::parseIntegerSetReference(IntegerSet &set) {
+  llvm::SMLoc startLoc = getToken().getLoc();
+  AffineMap unexpectedMap;
+  if (parseAffineMapOrIntegerSetReference(unexpectedMap, set))
+    return failure();
+  if (!unexpectedMap)
+    return success();
+  return emitError(startLoc, "expected IntegerSet, but got AffineMap");
+}
+
+/// Parse an AffineMap whose dims and symbols are SSA ids, delimited as
+/// selected by `delimiter`. `parseElement` is called once for the first
+/// occurrence of each distinct SSA id; its argument is `true` when that id is
+/// used as a symbol.
+ParseResult
+Parser::parseAffineMapOfSSAIds(AffineMap &map,
+                               function_ref<ParseResult(bool)> parseElement,
+                               OpAsmParser::Delimiter delimiter) {
+  AffineParser ssaMapParser(state, /*allowParsingSSAIds=*/true, parseElement);
+  return ssaMapParser.parseAffineMapOfSSAIds(map, delimiter);
+}
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
new file mode 100644
index 000000000000..ebbb1293f19d
--- /dev/null
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -0,0 +1,910 @@
+//===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the parser for the MLIR Attributes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Parser.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/StringExtras.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+/// Parse an arbitrary attribute.
+///
+///  attribute-value ::= `unit`
+///                    | bool-literal
+///                    | integer-literal (`:` (index-type | integer-type))?
+///                    | float-literal (`:` float-type)?
+///                    | string-literal (`:` type)?
+///                    | type
+///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
+///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
+///                    | symbol-ref-id (`::` symbol-ref-id)*
+///                    | `dense` `<` attribute-value `>` `:`
+///                      (tensor-type | vector-type)
+///                    | `sparse` `<` attribute-value `,` attribute-value `>`
+///                      `:` (tensor-type | vector-type)
+///                    | `opaque` `<` dialect-namespace  `,` hex-string-literal
+///                      `>` `:` (tensor-type | vector-type)
+///                    | `affine_map` `<` affine-map `>`
+///                    | `affine_set` `<` integer-set `>`
+///                    | location (parsed by parseLocation)
+///                    | extended-attribute
+///
+/// When 'type' is non-null it is forwarded to the literal forms that accept
+/// one (numbers, strings, elements attributes, extended attributes), and
+/// those forms then do not parse their own trailing `:` type. Returns a null
+/// Attribute on failure.
+Attribute Parser::parseAttribute(Type type) {
+  switch (getToken().getKind()) {
+  // Parse an AffineMap or IntegerSet attribute.
+  case Token::kw_affine_map: {
+    consumeToken(Token::kw_affine_map);
+
+    AffineMap map;
+    if (parseToken(Token::less, "expected '<' in affine map") ||
+        parseAffineMapReference(map) ||
+        parseToken(Token::greater, "expected '>' in affine map"))
+      return Attribute();
+    return AffineMapAttr::get(map);
+  }
+  case Token::kw_affine_set: {
+    consumeToken(Token::kw_affine_set);
+
+    IntegerSet set;
+    if (parseToken(Token::less, "expected '<' in integer set") ||
+        parseIntegerSetReference(set) ||
+        parseToken(Token::greater, "expected '>' in integer set"))
+      return Attribute();
+    return IntegerSetAttr::get(set);
+  }
+
+  // Parse an array attribute.
+  case Token::l_square: {
+    consumeToken(Token::l_square);
+
+    SmallVector<Attribute, 4> elements;
+    auto parseElt = [&]() -> ParseResult {
+      elements.push_back(parseAttribute());
+      return elements.back() ? success() : failure();
+    };
+
+    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
+      return nullptr;
+    return builder.getArrayAttr(elements);
+  }
+
+  // Parse a boolean attribute.
+  case Token::kw_false:
+    consumeToken(Token::kw_false);
+    return builder.getBoolAttr(false);
+  case Token::kw_true:
+    consumeToken(Token::kw_true);
+    return builder.getBoolAttr(true);
+
+  // Parse a dense elements attribute.
+  case Token::kw_dense:
+    return parseDenseElementsAttr(type);
+
+  // Parse a dictionary attribute.
+  case Token::l_brace: {
+    NamedAttrList elements;
+    if (parseAttributeDict(elements))
+      return nullptr;
+    return elements.getDictionary(getContext());
+  }
+
+  // Parse an extended attribute, i.e. alias or dialect attribute.
+  case Token::hash_identifier:
+    return parseExtendedAttr(type);
+
+  // Parse floating point and integer attributes.
+  case Token::floatliteral:
+    return parseFloatAttr(type, /*isNegative=*/false);
+  case Token::integer:
+    return parseDecOrHexAttr(type, /*isNegative=*/false);
+  case Token::minus: {
+    consumeToken(Token::minus);
+    if (getToken().is(Token::integer))
+      return parseDecOrHexAttr(type, /*isNegative=*/true);
+    if (getToken().is(Token::floatliteral))
+      return parseFloatAttr(type, /*isNegative=*/true);
+
+    return (emitError("expected constant integer or floating point value"),
+            nullptr);
+  }
+
+  // Parse a location attribute.
+  case Token::kw_loc: {
+    LocationAttr attr;
+    return failed(parseLocation(attr)) ? Attribute() : attr;
+  }
+
+  // Parse an opaque elements attribute.
+  case Token::kw_opaque:
+    return parseOpaqueElementsAttr(type);
+
+  // Parse a sparse elements attribute.
+  case Token::kw_sparse:
+    return parseSparseElementsAttr(type);
+
+  // Parse a string attribute.
+  case Token::string: {
+    auto val = getToken().getStringValue();
+    consumeToken(Token::string);
+    // Parse the optional trailing colon type if one wasn't explicitly provided.
+    if (!type && consumeIf(Token::colon) && !(type = parseType()))
+      return Attribute();
+
+    return type ? StringAttr::get(val, type)
+                : StringAttr::get(val, getContext());
+  }
+
+  // Parse a symbol reference attribute.
+  case Token::at_identifier: {
+    std::string rootName = getToken().getSymbolReference();
+    consumeToken(Token::at_identifier);
+
+    // Parse any nested references.
+    std::vector<FlatSymbolRefAttr> nestedRefs;
+    while (getToken().is(Token::colon)) {
+      // Check for the '::' prefix. A lone ':' does not belong to the symbol
+      // reference, so rewind the lexer to that colon and re-lex it as the
+      // current token, leaving it for the caller.
+      const char *curPointer = getToken().getLoc().getPointer();
+      consumeToken(Token::colon);
+      if (!consumeIf(Token::colon)) {
+        state.lex.resetPointer(curPointer);
+        consumeToken();
+        break;
+      }
+      // Parse the reference itself.
+      auto curLoc = getToken().getLoc();
+      if (getToken().isNot(Token::at_identifier)) {
+        emitError(curLoc, "expected nested symbol reference identifier");
+        return Attribute();
+      }
+
+      // Use a distinct name so the root symbol name is not shadowed.
+      std::string nestedName = getToken().getSymbolReference();
+      consumeToken(Token::at_identifier);
+      nestedRefs.push_back(SymbolRefAttr::get(nestedName, getContext()));
+    }
+
+    return builder.getSymbolRefAttr(rootName, nestedRefs);
+  }
+
+  // Parse a 'unit' attribute.
+  case Token::kw_unit:
+    consumeToken(Token::kw_unit);
+    return builder.getUnitAttr();
+
+  default:
+    // Parse a type attribute. Use a distinct name so the 'type' parameter is
+    // not shadowed.
+    if (Type parsedType = parseType())
+      return TypeAttr::get(parsedType);
+    return nullptr;
+  }
+}
+
+/// Attribute dictionary.
+///
+///   attribute-dict ::= `{` `}`
+///                    | `{` attribute-entry (`,` attribute-entry)* `}`
+///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
+///
+/// An entry written without `=` is recorded as a unit attribute. Parsed
+/// entries are appended to 'attributes'. A key that repeats within this
+/// dictionary is reported as an error.
+ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
+  if (parseToken(Token::l_brace, "expected '{' in attribute dictionary"))
+    return failure();
+
+  // Returns the identifier spelled by the current token if that token may name
+  // an attribute (string literal, bare identifier, integer type spelling, or
+  // keyword), and None otherwise. The token is not consumed.
+  auto keyFromCurrentToken = [&]() -> Optional<Identifier> {
+    if (getToken().is(Token::string))
+      return builder.getIdentifier(getToken().getStringValue());
+    bool isNameLike =
+        getToken().isAny(Token::bare_identifier, Token::inttype) ||
+        getToken().isKeyword();
+    if (isNameLike)
+      return builder.getIdentifier(getTokenSpelling());
+    return llvm::None;
+  };
+
+  llvm::SmallDenseSet<Identifier> seenKeys;
+  auto parseEntry = [&]() -> ParseResult {
+    Optional<Identifier> key = keyFromCurrentToken();
+    if (!key)
+      return emitError("expected attribute name");
+    if (!seenKeys.insert(*key).second)
+      return emitError("duplicate key in dictionary attribute");
+    consumeToken();
+
+    // Without a following '=', the entry is a unit attribute.
+    if (!consumeIf(Token::equal)) {
+      attributes.push_back({*key, builder.getUnitAttr()});
+      return success();
+    }
+
+    Attribute value = parseAttribute();
+    if (!value)
+      return failure();
+    attributes.push_back({*key, value});
+    return success();
+  };
+
+  return parseCommaSeparatedListUntil(Token::r_brace, parseEntry);
+}
+
+/// Parse a float attribute from the current float literal token.
+///
+/// 'isNegative' is true when the caller has already consumed a leading '-'.
+/// If 'type' is null, an optional `: type` suffix is parsed, and the type
+/// defaults to f64 when the suffix is absent. Diagnostics are reported at the
+/// literal. Returns null on failure.
+Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
+  // Record the literal's location before consuming it. Without this, a
+  // diagnostic emitted after parsing the trailing type would point past the
+  // literal.
+  llvm::SMLoc loc = getToken().getLoc();
+  auto val = getToken().getFloatingPointValue();
+  if (!val.hasValue())
+    return (emitError(loc, "floating point value too large for attribute"),
+            nullptr);
+  consumeToken(Token::floatliteral);
+  if (!type) {
+    // Default to F64 when no type is specified.
+    if (!consumeIf(Token::colon))
+      type = builder.getF64Type();
+    else if (!(type = parseType()))
+      return nullptr;
+  }
+  if (!type.isa<FloatType>())
+    return (emitError(loc,
+                      "floating point value not valid for specified type"),
+            nullptr);
+  return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
+}
+
+/// Reinterpret the 64-bit integer 'value' bitwise as a float of type 'type'.
+/// For types narrower than 64 bits, an error is emitted through 'p' (at its
+/// current token) and None is returned when 'value' has bits that do not fit
+/// the type's width.
+static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
+                                                      uint64_t value) {
+  const llvm::fltSemantics &semantics = type.getFloatSemantics();
+  // A 64-bit payload always fits an f64, so no range check is needed there.
+  if (type.isF64())
+    return APFloat(semantics, APInt(/*numBits=*/64, value));
+
+  // Narrow to the type's width; if the narrowed bits no longer compare equal
+  // to 'value', the literal had bits beyond that width.
+  APInt bits(type.getWidth(), value);
+  if (bits == value)
+    return APFloat(semantics, bits);
+
+  p->emitError("hexadecimal float constant out of range for type");
+  return llvm::None;
+}
+
+/// Construct an APInt from a parsed integer literal, a known attribute type,
+/// and a sign.
+///
+/// 'spelling' is the literal's source text. It is hexadecimal when its second
+/// character is 'x' and decimal otherwise. 'isNegative' indicates that a
+/// leading '-' was parsed separately. The result is sized to the bit width of
+/// 'type', using IndexType::kInternalStorageBitWidth for index types.
+/// Returns None if the spelling does not parse as an integer or the value does
+/// not fit in 'type'.
+///
+/// NOTE(review): "-0" negates to 0, whose sign bit is clear, so it is
+/// currently rejected as out of range. Confirm that this is intended.
+static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
+ StringRef spelling) {
+ // Parse the integer value into an APInt that is big enough to hold the value.
+ APInt result;
+ bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+ // Radix 0 lets StringRef::getAsInteger auto-detect the '0x' prefix.
+ if (spelling.getAsInteger(isHex ? 0 : 10, result))
+ return llvm::None;
+
+ // Extend or truncate the bitwidth to the right size.
+ unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
+ : type.getIntOrFloatBitWidth();
+ if (width > result.getBitWidth()) {
+ result = result.zext(width);
+ } else if (width < result.getBitWidth()) {
+ // The parser can return an unnecessarily wide result with leading zeros.
+ // This isn't a problem, but truncating off bits is bad.
+ if (result.countLeadingZeros() < result.getBitWidth() - width)
+ return llvm::None;
+
+ result = result.trunc(width);
+ }
+
+ if (isNegative) {
+ // The value is negative, we have an overflow if the sign bit is not set
+ // in the negated apInt.
+ result.negate();
+ if (!result.isSignBitSet())
+ return llvm::None;
+ } else if ((type.isSignedInteger() || type.isIndex()) &&
+ result.isSignBitSet()) {
+ // The value is a positive signed integer or index,
+ // we have an overflow if the sign bit is set.
+ // (Unsigned and signless types accept positive values with the top bit set.)
+ return llvm::None;
+ }
+
+ return result;
+}
+
+/// Parse a decimal or a hexadecimal literal, which can be either an integer
+/// or a float attribute.
+///
+/// 'isNegative' is true when the caller has already consumed a leading '-'.
+/// If 'type' is null, an optional `: type` suffix is parsed, and the type
+/// defaults to i64 when the suffix is absent. For float types only
+/// non-negative hexadecimal literals are accepted; they are reinterpreted
+/// bitwise as the float value. Diagnostics from this function are reported at
+/// the literal's location. Returns null on failure.
+Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
+  StringRef spelling = getToken().getSpelling();
+  auto loc = state.curToken.getLoc();
+  // Remember if the literal is hexadecimal.
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+
+  consumeToken(Token::integer);
+  if (!type) {
+    // Default to i64 if no type is specified.
+    if (!consumeIf(Token::colon))
+      type = builder.getIntegerType(64);
+    else if (!(type = parseType()))
+      return nullptr;
+  }
+
+  if (auto floatType = type.dyn_cast<FloatType>()) {
+    // Check the literal's form first. A decimal integer is never a valid float
+    // literal, so complaining about a "hexadecimal" leading minus for it
+    // (e.g. `-5 : f32`) would be misleading.
+    if (!isHex) {
+      emitError(loc, "unexpected decimal integer literal for a float attribute")
+              .attachNote()
+          << "add a trailing dot to make the literal a float";
+      return nullptr;
+    }
+    if (isNegative)
+      return emitError(
+                 loc,
+                 "hexadecimal float literal should not have a leading minus"),
+             nullptr;
+
+    auto val = Token::getUInt64IntegerValue(spelling);
+    if (!val.hasValue())
+      return emitError(loc, "integer constant out of range for attribute"),
+             nullptr;
+
+    // Construct a float attribute bitwise equivalent to the integer literal.
+    Optional<APFloat> apVal =
+        buildHexadecimalFloatLiteral(this, floatType, *val);
+    return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
+  }
+
+  if (!type.isa<IntegerType>() && !type.isa<IndexType>())
+    return emitError(loc, "integer literal not valid for specified type"),
+           nullptr;
+
+  if (isNegative && type.isUnsignedInteger()) {
+    emitError(loc,
+              "negative integer literal not valid for unsigned integer type");
+    return nullptr;
+  }
+
+  Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
+  if (!apInt)
+    return emitError(loc, "integer constant out of range for attribute"),
+           nullptr;
+  return builder.getIntegerAttr(type, *apInt);
+}
+
+//===----------------------------------------------------------------------===//
+// TensorLiteralParser
+//===----------------------------------------------------------------------===//
+
+/// Parse element values stored within a hex string. The string must begin
+/// with '0x' followed only by hex digits. On success, the binary form produced
+/// by llvm::fromHex is stored into 'result'.
+static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
+                                             std::string &result) {
+  std::string text = tok.getStringValue();
+  StringRef payload(text);
+  if (!payload.startswith("0x"))
+    return parser.emitError(tok.getLoc(),
+                            "elements hex string should start with '0x'");
+
+  // Everything after the prefix must be a hex digit.
+  payload = payload.drop_front(2);
+  bool hasNonHexDigit =
+      llvm::any_of(payload, [](char c) { return !llvm::isHexDigit(c); });
+  if (hasNonHexDigit)
+    return parser.emitError(tok.getLoc(),
+                            "elements hex string only contains hex digits");
+
+  result = llvm::fromHex(payload);
+  return success();
+}
+
+namespace {
+/// This class implements a parser for TensorLiterals. A tensor literal is
+/// either a single element (e.g., 5) or a multi-dimensional list of elements
+/// (e.g., [[5, 5]]).
+class TensorLiteralParser {
+public:
+  explicit TensorLiteralParser(Parser &p) : p(p) {}
+
+  /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
+  /// may also parse a tensor literal that is stored as a hex string.
+  ParseResult parse(bool allowHex);
+
+  /// Build a dense attribute instance with the parsed elements and the given
+  /// shaped type.
+  DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);
+
+  /// Return the shape inferred from a parsed list literal. It is empty when a
+  /// single element or a hex string was parsed.
+  ArrayRef<int64_t> getShape() const { return shape; }
+
+private:
+  /// Convert the parsed element tokens into integer values of type 'eltTy'.
+  ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
+                                 std::vector<APInt> &intValues);
+
+  /// Convert the parsed element tokens into float values of type 'eltTy'.
+  ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
+                                   std::vector<APFloat> &floatValues);
+
+  /// Build a Dense String attribute for the given type.
+  DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
+
+  /// Build a Dense attribute with hex data for the given type.
+  DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
+
+  /// Parse a single element, returning failure if it isn't a valid element
+  /// literal. For example:
+  /// parseElement(1) -> Success, 1
+  /// parseElement([1]) -> Failure
+  ParseResult parseElement();
+
+  /// Parse a list of either lists or elements, returning the dimensions of the
+  /// parsed sub-tensors in dims. For example:
+  ///   parseList([1, 2, 3]) -> Success, [3]
+  ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
+  ///   parseList([[1, 2], 3]) -> Failure
+  ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
+  ParseResult parseList(SmallVectorImpl<int64_t> &dims);
+
+  /// Parse a literal that was printed as a hex string.
+  ParseResult parseHexElements();
+
+  /// The parser providing the token stream and diagnostics.
+  Parser &p;
+
+  /// The shape inferred from the parsed elements.
+  SmallVector<int64_t, 4> shape;
+
+  /// Storage used when parsing elements, this is a pair of <is_negated, token>.
+  std::vector<std::pair<bool, Token>> storage;
+
+  /// Storage used when parsing elements that were stored as hex values.
+  Optional<Token> hexStorage;
+};
+} // end anonymous namespace
+
+/// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
+/// may also accept a tensor literal stored as a hex string. That string token
+/// is only recorded here and is interpreted later, when the attribute is built.
+ParseResult TensorLiteralParser::parse(bool allowHex) {
+  const Token &tok = p.getToken();
+  if (tok.is(Token::l_square))
+    return parseList(shape);
+
+  if (!allowHex || tok.isNot(Token::string))
+    return parseElement();
+
+  // Stash a copy of the hex string token before consuming it.
+  hexStorage = tok;
+  p.consumeToken(Token::string);
+  return success();
+}
+
+/// Build a dense attribute instance with the parsed elements and the given
+/// shaped type.
+///
+/// For integer, index, float, and complex element types, a parsed hex string
+/// is decoded directly. Otherwise a parsed list's inferred shape must match
+/// 'type' (a mismatch is reported at 'loc'). The element tokens are then
+/// converted according to the element type: integers/indices and floats, each
+/// optionally wrapped in a complex type. Any other element type is treated as
+/// strings. Returns null on failure.
+DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
+                                               ShapedType type) {
+  Type eltType = type.getElementType();
+
+  // Check to see if we parse the literal from a hex string. This must accept
+  // every element type getHexAttr supports, including index. Otherwise an
+  // index-typed hex literal falls through to the integer path below, which
+  // has no element tokens to convert.
+  if (hexStorage.hasValue() &&
+      (eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>()))
+    return getHexAttr(loc, type);
+
+  // Check that the parsed storage size has the same number of elements to the
+  // type, or is a known splat.
+  if (!shape.empty() && getShape() != type.getShape()) {
+    p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
+                     << "]) does not match type ([" << type.getShape() << "])";
+    return nullptr;
+  }
+
+  // Handle complex types in the specific element type cases below.
+  bool isComplex = false;
+  if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
+    eltType = complexTy.getElementType();
+    isComplex = true;
+  }
+
+  // Handle integer and index types.
+  if (eltType.isIntOrIndex()) {
+    std::vector<APInt> intValues;
+    if (failed(getIntAttrElements(loc, eltType, intValues)))
+      return nullptr;
+    if (isComplex) {
+      // If this is a complex, treat the parsed values as complex values. Each
+      // complex element was parsed as two consecutive values, so adjacent
+      // pairs are reinterpreted as std::complex.
+      auto complexData = llvm::makeArrayRef(
+          reinterpret_cast<std::complex<APInt> *>(intValues.data()),
+          intValues.size() / 2);
+      return DenseElementsAttr::get(type, complexData);
+    }
+    return DenseElementsAttr::get(type, intValues);
+  }
+  // Handle floating point types.
+  if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
+    std::vector<APFloat> floatValues;
+    if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
+      return nullptr;
+    if (isComplex) {
+      // If this is a complex, treat the parsed values as complex values.
+      auto complexData = llvm::makeArrayRef(
+          reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
+          floatValues.size() / 2);
+      return DenseElementsAttr::get(type, complexData);
+    }
+    return DenseElementsAttr::get(type, floatValues);
+  }
+
+  // Other types are assumed to be string representations.
+  return getStringAttr(loc, type, type.getElementType());
+}
+
+/// Convert the parsed element tokens into integer values of element type
+/// 'eltTy', appending them to 'intValues'.
+///
+/// Fails, with a diagnostic at the offending token, on any of:
+///   - a negated value for an unsigned element type;
+///   - a floating-point literal;
+///   - a 'true'/'false' literal for a non-i1 element type;
+///   - an integer that does not fit 'eltTy'.
+/// 'loc' is not used by this implementation.
+ParseResult
+TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
+                                        std::vector<APInt> &intValues) {
+  const bool unsignedElts = eltTy.isUnsignedInteger();
+  intValues.reserve(storage.size());
+
+  for (const std::pair<bool, Token> &entry : storage) {
+    const bool negated = entry.first;
+    const Token &tok = entry.second;
+    const llvm::SMLoc tokLoc = tok.getLoc();
+
+    if (unsignedElts && negated)
+      return p.emitError(tokLoc)
+             << "expected unsigned integer elements, but parsed negative value";
+
+    if (tok.is(Token::floatliteral))
+      return p.emitError(tokLoc)
+             << "expected integer elements, but parsed floating-point";
+
+    assert(tok.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
+           "unexpected token type");
+
+    const bool isBoolLiteral = tok.isAny(Token::kw_true, Token::kw_false);
+    if (!isBoolLiteral) {
+      // Size the literal to the element type, rejecting values that overflow.
+      Optional<APInt> value =
+          buildAttributeAPInt(eltTy, negated, tok.getSpelling());
+      if (!value)
+        return p.emitError(tokLoc, "integer constant out of range for type");
+      intValues.push_back(*value);
+      continue;
+    }
+
+    // Boolean literals are only accepted for i1 elements.
+    if (!eltTy.isInteger(1))
+      return p.emitError(tokLoc)
+             << "expected i1 type for 'true' or 'false' values";
+    intValues.push_back(APInt(/*numBits=*/1, tok.is(Token::kw_true),
+                              /*isSigned=*/false));
+  }
+  return success();
+}
+
+/// Convert the parsed element tokens into float values of element type
+/// 'eltTy', appending them to 'floatValues'.
+///
+/// Integer literals spelled with a "0x" prefix are reinterpreted bitwise as
+/// float values. Float literals are converted to the semantics of 'eltTy',
+/// rounding to nearest with ties to even. Diagnostics are reported at the
+/// offending element's token. 'loc' is not used by this implementation.
+ParseResult
+TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
+                                          std::vector<APFloat> &floatValues) {
+  floatValues.reserve(storage.size());
+  for (const auto &signAndToken : storage) {
+    bool isNegative = signAndToken.first;
+    const Token &token = signAndToken.second;
+    // Report errors at the element itself, not at wherever the parser's
+    // current token happens to be after the whole literal was consumed.
+    llvm::SMLoc tokenLoc = token.getLoc();
+
+    // Handle hexadecimal float literals.
+    if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
+      if (isNegative) {
+        return p.emitError(tokenLoc)
+               << "hexadecimal float literal should not have a leading minus";
+      }
+      auto val = token.getUInt64IntegerValue();
+      if (!val.hasValue()) {
+        return p.emitError(
+            tokenLoc, "hexadecimal float constant out of range for attribute");
+      }
+      Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
+      if (!apVal)
+        return failure();
+      floatValues.push_back(*apVal);
+      continue;
+    }
+
+    // Check to see if any decimal integers or booleans were parsed.
+    if (!token.is(Token::floatliteral))
+      return p.emitError(tokenLoc)
+             << "expected floating-point elements, but parsed integer";
+
+    // Build the float values from tokens.
+    auto val = token.getFloatingPointValue();
+    if (!val.hasValue())
+      return p.emitError(tokenLoc,
+                         "floating point value too large for attribute");
+
+    APFloat apVal(isNegative ? -*val : *val);
+    if (!eltTy.isF64()) {
+      // The conversion status and the loses-info flag are ignored: narrowing
+      // is accepted with rounding.
+      bool unused;
+      apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
+                    &unused);
+    }
+    floatValues.push_back(apVal);
+  }
+  return success();
+}
+
+/// Build a dense string elements attribute of the given shaped type.
+///
+/// A parsed hex string supplies a single string element. Otherwise each
+/// parsed element token contributes its string value, in order. 'loc' and
+/// 'eltTy' are not used by this implementation.
+DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
+                                                     ShapedType type,
+                                                     Type eltTy) {
+  if (hexStorage.hasValue()) {
+    std::string hexText = hexStorage.getValue().getStringValue();
+    return DenseStringElementsAttr::get(type, {hexText});
+  }
+
+  // 'ownedStrings' is reserved to its final size up front, so later
+  // push_backs never reallocate it and the StringRefs in 'refs' stay valid.
+  std::vector<std::string> ownedStrings;
+  std::vector<StringRef> refs;
+  ownedStrings.reserve(storage.size());
+  refs.reserve(storage.size());
+  for (const std::pair<bool, Token> &entry : storage) {
+    ownedStrings.push_back(entry.second.getStringValue());
+    refs.push_back(ownedStrings.back());
+  }
+  return DenseStringElementsAttr::get(type, refs);
+}
+
+/// Build a dense elements attribute from the raw data of the parsed hex
+/// string. The element type must be an integer, index, float, or complex type.
+/// The decoded data must also be a valid raw buffer for 'type';
+/// isValidRawBuffer reports whether that buffer is a splat. Diagnostics are
+/// reported at 'loc'. Returns null on failure.
+DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
+                                                  ShapedType type) {
+  Type elementType = type.getElementType();
+  bool isSupportedElt =
+      elementType.isIntOrIndexOrFloat() || elementType.isa<ComplexType>();
+  if (!isSupportedElt) {
+    p.emitError(loc)
+        << "expected floating-point, integer, or complex element type, got "
+        << elementType;
+    return nullptr;
+  }
+
+  std::string bytes;
+  if (parseElementAttrHexValues(p, hexStorage.getValue(), bytes))
+    return nullptr;
+
+  ArrayRef<char> rawData(bytes.data(), bytes.size());
+  bool isSplat = false;
+  if (DenseElementsAttr::isValidRawBuffer(type, rawData, isSplat))
+    return DenseElementsAttr::getFromRawBuffer(type, rawData, isSplat);
+
+  p.emitError(loc) << "elements hex data size is invalid for provided type: "
+                   << type;
+  return nullptr;
+}
+
+/// Parse a single element of a tensor literal and record it in 'storage'.
+///
+/// Accepts any of:
+///   - an integer, float, 'true'/'false', or string literal;
+///   - a '-' followed by an integer or float literal (recorded as negated);
+///   - a complex element written '(' element ',' element ')', whose two parts
+///     are recorded in order.
+ParseResult TensorLiteralParser::parseElement() {
+  switch (p.getToken().getKind()) {
+  // Plain literals are recorded as-is, without negation.
+  case Token::kw_true:
+  case Token::kw_false:
+  case Token::floatliteral:
+  case Token::integer:
+  case Token::string:
+    storage.emplace_back(/*isNegative=*/false, p.getToken());
+    p.consumeToken();
+    return success();
+
+  // A leading minus must be followed by a numeric literal.
+  case Token::minus:
+    p.consumeToken(Token::minus);
+    if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+      return p.emitError("expected integer or floating point literal");
+    storage.emplace_back(/*isNegative=*/true, p.getToken());
+    p.consumeToken();
+    return success();
+
+  // A complex element; both of its parts are parsed recursively.
+  case Token::l_paren:
+    p.consumeToken(Token::l_paren);
+    if (parseElement() ||
+        p.parseToken(Token::comma, "expected ',' between complex elements") ||
+        parseElement() ||
+        p.parseToken(Token::r_paren, "expected ')' after complex elements"))
+      return failure();
+    return success();
+
+  default:
+    return p.emitError("expected element literal of primitive type");
+  }
+}
+
+/// Parse a list of either lists or elements, returning the dimensions of the
+/// parsed sub-tensors in dims. For example:
+/// parseList([1, 2, 3]) -> Success, [3]
+/// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
+/// parseList([[1, 2], 3]) -> Failure
+/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
+///
+/// Expects the current token to be '[' (callers only invoke it on that token).
+/// Leaf elements are appended to 'storage' via parseElement. On success 'dims'
+/// is overwritten with [number of entries, dims of the first entry...]. An
+/// empty list "[]" yields [0].
+ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
+ p.consumeToken(Token::l_square);
+
+ // Every entry after the first must have exactly the same sub-dimensions as
+ // the first. NOTE(review): the message mentions ranks, but it also fires for
+ // equal ranks with different sizes (e.g. [[1, 2], [3]]). The error is
+ // reported at the parser's current token.
+ auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
+ const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
+ if (prevDims == newDims)
+ return success();
+ return p.emitError("tensor literal is invalid; ranks are not consistent "
+ "between elements");
+ };
+
+ // 'newDims' holds the sub-dimensions of the first entry; 'size' counts the
+ // entries in this list.
+ bool first = true;
+ SmallVector<int64_t, 4> newDims;
+ unsigned size = 0;
+ // Parses one entry: a nested list (recursively) or a leaf element, which
+ // contributes no dimensions.
+ auto parseCommaSeparatedList = [&]() -> ParseResult {
+ SmallVector<int64_t, 4> thisDims;
+ if (p.getToken().getKind() == Token::l_square) {
+ if (parseList(thisDims))
+ return failure();
+ } else if (parseElement()) {
+ return failure();
+ }
+ ++size;
+ if (!first)
+ return checkDims(newDims, thisDims);
+ newDims = thisDims;
+ first = false;
+ return success();
+ };
+ if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
+ return failure();
+
+ // Return the sublists' dimensions with 'size' prepended.
+ dims.clear();
+ dims.push_back(size);
+ dims.append(newDims.begin(), newDims.end());
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ElementsAttr Parser
+//===----------------------------------------------------------------------===//
+
+/// Parse a dense elements attribute:
+///
+///   `dense` `<` tensor-literal `>` (`:` elements-literal-type)?
+///
+/// The literal may also be written as a hex string. The trailing type is
+/// parsed only when 'attrType' is null; otherwise 'attrType' is used.
+/// Returns null on failure.
+Attribute Parser::parseDenseElementsAttr(Type attrType) {
+  consumeToken(Token::kw_dense);
+
+  // Parse the literal data (hex strings allowed) between '<' and '>'.
+  TensorLiteralParser literalParser(*this);
+  if (parseToken(Token::less, "expected '<' after 'dense'") ||
+      literalParser.parse(/*allowHex=*/true) ||
+      parseToken(Token::greater, "expected '>'"))
+    return nullptr;
+
+  // Mismatches between the literal and the type are reported at the type.
+  llvm::SMLoc typeLoc = getToken().getLoc();
+  ShapedType type = parseElementsLiteralType(attrType);
+  if (!type)
+    return nullptr;
+  return literalParser.getAttr(typeLoc, type);
+}
+
+/// Parse an opaque elements attribute:
+///
+///   `opaque` `<` dialect-namespace `,` hex-string-literal `>`
+///       (`:` elements-literal-type)?
+///
+/// The dialect must be registered in the context. The hex payload is decoded
+/// only after the type has been parsed. Returns null on failure.
+Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
+  consumeToken(Token::kw_opaque);
+  if (parseToken(Token::less, "expected '<' after 'opaque'"))
+    return nullptr;
+
+  if (getToken().isNot(Token::string))
+    return (emitError("expected dialect namespace"), nullptr);
+
+  // TODO(shpeisman): Allow for having an unknown dialect on an opaque
+  // attribute. Otherwise, it can't be roundtripped without having the dialect
+  // registered.
+  std::string dialectName = getToken().getStringValue();
+  auto *dialect = builder.getContext()->getRegisteredDialect(dialectName);
+  if (!dialect)
+    return (emitError("no registered dialect with namespace '" + dialectName +
+                      "'"),
+            nullptr);
+  consumeToken(Token::string);
+
+  if (parseToken(Token::comma, "expected ','"))
+    return nullptr;
+
+  // Keep a copy of the hex token; its contents are decoded after the type.
+  Token hexTok = getToken();
+  if (parseToken(Token::string, "elements hex string should start with '0x'") ||
+      parseToken(Token::greater, "expected '>'"))
+    return nullptr;
+
+  ShapedType type = parseElementsLiteralType(attrType);
+  if (!type)
+    return nullptr;
+
+  std::string data;
+  if (parseElementAttrHexValues(*this, hexTok, data))
+    return nullptr;
+  return OpaqueElementsAttr::get(dialect, type, data);
+}
+
+/// Shaped type for elements attribute.
+///
+///   elements-literal-type ::= vector-type | ranked-tensor-type
+///
+/// When 'type' is null, a `:` followed by a type is parsed; otherwise 'type'
+/// is validated as given. The type must be a ranked tensor or vector type with
+/// a static shape. Returns a null type on failure.
+ShapedType Parser::parseElementsLiteralType(Type type) {
+  // Only parse a trailing colon type if the caller did not supply one.
+  if (!type &&
+      (parseToken(Token::colon, "expected ':'") || !(type = parseType())))
+    return nullptr;
+
+  bool isElementsContainer =
+      type.isa<RankedTensorType>() || type.isa<VectorType>();
+  if (!isElementsContainer) {
+    emitError("elements literal must be a ranked tensor or vector type");
+    return nullptr;
+  }
+
+  auto shapedType = type.cast<ShapedType>();
+  if (shapedType.hasStaticShape())
+    return shapedType;
+  return (emitError("elements literal type must have static shape"), nullptr);
+}
+
+/// Parse a sparse elements attribute:
+///
+///   `sparse` `<` indices-literal `,` values-literal `>`
+///       (`:` elements-literal-type)?
+///
+/// The indices literal becomes an i64 tensor. A single element literal is
+/// treated as one index of shape [1, rank]. The values literal must be 1-d,
+/// with one value per index; a single element literal becomes a splat of
+/// that length. The values, unlike the indices, may be a hex string.
+/// Returns null on failure.
+Attribute Parser::parseSparseElementsAttr(Type attrType) {
+  consumeToken(Token::kw_sparse);
+  if (parseToken(Token::less, "Expected '<' after 'sparse'"))
+    return nullptr;
+
+  // Parse the indices. We don't allow hex values here as we may need to use
+  // the inferred shape.
+  auto indicesLoc = getToken().getLoc();
+  TensorLiteralParser indiceParser(*this);
+  if (indiceParser.parse(/*allowHex=*/false))
+    return nullptr;
+
+  if (parseToken(Token::comma, "expected ','"))
+    return nullptr;
+
+  // Parse the values.
+  auto valuesLoc = getToken().getLoc();
+  TensorLiteralParser valuesParser(*this);
+  if (valuesParser.parse(/*allowHex=*/true))
+    return nullptr;
+
+  if (parseToken(Token::greater, "expected '>'"))
+    return nullptr;
+
+  auto type = parseElementsLiteralType(attrType);
+  if (!type)
+    return nullptr;
+
+  // If the indices are a splat, i.e. the literal parser parsed an element and
+  // not a list, we set the shape explicitly. The indices are represented by a
+  // 2-dimensional shape where the second dimension is the rank of the type.
+  // Given that the parsed indices is a splat, we know that we only have one
+  // index and thus one for the first dimension.
+  auto indiceEltType = builder.getIntegerType(64);
+  ShapedType indicesType;
+  if (indiceParser.getShape().empty()) {
+    indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
+  } else {
+    // Otherwise, set the shape to the one parsed by the literal parser.
+    indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
+  }
+  auto indices = indiceParser.getAttr(indicesLoc, indicesType);
+  // getAttr has already emitted a diagnostic on failure. Bail out instead of
+  // building the attribute from a null operand.
+  if (!indices)
+    return nullptr;
+
+  // If the values are a splat, set the shape explicitly based on the number of
+  // indices. The number of indices is encoded in the first dimension of the
+  // indice shape type.
+  auto valuesEltType = type.getElementType();
+  ShapedType valuesType =
+      valuesParser.getShape().empty()
+          ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
+          : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
+  auto values = valuesParser.getAttr(valuesLoc, valuesType);
+  if (!values)
+    return nullptr;
+
+  // Sanity check.
+  if (valuesType.getRank() != 1)
+    return (emitError("expected 1-d tensor for values"), nullptr);
+
+  auto sameShape = (indicesType.getRank() == 1) ||
+                   (type.getRank() == indicesType.getDimSize(1));
+  auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
+  if (!sameShape || !sameElementNum) {
+    emitError() << "expected shape ([" << type.getShape()
+                << "]); inferred shape of indices literal (["
+                << indicesType.getShape()
+                << "]); inferred shape of values literal (["
+                << valuesType.getShape() << "])";
+    return nullptr;
+  }
+
+  // Build the sparse elements attribute by the indices and values.
+  return SparseElementsAttr::get(type, indices, values);
+}
diff --git a/mlir/lib/Parser/CMakeLists.txt b/mlir/lib/Parser/CMakeLists.txt
index b9ab3f33ba20..4d68c5c839c9 100644
--- a/mlir/lib/Parser/CMakeLists.txt
+++ b/mlir/lib/Parser/CMakeLists.txt
@@ -1,7 +1,12 @@
add_mlir_library(MLIRParser
+ AffineParser.cpp
+ AttributeParser.cpp
+ DialectSymbolParser.cpp
Lexer.cpp
+ LocationParser.cpp
Parser.cpp
Token.cpp
+ TypeParser.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Parser
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
new file mode 100644
index 000000000000..9d14d6f4fa4f
--- /dev/null
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -0,0 +1,617 @@
+//===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the parser for the dialect symbols, such as extended
+// attributes and types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Parser.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+using llvm::MemoryBuffer;
+using llvm::SMLoc;
+using llvm::SourceMgr;
+
+namespace {
+/// The DialectAsmParser handed to dialect hooks when they parse their own
+/// attributes and types. Every query is forwarded to the wrapped MLIR
+/// `Parser`, which is how dialect parsing hooks into the main parsing logic.
+class CustomDialectAsmParser : public DialectAsmParser {
+public:
+  /// Wrap `parser`. `fullSpec` is the complete text of the symbol being
+  /// parsed; the parser's current token location is remembered as the name
+  /// location.
+  CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
+      : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()),
+        parser(parser) {}
+  ~CustomDialectAsmParser() override = default;
+
+  //===--------------------------------------------------------------------===//
+  // Diagnostics, locations and context
+  //===--------------------------------------------------------------------===//
+
+  /// Report an error at `loc` through the wrapped parser.
+  InFlightDiagnostic emitError(SMLoc loc, const Twine &message) override {
+    return parser.emitError(loc, message);
+  }
+
+  /// Builder for reaching the MLIRContext and creating attributes/types.
+  Builder &getBuilder() const override { return parser.builder; }
+
+  /// Location of the token that will be consumed next. Never fails.
+  SMLoc getCurrentLocation() override { return parser.getToken().getLoc(); }
+
+  /// Location that was current when this parser was constructed.
+  SMLoc getNameLoc() const override { return nameLoc; }
+
+  /// Convert a raw buffer location into an MLIR Location.
+  Location getEncodedSourceLoc(SMLoc loc) override {
+    return parser.getEncodedSourceLocation(loc);
+  }
+
+  /// The complete symbol text, for dialects that prefer to run their own
+  /// parser over it.
+  StringRef getFullSymbolSpec() const override { return fullSpec; }
+
+  //===--------------------------------------------------------------------===//
+  // Numeric literals
+  //===--------------------------------------------------------------------===//
+
+  /// Parse an optionally negated floating point literal into `result`.
+  ParseResult parseFloat(double &result) override {
+    bool isNegative = parser.consumeIf(Token::minus);
+    Token tok = parser.getToken();
+
+    // TODO(riverriddle) support hex floating point values.
+    if (tok.isNot(Token::floatliteral))
+      return emitError(getCurrentLocation(), "expected floating point literal");
+
+    auto magnitude = tok.getFloatingPointValue();
+    if (!magnitude.hasValue())
+      return emitError(tok.getLoc(), "floating point value too large");
+    parser.consumeToken(Token::floatliteral);
+    result = isNegative ? -*magnitude : *magnitude;
+    return success();
+  }
+
+  /// Parse an optionally negated integer into `result`. Returns None without
+  /// consuming anything unless the input starts with an integer or '-'.
+  OptionalParseResult parseOptionalInteger(uint64_t &result) override {
+    if (parser.getToken().isNot(Token::integer, Token::minus))
+      return llvm::None;
+
+    bool isNegative = parser.consumeIf(Token::minus);
+    // Capture the digits token before parseToken advances past it.
+    Token digits = parser.getToken();
+    if (parser.parseToken(Token::integer, "expected integer value"))
+      return failure();
+
+    auto magnitude = digits.getUInt64IntegerValue();
+    if (!magnitude)
+      return emitError(digits.getLoc(), "integer value too large");
+    // Negation is applied to the unsigned magnitude (wraps modulo 2^64).
+    result = isNegative ? -*magnitude : *magnitude;
+    return success();
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Punctuation
+  //===--------------------------------------------------------------------===//
+
+  ParseResult parseArrow() override {
+    return expectToken(Token::arrow, "expected '->'");
+  }
+  ParseResult parseOptionalArrow() override { return acceptToken(Token::arrow); }
+
+  ParseResult parseLBrace() override {
+    return expectToken(Token::l_brace, "expected '{'");
+  }
+  ParseResult parseOptionalLBrace() override {
+    return acceptToken(Token::l_brace);
+  }
+
+  ParseResult parseRBrace() override {
+    return expectToken(Token::r_brace, "expected '}'");
+  }
+  ParseResult parseOptionalRBrace() override {
+    return acceptToken(Token::r_brace);
+  }
+
+  ParseResult parseColon() override {
+    return expectToken(Token::colon, "expected ':'");
+  }
+  ParseResult parseOptionalColon() override { return acceptToken(Token::colon); }
+
+  ParseResult parseComma() override {
+    return expectToken(Token::comma, "expected ','");
+  }
+  ParseResult parseOptionalComma() override { return acceptToken(Token::comma); }
+
+  ParseResult parseOptionalEllipsis() override {
+    return acceptToken(Token::ellipsis);
+  }
+
+  ParseResult parseEqual() override {
+    return expectToken(Token::equal, "expected '='");
+  }
+  ParseResult parseOptionalEqual() override { return acceptToken(Token::equal); }
+
+  ParseResult parseLess() override {
+    return expectToken(Token::less, "expected '<'");
+  }
+  ParseResult parseOptionalLess() override { return acceptToken(Token::less); }
+
+  ParseResult parseGreater() override {
+    return expectToken(Token::greater, "expected '>'");
+  }
+  ParseResult parseOptionalGreater() override {
+    return acceptToken(Token::greater);
+  }
+
+  ParseResult parseLParen() override {
+    return expectToken(Token::l_paren, "expected '('");
+  }
+  ParseResult parseOptionalLParen() override {
+    return acceptToken(Token::l_paren);
+  }
+
+  ParseResult parseRParen() override {
+    return expectToken(Token::r_paren, "expected ')'");
+  }
+  ParseResult parseOptionalRParen() override {
+    return acceptToken(Token::r_paren);
+  }
+
+  ParseResult parseLSquare() override {
+    return expectToken(Token::l_square, "expected '['");
+  }
+  ParseResult parseOptionalLSquare() override {
+    return acceptToken(Token::l_square);
+  }
+
+  ParseResult parseRSquare() override {
+    return expectToken(Token::r_square, "expected ']'");
+  }
+  ParseResult parseOptionalRSquare() override {
+    return acceptToken(Token::r_square);
+  }
+
+  ParseResult parseOptionalQuestion() override {
+    return acceptToken(Token::question);
+  }
+
+  ParseResult parseOptionalStar() override { return acceptToken(Token::star); }
+
+  //===--------------------------------------------------------------------===//
+  // Keywords
+  //===--------------------------------------------------------------------===//
+
+  /// True when the current token is a bare identifier or a reserved keyword.
+  bool isCurrentTokenAKeyword() const {
+    const Token &tok = parser.getToken();
+    return tok.is(Token::bare_identifier) || tok.isKeyword();
+  }
+
+  /// Consume the current token only if it is a keyword spelled `keyword`.
+  ParseResult parseOptionalKeyword(StringRef keyword) override {
+    if (isCurrentTokenAKeyword() && parser.getTokenSpelling() == keyword) {
+      parser.consumeToken();
+      return success();
+    }
+    return failure();
+  }
+
+  /// Consume any keyword, storing its spelling in `*keyword`.
+  ParseResult parseOptionalKeyword(StringRef *keyword) override {
+    if (!isCurrentTokenAKeyword())
+      return failure();
+    StringRef spelling = parser.getTokenSpelling();
+    *keyword = spelling;
+    parser.consumeToken();
+    return success();
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Attributes
+  //===--------------------------------------------------------------------===//
+
+  /// Parse any attribute (with an optional expected `type`) into `result`.
+  ParseResult parseAttribute(Attribute &result, Type type) override {
+    result = parser.parseAttribute(type);
+    const bool ok = static_cast<bool>(result);
+    return success(ok);
+  }
+
+  /// Parse an affine map reference into `map`.
+  ParseResult parseAffineMap(AffineMap &map) override {
+    return parser.parseAffineMapReference(map);
+  }
+
+  /// Parse an integer set reference into `set`. Despite the `print` prefix,
+  /// this is a parse operation; the name is fixed by the DialectAsmParser hook
+  /// it overrides.
+  ParseResult printIntegerSet(IntegerSet &set) override {
+    return parser.parseIntegerSetReference(set);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Types
+  //===--------------------------------------------------------------------===//
+
+  /// Parse any type into `result`.
+  ParseResult parseType(Type &result) override {
+    result = parser.parseType();
+    const bool ok = static_cast<bool>(result);
+    return success(ok);
+  }
+
+  /// Parse a ranked dimension list, optionally allowing dynamic ('?') sizes.
+  ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
+                                 bool allowDynamic) override {
+    return parser.parseDimensionListRanked(dimensions, allowDynamic);
+  }
+
+private:
+  /// Consume a token of `kind`, emitting `message` through the parser if the
+  /// current token is something else.
+  ParseResult expectToken(Token::Kind kind, const Twine &message) {
+    return parser.parseToken(kind, message);
+  }
+
+  /// Consume a token of `kind` if it is next; succeeds only when it was.
+  ParseResult acceptToken(Token::Kind kind) {
+    return success(parser.consumeIf(kind));
+  }
+
+  /// Complete text of the symbol being parsed.
+  StringRef fullSpec;
+
+  /// Location of the token that was current at construction.
+  SMLoc nameLoc;
+
+  /// The wrapped parser that does all the real work.
+  Parser &parser;
+};
+} // namespace
+
+/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
+/// and may be recursive. Return with the 'prettyName' StringRef encompassing
+/// the entire pretty name.
+///
+/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
+/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
+/// | '(' pretty-dialect-sym-contents+ ')'
+/// | '[' pretty-dialect-sym-contents+ ']'
+/// | '{' pretty-dialect-sym-contents+ '}'
+/// | '[^[<({>\])}\0]+'
+///
+/// Precondition: the current token is the opening '<' (asserted below), and
+/// 'prettyName' already points into the same buffer, starting before that '<'.
+/// (parseExtendedSymbol only calls this when 'prettyName' ends exactly where
+/// the '<' token begins.) On success, 'prettyName' is widened to run from its
+/// original start through the matching '>', and the lexer is repositioned just
+/// past that '>'.
+///
+/// Only the four bracket pairs are tracked. Every other character is skipped
+/// uninterpreted, apart from the '->' special case and '\0'.
+/// NOTE(review): a bracket inside a quoted string in the body is therefore
+/// still counted.
+ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
+ // Pretty symbol names are a relatively unstructured format that contains a
+ // series of properly nested punctuation, with anything else in the middle.
+ // Scan ahead to find it and consume it if successful, otherwise emit an
+ // error.
+ auto *curPtr = getTokenSpelling().data();
+
+ // Stack of currently open bracket characters.
+ SmallVector<char, 8> nestedPunctuation;
+
+ // Scan over the nested punctuation, bailing out on error and consuming until
+ // we find the end. We know that we're currently looking at the '<', so we
+ // can go until we find the matching '>' character.
+ assert(*curPtr == '<');
+ // The first iteration pushes the initial '<', so the stack is non-empty
+ // whenever a closing bracket pops it. Both `continue` and the `break` out of
+ // the switch go to the loop condition, which ends the scan once the initial
+ // '<' has been matched.
+ do {
+ char c = *curPtr++;
+ switch (c) {
+ case '\0':
+ // This also handles the EOF case.
+ return emitError("unexpected nul or EOF in pretty dialect name");
+ case '<':
+ case '[':
+ case '(':
+ case '{':
+ nestedPunctuation.push_back(c);
+ continue;
+
+ case '-':
+ // The sequence `->` is treated as special token.
+ // Skip its '>' so it is not mistaken for a closing bracket.
+ if (*curPtr == '>')
+ ++curPtr;
+ continue;
+
+ // Each closing bracket must match the most recently opened one.
+ case '>':
+ if (nestedPunctuation.pop_back_val() != '<')
+ return emitError("unbalanced '>' character in pretty dialect name");
+ break;
+ case ']':
+ if (nestedPunctuation.pop_back_val() != '[')
+ return emitError("unbalanced ']' character in pretty dialect name");
+ break;
+ case ')':
+ if (nestedPunctuation.pop_back_val() != '(')
+ return emitError("unbalanced ')' character in pretty dialect name");
+ break;
+ case '}':
+ if (nestedPunctuation.pop_back_val() != '{')
+ return emitError("unbalanced '}' character in pretty dialect name");
+ break;
+
+ default:
+ continue;
+ }
+ } while (!nestedPunctuation.empty());
+
+ // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
+ // consuming all this stuff, and return.
+ // curPtr now points just past the matching '>'.
+ state.lex.resetPointer(curPtr);
+
+ // Widen prettyName so it covers everything from its start through the '>'.
+ unsigned length = curPtr - prettyName.begin();
+ prettyName = StringRef(prettyName.begin(), length);
+ // Replace the stale current token (the '<'). Presumably this lexes the next
+ // token from the reset position.
+ consumeToken();
+ return success();
+}
+
+/// Parse an extended dialect symbol introduced by `identifierTok` ('#' for
+/// attributes, '!' for types). Three spellings are accepted:
+///   - alias:   `#name` with no '.' and no following '<'. The result is looked
+///              up in `aliases`.
+///   - verbose: `#dialect<"data">`. The string literal is the symbol data.
+///   - pretty:  `#dialect.name` optionally followed by an adjacent `<...>` body.
+/// For the non-alias forms, `createSymbol(dialectName, symbolData, loc)`
+/// builds the result. While it runs, `loc` (remapped to the top-level buffer)
+/// is pushed on the nested-parser location stack.
+/// Returns null on error.
+template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
+static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
+                                  SymbolAliasMap &aliases,
+                                  CreateFn &&createSymbol) {
+  // Drop the leading sigil to get the dialect namespace (or alias name).
+  StringRef identifier = p.getTokenSpelling().drop_front();
+  auto symbolLoc = p.getToken().getLoc();
+  p.consumeToken(identifierTok);
+
+  const bool isPrettyForm = identifier.contains('.');
+
+  // No dot and no '<': this can only name a previously defined alias.
+  if (!isPrettyForm && p.getToken().isNot(Token::less)) {
+    auto aliasIt = aliases.find(identifier);
+    if (aliasIt != aliases.end())
+      return aliasIt->second;
+    return (p.emitError("undefined symbol alias id '" + identifier + "'"),
+            nullptr);
+  }
+
+  auto dialectName = identifier;
+  std::string symbolData;
+
+  if (isPrettyForm) {
+    // `dialect.rest`: the dialect name precedes the first dot and the symbol
+    // text (or its start) follows it.
+    auto halves = identifier.split('.');
+    dialectName = halves.first;
+    auto prettyName = halves.second;
+    symbolLoc = llvm::SMLoc::getFromPointer(prettyName.data());
+
+    // A '<' glued directly onto the name begins a body that is lexed into
+    // prettyName as well.
+    bool bodyIsAdjacent =
+        p.getToken().is(Token::less) &&
+        prettyName.bytes_end() == p.getTokenSpelling().bytes_begin();
+    if (bodyIsAdjacent && p.parsePrettyDialectSymbolName(prettyName))
+      return nullptr;
+
+    symbolData = prettyName.str();
+  } else {
+    // Verbose form: `dialect<"data">`.
+    if (p.parseToken(Token::less, "expected '<' in dialect type"))
+      return nullptr;
+
+    if (p.getToken().isNot(Token::string))
+      return (p.emitError("expected string literal data in dialect symbol"),
+              nullptr);
+    symbolData = p.getToken().getStringValue();
+    // Point just past the opening quote so nested diagnostics line up with
+    // the string contents.
+    symbolLoc =
+        llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1);
+    p.consumeToken(Token::string);
+
+    if (p.parseToken(Token::greater, "expected '>' in dialect symbol"))
+      return nullptr;
+  }
+
+  // Anchor diagnostics from nested parsers to the top-level buffer while the
+  // symbol is being constructed.
+  p.getState().symbols.nestedParserLocs.push_back(
+      p.remapLocationToTopLevelBuffer(symbolLoc));
+  Symbol result = createSymbol(dialectName, symbolData, symbolLoc);
+  p.getState().symbols.nestedParserLocs.pop_back();
+  return result;
+}
+
+/// Run `parserFn` over `inputStr` with a fresh Parser that shares
+/// `symbolState` (aliases and nested-parser locations) with the caller.
+/// Returns the parsed symbol, or a null T() on failure.
+///
+/// When `numRead` is non-null it receives the number of bytes consumed, and
+/// leftover input is allowed. When it is null, the parse must either reach
+/// EOF or consume nothing at all. The latter case admits dialect hooks that
+/// parse getFullSymbolSpec() with their own parser. Otherwise
+/// "encountered unexpected token" is reported.
+template <typename T, typename ParserFn>
+static T parseSymbol(StringRef inputStr, MLIRContext *context,
+                     SymbolState &symbolState, ParserFn &&parserFn,
+                     size_t *numRead = nullptr) {
+  // Wrap the (non-owned) input in a buffer the nested lexer can read.
+  SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(
+      MemoryBuffer::getMemBuffer(inputStr,
+                                 /*BufferName=*/"<mlir_parser_buffer>",
+                                 /*RequiresNullTerminator=*/false),
+      SMLoc());
+  ParserState state(sourceMgr, context, symbolState);
+  Parser parser(state);
+
+  auto startLoc = parser.getToken().getLoc();
+  T symbol = parserFn(parser);
+  if (!symbol)
+    return T();
+
+  Token endTok = parser.getToken();
+  if (numRead) {
+    *numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
+                                   startLoc.getPointer());
+    return symbol;
+  }
+
+  // No byte count requested: insist the whole input was consumed.
+  bool consumedAnything = startLoc != endTok.getLoc();
+  if (consumedAnything && endTok.isNot(Token::eof)) {
+    parser.emitError(endTok.getLoc(), "encountered unexpected token");
+    return T();
+  }
+  return symbol;
+}
+
+/// Parse an extended attribute.
+///
+/// extended-attribute ::= (dialect-attribute | attribute-alias)
+/// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
+/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
+/// attribute-alias ::= `#` alias-name
+///
+/// A dialect attribute (not an alias) may be followed by `: type`. Otherwise
+/// `type`, which may be null, is used as its type. If `type` is non-null, the
+/// resulting attribute must have exactly that type; if it does not, an error
+/// is emitted and null is returned.
+Attribute Parser::parseExtendedAttr(Type type) {
+  Attribute attr = parseExtendedSymbol<Attribute>(
+      *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
+      [&](StringRef dialectName, StringRef symbolData,
+          llvm::SMLoc loc) -> Attribute {
+        // Parse an optional trailing colon type.
+        Type attrType = type;
+        if (consumeIf(Token::colon) && !(attrType = parseType()))
+          return Attribute();
+
+        // If we found a registered dialect, then ask it to parse the attribute.
+        if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
+          return parseSymbol<Attribute>(
+              symbolData, state.context, state.symbols, [&](Parser &parser) {
+                CustomDialectAsmParser customParser(symbolData, parser);
+                return dialect->parseAttribute(customParser, attrType);
+              });
+        }
+
+        // Otherwise, form a new opaque attribute. If no type was requested or
+        // written, it gets NoneType.
+        return OpaqueAttr::getChecked(
+            Identifier::get(dialectName, state.context), symbolData,
+            attrType ? attrType : NoneType::get(state.context),
+            getEncodedSourceLocation(loc));
+      });
+
+  // Ensure that the attribute has the same type as requested.
+  if (attr && type && attr.getType() != type) {
+    emitError("attribute type different than expected: expected ")
+        << type << ", but got " << attr.getType();
+    return nullptr;
+  }
+  return attr;
+}
+
+/// Parse an extended type.
+///
+/// extended-type ::= (dialect-type | type-alias)
+/// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
+/// dialect-type ::= `!` alias-name pretty-dialect-sym-body?
+/// type-alias ::= `!` alias-name
+///
+/// A registered dialect re-parses the symbol text with its own parseType
+/// hook. Types of unregistered dialects become an OpaqueType holding the raw
+/// text.
+Type Parser::parseExtendedType() {
+  auto buildType = [&](StringRef dialectName, StringRef symbolData,
+                       llvm::SMLoc loc) -> Type {
+    auto *dialect = state.context->getRegisteredDialect(dialectName);
+    if (!dialect) {
+      // Unknown dialect: preserve the text verbatim as an opaque type.
+      return OpaqueType::getChecked(
+          Identifier::get(dialectName, state.context), symbolData,
+          state.context, getEncodedSourceLocation(loc));
+    }
+
+    // Known dialect: hand the symbol body to its type parser.
+    return parseSymbol<Type>(
+        symbolData, state.context, state.symbols, [&](Parser &nested) {
+          CustomDialectAsmParser customParser(symbolData, nested);
+          return dialect->parseType(customParser);
+        });
+  };
+
+  return parseExtendedSymbol<Type>(*this, Token::exclamation_identifier,
+                                   state.symbols.typeAliasDefinitions,
+                                   buildType);
+}
+
+//===----------------------------------------------------------------------===//
+// mlir::parseAttribute/parseType
+//===----------------------------------------------------------------------===//
+
+/// Parse a symbol of type 'T' from a standalone string, starting with no
+/// known attribute or type aliases. A SourceMgrDiagnosticHandler over the
+/// temporary buffer is kept alive around `parserFn`, so diagnostics go through
+/// it. Returns a null T() on failure. On success, 'numRead' receives the number
+/// of bytes consumed from 'inputStr'.
+template <typename T, typename ParserFn>
+static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
+                     ParserFn &&parserFn) {
+  SymbolState freshSymbols;
+
+  auto parseWithDiagnostics = [&](Parser &parser) {
+    auto &bufferMgr = const_cast<llvm::SourceMgr &>(parser.getSourceMgr());
+    SourceMgrDiagnosticHandler handler(bufferMgr, parser.getContext());
+    return parserFn(parser);
+  };
+
+  return parseSymbol<T>(inputStr, context, freshSymbols, parseWithDiagnostics,
+                        &numRead);
+}
+
+/// Convenience overloads that discard the consumed-byte count.
+/// NOTE(review): these forward a numRead out-parameter, so the underlying
+/// parseSymbol never rejects trailing characters after the attribute. Confirm
+/// that callers do not rely on whole-string validation.
+Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) {
+  size_t ignoredNumRead = 0;
+  return parseAttribute(attrStr, context, ignoredNumRead);
+}
+Attribute mlir::parseAttribute(StringRef attrStr, Type type) {
+  size_t ignoredNumRead = 0;
+  return parseAttribute(attrStr, type, ignoredNumRead);
+}
+
+/// Parse an attribute from `attrStr`, reporting the bytes consumed in
+/// `numRead`. Returns null on failure.
+Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
+                               size_t &numRead) {
+  auto parseFn = [](Parser &parser) { return parser.parseAttribute(); };
+  return parseSymbol<Attribute>(attrStr, context, numRead, parseFn);
+}
+/// As above, but the attribute is parsed against the expected `type`, and the
+/// context is taken from that type.
+Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) {
+  MLIRContext *ctx = type.getContext();
+  auto parseFn = [type](Parser &parser) {
+    return parser.parseAttribute(type);
+  };
+  return parseSymbol<Attribute>(attrStr, ctx, numRead, parseFn);
+}
+
+/// Parse a type from `typeStr`, discarding the consumed-byte count.
+Type mlir::parseType(StringRef typeStr, MLIRContext *context) {
+  size_t ignoredNumRead = 0;
+  return parseType(typeStr, context, ignoredNumRead);
+}
+
+/// Parse a type from `typeStr`, reporting the bytes consumed in `numRead`.
+/// Returns null on failure.
+Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) {
+  auto parseFn = [](Parser &parser) { return parser.parseType(); };
+  return parseSymbol<Type>(typeStr, context, numRead, parseFn);
+}
diff --git a/mlir/lib/Parser/LocationParser.cpp b/mlir/lib/Parser/LocationParser.cpp
new file mode 100644
index 000000000000..886514a92492
--- /dev/null
+++ b/mlir/lib/Parser/LocationParser.cpp
@@ -0,0 +1,197 @@
+//===- LocationParser.cpp - MLIR Location Parser -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Parser.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+/// Parse a location.
+///
+/// location ::= `loc` inline-location
+/// inline-location ::= '(' location-inst ')'
+///
+/// On success `loc` holds the parsed location. On failure, the diagnostic
+/// comes from the sub-parse that failed.
+ParseResult Parser::parseLocation(LocationAttr &loc) {
+  // Check for 'loc' identifier. parseToken has already reported
+  // "expected 'loc' keyword" on failure, so just propagate the failure
+  // instead of emitting a second, empty error.
+  if (parseToken(Token::kw_loc, "expected 'loc' keyword"))
+    return failure();
+
+  // Parse the inline-location.
+  if (parseToken(Token::l_paren, "expected '(' in inline location") ||
+      parseLocationInstance(loc) ||
+      parseToken(Token::r_paren, "expected ')' in inline location"))
+    return failure();
+  return success();
+}
+
+/// Specific location instances.
+///
+/// location-inst ::= filelinecol-location |
+///                   name-location |
+///                   callsite-location |
+///                   fused-location |
+///                   unknown-location
+/// filelinecol-location ::= string-literal ':' integer-literal
+///                          ':' integer-literal
+/// name-location ::= string-literal ('(' location-inst ')')?
+/// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')'
+/// fused-location ::= 'fused' ('<' attribute-value '>')?
+///                    '[' location-inst (',' location-inst)* ']'
+/// unknown-location ::= 'unknown'
+///
+/// Parse a callsite-location. The current token is the 'callsite' keyword.
+ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) {
+  consumeToken(Token::bare_identifier);
+
+  LocationAttr calleeLoc;
+  LocationAttr callerLoc;
+
+  if (parseToken(Token::l_paren, "expected '(' in callsite location") ||
+      parseLocationInstance(calleeLoc))
+    return failure();
+
+  // Callee and caller are separated by the contextual keyword 'at'.
+  bool atIsNext = getToken().is(Token::bare_identifier) &&
+                  getToken().getSpelling() == "at";
+  if (!atIsNext)
+    return emitError("expected 'at' in callsite location");
+  consumeToken(Token::bare_identifier);
+
+  if (parseLocationInstance(callerLoc) ||
+      parseToken(Token::r_paren, "expected ')' in callsite location"))
+    return failure();
+
+  loc = CallSiteLoc::get(calleeLoc, callerLoc);
+  return success();
+}
+
+/// Parse a fused-location. The current token is the 'fused' keyword. It is
+/// followed by optional `<attr>` metadata and a bracketed, comma-separated
+/// list of location instances.
+ParseResult Parser::parseFusedLocation(LocationAttr &loc) {
+  consumeToken(Token::bare_identifier);
+
+  Attribute metadata;
+  if (consumeIf(Token::less)) {
+    metadata = parseAttribute();
+    if (!metadata)
+      return emitError("expected valid attribute metadata");
+    if (parseToken(Token::greater,
+                   "expected '>' after fused location metadata"))
+      return failure();
+  }
+
+  if (parseToken(Token::l_square, "expected '[' in fused location"))
+    return failure();
+
+  SmallVector<Location, 4> locations;
+  ParseResult listResult = parseCommaSeparatedList([&] {
+    LocationAttr element;
+    if (parseLocationInstance(element))
+      return failure();
+    locations.push_back(element);
+    return success();
+  });
+  if (listResult ||
+      parseToken(Token::r_square, "expected ']' in fused location"))
+    return failure();
+
+  loc = FusedLoc::get(locations, metadata, getContext());
+  return success();
+}
+
+/// Parse a location that starts with a string literal. This is either a
+/// `"file":line:col` FileLineColLoc or a `"name"` NameLoc with an optional
+/// parenthesized child location, which must not itself be a NameLoc.
+ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
+  auto *ctx = getContext();
+  auto str = getToken().getStringValue();
+  consumeToken(Token::string);
+
+  // Read one unsigned integer token into `value`. `errorMsg` is reported if
+  // the token is missing or its value cannot be represented.
+  auto parseLocNumber = [&](unsigned &value,
+                            const char *errorMsg) -> ParseResult {
+    if (getToken().isNot(Token::integer))
+      return emitError(errorMsg);
+    auto parsed = getToken().getUnsignedIntegerValue();
+    if (!parsed.hasValue())
+      return emitError(errorMsg);
+    value = parsed.getValue();
+    consumeToken(Token::integer);
+    return success();
+  };
+
+  // A ':' after the string means this is a file:line:col location.
+  if (consumeIf(Token::colon)) {
+    unsigned line = 0;
+    unsigned column = 0;
+    if (parseLocNumber(line,
+                       "expected integer line number in FileLineColLoc") ||
+        parseToken(Token::colon, "expected ':' in FileLineColLoc") ||
+        parseLocNumber(column,
+                       "expected integer column number in FileLineColLoc"))
+      return failure();
+    loc = FileLineColLoc::get(str, line, column, ctx);
+    return success();
+  }
+
+  // Otherwise it is a NameLoc. Without '(' there is no child location.
+  if (!consumeIf(Token::l_paren)) {
+    loc = NameLoc::get(Identifier::get(str, ctx), ctx);
+    return success();
+  }
+
+  auto childSourceLoc = getToken().getLoc();
+  LocationAttr childLoc;
+  if (parseLocationInstance(childLoc))
+    return failure();
+
+  if (childLoc.isa<NameLoc>())
+    return emitError(childSourceLoc,
+                     "child of NameLoc cannot be another NameLoc");
+
+  // `loc` is assigned before the closing ')' is checked.
+  loc = NameLoc::get(Identifier::get(str, ctx), childLoc);
+  if (parseToken(Token::r_paren,
+                 "expected ')' after child location of NameLoc"))
+    return failure();
+  return success();
+}
+
+/// Parse a single location-inst, dispatching on the leading token. A string
+/// literal starts a name or file:line:col location. The bare identifiers
+/// 'callsite', 'fused' and 'unknown' select the other forms.
+ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
+  if (getToken().is(Token::string))
+    return parseNameOrFileLineColLocation(loc);
+
+  if (getToken().isNot(Token::bare_identifier))
+    return emitError("expected location instance");
+
+  StringRef keyword = getToken().getSpelling();
+  if (keyword == "callsite")
+    return parseCallSiteLocation(loc);
+  if (keyword == "fused")
+    return parseFusedLocation(loc);
+  if (keyword == "unknown") {
+    consumeToken(Token::bare_identifier);
+    loc = UnknownLoc::get(getContext());
+    return success();
+  }
+
+  return emitError("expected location instance");
+}
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index be96a39e6789..465709d925e0 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -10,3312 +10,86 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Parser.h"
-#include "Lexer.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Dialect.h"
-#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/IntegerSet.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Verifier.h"
-#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/StringExtras.h"
-#include "llvm/ADT/StringSet.h"
-#include "llvm/ADT/bit.h"
-#include "llvm/Support/PrettyStackTrace.h"
-#include "llvm/Support/SMLoc.h"
-#include "llvm/Support/SourceMgr.h"
-#include <algorithm>
-using namespace mlir;
-using llvm::MemoryBuffer;
-using llvm::SMLoc;
-using llvm::SourceMgr;
-
-namespace {
-class Parser;
-
-//===----------------------------------------------------------------------===//
-// SymbolState
-//===----------------------------------------------------------------------===//
-
-/// Records parsed top-level symbols (type/attribute aliases) and nested-parser state.
-struct SymbolState {
- /// A map from attribute alias identifier to Attribute.
- llvm::StringMap<Attribute> attributeAliasDefinitions;
-
- /// A map from type alias identifier to Type.
- llvm::StringMap<Type> typeAliasDefinitions;
-
- /// A stack of locations into the main parser memory buffer, one for each of
- /// the active nested parsers. Given that some nested parsers, e.g. custom
- /// dialect parsers, operate on a temporary memory buffer, this provides an
- /// anchor point for emitting diagnostics.
- SmallVector<llvm::SMLoc, 1> nestedParserLocs;
-
- /// The top-level lexer that contains the original memory buffer provided by
- /// the user. Nested parsers use it to get a properly encoded source location.
- /// Non-owning: set and cleared by the ParserState whose `lex` it points at.
- Lexer *topLevelLexer = nullptr;
-};
-
-//===----------------------------------------------------------------------===//
-// ParserState
-//===----------------------------------------------------------------------===//
-
-/// This class holds all of the state maintained globally by a parser, such as
-/// the lexer, the current token, and a reference to the shared SymbolState.
-struct ParserState {
- ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
- SymbolState &symbols)
- : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), // relies on `lex` being declared before `curToken`
- symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) {
- // Register our lexer as the top-level lexer if none is registered yet.
- if (!symbols.topLevelLexer)
- symbols.topLevelLexer = &lex;
- }
- ~ParserState() {
- // Reset the top level lexer if it refers to our lexer, so it never dangles.
- if (symbols.topLevelLexer == &lex)
- symbols.topLevelLexer = nullptr;
- }
- ParserState(const ParserState &) = delete; // non-copyable: SymbolState may point at `lex`
- void operator=(const ParserState &) = delete;
-
- /// The context we're parsing into.
- MLIRContext *const context;
-
- /// The lexer for the source file we're parsing.
- Lexer lex;
-
- /// This is the next token that hasn't been consumed yet.
- Token curToken;
-
- /// The symbol state, held by reference so it can be shared across parser states.
- SymbolState &symbols;
-
- /// The depth of this parser in the nested parsing stack (0 for the top level).
- size_t parserDepth;
-};
-
-//===----------------------------------------------------------------------===//
-// Parser
-//===----------------------------------------------------------------------===//
-
-/// This class implements support for parsing global entities like types and
-/// shared entities like SSA names. It is intended to be subclassed by
-/// specialized subparsers that carry extra state, e.g. a local symbol table.
-class Parser {
-public:
- Builder builder; // Constructed over the parser's MLIRContext.
- /// 'state' is held by reference and must outlive this parser.
- Parser(ParserState &state) : builder(state.context), state(state) {}
-
- // Helper methods to access the parser-global state.
- ParserState &getState() const { return state; }
- MLIRContext *getContext() const { return state.context; }
- const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
-
- /// Parse a comma-separated list of elements up until the specified end token.
- ParseResult
- parseCommaSeparatedListUntil(Token::Kind rightToken,
- const std::function<ParseResult()> &parseElement,
- bool allowEmptyList = true);
-
- /// Parse a comma separated list of elements that must have at least one entry
- /// in it.
- ParseResult
- parseCommaSeparatedList(const std::function<ParseResult()> &parseElement);
- /// Lex the `<...>` body of a pretty dialect symbol, extending 'prettyName' to cover it.
- ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);
-
- // We have two forms of parsing methods: those that return a non-null value
- // on success (null on failure), and those that return a ParseResult to
- // indicate whether they failed. The second class fills in by-reference
- // arguments as the results of their action.
-
- //===--------------------------------------------------------------------===//
- // Error Handling
- //===--------------------------------------------------------------------===//
-
- /// Emit an error at the current token; the diagnostic converts to failure.
- InFlightDiagnostic emitError(const Twine &message = {}) {
- return emitError(state.curToken.getLoc(), message);
- }
- InFlightDiagnostic emitError(SMLoc loc, const Twine &message = {}); // Emit an error at 'loc'.
-
- /// Encode the specified source location information into an attribute for
- /// attachment to the IR.
- Location getEncodedSourceLocation(llvm::SMLoc loc) {
- // If there are no active nested parsers, we can get the encoded source
- // location directly.
- if (state.parserDepth == 0)
- return state.lex.getEncodedSourceLocation(loc);
- // Otherwise, we need to re-encode it to point to the top level buffer.
- return state.symbols.topLevelLexer->getEncodedSourceLocation(
- remapLocationToTopLevelBuffer(loc));
- }
-
- /// Remaps the given SMLoc to the top level lexer of the parser. This is used
- /// to adjust locations of potentially nested parsers to ensure that they can
- /// be emitted properly as diagnostics.
- llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) {
- // If there are no active nested parsers, we can return location directly.
- SymbolState &symbols = state.symbols;
- if (state.parserDepth == 0)
- return loc;
- assert(symbols.topLevelLexer && "expected valid top-level lexer");
-
- // Otherwise, we need to remap the location to the main parser. This is
- // simply offsetting the location onto the location of the last nested
- // parser.
- size_t offset = loc.getPointer() - state.lex.getBufferBegin();
- auto *rawLoc =
- symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset;
- return llvm::SMLoc::getFromPointer(rawLoc);
- }
-
- //===--------------------------------------------------------------------===//
- // Token Parsing
- //===--------------------------------------------------------------------===//
-
- /// Return the current token the parser is inspecting.
- const Token &getToken() const { return state.curToken; }
- StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }
-
- /// If the current token has the specified kind, consume it and return true.
- /// If not, return false.
- bool consumeIf(Token::Kind kind) {
- if (state.curToken.isNot(kind))
- return false;
- consumeToken(kind);
- return true;
- }
-
- /// Advance the current lexer onto the next token.
- void consumeToken() {
- assert(state.curToken.isNot(Token::eof, Token::error) &&
- "shouldn't advance past EOF or errors");
- state.curToken = state.lex.lexToken();
- }
-
- /// Advance the current lexer onto the next token, asserting what the expected
- /// current token is. This is preferred to the above method because it leads
- /// to more self-documenting code with better checking.
- void consumeToken(Token::Kind kind) {
- assert(state.curToken.is(kind) && "consumed an unexpected token");
- consumeToken();
- }
-
- /// Consume the specified token if present and return success. On failure,
- /// output a diagnostic and return failure.
- ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
-
- //===--------------------------------------------------------------------===//
- // Type Parsing
- //===--------------------------------------------------------------------===//
-
- ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
- ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
- ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
-
- /// Optionally parse a type.
- OptionalParseResult parseOptionalType(Type &type);
-
- /// Parse an arbitrary type.
- Type parseType();
-
- /// Parse a complex type.
- Type parseComplexType();
-
- /// Parse an extended type.
- Type parseExtendedType();
-
- /// Parse a function type.
- Type parseFunctionType();
-
- /// Parse a memref type.
- Type parseMemRefType();
-
- /// Parse a non function type.
- Type parseNonFunctionType();
-
- /// Parse a tensor type.
- Type parseTensorType();
-
- /// Parse a tuple type.
- Type parseTupleType();
-
- /// Parse a vector type.
- VectorType parseVectorType();
- ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
- bool allowDynamic = true);
- ParseResult parseXInDimensionList();
-
- /// Parse strided layout specification.
- ParseResult parseStridedLayout(int64_t &offset,
- SmallVectorImpl<int64_t> &strides);
-
- // Parse a brace-delimited list of comma-separated integers with `?` as an
- // unknown marker.
- ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);
-
- //===--------------------------------------------------------------------===//
- // Attribute Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse an arbitrary attribute with an optional type.
- Attribute parseAttribute(Type type = {});
-
- /// Parse an attribute dictionary.
- ParseResult parseAttributeDict(NamedAttrList &attributes);
-
- /// Parse an extended attribute.
- Attribute parseExtendedAttr(Type type);
-
- /// Parse a float attribute.
- Attribute parseFloatAttr(Type type, bool isNegative);
-
- /// Parse a decimal or a hexadecimal literal, which can be either an integer
- /// or a float attribute.
- Attribute parseDecOrHexAttr(Type type, bool isNegative);
-
- /// Parse an opaque elements attribute.
- Attribute parseOpaqueElementsAttr(Type attrType);
-
- /// Parse a dense elements attribute.
- Attribute parseDenseElementsAttr(Type attrType);
- ShapedType parseElementsLiteralType(Type type);
-
- /// Parse a sparse elements attribute.
- Attribute parseSparseElementsAttr(Type attrType);
-
- //===--------------------------------------------------------------------===//
- // Location Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse an inline location.
- ParseResult parseLocation(LocationAttr &loc);
-
- /// Parse a raw location instance.
- ParseResult parseLocationInstance(LocationAttr &loc);
-
- /// Parse a callsite location instance.
- ParseResult parseCallSiteLocation(LocationAttr &loc);
-
- /// Parse a fused location instance.
- ParseResult parseFusedLocation(LocationAttr &loc);
-
- /// Parse a name or FileLineCol location instance.
- ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);
-
- /// Parse an optional trailing location.
- ///
- /// trailing-location ::= (`loc` `(` location `)`)?
- ///
- ParseResult parseOptionalTrailingLocation(Location &loc) {
- // Without a leading 'loc' keyword there is no trailing location; leave 'loc' as is.
- if (!getToken().is(Token::kw_loc))
- return success();
-
- // Parse the location, starting at the (not yet consumed) 'loc' keyword.
- LocationAttr directLoc;
- if (parseLocation(directLoc))
- return failure();
- loc = directLoc;
- return success();
- }
-
- //===--------------------------------------------------------------------===//
- // Affine Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse a reference to either an affine map, or an integer set.
- ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
- IntegerSet &set);
- ParseResult parseAffineMapReference(AffineMap &map);
- ParseResult parseIntegerSetReference(IntegerSet &set);
-
- /// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
- ParseResult
- parseAffineMapOfSSAIds(AffineMap &map,
- function_ref<ParseResult(bool)> parseElement,
- OpAsmParser::Delimiter delimiter);
-
-private:
- /// The Parser is subclassed and reinstantiated. Do not add additional
- /// non-trivial state here, add it to the ParserState class.
- ParserState &state;
-};
-} // end anonymous namespace
-
-//===----------------------------------------------------------------------===//
-// Helper methods.
-//===----------------------------------------------------------------------===//
-
-/// Parse a comma separated list of at least one element. Parsing stops, with
-/// success, at the first non-',' token that follows an element.
-ParseResult Parser::parseCommaSeparatedList(
- const std::function<ParseResult()> &parseElement) {
- // The list is non-empty, so it always starts with an element.
- if (parseElement())
- return failure();
-
- // Each subsequent ',' must be followed by another element.
- while (consumeIf(Token::comma)) {
- if (parseElement())
- return failure();
- }
- return success();
-}
-
-/// Parse a comma-separated list of elements terminated by 'rightToken', which
-/// is consumed on success. Empty lists are accepted only if allowEmptyList is true.
-///
-/// abstract-list ::= rightToken // if allowEmptyList == true
-/// abstract-list ::= element (',' element)* rightToken
-///
-ParseResult Parser::parseCommaSeparatedListUntil(
- Token::Kind rightToken, const std::function<ParseResult()> &parseElement,
- bool allowEmptyList) {
- // An immediate terminator means the list is empty.
- if (getToken().is(rightToken)) {
- if (!allowEmptyList)
- return emitError("expected list element");
- consumeToken(rightToken);
- return success();
- }
- // Otherwise parse one or more elements followed by the terminator.
- if (parseCommaSeparatedList(parseElement) ||
- parseToken(rightToken, "expected ',' or '" +
- Token::getTokenSpelling(rightToken) + "'"))
- return failure();
-
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// DialectAsmParser
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class provides the main implementation of the DialectAsmParser that
-/// allows dialects to parse attributes and types. It lets dialects hook into
-/// the main MLIR parsing logic by forwarding to the wrapped Parser.
-class CustomDialectAsmParser : public DialectAsmParser {
-public:
- CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
- : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()),
- parser(parser) {}
- ~CustomDialectAsmParser() override {}
-
- /// Emit a diagnostic at the specified location and return failure.
- InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
- return parser.emitError(loc, message);
- }
-
- /// Return a builder which provides useful access to MLIRContext, global
- /// objects like types and attributes.
- Builder &getBuilder() const override { return parser.builder; }
-
- /// Return the location of the parser's current (not yet consumed) token.
- /// This always succeeds.
- llvm::SMLoc getCurrentLocation() override {
- return parser.getToken().getLoc();
- }
-
- /// Return the location of the name token, i.e. the parser's current token at construction.
- llvm::SMLoc getNameLoc() const override { return nameLoc; }
-
- /// Re-encode the given source location as an MLIR location and return it.
- Location getEncodedSourceLoc(llvm::SMLoc loc) override {
- return parser.getEncodedSourceLocation(loc);
- }
-
- /// Returns the full specification of the symbol being parsed. This allows
- /// for using a separate parser if necessary.
- StringRef getFullSymbolSpec() const override { return fullSpec; }
-
- /// Parse a floating point value, with an optional leading '-', into 'result'.
- ParseResult parseFloat(double &result) override {
- bool negative = parser.consumeIf(Token::minus);
- Token curTok = parser.getToken();
-
- // Check for a floating point value.
- if (curTok.is(Token::floatliteral)) {
- auto val = curTok.getFloatingPointValue();
- if (!val.hasValue())
- return emitError(curTok.getLoc(), "floating point value too large");
- parser.consumeToken(Token::floatliteral);
- result = negative ? -*val : *val;
- return success();
- }
-
- // TODO(riverriddle) support hex floating point values.
- return emitError(getCurrentLocation(), "expected floating point literal");
- }
-
- /// Parse an optional, possibly '-'-prefixed, integer; None if no integer or '-' is next.
- OptionalParseResult parseOptionalInteger(uint64_t &result) override {
- Token curToken = parser.getToken();
- if (curToken.isNot(Token::integer, Token::minus))
- return llvm::None;
-
- bool negative = parser.consumeIf(Token::minus);
- Token curTok = parser.getToken();
- if (parser.parseToken(Token::integer, "expected integer value"))
- return failure();
-
- auto val = curTok.getUInt64IntegerValue();
- if (!val)
- return emitError(curTok.getLoc(), "integer value too large");
- result = negative ? -*val : *val; // unsigned negation wraps modulo 2^64
- return success();
- }
-
- //===--------------------------------------------------------------------===//
- // Token Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse a `->` token.
- ParseResult parseArrow() override {
- return parser.parseToken(Token::arrow, "expected '->'");
- }
-
- /// Parses a `->` if present.
- ParseResult parseOptionalArrow() override {
- return success(parser.consumeIf(Token::arrow));
- }
-
- /// Parse a '{' token.
- ParseResult parseLBrace() override {
- return parser.parseToken(Token::l_brace, "expected '{'");
- }
-
- /// Parse a '{' token if present
- ParseResult parseOptionalLBrace() override {
- return success(parser.consumeIf(Token::l_brace));
- }
-
- /// Parse a `}` token.
- ParseResult parseRBrace() override {
- return parser.parseToken(Token::r_brace, "expected '}'");
- }
-
- /// Parse a `}` token if present
- ParseResult parseOptionalRBrace() override {
- return success(parser.consumeIf(Token::r_brace));
- }
-
- /// Parse a `:` token.
- ParseResult parseColon() override {
- return parser.parseToken(Token::colon, "expected ':'");
- }
-
- /// Parse a `:` token if present.
- ParseResult parseOptionalColon() override {
- return success(parser.consumeIf(Token::colon));
- }
-
- /// Parse a `,` token.
- ParseResult parseComma() override {
- return parser.parseToken(Token::comma, "expected ','");
- }
-
- /// Parse a `,` token if present.
- ParseResult parseOptionalComma() override {
- return success(parser.consumeIf(Token::comma));
- }
-
- /// Parses a `...` if present.
- ParseResult parseOptionalEllipsis() override {
- return success(parser.consumeIf(Token::ellipsis));
- }
-
- /// Parse a `=` token.
- ParseResult parseEqual() override {
- return parser.parseToken(Token::equal, "expected '='");
- }
-
- /// Parse a `=` token if present.
- ParseResult parseOptionalEqual() override {
- return success(parser.consumeIf(Token::equal));
- }
-
- /// Parse a '<' token.
- ParseResult parseLess() override {
- return parser.parseToken(Token::less, "expected '<'");
- }
-
- /// Parse a `<` token if present.
- ParseResult parseOptionalLess() override {
- return success(parser.consumeIf(Token::less));
- }
-
- /// Parse a '>' token.
- ParseResult parseGreater() override {
- return parser.parseToken(Token::greater, "expected '>'");
- }
-
- /// Parse a `>` token if present.
- ParseResult parseOptionalGreater() override {
- return success(parser.consumeIf(Token::greater));
- }
-
- /// Parse a `(` token.
- ParseResult parseLParen() override {
- return parser.parseToken(Token::l_paren, "expected '('");
- }
-
- /// Parses a '(' if present.
- ParseResult parseOptionalLParen() override {
- return success(parser.consumeIf(Token::l_paren));
- }
-
- /// Parse a `)` token.
- ParseResult parseRParen() override {
- return parser.parseToken(Token::r_paren, "expected ')'");
- }
-
- /// Parses a ')' if present.
- ParseResult parseOptionalRParen() override {
- return success(parser.consumeIf(Token::r_paren));
- }
-
- /// Parse a `[` token.
- ParseResult parseLSquare() override {
- return parser.parseToken(Token::l_square, "expected '['");
- }
-
- /// Parses a '[' if present.
- ParseResult parseOptionalLSquare() override {
- return success(parser.consumeIf(Token::l_square));
- }
-
- /// Parse a `]` token.
- ParseResult parseRSquare() override {
- return parser.parseToken(Token::r_square, "expected ']'");
- }
-
- /// Parses a ']' if present.
- ParseResult parseOptionalRSquare() override {
- return success(parser.consumeIf(Token::r_square));
- }
-
- /// Parses a '?' if present.
- ParseResult parseOptionalQuestion() override {
- return success(parser.consumeIf(Token::question));
- }
-
- /// Parses a '*' if present.
- ParseResult parseOptionalStar() override {
- return success(parser.consumeIf(Token::star));
- }
-
- /// Returns true if the current token is a bare identifier or a keyword token.
- bool isCurrentTokenAKeyword() const {
- return parser.getToken().is(Token::bare_identifier) ||
- parser.getToken().isKeyword();
- }
-
- /// Parse the given keyword if present.
- ParseResult parseOptionalKeyword(StringRef keyword) override {
- // Check that the current token has the same spelling.
- if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword)
- return failure();
- parser.consumeToken();
- return success();
- }
-
- /// Parse a keyword, if present, into 'keyword'.
- ParseResult parseOptionalKeyword(StringRef *keyword) override {
- // Check that the current token is a keyword.
- if (!isCurrentTokenAKeyword())
- return failure();
-
- *keyword = parser.getTokenSpelling();
- parser.consumeToken();
- return success();
- }
-
- //===--------------------------------------------------------------------===//
- // Attribute Parsing
- //===--------------------------------------------------------------------===//
-
- /// Parse an arbitrary attribute and return it in result.
- ParseResult parseAttribute(Attribute &result, Type type) override {
- result = parser.parseAttribute(type);
- return success(static_cast<bool>(result));
- }
-
- /// Parse an affine map instance into 'map'.
- ParseResult parseAffineMap(AffineMap &map) override {
- return parser.parseAffineMapReference(map);
- }
-
- /// Parse an integer set instance into 'set' (the name is fixed by the overridden base method).
- ParseResult printIntegerSet(IntegerSet &set) override {
- return parser.parseIntegerSetReference(set);
- }
-
- //===--------------------------------------------------------------------===//
- // Type Parsing
- //===--------------------------------------------------------------------===//
- /// Parse an arbitrary type into 'result'.
- ParseResult parseType(Type &result) override {
- result = parser.parseType();
- return success(static_cast<bool>(result));
- }
- /// Parse a ranked dimension list by forwarding to Parser::parseDimensionListRanked.
- ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
- bool allowDynamic) override {
- return parser.parseDimensionListRanked(dimensions, allowDynamic);
- }
-
-private:
- /// The full symbol specification.
- StringRef fullSpec;
-
- /// The source location of the dialect symbol.
- SMLoc nameLoc;
-
- /// The main parser.
- Parser &parser;
-};
-} // namespace
-
-/// Parse the body of a pretty dialect symbol, which starts and ends with <>'s,
-/// and may be recursive. Return with the 'prettyName' StringRef encompassing
-/// the entire pretty name.
-///
-/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
-/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
-/// | '(' pretty-dialect-sym-contents+ ')'
-/// | '[' pretty-dialect-sym-contents+ ']'
-/// | '{' pretty-dialect-sym-contents+ '}'
-/// | '[^[<({>\])}\0]+'
-///
-ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) {
- // Pretty symbol names are a relatively unstructured format that contains a
- // series of properly nested punctuation, with anything else in the middle.
- // Scan ahead to find it and consume it if successful, otherwise emit an
- // error.
- auto *curPtr = getTokenSpelling().data(); // points at the current '<' token
-
- SmallVector<char, 8> nestedPunctuation;
-
- // Scan over the nested punctuation, bailing out on error and consuming until
- // we find the end. We know that we're currently looking at the '<', so we
- // can go until we find the matching '>' character.
- assert(*curPtr == '<');
- do {
- char c = *curPtr++;
- switch (c) {
- case '\0':
- // This also handles the EOF case.
- return emitError("unexpected nul or EOF in pretty dialect name");
- case '<':
- case '[':
- case '(':
- case '{':
- nestedPunctuation.push_back(c);
- continue;
-
- case '-':
- // The sequence `->` is a single token; skip its '>' so it closes nothing.
- if (*curPtr == '>')
- ++curPtr;
- continue;
- // Closers must match the innermost opener; the stack is never empty here.
- case '>':
- if (nestedPunctuation.pop_back_val() != '<')
- return emitError("unbalanced '>' character in pretty dialect name");
- break;
- case ']':
- if (nestedPunctuation.pop_back_val() != '[')
- return emitError("unbalanced ']' character in pretty dialect name");
- break;
- case ')':
- if (nestedPunctuation.pop_back_val() != '(')
- return emitError("unbalanced ')' character in pretty dialect name");
- break;
- case '}':
- if (nestedPunctuation.pop_back_val() != '{')
- return emitError("unbalanced '}' character in pretty dialect name");
- break;
-
- default:
- continue;
- }
- } while (!nestedPunctuation.empty());
-
- // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
- // consuming all this stuff, and return.
- state.lex.resetPointer(curPtr);
- // Extend 'prettyName' (which must begin at the symbol name) to end at curPtr.
- unsigned length = curPtr - prettyName.begin();
- prettyName = StringRef(prettyName.begin(), length);
- consumeToken();
- return success();
-}
-
-/// Parse an extended dialect symbol: an alias, `dialect<"...">`, or `dialect.name<...>`.
-template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
-static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
- SymbolAliasMap &aliases,
- CreateFn &&createSymbol) {
- // Drop the leading sigil character to get the identifier, then consume it.
- StringRef identifier = p.getTokenSpelling().drop_front();
- auto loc = p.getToken().getLoc();
- p.consumeToken(identifierTok);
-
- // If there is no '<' token following this, and if the typename contains no
- // dot, then we are parsing a symbol alias.
- if (p.getToken().isNot(Token::less) && !identifier.contains('.')) {
- // Look up the alias; an unknown alias is an error.
- auto aliasIt = aliases.find(identifier);
- if (aliasIt == aliases.end())
- return (p.emitError("undefined symbol alias id '" + identifier + "'"),
- nullptr);
- return aliasIt->second;
- }
-
- // Otherwise, we are parsing a dialect-specific symbol. If the name contains
- // a dot, then this is the "pretty" form. If not, it is the verbose form that
- // looks like <"...">.
- std::string symbolData;
- auto dialectName = identifier;
-
- // Handle the verbose form, where "identifier" is a simple dialect name.
- if (!identifier.contains('.')) {
- // Consume the '<'.
- if (p.parseToken(Token::less, "expected '<' in dialect type"))
- return nullptr;
-
- // Parse the symbol specific data.
- if (p.getToken().isNot(Token::string))
- return (p.emitError("expected string literal data in dialect symbol"),
- nullptr);
- symbolData = p.getToken().getStringValue();
- loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); // +1: presumably skips the opening quote
- p.consumeToken(Token::string);
-
- // Consume the '>'.
- if (p.parseToken(Token::greater, "expected '>' in dialect symbol"))
- return nullptr;
- } else {
- // Ok, the dialect name is the part of the identifier before the dot, the
- // part after the dot is the dialect's symbol, or the start thereof.
- auto dotHalves = identifier.split('.');
- dialectName = dotHalves.first;
- auto prettyName = dotHalves.second;
- loc = llvm::SMLoc::getFromPointer(prettyName.data());
-
- // If the dialect's symbol is followed immediately by a <, then lex the body
- // of it into prettyName.
- if (p.getToken().is(Token::less) &&
- prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) {
- if (p.parsePrettyDialectSymbolName(prettyName))
- return nullptr;
- }
-
- symbolData = prettyName.str();
- }
-
- // Record the symbol's name location, remapped to the top-level buffer, for nested parsers.
- llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc);
- p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer);
-
- // Call into the provided symbol construction function.
- Symbol sym = createSymbol(dialectName, symbolData, loc);
-
- // Pop the location pushed above, restoring the previous nesting depth.
- p.getState().symbols.nestedParserLocs.pop_back();
- return sym;
-}
-
-/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
-/// parsing failed, a null T() is returned. The number of bytes read from the input
-/// string is returned in 'numRead'; when it is null, leftover tokens are an error.
-template <typename T, typename ParserFn>
-static T parseSymbol(StringRef inputStr, MLIRContext *context,
- SymbolState &symbolState, ParserFn &&parserFn,
- size_t *numRead = nullptr) {
- SourceMgr sourceMgr;
- auto memBuffer = MemoryBuffer::getMemBuffer(
- inputStr, /*BufferName=*/"<mlir_parser_buffer>",
- /*RequiresNullTerminator=*/false);
- sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
- ParserState state(sourceMgr, context, symbolState);
- Parser parser(state);
- // Remember the first token so the number of consumed bytes can be computed.
- Token startTok = parser.getToken();
- T symbol = parserFn(parser);
- if (!symbol)
- return T();
-
- // If 'numRead' is valid, then provide the number of bytes that were read.
- Token endTok = parser.getToken();
- if (numRead) {
- *numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
- startTok.getLoc().getPointer());
-
- // Otherwise, ensure all tokens were parsed (unless no token was consumed).
- } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) {
- parser.emitError(endTok.getLoc(), "encountered unexpected token");
- return T();
- }
- return symbol;
-}
-
-//===----------------------------------------------------------------------===//
-// Error Handling
-//===----------------------------------------------------------------------===//
-
-InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
- auto diag = mlir::emitError(getEncodedSourceLocation(loc), message);
- // 'loc' was re-encoded above, remapped to the top-level buffer when nested.
- // If we hit a parse error in response to a lexer error, the lexer already
- // reported it, so abandon this diagnostic to avoid a duplicate.
- if (getToken().is(Token::error))
- diag.abandon();
- return diag;
-}
-
-//===----------------------------------------------------------------------===//
-// Token Parsing
-//===----------------------------------------------------------------------===//
-
-/// Consume the specified token if present and return success. On failure,
-/// emit 'message' at the current token's location and return failure.
-ParseResult Parser::parseToken(Token::Kind expectedToken,
- const Twine &message) {
- if (consumeIf(expectedToken))
- return success();
- return emitError(message);
-}
-
-//===----------------------------------------------------------------------===//
-// Type Parsing
-//===----------------------------------------------------------------------===//
-
-/// Optionally parse a type into 'type'. Returns llvm::None, consuming nothing,
-/// if the current token cannot start a type; otherwise parses it via parseType.
-OptionalParseResult Parser::parseOptionalType(Type &type) {
- // There are many different starting tokens for a type, check them here.
- switch (getToken().getKind()) {
- case Token::l_paren:
- case Token::kw_memref:
- case Token::kw_tensor:
- case Token::kw_complex:
- case Token::kw_tuple:
- case Token::kw_vector:
- case Token::inttype:
- case Token::kw_bf16:
- case Token::kw_f16:
- case Token::kw_f32:
- case Token::kw_f64:
- case Token::kw_index:
- case Token::kw_none:
- case Token::exclamation_identifier:
- // Failure iff parseType produced a null type.
- return failure(!(type = parseType()));
-
- default:
- return llvm::None;
- }
-}
-
-/// Parse an arbitrary type.
-/// A leading '(' always selects the function-type form; returns null on failure.
-/// type ::= function-type
-/// | non-function-type
-///
-Type Parser::parseType() {
- if (getToken().is(Token::l_paren))
- return parseFunctionType();
- return parseNonFunctionType();
-}
-
-/// Parse a function result type, appending the parsed type(s) to 'elements'.
-///
-/// function-result-type ::= type-list-parens
-/// | non-function-type
-///
-ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
- if (getToken().is(Token::l_paren))
- return parseTypeListParens(elements);
- // Otherwise the result is a single non-function type.
- Type t = parseNonFunctionType();
- if (!t)
- return failure();
- elements.push_back(t);
- return success();
-}
-
-/// Parse a list of types without an enclosing parenthesis. The list must have
-/// at least one member. Parsed types are appended to 'elements'.
-///
-/// type-list-no-parens ::= type (`,` type)*
-///
-ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
- auto parseElt = [&]() -> ParseResult {
- auto elt = parseType();
- elements.push_back(elt); // NOTE(review): appended even when null (failed parse)
- return elt ? success() : failure();
- };
-
- return parseCommaSeparatedList(parseElt);
-}
-
-/// Parse a parenthesized list of types.
-///
-/// type-list-parens ::= `(` `)`
-/// | `(` type-list-no-parens `)`
-///
-ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
- if (parseToken(Token::l_paren, "expected '('"))
- return failure();
-
- // Handle empty lists.
- if (getToken().is(Token::r_paren))
- return consumeToken(), success();
-
- if (parseTypeListNoParens(elements) ||
- parseToken(Token::r_paren, "expected ')'"))
- return failure();
- return success();
-}
-
-/// Parse a complex type.
-///
-/// complex-type ::= `complex` `<` type `>`
-///
-Type Parser::parseComplexType() {
- consumeToken(Token::kw_complex);
-
- // Parse the '<'.
- if (parseToken(Token::less, "expected '<' in complex type"))
- return nullptr;
-
- llvm::SMLoc elementTypeLoc = getToken().getLoc();
- auto elementType = parseType();
- if (!elementType ||
- parseToken(Token::greater, "expected '>' in complex type"))
- return nullptr;
- if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
- return emitError(elementTypeLoc, "invalid element type for complex"),
- nullptr;
-
- return ComplexType::get(elementType);
-}
-
-/// Parse an extended type.
-///
-/// extended-type ::= (dialect-type | type-alias)
-/// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
-/// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
-/// type-alias ::= `!` alias-name
-///
-Type Parser::parseExtendedType() {
- return parseExtendedSymbol<Type>(
- *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
- [&](StringRef dialectName, StringRef symbolData,
- llvm::SMLoc loc) -> Type {
- // If we found a registered dialect, then ask it to parse the type.
- if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
- return parseSymbol<Type>(
- symbolData, state.context, state.symbols, [&](Parser &parser) {
- CustomDialectAsmParser customParser(symbolData, parser);
- return dialect->parseType(customParser);
- });
- }
-
- // Otherwise, form a new opaque type.
- return OpaqueType::getChecked(
- Identifier::get(dialectName, state.context), symbolData,
- state.context, getEncodedSourceLocation(loc));
- });
-}
-
-/// Parse a function type.
-///
-/// function-type ::= type-list-parens `->` function-result-type
-///
-Type Parser::parseFunctionType() {
- assert(getToken().is(Token::l_paren));
-
- SmallVector<Type, 4> arguments, results;
- if (parseTypeListParens(arguments) ||
- parseToken(Token::arrow, "expected '->' in function type") ||
- parseFunctionResultTypes(results))
- return nullptr;
-
- return builder.getFunctionType(arguments, results);
-}
-
-/// Parse the offset and strides from a strided layout specification.
-///
-/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-///
-ParseResult Parser::parseStridedLayout(int64_t &offset,
- SmallVectorImpl<int64_t> &strides) {
- // Parse offset.
- consumeToken(Token::kw_offset);
- if (!consumeIf(Token::colon))
- return emitError("expected colon after `offset` keyword");
- auto maybeOffset = getToken().getUnsignedIntegerValue();
- bool question = getToken().is(Token::question);
- if (!maybeOffset && !question)
- return emitError("invalid offset");
- offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue())
- : MemRefType::getDynamicStrideOrOffset();
- consumeToken();
-
- if (!consumeIf(Token::comma))
- return emitError("expected comma after offset value");
-
- // Parse stride list.
- if (!consumeIf(Token::kw_strides))
- return emitError("expected `strides` keyword after offset specification");
- if (!consumeIf(Token::colon))
- return emitError("expected colon after `strides` keyword");
- if (failed(parseStrideList(strides)))
- return emitError("invalid braces-enclosed stride list");
- if (llvm::any_of(strides, [](int64_t st) { return st == 0; }))
- return emitError("invalid memref stride");
-
- return success();
-}
-
-/// Parse a memref type.
-///
-/// memref-type ::= ranked-memref-type | unranked-memref-type
-///
-/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
-/// (`,` semi-affine-map-composition)? (`,`
-/// memory-space)? `>`
-///
-/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
-///
-/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
-/// memory-space ::= integer-literal /* | TODO: address-space-id */
-///
-Type Parser::parseMemRefType() {
- consumeToken(Token::kw_memref);
-
- if (parseToken(Token::less, "expected '<' in memref type"))
- return nullptr;
-
- bool isUnranked;
- SmallVector<int64_t, 4> dimensions;
-
- if (consumeIf(Token::star)) {
- // This is an unranked memref type.
- isUnranked = true;
- if (parseXInDimensionList())
- return nullptr;
-
- } else {
- isUnranked = false;
- if (parseDimensionListRanked(dimensions))
- return nullptr;
- }
-
- // Parse the element type.
- auto typeLoc = getToken().getLoc();
- auto elementType = parseType();
- if (!elementType)
- return nullptr;
-
- // Check that memref is formed from allowed types.
- if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
- !elementType.isa<ComplexType>())
- return emitError(typeLoc, "invalid memref element type"), nullptr;
-
- // Parse semi-affine-map-composition.
- SmallVector<AffineMap, 2> affineMapComposition;
- Optional<unsigned> memorySpace;
- unsigned numDims = dimensions.size();
-
- auto parseElt = [&]() -> ParseResult {
- // Check for the memory space.
- if (getToken().is(Token::integer)) {
- if (memorySpace)
- return emitError("multiple memory spaces specified in memref type");
- memorySpace = getToken().getUnsignedIntegerValue();
- if (!memorySpace.hasValue())
- return emitError("invalid memory space in memref type");
- consumeToken(Token::integer);
- return success();
- }
- if (isUnranked)
- return emitError("cannot have affine map for unranked memref type");
- if (memorySpace)
- return emitError("expected memory space to be last in memref type");
-
- AffineMap map;
- llvm::SMLoc mapLoc = getToken().getLoc();
- if (getToken().is(Token::kw_offset)) {
- int64_t offset;
- SmallVector<int64_t, 4> strides;
- if (failed(parseStridedLayout(offset, strides)))
- return failure();
- // Construct strided affine map.
- map = makeStridedLinearLayoutMap(strides, offset, state.context);
- } else {
- // Parse an affine map attribute.
- auto affineMap = parseAttribute();
- if (!affineMap)
- return failure();
- auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>();
- if (!affineMapAttr)
- return emitError("expected affine map in memref type");
- map = affineMapAttr.getValue();
- }
-
- if (map.getNumDims() != numDims) {
- size_t i = affineMapComposition.size();
- return emitError(mapLoc, "memref affine map dimension mismatch between ")
- << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
- << " and affine map" << i + 1 << ": " << numDims
- << " != " << map.getNumDims();
- }
- numDims = map.getNumResults();
- affineMapComposition.push_back(map);
- return success();
- };
-
- // Parse a list of mappings and address space if present.
- if (!consumeIf(Token::greater)) {
- // Parse comma separated list of affine maps, followed by memory space.
- if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
- parseCommaSeparatedListUntil(Token::greater, parseElt,
- /*allowEmptyList=*/false)) {
- return nullptr;
- }
- }
-
- if (isUnranked)
- return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0));
-
- return MemRefType::get(dimensions, elementType, affineMapComposition,
- memorySpace.getValueOr(0));
-}
-
-/// Parse any type except the function type.
-///
-/// non-function-type ::= integer-type
-/// | index-type
-/// | float-type
-/// | extended-type
-/// | vector-type
-/// | tensor-type
-/// | memref-type
-/// | complex-type
-/// | tuple-type
-/// | none-type
-///
-/// index-type ::= `index`
-/// float-type ::= `f16` | `bf16` | `f32` | `f64`
-/// none-type ::= `none`
-///
-Type Parser::parseNonFunctionType() {
- switch (getToken().getKind()) {
- default:
- return (emitError("expected non-function type"), nullptr);
- case Token::kw_memref:
- return parseMemRefType();
- case Token::kw_tensor:
- return parseTensorType();
- case Token::kw_complex:
- return parseComplexType();
- case Token::kw_tuple:
- return parseTupleType();
- case Token::kw_vector:
- return parseVectorType();
- // integer-type
- case Token::inttype: {
- auto width = getToken().getIntTypeBitwidth();
- if (!width.hasValue())
- return (emitError("invalid integer width"), nullptr);
- if (width.getValue() > IntegerType::kMaxWidth) {
- emitError(getToken().getLoc(), "integer bitwidth is limited to ")
- << IntegerType::kMaxWidth << " bits";
- return nullptr;
- }
-
- IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
- if (Optional<bool> signedness = getToken().getIntTypeSignedness())
- signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
-
- auto loc = getEncodedSourceLocation(getToken().getLoc());
- consumeToken(Token::inttype);
- return IntegerType::getChecked(width.getValue(), signSemantics, loc);
- }
-
- // float-type
- case Token::kw_bf16:
- consumeToken(Token::kw_bf16);
- return builder.getBF16Type();
- case Token::kw_f16:
- consumeToken(Token::kw_f16);
- return builder.getF16Type();
- case Token::kw_f32:
- consumeToken(Token::kw_f32);
- return builder.getF32Type();
- case Token::kw_f64:
- consumeToken(Token::kw_f64);
- return builder.getF64Type();
-
- // index-type
- case Token::kw_index:
- consumeToken(Token::kw_index);
- return builder.getIndexType();
-
- // none-type
- case Token::kw_none:
- consumeToken(Token::kw_none);
- return builder.getNoneType();
-
- // extended type
- case Token::exclamation_identifier:
- return parseExtendedType();
- }
-}
-
-/// Parse a tensor type.
-///
-/// tensor-type ::= `tensor` `<` dimension-list type `>`
-/// dimension-list ::= dimension-list-ranked | `*x`
-///
-Type Parser::parseTensorType() {
- consumeToken(Token::kw_tensor);
-
- if (parseToken(Token::less, "expected '<' in tensor type"))
- return nullptr;
-
- bool isUnranked;
- SmallVector<int64_t, 4> dimensions;
-
- if (consumeIf(Token::star)) {
- // This is an unranked tensor type.
- isUnranked = true;
-
- if (parseXInDimensionList())
- return nullptr;
-
- } else {
- isUnranked = false;
- if (parseDimensionListRanked(dimensions))
- return nullptr;
- }
-
- // Parse the element type.
- auto elementTypeLoc = getToken().getLoc();
- auto elementType = parseType();
- if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
- return nullptr;
- if (!TensorType::isValidElementType(elementType))
- return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
-
- if (isUnranked)
- return UnrankedTensorType::get(elementType);
- return RankedTensorType::get(dimensions, elementType);
-}
-
-/// Parse a tuple type.
-///
-/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
-///
-Type Parser::parseTupleType() {
- consumeToken(Token::kw_tuple);
-
- // Parse the '<'.
- if (parseToken(Token::less, "expected '<' in tuple type"))
- return nullptr;
-
- // Check for an empty tuple by directly parsing '>'.
- if (consumeIf(Token::greater))
- return TupleType::get(getContext());
-
- // Parse the element types and the '>'.
- SmallVector<Type, 4> types;
- if (parseTypeListNoParens(types) ||
- parseToken(Token::greater, "expected '>' in tuple type"))
- return nullptr;
-
- return TupleType::get(types, getContext());
-}
-
-/// Parse a vector type.
-///
-/// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>`
-/// non-empty-static-dimension-list ::= decimal-literal `x`
-/// static-dimension-list
-/// static-dimension-list ::= (decimal-literal `x`)*
-///
-VectorType Parser::parseVectorType() {
- consumeToken(Token::kw_vector);
-
- if (parseToken(Token::less, "expected '<' in vector type"))
- return nullptr;
-
- SmallVector<int64_t, 4> dimensions;
- if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
- return nullptr;
- if (dimensions.empty())
- return (emitError("expected dimension size in vector type"), nullptr);
- if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
- return emitError(getToken().getLoc(),
- "vector types must have positive constant sizes"),
- nullptr;
-
- // Parse the element type.
- auto typeLoc = getToken().getLoc();
- auto elementType = parseType();
- if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
- return nullptr;
- if (!VectorType::isValidElementType(elementType))
- return emitError(typeLoc, "vector elements must be int or float type"),
- nullptr;
-
- return VectorType::get(dimensions, elementType);
-}
-
-/// Parse a dimension list of a tensor or memref type. This populates the
-/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
-/// errors out on `?` otherwise.
-///
-/// dimension-list-ranked ::= (dimension `x`)*
-/// dimension ::= `?` | decimal-literal
-///
-/// When `allowDynamic` is not set, this is used to parse:
-///
-/// static-dimension-list ::= (decimal-literal `x`)*
-ParseResult
-Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
- bool allowDynamic) {
- while (getToken().isAny(Token::integer, Token::question)) {
- if (consumeIf(Token::question)) {
- if (!allowDynamic)
- return emitError("expected static shape");
- dimensions.push_back(-1);
- } else {
- // Hexadecimal integer literals (starting with `0x`) are not allowed in
- // aggregate type declarations. Therefore, `0xf32` should be processed as
- // a sequence of separate elements `0`, `x`, `f32`.
- if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
- // We can get here only if the token is an integer literal. Hexadecimal
- // integer literals can only start with `0x` (`1x` wouldn't lex as a
- // literal, just `1` would, at which point we don't get into this
- // branch).
- assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
- dimensions.push_back(0);
- state.lex.resetPointer(getTokenSpelling().data() + 1);
- consumeToken();
- } else {
- // Make sure this integer value is in bound and valid.
- auto dimension = getToken().getUnsignedIntegerValue();
- if (!dimension.hasValue())
- return emitError("invalid dimension");
- dimensions.push_back((int64_t)dimension.getValue());
- consumeToken(Token::integer);
- }
- }
-
- // Make sure we have an 'x' or something like 'xbf32'.
- if (parseXInDimensionList())
- return failure();
- }
-
- return success();
-}
-
-/// Parse an 'x' token in a dimension list, handling the case where the x is
-/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
-/// token.
-ParseResult Parser::parseXInDimensionList() {
- if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
- return emitError("expected 'x' in dimension list");
-
- // If we had a prefix of 'x', lex the next token immediately after the 'x'.
- if (getTokenSpelling().size() != 1)
- state.lex.resetPointer(getTokenSpelling().data() + 1);
-
- // Consume the 'x'.
- consumeToken(Token::bare_identifier);
-
- return success();
-}
-
-// Parse a comma-separated list of dimensions, possibly empty:
-// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
-ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
- if (!consumeIf(Token::l_square))
- return failure();
- // Empty list early exit.
- if (consumeIf(Token::r_square))
- return success();
- while (true) {
- if (consumeIf(Token::question)) {
- dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
- } else {
- // This must be an integer value.
- int64_t val;
- if (getToken().getSpelling().getAsInteger(10, val))
- return emitError("invalid integer value: ") << getToken().getSpelling();
- // Make sure it is not the one value for `?`.
- if (ShapedType::isDynamic(val))
- return emitError("invalid integer value: ")
- << getToken().getSpelling()
- << ", use `?` to specify a dynamic dimension";
- dimensions.push_back(val);
- consumeToken(Token::integer);
- }
- if (!consumeIf(Token::comma))
- break;
- }
- if (!consumeIf(Token::r_square))
- return failure();
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Attribute parsing.
-//===----------------------------------------------------------------------===//
-
-/// Return the symbol reference referred to by the given token, that is known to
-/// be an @-identifier.
-static std::string extractSymbolReference(Token tok) {
- assert(tok.is(Token::at_identifier) && "expected valid @-identifier");
- StringRef nameStr = tok.getSpelling().drop_front();
-
- // Check to see if the reference is a string literal, or a bare identifier.
- if (nameStr.front() == '"')
- return tok.getStringValue();
- return std::string(nameStr);
-}
-
-/// Parse an arbitrary attribute.
-///
-/// attribute-value ::= `unit`
-/// | bool-literal
-/// | integer-literal (`:` (index-type | integer-type))?
-/// | float-literal (`:` float-type)?
-/// | string-literal (`:` type)?
-/// | type
-/// | `[` (attribute-value (`,` attribute-value)*)? `]`
-/// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
-/// | symbol-ref-id (`::` symbol-ref-id)*
-/// | `dense` `<` attribute-value `>` `:`
-/// (tensor-type | vector-type)
-/// | `sparse` `<` attribute-value `,` attribute-value `>`
-/// `:` (tensor-type | vector-type)
-/// | `opaque` `<` dialect-namespace `,` hex-string-literal
-/// `>` `:` (tensor-type | vector-type)
-/// | extended-attribute
-///
-Attribute Parser::parseAttribute(Type type) {
- switch (getToken().getKind()) {
- // Parse an AffineMap or IntegerSet attribute.
- case Token::kw_affine_map: {
- consumeToken(Token::kw_affine_map);
-
- AffineMap map;
- if (parseToken(Token::less, "expected '<' in affine map") ||
- parseAffineMapReference(map) ||
- parseToken(Token::greater, "expected '>' in affine map"))
- return Attribute();
- return AffineMapAttr::get(map);
- }
- case Token::kw_affine_set: {
- consumeToken(Token::kw_affine_set);
-
- IntegerSet set;
- if (parseToken(Token::less, "expected '<' in integer set") ||
- parseIntegerSetReference(set) ||
- parseToken(Token::greater, "expected '>' in integer set"))
- return Attribute();
- return IntegerSetAttr::get(set);
- }
-
- // Parse an array attribute.
- case Token::l_square: {
- consumeToken(Token::l_square);
-
- SmallVector<Attribute, 4> elements;
- auto parseElt = [&]() -> ParseResult {
- elements.push_back(parseAttribute());
- return elements.back() ? success() : failure();
- };
-
- if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
- return nullptr;
- return builder.getArrayAttr(elements);
- }
-
- // Parse a boolean attribute.
- case Token::kw_false:
- consumeToken(Token::kw_false);
- return builder.getBoolAttr(false);
- case Token::kw_true:
- consumeToken(Token::kw_true);
- return builder.getBoolAttr(true);
-
- // Parse a dense elements attribute.
- case Token::kw_dense:
- return parseDenseElementsAttr(type);
-
- // Parse a dictionary attribute.
- case Token::l_brace: {
- NamedAttrList elements;
- if (parseAttributeDict(elements))
- return nullptr;
- return elements.getDictionary(getContext());
- }
-
- // Parse an extended attribute, i.e. alias or dialect attribute.
- case Token::hash_identifier:
- return parseExtendedAttr(type);
-
- // Parse floating point and integer attributes.
- case Token::floatliteral:
- return parseFloatAttr(type, /*isNegative=*/false);
- case Token::integer:
- return parseDecOrHexAttr(type, /*isNegative=*/false);
- case Token::minus: {
- consumeToken(Token::minus);
- if (getToken().is(Token::integer))
- return parseDecOrHexAttr(type, /*isNegative=*/true);
- if (getToken().is(Token::floatliteral))
- return parseFloatAttr(type, /*isNegative=*/true);
-
- return (emitError("expected constant integer or floating point value"),
- nullptr);
- }
-
- // Parse a location attribute.
- case Token::kw_loc: {
- LocationAttr attr;
- return failed(parseLocation(attr)) ? Attribute() : attr;
- }
-
- // Parse an opaque elements attribute.
- case Token::kw_opaque:
- return parseOpaqueElementsAttr(type);
-
- // Parse a sparse elements attribute.
- case Token::kw_sparse:
- return parseSparseElementsAttr(type);
-
- // Parse a string attribute.
- case Token::string: {
- auto val = getToken().getStringValue();
- consumeToken(Token::string);
- // Parse the optional trailing colon type if one wasn't explicitly provided.
- if (!type && consumeIf(Token::colon) && !(type = parseType()))
- return Attribute();
-
- return type ? StringAttr::get(val, type)
- : StringAttr::get(val, getContext());
- }
-
- // Parse a symbol reference attribute.
- case Token::at_identifier: {
- std::string nameStr = extractSymbolReference(getToken());
- consumeToken(Token::at_identifier);
-
- // Parse any nested references.
- std::vector<FlatSymbolRefAttr> nestedRefs;
- while (getToken().is(Token::colon)) {
- // Check for the '::' prefix.
- const char *curPointer = getToken().getLoc().getPointer();
- consumeToken(Token::colon);
- if (!consumeIf(Token::colon)) {
- state.lex.resetPointer(curPointer);
- consumeToken();
- break;
- }
- // Parse the reference itself.
- auto curLoc = getToken().getLoc();
- if (getToken().isNot(Token::at_identifier)) {
- emitError(curLoc, "expected nested symbol reference identifier");
- return Attribute();
- }
-
- std::string nameStr = extractSymbolReference(getToken());
- consumeToken(Token::at_identifier);
- nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
- }
-
- return builder.getSymbolRefAttr(nameStr, nestedRefs);
- }
-
- // Parse a 'unit' attribute.
- case Token::kw_unit:
- consumeToken(Token::kw_unit);
- return builder.getUnitAttr();
-
- default:
- // Parse a type attribute.
- if (Type type = parseType())
- return TypeAttr::get(type);
- return nullptr;
- }
-}
-
-/// Attribute dictionary.
-///
-/// attribute-dict ::= `{` `}`
-/// | `{` attribute-entry (`,` attribute-entry)* `}`
-/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
-///
-ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
- if (parseToken(Token::l_brace, "expected '{' in attribute dictionary"))
- return failure();
-
- llvm::SmallDenseSet<Identifier> seenKeys;
- auto parseElt = [&]() -> ParseResult {
- // The name of an attribute can either be a bare identifier, or a string.
- Optional<Identifier> nameId;
- if (getToken().is(Token::string))
- nameId = builder.getIdentifier(getToken().getStringValue());
- else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
- getToken().isKeyword())
- nameId = builder.getIdentifier(getTokenSpelling());
- else
- return emitError("expected attribute name");
- if (!seenKeys.insert(*nameId).second)
- return emitError("duplicate key in dictionary attribute");
- consumeToken();
-
- // Try to parse the '=' for the attribute value.
- if (!consumeIf(Token::equal)) {
- // If there is no '=', we treat this as a unit attribute.
- attributes.push_back({*nameId, builder.getUnitAttr()});
- return success();
- }
-
- auto attr = parseAttribute();
- if (!attr)
- return failure();
- attributes.push_back({*nameId, attr});
- return success();
- };
-
- if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
- return failure();
-
- return success();
-}
-
-/// Parse an extended attribute.
-///
-/// extended-attribute ::= (dialect-attribute | attribute-alias)
-/// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
-/// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
-/// attribute-alias ::= `#` alias-name
-///
-Attribute Parser::parseExtendedAttr(Type type) {
- Attribute attr = parseExtendedSymbol<Attribute>(
- *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
- [&](StringRef dialectName, StringRef symbolData,
- llvm::SMLoc loc) -> Attribute {
- // Parse an optional trailing colon type.
- Type attrType = type;
- if (consumeIf(Token::colon) && !(attrType = parseType()))
- return Attribute();
-
- // If we found a registered dialect, then ask it to parse the attribute.
- if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
- return parseSymbol<Attribute>(
- symbolData, state.context, state.symbols, [&](Parser &parser) {
- CustomDialectAsmParser customParser(symbolData, parser);
- return dialect->parseAttribute(customParser, attrType);
- });
- }
-
- // Otherwise, form a new opaque attribute.
- return OpaqueAttr::getChecked(
- Identifier::get(dialectName, state.context), symbolData,
- attrType ? attrType : NoneType::get(state.context),
- getEncodedSourceLocation(loc));
- });
-
- // Ensure that the attribute has the same type as requested.
- if (attr && type && attr.getType() != type) {
- emitError("attribute type different than expected: expected ")
- << type << ", but got " << attr.getType();
- return nullptr;
- }
- return attr;
-}
-
-/// Parse a float attribute.
-Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
- auto val = getToken().getFloatingPointValue();
- if (!val.hasValue())
- return (emitError("floating point value too large for attribute"), nullptr);
- consumeToken(Token::floatliteral);
- if (!type) {
- // Default to F64 when no type is specified.
- if (!consumeIf(Token::colon))
- type = builder.getF64Type();
- else if (!(type = parseType()))
- return nullptr;
- }
- if (!type.isa<FloatType>())
- return (emitError("floating point value not valid for specified type"),
- nullptr);
- return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
-}
-
-/// Construct a float attribute bitwise equivalent to the integer literal.
-static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
- uint64_t value) {
- if (type.isF64())
- return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
-
- APInt apInt(type.getWidth(), value);
- if (apInt != value) {
- p->emitError("hexadecimal float constant out of range for type");
- return llvm::None;
- }
- return APFloat(type.getFloatSemantics(), apInt);
-}
-
-/// Construct an APint from a parsed value, a known attribute type and
-/// sign.
-static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
- StringRef spelling) {
- // Parse the integer value into an APInt that is big enough to hold the value.
- APInt result;
- bool isHex = spelling.size() > 1 && spelling[1] == 'x';
- if (spelling.getAsInteger(isHex ? 0 : 10, result))
- return llvm::None;
-
- // Extend or truncate the bitwidth to the right size.
- unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
- : type.getIntOrFloatBitWidth();
- if (width > result.getBitWidth()) {
- result = result.zext(width);
- } else if (width < result.getBitWidth()) {
- // The parser can return an unnecessarily wide result with leading zeros.
- // This isn't a problem, but truncating off bits is bad.
- if (result.countLeadingZeros() < result.getBitWidth() - width)
- return llvm::None;
-
- result = result.trunc(width);
- }
-
- if (isNegative) {
- // The value is negative, we have an overflow if the sign bit is not set
- // in the negated apInt.
- result.negate();
- if (!result.isSignBitSet())
- return llvm::None;
- } else if ((type.isSignedInteger() || type.isIndex()) &&
- result.isSignBitSet()) {
- // The value is a positive signed integer or index,
- // we have an overflow if the sign bit is set.
- return llvm::None;
- }
-
- return result;
-}
-
-/// Parse a decimal or a hexadecimal literal, which can be either an integer
-/// or a float attribute.
-Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
- // Remember if the literal is hexadecimal.
- StringRef spelling = getToken().getSpelling();
- auto loc = state.curToken.getLoc();
- bool isHex = spelling.size() > 1 && spelling[1] == 'x';
-
- consumeToken(Token::integer);
- if (!type) {
- // Default to i64 if not type is specified.
- if (!consumeIf(Token::colon))
- type = builder.getIntegerType(64);
- else if (!(type = parseType()))
- return nullptr;
- }
-
- if (auto floatType = type.dyn_cast<FloatType>()) {
- if (isNegative)
- return emitError(
- loc,
- "hexadecimal float literal should not have a leading minus"),
- nullptr;
- if (!isHex) {
- emitError(loc, "unexpected decimal integer literal for a float attribute")
- .attachNote()
- << "add a trailing dot to make the literal a float";
- return nullptr;
- }
-
- auto val = Token::getUInt64IntegerValue(spelling);
- if (!val.hasValue())
- return emitError("integer constant out of range for attribute"), nullptr;
-
- // Construct a float attribute bitwise equivalent to the integer literal.
- Optional<APFloat> apVal =
- buildHexadecimalFloatLiteral(this, floatType, *val);
- return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
- }
-
- if (!type.isa<IntegerType>() && !type.isa<IndexType>())
- return emitError(loc, "integer literal not valid for specified type"),
- nullptr;
-
- if (isNegative && type.isUnsignedInteger()) {
- emitError(loc,
- "negative integer literal not valid for unsigned integer type");
- return nullptr;
- }
-
- Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
- if (!apInt)
- return emitError(loc, "integer constant out of range for attribute"),
- nullptr;
- return builder.getIntegerAttr(type, *apInt);
-}
-
-/// Parse elements values stored within a hex etring. On success, the values are
-/// stored into 'result'.
-static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
- std::string &result) {
- std::string val = tok.getStringValue();
- if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
- return parser.emitError(tok.getLoc(),
- "elements hex string should start with '0x'");
-
- StringRef hexValues = StringRef(val).drop_front(2);
- if (!llvm::all_of(hexValues, llvm::isHexDigit))
- return parser.emitError(tok.getLoc(),
- "elements hex string only contains hex digits");
-
- result = llvm::fromHex(hexValues);
- return success();
-}
-
-/// Parse an opaque elements attribute.
-Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
- consumeToken(Token::kw_opaque);
- if (parseToken(Token::less, "expected '<' after 'opaque'"))
- return nullptr;
-
- if (getToken().isNot(Token::string))
- return (emitError("expected dialect namespace"), nullptr);
-
- auto name = getToken().getStringValue();
- auto *dialect = builder.getContext()->getRegisteredDialect(name);
- // TODO(shpeisman): Allow for having an unknown dialect on an opaque
- // attribute. Otherwise, it can't be roundtripped without having the dialect
- // registered.
- if (!dialect)
- return (emitError("no registered dialect with namespace '" + name + "'"),
- nullptr);
- consumeToken(Token::string);
-
- if (parseToken(Token::comma, "expected ','"))
- return nullptr;
-
- Token hexTok = getToken();
- if (parseToken(Token::string, "elements hex string should start with '0x'") ||
- parseToken(Token::greater, "expected '>'"))
- return nullptr;
- auto type = parseElementsLiteralType(attrType);
- if (!type)
- return nullptr;
-
- std::string data;
- if (parseElementAttrHexValues(*this, hexTok, data))
- return nullptr;
- return OpaqueElementsAttr::get(dialect, type, data);
-}
-
-namespace {
-/// Parses the element list of a `dense` or `sparse` tensor literal. Once the
-/// literal's type is known, it turns the parsed tokens into a
-/// DenseElementsAttr.
-class TensorLiteralParser {
-public:
-  TensorLiteralParser(Parser &parser) : p(parser) {}
-
-  /// Parse the elements of a tensor literal. When 'allowHex' is true, a
-  /// literal stored as a single hex string is accepted as well.
-  ParseResult parse(bool allowHex);
-
-  /// Build a dense attribute instance for the given shaped type from the
-  /// parsed elements.
-  DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);
-
-  /// Return the shape inferred from the literal's list structure. It is empty
-  /// when the literal was a single (non-list) element.
-  ArrayRef<int64_t> getShape() const { return shape; }
-
-private:
-  /// Convert the parsed tokens into integer values of type 'eltTy'.
-  ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
-                                 std::vector<APInt> &intValues);
-
-  /// Convert the parsed tokens into floating-point values of type 'eltTy'.
-  ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
-                                   std::vector<APFloat> &floatValues);
-
-  /// Build a dense string attribute of the given type from the parsed tokens.
-  DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
-
-  /// Build a dense attribute of the given type by decoding the hex literal.
-  DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
-
-  /// Parse one element literal: a boolean, a number, a string, or a
-  /// parenthesized complex pair. Lists are rejected. For example:
-  ///   parseElement(1) -> Success, 1
-  ///   parseElement([1]) -> Failure
-  ParseResult parseElement();
-
-  /// Parse a bracketed list whose entries (lists or elements) all have the
-  /// same dimensions. The dimensions of the parsed sub-tensor are written to
-  /// 'dims'. For example:
-  ///   parseList([1, 2, 3]) -> Success, [3]
-  ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
-  ///   parseList([[1, 2], 3]) -> Failure
-  ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
-  ParseResult parseList(SmallVectorImpl<int64_t> &dims);
-
-  /// Parse a literal that was printed as a hex string.
-  /// NOTE(review): no definition is visible in this part of the file.
-  ParseResult parseHexElements();
-
-  /// Parser whose token stream is consumed.
-  Parser &p;
-
-  /// Shape inferred from the parsed elements.
-  SmallVector<int64_t, 4> shape;
-
-  /// Parsed scalar tokens, each paired with whether a '-' preceded it.
-  std::vector<std::pair<bool, Token>> storage;
-
-  /// The hex string token, set when the literal was written in hex form.
-  Optional<Token> hexStorage;
-};
-} // namespace
-
-/// Parse the elements of a tensor literal into this parser's storage.
-///
-/// When 'allowHex' is true, a literal written as a single string token is
-/// recorded as hex data and only interpreted later by getAttr. Otherwise the
-/// literal is either a (possibly nested) list or a single element.
-ParseResult TensorLiteralParser::parse(bool allowHex) {
-  bool isHexLiteral = allowHex && p.getToken().is(Token::string);
-  if (isHexLiteral) {
-    hexStorage = p.getToken();
-    p.consumeToken(Token::string);
-    return success();
-  }
-
-  if (p.getToken().isNot(Token::l_square))
-    return parseElement();
-  return parseList(shape);
-}
-
-/// Build a dense attribute instance with the parsed elements and the given
-/// shaped type.
-///
-/// Hex literals are decoded directly into a raw buffer. Otherwise the parsed
-/// scalar tokens are converted to integer or floating-point values. Any other
-/// element type is treated as strings. For complex element types, every
-/// element was parsed as a '(' real ',' imag ')' pair, so it occupies two
-/// scalar slots. Returns null after emitting a diagnostic on failure.
-DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
-                                               ShapedType type) {
-  Type eltType = type.getElementType();
-
-  // Check to see if we parse the literal from a hex string.
-  if (hexStorage.hasValue() &&
-      (eltType.isIntOrFloat() || eltType.isa<ComplexType>()))
-    return getHexAttr(loc, type);
-
-  // Check that the parsed storage size has the same number of elements to the
-  // type, or is a known splat.
-  if (!shape.empty() && getShape() != type.getShape()) {
-    p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
-                     << "]) does not match type ([" << type.getShape() << "])";
-    return nullptr;
-  }
-
-  // Handle complex types in the specific element type cases below.
-  bool isComplex = false;
-  if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
-    eltType = complexTy.getElementType();
-    isComplex = true;
-  }
-
-  // The shape check above only counts list entries. It cannot tell a complex
-  // '(a, b)' element (two storage slots) from a plain scalar (one slot).
-  // Verify that the number of parsed scalars matches the type, or a single
-  // splat element, before the values are handed to DenseElementsAttr::get or
-  // reinterpreted as complex pairs. This reports mismatches such as
-  // 'dense<[1, 2]> : tensor<2xcomplex<i32>>' or 'dense<(1, 2)> : tensor<i32>'
-  // instead of building a malformed attribute.
-  auto verifyScalarCount = [&]() -> ParseResult {
-    int64_t numElements = shape.empty() ? 1 : type.getNumElements();
-    int64_t expected = numElements * (isComplex ? 2 : 1);
-    int64_t parsed = static_cast<int64_t>(storage.size());
-    if (parsed == expected)
-      return success();
-    return p.emitError(loc)
-           << "expected " << expected
-           << " scalar value(s) in elements literal of type " << type
-           << ", but parsed " << parsed;
-  };
-
-  // Handle integer and index types.
-  if (eltType.isIntOrIndex()) {
-    std::vector<APInt> intValues;
-    if (failed(verifyScalarCount()) ||
-        failed(getIntAttrElements(loc, eltType, intValues)))
-      return nullptr;
-    if (isComplex) {
-      // If this is a complex, treat the parsed values as complex values.
-      auto complexData = llvm::makeArrayRef(
-          reinterpret_cast<std::complex<APInt> *>(intValues.data()),
-          intValues.size() / 2);
-      return DenseElementsAttr::get(type, complexData);
-    }
-    return DenseElementsAttr::get(type, intValues);
-  }
-  // Handle floating point types.
-  if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
-    std::vector<APFloat> floatValues;
-    if (failed(verifyScalarCount()) ||
-        failed(getFloatAttrElements(loc, floatTy, floatValues)))
-      return nullptr;
-    if (isComplex) {
-      // If this is a complex, treat the parsed values as complex values.
-      auto complexData = llvm::makeArrayRef(
-          reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
-          floatValues.size() / 2);
-      return DenseElementsAttr::get(type, complexData);
-    }
-    return DenseElementsAttr::get(type, floatValues);
-  }
-
-  // Other types are assumed to be string representations.
-  return getStringAttr(loc, type, type.getElementType());
-}
-
-/// Convert the parsed tokens into integer values of type 'eltTy' and append
-/// them to 'intValues'.
-///
-/// 'true'/'false' are only accepted for i1, and negative values are rejected
-/// for unsigned types. Each literal must fit the element type. Errors are
-/// reported at the offending token; 'loc' is currently unused.
-ParseResult
-TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
-                                        std::vector<APInt> &intValues) {
-  const bool isUnsignedElt = eltTy.isUnsignedInteger();
-  intValues.reserve(storage.size());
-
-  for (const std::pair<bool, Token> &entry : storage) {
-    const bool negated = entry.first;
-    const Token &tok = entry.second;
-    const llvm::SMLoc tokLoc = tok.getLoc();
-
-    // A leading '-' cannot be represented by an unsigned element type.
-    if (negated && isUnsignedElt)
-      return p.emitError(tokLoc)
-             << "expected unsigned integer elements, but parsed negative value";
-
-    // Floating-point literals cannot initialize integer elements.
-    if (tok.is(Token::floatliteral))
-      return p.emitError(tokLoc)
-             << "expected integer elements, but parsed floating-point";
-
-    assert(tok.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
-           "unexpected token type");
-
-    // Boolean keywords become 1-bit values and require an i1 element type.
-    if (tok.isAny(Token::kw_true, Token::kw_false)) {
-      if (!eltTy.isInteger(1))
-        return p.emitError(tokLoc)
-               << "expected i1 type for 'true' or 'false' values";
-      intValues.push_back(
-          APInt(/*numBits=*/1, tok.is(Token::kw_true), /*isSigned=*/false));
-      continue;
-    }
-
-    // Everything else is an integer literal sized to the element type.
-    Optional<APInt> value =
-        buildAttributeAPInt(eltTy, negated, tok.getSpelling());
-    if (!value)
-      return p.emitError(tokLoc, "integer constant out of range for type");
-    intValues.push_back(*value);
-  }
-  return success();
-}
-
-/// Convert the parsed tokens into floating-point values of type 'eltTy' and
-/// append them to 'floatValues'.
-///
-/// Integer tokens spelled with a "0x" prefix are handed to
-/// buildHexadecimalFloatLiteral. Decimal float literals are parsed as double
-/// and, for non-f64 element types, rounded (ties-to-even) to the element
-/// type's semantics. All diagnostics point at the offending element token,
-/// matching getIntAttrElements. 'loc' is currently unused.
-ParseResult
-TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
-                                          std::vector<APFloat> &floatValues) {
-  floatValues.reserve(storage.size());
-  for (const auto &signAndToken : storage) {
-    bool isNegative = signAndToken.first;
-    const Token &token = signAndToken.second;
-    auto tokenLoc = token.getLoc();
-
-    // Handle hexadecimal float literals.
-    if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
-      if (isNegative) {
-        return p.emitError(tokenLoc)
-               << "hexadecimal float literal should not have a leading minus";
-      }
-      auto val = token.getUInt64IntegerValue();
-      if (!val.hasValue()) {
-        return p.emitError(
-            tokenLoc, "hexadecimal float constant out of range for attribute");
-      }
-      Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
-      if (!apVal)
-        return failure();
-      floatValues.push_back(*apVal);
-      continue;
-    }
-
-    // Check to see if any decimal integers or booleans were parsed.
-    if (!token.is(Token::floatliteral))
-      return p.emitError(tokenLoc)
-             << "expected floating-point elements, but parsed integer";
-
-    // Build the float values from tokens.
-    auto val = token.getFloatingPointValue();
-    if (!val.hasValue())
-      return p.emitError(tokenLoc,
-                         "floating point value too large for attribute");
-
-    APFloat apVal(isNegative ? -*val : *val);
-    if (!eltTy.isF64()) {
-      // Loss of precision when narrowing to a smaller type is accepted; the
-      // 'losesInfo' result is intentionally ignored.
-      bool unused;
-      apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
-                    &unused);
-    }
-    floatValues.push_back(apVal);
-  }
-  return success();
-}
-
-/// Build a dense string attribute of the given type.
-///
-/// A literal recorded as a single hex/string token becomes one value.
-/// Otherwise every parsed token contributes its string value. 'loc' and
-/// 'eltTy' are currently unused.
-DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
-                                                     ShapedType type,
-                                                     Type eltTy) {
-  if (hexStorage.hasValue()) {
-    std::string hexString = hexStorage->getStringValue();
-    return DenseStringElementsAttr::get(type, {hexString});
-  }
-
-  // Materialize all owned strings first so that the StringRef views built
-  // below cannot be invalidated by a later reallocation.
-  std::vector<std::string> ownedStrings;
-  ownedStrings.reserve(storage.size());
-  for (const std::pair<bool, Token> &entry : storage)
-    ownedStrings.push_back(entry.second.getStringValue());
-
-  std::vector<StringRef> views(ownedStrings.begin(), ownedStrings.end());
-  return DenseStringElementsAttr::get(type, views);
-}
-
-/// Build a dense attribute of the given type from the hex literal recorded in
-/// 'hexStorage'.
-///
-/// The element type must be an integer, index, float, or complex type. The
-/// decoded bytes must pass DenseElementsAttr::isValidRawBuffer, whose splat
-/// flag is forwarded to getFromRawBuffer. Returns null after emitting a
-/// diagnostic at 'loc' otherwise.
-DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
-                                                  ShapedType type) {
-  Type eltType = type.getElementType();
-  bool supportedEltType =
-      eltType.isIntOrIndexOrFloat() || eltType.isa<ComplexType>();
-  if (!supportedEltType) {
-    p.emitError(loc)
-        << "expected floating-point, integer, or complex element type, got "
-        << eltType;
-    return nullptr;
-  }
-
-  std::string bytes;
-  if (parseElementAttrHexValues(p, *hexStorage, bytes))
-    return nullptr;
-
-  ArrayRef<char> buffer(bytes.data(), bytes.size());
-  bool isSplat = false;
-  if (DenseElementsAttr::isValidRawBuffer(type, buffer, isSplat))
-    return DenseElementsAttr::getFromRawBuffer(type, buffer, isSplat);
-
-  p.emitError(loc) << "elements hex data size is invalid for provided type: "
-                   << type;
-  return nullptr;
-}
-
-/// Parse a single element literal and append its scalar token(s) to
-/// 'storage'.
-///
-/// A complex element '(' element ',' element ')' is flattened into the
-/// storage entries of its two components.
-ParseResult TensorLiteralParser::parseElement() {
-  switch (p.getToken().getKind()) {
-  // Literals that are stored exactly as spelled.
-  case Token::kw_true:
-  case Token::kw_false:
-  case Token::floatliteral:
-  case Token::integer:
-  case Token::string:
-    storage.emplace_back(/*isNegative=*/false, p.getToken());
-    p.consumeToken();
-    return success();
-
-  // A leading '-' negates the numeric literal that follows it.
-  case Token::minus:
-    p.consumeToken(Token::minus);
-    if (!p.getToken().isAny(Token::floatliteral, Token::integer))
-      return p.emitError("expected integer or floating point literal");
-    storage.emplace_back(/*isNegative=*/true, p.getToken());
-    p.consumeToken();
-    return success();
-
-  // A complex element: '(' element ',' element ')'.
-  case Token::l_paren:
-    p.consumeToken(Token::l_paren);
-    if (parseElement() ||
-        p.parseToken(Token::comma, "expected ',' between complex elements") ||
-        parseElement() ||
-        p.parseToken(Token::r_paren, "expected ')' after complex elements"))
-      return failure();
-    return success();
-
-  default:
-    return p.emitError("expected element literal of primitive type");
-  }
-}
-
-/// Parse a bracketed list of lists or elements.
-///
-/// On success, 'dims' holds this list's length followed by the dimensions
-/// shared by all of its entries. Entries with differing dimensions are
-/// rejected. For example:
-///   parseList([1, 2, 3]) -> Success, [3]
-///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
-///   parseList([[1, 2], 3]) -> Failure
-///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
-ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
-  p.consumeToken(Token::l_square);
-
-  // Dimensions of the first entry; every later entry must agree with them.
-  SmallVector<int64_t, 4> entryDims;
-  unsigned numEntries = 0;
-
-  auto parseEntry = [&]() -> ParseResult {
-    SmallVector<int64_t, 4> dimsOfThisEntry;
-    bool isSubList = p.getToken().is(Token::l_square);
-    if (isSubList ? parseList(dimsOfThisEntry) : parseElement())
-      return failure();
-
-    // The first entry establishes the expected sub-dimensions.
-    if (++numEntries == 1) {
-      entryDims = dimsOfThisEntry;
-      return success();
-    }
-    if (entryDims != dimsOfThisEntry)
-      return p.emitError("tensor literal is invalid; ranks are not consistent "
-                         "between elements");
-    return success();
-  };
-
-  if (p.parseCommaSeparatedListUntil(Token::r_square, parseEntry))
-    return failure();
-
-  dims.clear();
-  dims.push_back(numEntries);
-  dims.append(entryDims.begin(), entryDims.end());
-  return success();
-}
-
-/// Parse a dense elements attribute:
-///
-///   dense-elements-attr ::= `dense` `<` tensor-literal `>`
-///                           (`:` elements-literal-type)?
-///
-/// The literal may also be given as a hex string. The trailing type is only
-/// parsed when 'attrType' is null.
-Attribute Parser::parseDenseElementsAttr(Type attrType) {
-  consumeToken(Token::kw_dense);
-  if (parseToken(Token::less, "expected '<' after 'dense'"))
-    return nullptr;
-
-  // Parse the literal data (hex form allowed), then the closing '>'.
-  TensorLiteralParser literalParser(*this);
-  if (literalParser.parse(/*allowHex=*/true) ||
-      parseToken(Token::greater, "expected '>'"))
-    return nullptr;
-
-  // Literal/type mismatches are reported at the token following '>'.
-  llvm::SMLoc typeLoc = getToken().getLoc();
-  ShapedType literalType = parseElementsLiteralType(attrType);
-  if (!literalType)
-    return nullptr;
-  return literalParser.getAttr(typeLoc, literalType);
-}
-
-/// Parse or validate the type of an elements literal.
-///
-///   elements-literal-type ::= vector-type | ranked-tensor-type
-///
-/// If 'type' is null, a `: type` suffix is parsed first. The resulting type
-/// must be a ranked tensor or vector with a static shape. Returns null after
-/// emitting a diagnostic otherwise.
-ShapedType Parser::parseElementsLiteralType(Type type) {
-  if (!type) {
-    if (parseToken(Token::colon, "expected ':'"))
-      return nullptr;
-    type = parseType();
-    if (!type)
-      return nullptr;
-  }
-
-  bool isSupportedKind = type.isa<RankedTensorType>() || type.isa<VectorType>();
-  if (!isSupportedKind) {
-    emitError("elements literal must be a ranked tensor or vector type");
-    return nullptr;
-  }
-
-  ShapedType shapedType = type.cast<ShapedType>();
-  if (shapedType.hasStaticShape())
-    return shapedType;
-
-  emitError("elements literal type must have static shape");
-  return nullptr;
-}
-
-/// Parse a sparse elements attribute of the form
-/// `sparse` `<` indices-literal `,` values-literal `>` followed by the
-/// elements literal type (parsed only when 'attrType' is null).
-///
-/// The indices are stored with i64 elements and the values with the literal
-/// type's element type. Splat indices/values are expanded to the shapes
-/// described below. Returns null after emitting a diagnostic on failure.
-Attribute Parser::parseSparseElementsAttr(Type attrType) {
-  consumeToken(Token::kw_sparse);
-  if (parseToken(Token::less, "Expected '<' after 'sparse'"))
-    return nullptr;
-
-  // Parse the indices. We don't allow hex values here as we may need to use
-  // the inferred shape.
-  auto indicesLoc = getToken().getLoc();
-  TensorLiteralParser indiceParser(*this);
-  if (indiceParser.parse(/*allowHex=*/false))
-    return nullptr;
-
-  if (parseToken(Token::comma, "expected ','"))
-    return nullptr;
-
-  // Parse the values.
-  auto valuesLoc = getToken().getLoc();
-  TensorLiteralParser valuesParser(*this);
-  if (valuesParser.parse(/*allowHex=*/true))
-    return nullptr;
-
-  if (parseToken(Token::greater, "expected '>'"))
-    return nullptr;
-
-  auto type = parseElementsLiteralType(attrType);
-  if (!type)
-    return nullptr;
-
-  // If the indices are a splat, i.e. the literal parser parsed an element and
-  // not a list, we set the shape explicitly. The indices are represented by a
-  // 2-dimensional shape where the second dimension is the rank of the type.
-  // Given that the parsed indices is a splat, we know that we only have one
-  // indice and thus one for the first dimension.
-  auto indiceEltType = builder.getIntegerType(64);
-  ShapedType indicesType;
-  if (indiceParser.getShape().empty()) {
-    indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
-  } else {
-    // Otherwise, set the shape to the one parsed by the literal parser.
-    indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
-  }
-  auto indices = indiceParser.getAttr(indicesLoc, indicesType);
-  // getAttr has already reported why the indices do not fit their type; do
-  // not go on to build an attribute from a null value.
-  if (!indices)
-    return nullptr;
-
-  // If the values are a splat, set the shape explicitly based on the number of
-  // indices. The number of indices is encoded in the first dimension of the
-  // indice shape type.
-  auto valuesEltType = type.getElementType();
-  ShapedType valuesType =
-      valuesParser.getShape().empty()
-          ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
-          : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
-  auto values = valuesParser.getAttr(valuesLoc, valuesType);
-  // As above, a null result means a diagnostic was already emitted.
-  if (!values)
-    return nullptr;
-
-  // Sanity check.
-  if (valuesType.getRank() != 1)
-    return (emitError("expected 1-d tensor for values"), nullptr);
-
-  auto sameShape = (indicesType.getRank() == 1) ||
-                   (type.getRank() == indicesType.getDimSize(1));
-  auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
-  if (!sameShape || !sameElementNum) {
-    emitError() << "expected shape ([" << type.getShape()
-                << "]); inferred shape of indices literal (["
-                << indicesType.getShape()
-                << "]); inferred shape of values literal (["
-                << valuesType.getShape() << "])";
-    return nullptr;
-  }
-
-  // Build the sparse elements attribute by the indices and values.
-  return SparseElementsAttr::get(type, indices, values);
-}
-
-//===----------------------------------------------------------------------===//
-// Location parsing.
-//===----------------------------------------------------------------------===//
-
-/// Parse a location.
-///
-///   location ::= `loc` inline-location
-///   inline-location ::= '(' location-inst ')'
-///
-ParseResult Parser::parseLocation(LocationAttr &loc) {
-  // Check for the 'loc' keyword. parseToken has already reported the error
-  // when it is missing. Emitting another diagnostic here would only add a
-  // second, message-less error, so just propagate the failure.
-  if (parseToken(Token::kw_loc, "expected 'loc' keyword"))
-    return failure();
-
-  // Parse the inline-location.
-  if (parseToken(Token::l_paren, "expected '(' in inline location") ||
-      parseLocationInstance(loc) ||
-      parseToken(Token::r_paren, "expected ')' in inline location"))
-    return failure();
-  return success();
-}
-
-/// Specific location instances.
-///
-/// location-inst ::= filelinecol-location |
-///                   name-location |
-///                   callsite-location |
-///                   fused-location |
-///                   unknown-location
-/// filelinecol-location ::= string-literal ':' integer-literal
-///                                         ':' integer-literal
-/// name-location ::= string-literal ('(' location-inst ')')?
-/// callsite-location ::= 'callsite' '(' location-inst 'at' location-inst ')'
-/// fused-location ::= 'fused' ('<' attribute-value '>')?
-///                    '[' location-inst (',' location-inst)* ']'
-/// unknown-location ::= 'unknown'
-///
-/// Parse a callsite location. The current token is the 'callsite' keyword.
-ParseResult Parser::parseCallSiteLocation(LocationAttr &loc) {
-  consumeToken(Token::bare_identifier);
-  if (parseToken(Token::l_paren, "expected '(' in callsite location"))
-    return failure();
-
-  LocationAttr calleeLoc;
-  if (parseLocationInstance(calleeLoc))
-    return failure();
-
-  // Callee and caller are separated by a bare 'at' keyword.
-  bool sawAt = getToken().is(Token::bare_identifier) &&
-               getToken().getSpelling() == "at";
-  if (!sawAt)
-    return emitError("expected 'at' in callsite location");
-  consumeToken(Token::bare_identifier);
-
-  LocationAttr callerLoc;
-  if (parseLocationInstance(callerLoc) ||
-      parseToken(Token::r_paren, "expected ')' in callsite location"))
-    return failure();
-
-  loc = CallSiteLoc::get(calleeLoc, callerLoc);
-  return success();
-}
-
-/// Parse a fused location. The current token is the 'fused' keyword.
-///
-/// The keyword may be followed by attribute metadata in angle brackets. A
-/// square-bracketed, comma-separated list of location instances comes last.
-ParseResult Parser::parseFusedLocation(LocationAttr &loc) {
-  consumeToken(Token::bare_identifier);
-
-  // Optional metadata: '<' attribute '>'.
-  Attribute metadata;
-  if (consumeIf(Token::less)) {
-    metadata = parseAttribute();
-    if (!metadata)
-      return emitError("expected valid attribute metadata");
-    if (parseToken(Token::greater,
-                   "expected '>' after fused location metadata"))
-      return failure();
-  }
-
-  if (parseToken(Token::l_square, "expected '[' in fused location"))
-    return failure();
-
-  SmallVector<Location, 4> fusedLocs;
-  auto parseOneLocation = [&]() -> ParseResult {
-    LocationAttr child;
-    if (parseLocationInstance(child))
-      return failure();
-    fusedLocs.push_back(child);
-    return success();
-  };
-
-  if (parseCommaSeparatedList(parseOneLocation) ||
-      parseToken(Token::r_square, "expected ']' in fused location"))
-    return failure();
-
-  loc = FusedLoc::get(fusedLocs, metadata, getContext());
-  return success();
-}
-
-/// Parse a location that starts with a string literal.
-///
-/// A ':' after the string selects a file:line:col location. Otherwise this
-/// is a name location, which may wrap a parenthesized child location; the
-/// child must not itself be a name location.
-ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
-  auto *ctx = getContext();
-  auto str = getToken().getStringValue();
-  consumeToken(Token::string);
-
-  // Parse one unsigned integer component of a FileLineColLoc. 'errorMsg' is
-  // emitted if the current token is not a representable integer.
-  auto parseLocComponent = [&](unsigned &result,
-                               StringRef errorMsg) -> ParseResult {
-    if (getToken().isNot(Token::integer))
-      return emitError(errorMsg);
-    auto value = getToken().getUnsignedIntegerValue();
-    if (!value.hasValue())
-      return emitError(errorMsg);
-    result = value.getValue();
-    consumeToken(Token::integer);
-    return success();
-  };
-
-  if (consumeIf(Token::colon)) {
-    unsigned line = 0;
-    unsigned column = 0;
-    if (parseLocComponent(line,
-                          "expected integer line number in FileLineColLoc") ||
-        parseToken(Token::colon, "expected ':' in FileLineColLoc") ||
-        parseLocComponent(column,
-                          "expected integer column number in FileLineColLoc"))
-      return failure();
-    loc = FileLineColLoc::get(str, line, column, ctx);
-    return success();
-  }
-
-  // A NameLoc without a child location.
-  if (!consumeIf(Token::l_paren)) {
-    loc = NameLoc::get(Identifier::get(str, ctx), ctx);
-    return success();
-  }
-
-  auto childSourceLoc = getToken().getLoc();
-  LocationAttr childLoc;
-  if (parseLocationInstance(childLoc))
-    return failure();
-
-  // A NameLoc may not directly wrap another NameLoc.
-  if (childLoc.isa<NameLoc>())
-    return emitError(childSourceLoc,
-                     "child of NameLoc cannot be another NameLoc");
-  loc = NameLoc::get(Identifier::get(str, ctx), childLoc);
-
-  return parseToken(Token::r_paren,
-                    "expected ')' after child location of NameLoc");
-}
-
-/// Parse a single location instance into 'loc', dispatching on the leading
-/// token.
-///
-/// A string literal starts a name or file:line:col location. The bare
-/// keywords 'callsite', 'fused', and 'unknown' select the remaining forms.
-ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
-  if (getToken().is(Token::string))
-    return parseNameOrFileLineColLocation(loc);
-
-  // Every other form is introduced by a bare keyword.
-  if (getToken().isNot(Token::bare_identifier))
-    return emitError("expected location instance");
-
-  StringRef keyword = getToken().getSpelling();
-  if (keyword == "callsite")
-    return parseCallSiteLocation(loc);
-  if (keyword == "fused")
-    return parseFusedLocation(loc);
-  if (keyword == "unknown") {
-    consumeToken(Token::bare_identifier);
-    loc = UnknownLoc::get(getContext());
-    return success();
-  }
+#include "Parser.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/Parser.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/bit.h"
+#include "llvm/Support/PrettyStackTrace.h"
+#include "llvm/Support/SourceMgr.h"
+#include <algorithm>
-  return emitError("expected location instance");
-}
+using namespace mlir;
+using namespace mlir::detail;
+using llvm::MemoryBuffer;
+using llvm::SMLoc;
+using llvm::SourceMgr;
//===----------------------------------------------------------------------===//
-// Affine parsing.
+// Parser
//===----------------------------------------------------------------------===//
-/// Binary operators at the lower of the two affine precedence levels
-/// (addition and subtraction). The null member LNoOp is explicitly zero, so a
-/// value of this type can be tested directly in a boolean context.
-enum AffineLowPrecOp {
-  LNoOp = 0, ///< No operator; converts to false.
-  Add,
-  Sub
-};
-
-/// Binary operators at the higher of the two affine precedence levels
-/// (multiplication, floor/ceil division, and modulo). The null member HNoOp is
-/// explicitly zero, so it converts to false in a boolean context.
-enum AffineHighPrecOp {
-  HNoOp = 0, ///< No operator; converts to false.
-  Mul,
-  FloorDiv,
-  CeilDiv,
-  Mod
-};
-
-namespace {
-/// Parser for affine structures (affine maps, affine expressions, and integer
-/// sets). It extends Parser with state that only lives while one of those
-/// bodies is being parsed: the identifiers in scope and, when enabled, the
-/// SSA operands bound as dimensions or symbols.
-class AffineParser : public Parser {
-public:
-  AffineParser(ParserState &state, bool allowParsingSSAIds = false,
-               function_ref<ParseResult(bool)> parseElement = nullptr)
-      : Parser(state), allowParsingSSAIds(allowParsingSSAIds),
-        parseElement(parseElement) {}
-
-  AffineMap parseAffineMapRange(unsigned numDims, unsigned numSymbols);
-  ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set);
-  IntegerSet parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols);
-  ParseResult parseAffineMapOfSSAIds(AffineMap &map,
-                                     OpAsmParser::Delimiter delimiter);
-  void getDimsAndSymbolSSAIds(SmallVectorImpl<StringRef> &dimAndSymbolSSAIds,
-                              unsigned &numDims);
-
-private:
-  // Recognizing and consuming binary affine operators.
-  AffineLowPrecOp consumeIfLowPrecOp();
-  AffineHighPrecOp consumeIfHighPrecOp();
-
-  // Dimension and symbol identifier lists of polyhedral structures.
-  ParseResult parseDimIdList(unsigned &numDims);
-  ParseResult parseSymbolIdList(unsigned &numSymbols);
-  ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims,
-                                              unsigned &numSymbols);
-  ParseResult parseIdentifierDefinition(AffineExpr idExpr);
-
-  // Pieces of an affine expression.
-  AffineExpr parseAffineExpr();
-  AffineExpr parseParentheticalExpr();
-  AffineExpr parseNegateExpression(AffineExpr lhs);
-  AffineExpr parseIntegerExpr();
-  AffineExpr parseBareIdExpr();
-  AffineExpr parseSSAIdExpr(bool isSymbol);
-  AffineExpr parseSymbolSSAIdExpr();
-
-  // Combining operands into binary expressions.
-  AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
-                                   AffineExpr rhs, SMLoc opLoc);
-  AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
-                                   AffineExpr rhs);
-  AffineExpr parseAffineOperandExpr(AffineExpr lhs);
-  AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
-  AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
-                                       SMLoc llhsOpLoc);
-  AffineExpr parseAffineConstraint(bool *isEq);
-
-private:
-  /// Whether SSA ids ('%name') may appear as dimensions or symbols.
-  bool allowParsingSSAIds;
-  /// Callback used to parse an SSA operand. Its argument is true when the
-  /// operand is bound as a symbol rather than a dimension.
-  function_ref<ParseResult(bool)> parseElement;
-  /// Number of dimension SSA operands bound so far.
-  unsigned numDimOperands = 0;
-  /// Number of symbol SSA operands bound so far.
-  unsigned numSymbolOperands = 0;
-  /// Identifiers in scope, each mapped to the affine expression it denotes.
-  SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
-};
-} // end anonymous namespace
-
-/// Build 'lhs op rhs' for a high precedence operator (multiplication,
-/// floordiv, ceildiv, mod).
-///
-/// Returns null after emitting an error at 'opLoc' if the result would not be
-/// affine: a product needs at least one symbolic-or-constant factor, and a
-/// divisor or modulus must itself be symbolic or constant.
-AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
-                                               AffineExpr lhs, AffineExpr rhs,
-                                               SMLoc opLoc) {
-  // TODO: make the error location info accurate.
-  switch (op) {
-  case HNoOp:
-    llvm_unreachable("can't create affine expression for null high prec op");
-
-  case Mul:
-    if (lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant())
-      return lhs * rhs;
-    emitError(opLoc, "non-affine expression: at least one of the multiply "
-                     "operands has to be either a constant or symbolic");
-    return nullptr;
-
-  case FloorDiv:
-    if (rhs.isSymbolicOrConstant())
-      return lhs.floorDiv(rhs);
-    emitError(opLoc, "non-affine expression: right operand of floordiv "
-                     "has to be either a constant or symbolic");
-    return nullptr;
-
-  case CeilDiv:
-    if (rhs.isSymbolicOrConstant())
-      return lhs.ceilDiv(rhs);
-    emitError(opLoc, "non-affine expression: right operand of ceildiv "
-                     "has to be either a constant or symbolic");
-    return nullptr;
-
-  case Mod:
-    if (rhs.isSymbolicOrConstant())
-      return lhs % rhs;
-    emitError(opLoc, "non-affine expression: right operand of mod "
-                     "has to be either a constant or symbolic");
-    return nullptr;
-  }
-  llvm_unreachable("Unknown AffineHighPrecOp");
-}
-
-/// Build 'lhs op rhs' for a low precedence operator (addition or
-/// subtraction). Sums and differences of affine expressions are always
-/// affine, so no validation is needed.
-AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
-                                               AffineExpr lhs, AffineExpr rhs) {
-  switch (op) {
-  case AffineLowPrecOp::LNoOp:
-    llvm_unreachable("can't create affine expression for null low prec op");
-  case AffineLowPrecOp::Add:
-    return lhs + rhs;
-  case AffineLowPrecOp::Sub:
-    return lhs - rhs;
-  }
-  llvm_unreachable("Unknown AffineLowPrecOp");
-}
-
-/// If the current token is a lower precedence affine operator ('+' or '-'),
-/// consume it and return the operator. Otherwise leave the token in place and
-/// return LNoOp. There are only two precedence levels.
-AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
-  if (getToken().is(Token::plus)) {
-    consumeToken(Token::plus);
-    return AffineLowPrecOp::Add;
-  }
-  if (getToken().is(Token::minus)) {
-    consumeToken(Token::minus);
-    return AffineLowPrecOp::Sub;
-  }
-  return AffineLowPrecOp::LNoOp;
-}
-
-/// If the current token is a higher precedence affine operator ('*',
-/// 'floordiv', 'ceildiv', 'mod'), consume it and return the operator.
-/// Otherwise leave the token in place and return HNoOp. There are only two
-/// precedence levels.
-AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
-  static const std::pair<Token::Kind, AffineHighPrecOp> kHighPrecOps[] = {
-      {Token::star, Mul},
-      {Token::kw_floordiv, FloorDiv},
-      {Token::kw_ceildiv, CeilDiv},
-      {Token::kw_mod, Mod}};
-  for (const auto &entry : kHighPrecOps) {
-    if (getToken().is(entry.first)) {
-      consumeToken(entry.first);
-      return entry.second;
-    }
-  }
-  return HNoOp;
-}
-
-/// Parse a high precedence op expression list: mul, div, and mod are high
-/// precedence binary ops, i.e., parse a
-///   expr_1 op_1 expr_2 op_2 ... expr_n
-/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
-/// All affine binary ops are left associative.
-///
-/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
-/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
-/// null. llhsOpLoc is the location of the llhsOp token; it is used to report
-/// an error when (llhs llhsOp lhs) is not a valid affine expression.
-AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
-                                                   AffineHighPrecOp llhsOp,
-                                                   SMLoc llhsOpLoc) {
-  AffineExpr lhs = parseAffineOperandExpr(llhs);
-  if (!lhs)
-    return nullptr;
-
-  // Found an LHS. Parse the remaining expression.
-  auto opLoc = getToken().getLoc();
-  if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
-    if (llhs) {
-      // Fold (llhs llhsOp lhs) first. Any error from that fold concerns
-      // llhsOp, so report it at llhsOp's location rather than at the
-      // operator that follows.
-      AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
-      if (!expr)
-        return nullptr;
-      return parseAffineHighPrecOpExpr(expr, op, opLoc);
-    }
-    // No LLHS, get RHS
-    return parseAffineHighPrecOpExpr(lhs, op, opLoc);
-  }
-
-  // This is the last operand in this expression.
-  if (llhs)
-    return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
-
-  // No llhs, 'lhs' itself is the expression.
-  return lhs;
-}
-
-/// Parse a parenthesized affine expression.
-///
-///   affine-expr ::= `(` affine-expr `)`
-///
-/// Empty parentheses are rejected with a dedicated diagnostic.
-AffineExpr AffineParser::parseParentheticalExpr() {
-  if (parseToken(Token::l_paren, "expected '('"))
-    return nullptr;
-  if (getToken().is(Token::r_paren)) {
-    emitError("no expression inside parentheses");
-    return nullptr;
-  }
-
-  AffineExpr inner = parseAffineExpr();
-  if (!inner || parseToken(Token::r_paren, "expected ')'"))
-    return nullptr;
-  return inner;
-}
-
-/// Parse a negated operand.
-///
-///   affine-expr ::= `-` affine-expr
-///
-/// Negation binds tighter than every binary operator but looser than
-/// parentheses, so only a single operand, not a full expression, follows the
-/// '-'. The result is built as '(-1) * operand'.
-AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
-  if (parseToken(Token::minus, "expected '-'"))
-    return nullptr;
-
-  AffineExpr operand = parseAffineOperandExpr(lhs);
-  if (operand)
-    return (-1) * operand;
-
-  // parseAffineOperandExpr has already complained. This extra error points
-  // out the dangling negation for a clearer diagnostic.
-  emitError("missing operand of negation");
-  return nullptr;
-}
-
-/// Parse a bare identifier in an affine expression and return the dimension
-/// or symbol expression it was declared as.
-///
-///   affine-expr ::= bare-id
-///
-/// Identifiers not present in 'dimsAndSymbols' are reported as undeclared.
-AffineExpr AffineParser::parseBareIdExpr() {
-  if (getToken().isNot(Token::bare_identifier)) {
-    emitError("expected bare identifier");
-    return nullptr;
-  }
-
-  StringRef spelling = getTokenSpelling();
-  for (const auto &entry : dimsAndSymbols) {
-    if (entry.first != spelling)
-      continue;
-    consumeToken(Token::bare_identifier);
-    return entry.second;
-  }
-
-  emitError("use of undeclared identifier");
-  return nullptr;
-}
-
-/// Parse an SSA id ('%name') used in an affine expression.
-///
-/// The id is bound as a symbol when 'isSymbol' is set, otherwise as a
-/// dimension. An id that was seen before maps to the same expression. A new
-/// one is parsed via the 'parseElement' callback and bound to the next free
-/// dim or symbol position.
-AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) {
-  if (!allowParsingSSAIds) {
-    emitError("unexpected ssa identifier");
-    return nullptr;
-  }
-  if (getToken().isNot(Token::percent_identifier)) {
-    emitError("expected ssa identifier");
-    return nullptr;
-  }
-
-  // An SSA id that was already bound keeps its original position.
-  StringRef ssaName = getTokenSpelling();
-  for (const auto &entry : dimsAndSymbols) {
-    if (entry.first != ssaName)
-      continue;
-    consumeToken(Token::percent_identifier);
-    return entry.second;
-  }
-
-  if (parseElement(isSymbol)) {
-    emitError("failed to parse ssa identifier");
-    return nullptr;
-  }
-  AffineExpr idExpr;
-  if (isSymbol)
-    idExpr = getAffineSymbolExpr(numSymbolOperands++, getContext());
-  else
-    idExpr = getAffineDimExpr(numDimOperands++, getContext());
-  dimsAndSymbols.emplace_back(ssaName, idExpr);
-  return idExpr;
-}
-
-/// Parse `symbol` `(` ssa-id `)` and bind the SSA id as a symbol.
-AffineExpr AffineParser::parseSymbolSSAIdExpr() {
-  if (parseToken(Token::kw_symbol, "expected symbol keyword") ||
-      parseToken(Token::l_paren, "expected '(' at start of SSA symbol"))
-    return nullptr;
-
-  AffineExpr symbol = parseSSAIdExpr(/*isSymbol=*/true);
-  if (!symbol || parseToken(Token::r_paren, "expected ')' at end of SSA symbol"))
-    return nullptr;
-  return symbol;
-}
-
-/// Parse a non-negative integer literal in an affine expression.
-///
-///   affine-expr ::= integer-literal
-///
-/// The literal must fit in a signed 64-bit value.
-AffineExpr AffineParser::parseIntegerExpr() {
-  auto literal = getToken().getUInt64IntegerValue();
-  bool fitsInInt64 =
-      literal.hasValue() && static_cast<int64_t>(literal.getValue()) >= 0;
-  if (!fitsInInt64) {
-    emitError("constant too large for index");
-    return nullptr;
-  }
-
-  consumeToken(Token::integer);
-  return builder.getAffineConstantExpr(
-      static_cast<int64_t>(literal.getValue()));
-}
-
-/// Parses an expression that can be a valid operand of an affine expression.
-/// lhs: if non-null, lhs is an affine expression that is the lhs of a binary
-/// operator, the rhs of which is being parsed. This is used to determine
-/// whether an error should be emitted for a missing right operand.
-// Eg: for an expression without parentheses (like i + j + k + l), each
-// of the four identifiers is an operand. For i + j*k + l, j*k is not an
-// operand expression, it's an op expression and will be parsed via
-// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
-// -l are valid operands that will be parsed by this function.
-AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
- switch (getToken().getKind()) {
- case Token::bare_identifier:
- return parseBareIdExpr();
- case Token::kw_symbol:
- return parseSymbolSSAIdExpr();
- case Token::percent_identifier:
- return parseSSAIdExpr(/*isSymbol=*/false);
- case Token::integer:
- return parseIntegerExpr();
- case Token::l_paren:
- return parseParentheticalExpr();
- case Token::minus:
- return parseNegateExpression(lhs);
- case Token::kw_ceildiv:
- case Token::kw_floordiv:
- case Token::kw_mod:
- case Token::plus:
- case Token::star:
- if (lhs)
- emitError("missing right operand of binary operator");
- else
- emitError("missing left operand of binary operator");
- return nullptr;
- default:
- if (lhs)
- emitError("missing right operand of binary operator");
- else
- emitError("expected affine expression");
- return nullptr;
- }
-}
-
-/// Parse affine expressions that are bare-id's, integer constants,
-/// parenthetical affine expressions, and affine op expressions that are a
-/// composition of those.
-///
-/// All binary op's associate from left to right.
-///
-/// {add, sub} have lower precedence than {mul, div, and mod}.
-///
-/// Add, sub'are themselves at the same precedence level. Mul, floordiv,
-/// ceildiv, and mod are at the same higher precedence level. Negation has
-/// higher precedence than any binary op.
-///
-/// llhs: the affine expression appearing on the left of the one being parsed.
-/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
-/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
-/// if llhs is non-null; otherwise lhs is returned. This is to deal with left
-/// associativity.
-///
-/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
-/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
-/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
-AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
- AffineLowPrecOp llhsOp) {
- AffineExpr lhs;
- if (!(lhs = parseAffineOperandExpr(llhs)))
- return nullptr;
-
- // Found an LHS. Deal with the ops.
- if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
- if (llhs) {
- AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
- return parseAffineLowPrecOpExpr(sum, lOp);
- }
- // No LLHS, get RHS and form the expression.
- return parseAffineLowPrecOpExpr(lhs, lOp);
- }
- auto opLoc = getToken().getLoc();
- if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
- // We have a higher precedence op here. Get the rhs operand for the llhs
- // through parseAffineHighPrecOpExpr.
- AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
- if (!highRes)
- return nullptr;
-
- // If llhs is null, the product forms the first operand of the yet to be
- // found expression. If non-null, the op to associate with llhs is llhsOp.
- AffineExpr expr =
- llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
-
- // Recurse for subsequent low prec op's after the affine high prec op
- // expression.
- if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
- return parseAffineLowPrecOpExpr(expr, nextOp);
- return expr;
- }
- // Last operand in the expression list.
- if (llhs)
- return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
- // No llhs, 'lhs' itself is the expression.
- return lhs;
-}
-
-/// Parse an affine expression.
-/// affine-expr ::= `(` affine-expr `)`
-/// | `-` affine-expr
-/// | affine-expr `+` affine-expr
-/// | affine-expr `-` affine-expr
-/// | affine-expr `*` affine-expr
-/// | affine-expr `floordiv` affine-expr
-/// | affine-expr `ceildiv` affine-expr
-/// | affine-expr `mod` affine-expr
-/// | bare-id
-/// | integer-literal
-///
-/// Additional conditions are checked depending on the production. For eg.,
-/// one of the operands for `*` has to be either constant/symbolic; the second
-/// operand for floordiv, ceildiv, and mod has to be a positive integer.
-AffineExpr AffineParser::parseAffineExpr() {
- return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
-}
-
-/// Parse a dim or symbol from the lists appearing before the actual
-/// expressions of the affine map. Update our state to store the
-/// dimensional/symbolic identifier.
-ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) {
- if (getToken().isNot(Token::bare_identifier))
- return emitError("expected bare identifier");
-
- auto name = getTokenSpelling();
- for (auto entry : dimsAndSymbols) {
- if (entry.first == name)
- return emitError("redefinition of identifier '" + name + "'");
- }
- consumeToken(Token::bare_identifier);
-
- dimsAndSymbols.push_back({name, idExpr});
- return success();
-}
-
-/// Parse the list of dimensional identifiers to an affine map.
-ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
- if (parseToken(Token::l_paren,
- "expected '(' at start of dimensional identifiers list")) {
- return failure();
- }
-
- auto parseElt = [&]() -> ParseResult {
- auto dimension = getAffineDimExpr(numDims++, getContext());
- return parseIdentifierDefinition(dimension);
- };
- return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
-}
-
-/// Parse the list of symbolic identifiers to an affine map.
-ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
- consumeToken(Token::l_square);
- auto parseElt = [&]() -> ParseResult {
- auto symbol = getAffineSymbolExpr(numSymbols++, getContext());
- return parseIdentifierDefinition(symbol);
- };
- return parseCommaSeparatedListUntil(Token::r_square, parseElt);
-}
-
-/// Parse the list of symbolic identifiers to an affine map.
-ParseResult
-AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims,
- unsigned &numSymbols) {
- if (parseDimIdList(numDims)) {
- return failure();
- }
- if (!getToken().is(Token::l_square)) {
- numSymbols = 0;
- return success();
- }
- return parseSymbolIdList(numSymbols);
-}
-
-/// Parses an ambiguous affine map or integer set definition inline.
-ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
- IntegerSet &set) {
- unsigned numDims = 0, numSymbols = 0;
-
- // List of dimensional and optional symbol identifiers.
- if (parseDimAndOptionalSymbolIdList(numDims, numSymbols)) {
- return failure();
- }
-
- // This is needed for parsing attributes as we wouldn't know whether we would
- // be parsing an integer set attribute or an affine map attribute.
- bool isArrow = getToken().is(Token::arrow);
- bool isColon = getToken().is(Token::colon);
- if (!isArrow && !isColon) {
- return emitError("expected '->' or ':'");
- } else if (isArrow) {
- parseToken(Token::arrow, "expected '->' or '['");
- map = parseAffineMapRange(numDims, numSymbols);
- return map ? success() : failure();
- } else if (parseToken(Token::colon, "expected ':' or '['")) {
+/// Parse a comma-separated list that must contain at least one element.
+/// Parsing stops, without consuming it, at the first token that is not ','.
+ParseResult Parser::parseCommaSeparatedList(
+ const std::function<ParseResult()> &parseElement) {
+ // Non-empty case starts with an element.
+ if (parseElement())
return failure();
- }
- if ((set = parseIntegerSetConstraints(numDims, numSymbols)))
- return success();
-
- return failure();
-}
-
-/// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
-ParseResult
-AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
- OpAsmParser::Delimiter delimiter) {
- Token::Kind rightToken;
- switch (delimiter) {
- case OpAsmParser::Delimiter::Square:
- if (parseToken(Token::l_square, "expected '['"))
- return failure();
- rightToken = Token::r_square;
- break;
- case OpAsmParser::Delimiter::Paren:
- if (parseToken(Token::l_paren, "expected '('"))
+ // Otherwise we have a list of comma separated elements.
+ while (consumeIf(Token::comma)) {
+ if (parseElement())
return failure();
- rightToken = Token::r_paren;
- break;
- default:
- return emitError("unexpected delimiter");
}
-
- SmallVector<AffineExpr, 4> exprs;
- auto parseElt = [&]() -> ParseResult {
- auto elt = parseAffineExpr();
- exprs.push_back(elt);
- return elt ? success() : failure();
- };
-
- // Parse a multi-dimensional affine expression (a comma-separated list of
- // 1-d affine expressions); the list can be empty. Grammar:
- // multi-dim-affine-expr ::= `(` `)`
- // | `(` affine-expr (`,` affine-expr)* `)`
- if (parseCommaSeparatedListUntil(rightToken, parseElt,
- /*allowEmptyList=*/true))
- return failure();
- // Parsed a valid affine map.
- map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
- exprs, getContext());
return success();
}
-/// Parse the range and sizes affine map definition inline.
-///
-/// affine-map ::= dim-and-symbol-id-lists `->` multi-dim-affine-expr
-///
-/// multi-dim-affine-expr ::= `(` `)`
-/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
-AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
- unsigned numSymbols) {
- parseToken(Token::l_paren, "expected '(' at start of affine map range");
-
- SmallVector<AffineExpr, 4> exprs;
- auto parseElt = [&]() -> ParseResult {
- auto elt = parseAffineExpr();
- ParseResult res = elt ? success() : failure();
- exprs.push_back(elt);
- return res;
- };
-
- // Parse a multi-dimensional affine expression (a comma-separated list of
- // 1-d affine expressions). Grammar:
- // multi-dim-affine-expr ::= `(` `)`
- // | `(` affine-expr (`,` affine-expr)* `)`
- if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
- return AffineMap();
-
- // Parsed a valid affine map.
- return AffineMap::get(numDims, numSymbols, exprs, getContext());
-}
-
-/// Parse an affine constraint.
-/// affine-constraint ::= affine-expr `>=` `0`
-/// | affine-expr `==` `0`
+/// Parse a comma-separated list of elements terminated by `rightToken`, which
+/// is consumed. An empty list is accepted only if `allowEmptyList` is true.
///
-/// isEq is set to true if the parsed constraint is an equality, false if it
-/// is an inequality (greater than or equal).
+/// abstract-list ::= rightToken // if allowEmptyList == true
+/// abstract-list ::= element (',' element)* rightToken
///
-AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
- AffineExpr expr = parseAffineExpr();
- if (!expr)
- return nullptr;
-
- if (consumeIf(Token::greater) && consumeIf(Token::equal) &&
- getToken().is(Token::integer)) {
- auto dim = getToken().getUnsignedIntegerValue();
- if (dim.hasValue() && dim.getValue() == 0) {
- consumeToken(Token::integer);
- *isEq = false;
- return expr;
- }
- return (emitError("expected '0' after '>='"), nullptr);
+ParseResult Parser::parseCommaSeparatedListUntil(
+ Token::Kind rightToken, const std::function<ParseResult()> &parseElement,
+ bool allowEmptyList) {
+ // Handle the empty case.
+ if (getToken().is(rightToken)) {
+ if (!allowEmptyList)
+ return emitError("expected list element");
+ consumeToken(rightToken);
+ return success();
}
- if (consumeIf(Token::equal) && consumeIf(Token::equal) &&
- getToken().is(Token::integer)) {
- auto dim = getToken().getUnsignedIntegerValue();
- if (dim.hasValue() && dim.getValue() == 0) {
- consumeToken(Token::integer);
- *isEq = true;
- return expr;
- }
- return (emitError("expected '0' after '=='"), nullptr);
- }
+ if (parseCommaSeparatedList(parseElement) ||
+ parseToken(rightToken, "expected ',' or '" +
+ Token::getTokenSpelling(rightToken) + "'"))
+ return failure();
- return (emitError("expected '== 0' or '>= 0' at end of affine constraint"),
- nullptr);
+ return success();
}
-/// Parse the constraints that are part of an integer set definition.
-/// integer-set-inline
-/// ::= dim-and-symbol-id-lists `:`
-/// '(' affine-constraint-conjunction? ')'
-/// affine-constraint-conjunction ::= affine-constraint (`,`
-/// affine-constraint)*
-///
-IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
- unsigned numSymbols) {
- if (parseToken(Token::l_paren,
- "expected '(' at start of integer set constraint list"))
- return IntegerSet();
-
- SmallVector<AffineExpr, 4> constraints;
- SmallVector<bool, 4> isEqs;
- auto parseElt = [&]() -> ParseResult {
- bool isEq;
- auto elt = parseAffineConstraint(&isEq);
- ParseResult res = elt ? success() : failure();
- if (elt) {
- constraints.push_back(elt);
- isEqs.push_back(isEq);
- }
- return res;
- };
-
- // Parse a list of affine constraints (comma-separated).
- if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
- return IntegerSet();
-
- // If no constraints were parsed, then treat this as a degenerate 'true' case.
- if (constraints.empty()) {
- /* 0 == 0 */
- auto zero = getAffineConstantExpr(0, getContext());
- return IntegerSet::get(numDims, numSymbols, zero, true);
- }
-
- // Parsed a valid integer set.
- return IntegerSet::get(numDims, numSymbols, constraints, isEqs);
-}
+InFlightDiagnostic Parser::emitError(SMLoc loc, const Twine &message) {
+ auto diag = mlir::emitError(getEncodedSourceLocation(loc), message);
-/// Parse an ambiguous reference to either and affine map or an integer set.
-ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map,
- IntegerSet &set) {
- return AffineParser(state).parseAffineMapOrIntegerSetInline(map, set);
-}
-ParseResult Parser::parseAffineMapReference(AffineMap &map) {
- llvm::SMLoc curLoc = getToken().getLoc();
- IntegerSet set;
- if (parseAffineMapOrIntegerSetReference(map, set))
- return failure();
- if (set)
- return emitError(curLoc, "expected AffineMap, but got IntegerSet");
- return success();
-}
-ParseResult Parser::parseIntegerSetReference(IntegerSet &set) {
- llvm::SMLoc curLoc = getToken().getLoc();
- AffineMap map;
- if (parseAffineMapOrIntegerSetReference(map, set))
- return failure();
- if (map)
- return emitError(curLoc, "expected IntegerSet, but got AffineMap");
- return success();
+ // If the current token is a lexer error, the lexer has already reported
+ // it; abandon this diagnostic so the problem is not reported twice.
+ if (getToken().is(Token::error))
+ diag.abandon();
+ return diag;
}
-/// Parse an AffineMap of SSA ids. The callback 'parseElement' is used to
-/// parse SSA value uses encountered while parsing affine expressions.
-ParseResult
-Parser::parseAffineMapOfSSAIds(AffineMap &map,
- function_ref<ParseResult(bool)> parseElement,
- OpAsmParser::Delimiter delimiter) {
- return AffineParser(state, /*allowParsingSSAIds=*/true, parseElement)
- .parseAffineMapOfSSAIds(map, delimiter);
+/// Consume `expectedToken` if it is the current token and return success.
+/// Otherwise emit `message` at the current token and return failure.
+ParseResult Parser::parseToken(Token::Kind expectedToken,
+ const Twine &message) {
+ if (consumeIf(expectedToken))
+ return success();
+ return emitError(message);
}
//===----------------------------------------------------------------------===//
@@ -3413,7 +187,7 @@ class OperationParser : public Parser {
/// Parse an operation instance that is in the op-defined custom form.
/// resultInfo specifies information about the "%name =" specifiers.
- Operation *parseCustomOperation(ArrayRef<ResultRecord> resultInfo);
+ Operation *parseCustomOperation(ArrayRef<ResultRecord> resultIDs);
//===--------------------------------------------------------------------===//
// Region Parsing
@@ -4300,7 +1074,7 @@ class CustomOpAsmParser : public OpAsmParser {
if (atToken.isNot(Token::at_identifier))
return failure();
- result = getBuilder().getStringAttr(extractSymbolReference(atToken));
+ result = getBuilder().getStringAttr(atToken.getSymbolReference());
attrs.push_back(getBuilder().getNamedAttr(attrName, result));
parser.consumeToken();
return success();
@@ -5141,52 +1915,3 @@ OwningModuleRef mlir::parseSourceString(StringRef moduleStr,
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
return parseSourceFile(sourceMgr, context);
}
-
-/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
-/// parsing failed, nullptr is returned. The number of bytes read from the input
-/// string is returned in 'numRead'.
-template <typename T, typename ParserFn>
-static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
- ParserFn &&parserFn) {
- SymbolState aliasState;
- return parseSymbol<T>(
- inputStr, context, aliasState,
- [&](Parser &parser) {
- SourceMgrDiagnosticHandler handler(
- const_cast<llvm::SourceMgr &>(parser.getSourceMgr()),
- parser.getContext());
- return parserFn(parser);
- },
- &numRead);
-}
-
-Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) {
- size_t numRead = 0;
- return parseAttribute(attrStr, context, numRead);
-}
-Attribute mlir::parseAttribute(StringRef attrStr, Type type) {
- size_t numRead = 0;
- return parseAttribute(attrStr, type, numRead);
-}
-
-Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
- size_t &numRead) {
- return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) {
- return parser.parseAttribute();
- });
-}
-Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) {
- return parseSymbol<Attribute>(
- attrStr, type.getContext(), numRead,
- [type](Parser &parser) { return parser.parseAttribute(type); });
-}
-
-Type mlir::parseType(StringRef typeStr, MLIRContext *context) {
- size_t numRead = 0;
- return parseType(typeStr, context, numRead);
-}
-
-Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) {
- return parseSymbol<Type>(typeStr, context, numRead,
- [](Parser &parser) { return parser.parseType(); });
-}
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
new file mode 100644
index 000000000000..3d82d622bf06
--- /dev/null
+++ b/mlir/lib/Parser/Parser.h
@@ -0,0 +1,270 @@
+//===- Parser.h - MLIR Base Parser Class ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_PARSER_PARSER_H
+#define MLIR_LIB_PARSER_PARSER_H
+
+#include "ParserState.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+
+namespace mlir {
+namespace detail {
+//===----------------------------------------------------------------------===//
+// Parser
+//===----------------------------------------------------------------------===//
+
+/// This class implements support for parsing global entities like attributes
+/// and types. It is intended to be subclassed by specialized subparsers that
+/// include state.
+class Parser {
+public:
+  Builder builder;
+  Parser(ParserState &state) : builder(state.context), state(state) {}
+
+  /// Accessors for the parser-global state.
+  ParserState &getState() const { return state; }
+  MLIRContext *getContext() const { return state.context; }
+  const llvm::SourceMgr &getSourceMgr() const {
+    return state.lex.getSourceMgr();
+  }
+
+  /// Parse a comma-separated list of elements up until the specified end token.
+  ParseResult
+  parseCommaSeparatedListUntil(Token::Kind rightToken,
+                               const std::function<ParseResult()> &parseElement,
+                               bool allowEmptyList = true);
+
+  /// Parse a comma separated list of elements that must have at least one entry
+  /// in it.
+  ParseResult
+  parseCommaSeparatedList(const std::function<ParseResult()> &parseElement);
+
+  ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);
+
+  // Parsing methods come in two forms: those that return a non-null value on
+  // success, and those that return a ParseResult and fill in by-reference
+  // arguments with their results.
+
+  //===--------------------------------------------------------------------===//
+  // Error Handling
+  //===--------------------------------------------------------------------===//
+
+  /// Emit an error at the current token (or at `loc`); converts to failure.
+  InFlightDiagnostic emitError(const Twine &message = {}) {
+    return emitError(state.curToken.getLoc(), message);
+  }
+  InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message = {});
+
+  /// Encode the specified source location information into an attribute for
+  /// attachment to the IR.
+  Location getEncodedSourceLocation(llvm::SMLoc loc) {
+    // If there are no active nested parsers, we can get the encoded source
+    // location directly.
+    if (state.parserDepth == 0)
+      return state.lex.getEncodedSourceLocation(loc);
+    // Otherwise, we need to re-encode it to point to the top level buffer.
+    return state.symbols.topLevelLexer->getEncodedSourceLocation(
+        remapLocationToTopLevelBuffer(loc));
+  }
+
+  /// Remaps the given SMLoc to the top level lexer of the parser. This is used
+  /// to adjust locations of potentially nested parsers to ensure that they can
+  /// be emitted properly as diagnostics.
+  llvm::SMLoc remapLocationToTopLevelBuffer(llvm::SMLoc loc) {
+    // If there are no active nested parsers, we can return location directly.
+    SymbolState &symbols = state.symbols;
+    if (state.parserDepth == 0)
+      return loc;
+    assert(symbols.topLevelLexer && "expected valid top-level lexer");
+
+    // Otherwise, we need to remap the location to the main parser. This is
+    // simply offsetting the location onto the location of the last nested
+    // parser.
+    size_t offset = loc.getPointer() - state.lex.getBufferBegin();
+    auto *rawLoc =
+        symbols.nestedParserLocs[state.parserDepth - 1].getPointer() + offset;
+    return llvm::SMLoc::getFromPointer(rawLoc);
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Token Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Return the current token the parser is inspecting.
+  const Token &getToken() const { return state.curToken; }
+  StringRef getTokenSpelling() const { return state.curToken.getSpelling(); }
+
+  /// If the current token has the specified kind, consume it and return true.
+  /// If not, return false.
+  bool consumeIf(Token::Kind kind) {
+    if (state.curToken.isNot(kind))
+      return false;
+    consumeToken(kind);
+    return true;
+  }
+
+  /// Advance the current lexer onto the next token.
+  void consumeToken() {
+    assert(state.curToken.isNot(Token::eof, Token::error) &&
+           "shouldn't advance past EOF or errors");
+    state.curToken = state.lex.lexToken();
+  }
+
+  /// Advance the current lexer onto the next token, asserting what the expected
+  /// current token is. This is preferred to the above method because it leads
+  /// to more self-documenting code with better checking.
+  void consumeToken(Token::Kind kind) {
+    assert(state.curToken.is(kind) && "consumed an unexpected token");
+    consumeToken();
+  }
+
+  /// Consume the specified token if present and return success. On failure,
+  /// output a diagnostic and return failure.
+  ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
+
+  //===--------------------------------------------------------------------===//
+  // Type Parsing
+  //===--------------------------------------------------------------------===//
+
+  ParseResult parseFunctionResultTypes(SmallVectorImpl<Type> &elements);
+  ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
+  ParseResult parseTypeListParens(SmallVectorImpl<Type> &elements);
+
+  /// Optionally parse a type.
+  OptionalParseResult parseOptionalType(Type &type);
+
+  /// Parse an arbitrary type.
+  Type parseType();
+
+  /// Parse a complex type.
+  Type parseComplexType();
+
+  /// Parse an extended type.
+  Type parseExtendedType();
+
+  /// Parse a function type.
+  Type parseFunctionType();
+
+  /// Parse a memref type.
+  Type parseMemRefType();
+
+  /// Parse a non-function type.
+  Type parseNonFunctionType();
+
+  /// Parse a tensor type.
+  Type parseTensorType();
+
+  /// Parse a tuple type.
+  Type parseTupleType();
+
+  /// Parse a vector type.
+  VectorType parseVectorType();
+  ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
+                                       bool allowDynamic = true);
+  ParseResult parseXInDimensionList();
+
+  /// Parse strided layout specification.
+  ParseResult parseStridedLayout(int64_t &offset,
+                                 SmallVectorImpl<int64_t> &strides);
+
+  /// Parse a delimited list of comma-separated integers, using `?` as the
+  /// marker for an unknown value.
+  ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);
+
+  //===--------------------------------------------------------------------===//
+  // Attribute Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse an arbitrary attribute with an optional type.
+  Attribute parseAttribute(Type type = {});
+
+  /// Parse an attribute dictionary.
+  ParseResult parseAttributeDict(NamedAttrList &attributes);
+
+  /// Parse an extended attribute.
+  Attribute parseExtendedAttr(Type type);
+
+  /// Parse a float attribute.
+  Attribute parseFloatAttr(Type type, bool isNegative);
+
+  /// Parse a decimal or a hexadecimal literal, which can be either an integer
+  /// or a float attribute.
+  Attribute parseDecOrHexAttr(Type type, bool isNegative);
+
+  /// Parse an opaque elements attribute.
+  Attribute parseOpaqueElementsAttr(Type attrType);
+
+  /// Parse a dense elements attribute.
+  Attribute parseDenseElementsAttr(Type attrType);
+  ShapedType parseElementsLiteralType(Type type);
+
+  /// Parse a sparse elements attribute.
+  Attribute parseSparseElementsAttr(Type attrType);
+
+  //===--------------------------------------------------------------------===//
+  // Location Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse an inline location.
+  ParseResult parseLocation(LocationAttr &loc);
+
+  /// Parse a raw location instance.
+  ParseResult parseLocationInstance(LocationAttr &loc);
+
+  /// Parse a callsite location instance.
+  ParseResult parseCallSiteLocation(LocationAttr &loc);
+
+  /// Parse a fused location instance.
+  ParseResult parseFusedLocation(LocationAttr &loc);
+
+  /// Parse a name or FileLineCol location instance.
+  ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);
+
+  /// Parse an optional trailing location.
+  ///
+  /// trailing-location ::= (`loc` `(` location `)`)?
+  ///
+  ParseResult parseOptionalTrailingLocation(Location &loc) {
+    // If there is a 'loc' we parse a trailing location.
+    if (!getToken().is(Token::kw_loc))
+      return success();
+
+    // Parse the location.
+    LocationAttr directLoc;
+    if (parseLocation(directLoc))
+      return failure();
+    loc = directLoc;
+    return success();
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Affine Parsing
+  //===--------------------------------------------------------------------===//
+
+  /// Parse a reference to either an affine map, or an integer set.
+  ParseResult parseAffineMapOrIntegerSetReference(AffineMap &map,
+                                                  IntegerSet &set);
+  ParseResult parseAffineMapReference(AffineMap &map);
+  ParseResult parseIntegerSetReference(IntegerSet &set);
+
+  /// Parse an AffineMap where the dim and symbol identifiers are SSA ids.
+  ParseResult
+  parseAffineMapOfSSAIds(AffineMap &map,
+                         function_ref<ParseResult(bool)> parseElement,
+                         OpAsmParser::Delimiter delimiter);
+
+private:
+  /// The Parser is subclassed and reinstantiated. Do not add additional
+  /// non-trivial state here, add it to the ParserState class.
+  ParserState &state;
+};
+} // end namespace detail
+} // end namespace mlir
+
+#endif // MLIR_LIB_PARSER_PARSER_H
diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h
new file mode 100644
index 000000000000..27048d5d2f03
--- /dev/null
+++ b/mlir/lib/Parser/ParserState.h
@@ -0,0 +1,85 @@
+//===- ParserState.h - MLIR ParserState -------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_PARSER_PARSERSTATE_H
+#define MLIR_LIB_PARSER_PARSERSTATE_H
+
+#include "Lexer.h"
+#include "mlir/IR/Attributes.h"
+#include "llvm/ADT/StringMap.h"
+
+namespace mlir {
+namespace detail {
+
+//===----------------------------------------------------------------------===//
+// SymbolState
+//===----------------------------------------------------------------------===//
+
+/// This class records parsed top-level symbols and nested-parser bookkeeping.
+struct SymbolState {
+ /// A map from attribute alias identifier to Attribute.
+ llvm::StringMap<Attribute> attributeAliasDefinitions;
+
+ /// A map from type alias identifier to Type.
+ llvm::StringMap<Type> typeAliasDefinitions;
+
+ /// A list of locations into the main parser memory buffer for each of the
+ /// active nested parsers. Given that some nested parsers, i.e. custom dialect
+ /// parsers, operate on a temporary memory buffer, this provides an anchor
+ /// point for emitting diagnostics.
+ SmallVector<llvm::SMLoc, 1> nestedParserLocs;
+
+ /// The top-level lexer that contains the original memory buffer provided by
+ /// the user. This is used by nested parsers to get a properly encoded source
+ /// location.
+ Lexer *topLevelLexer = nullptr;
+};
+
+//===----------------------------------------------------------------------===//
+// ParserState
+//===----------------------------------------------------------------------===//
+
+/// This class holds all of the state maintained globally by the parser,
+/// such as the lexer and the next token to be consumed.
+struct ParserState {
+ ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx,
+ SymbolState &symbols)
+ : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()),
+ symbols(symbols), parserDepth(symbols.nestedParserLocs.size()) {
+ // Set the top level lexer for the symbol state if one doesn't exist.
+ if (!symbols.topLevelLexer)
+ symbols.topLevelLexer = &lex;
+ }
+ ~ParserState() {
+ // Clear the top level lexer if it refers to our lexer, which is going away.
+ if (symbols.topLevelLexer == &lex)
+ symbols.topLevelLexer = nullptr;
+ }
+ ParserState(const ParserState &) = delete;
+ void operator=(const ParserState &) = delete;
+
+ /// The context we're parsing into.
+ MLIRContext *const context;
+
+ /// The lexer for the source file we're parsing.
+ Lexer lex;
+
+ /// The next unconsumed token; initialized from `lex`, so declare after it.
+ Token curToken;
+
+ /// The current state for symbol parsing.
+ SymbolState &symbols;
+
+ /// Nesting depth of this parser (count of active nested parser locations).
+ size_t parserDepth;
+};
+
+} // end namespace detail
+} // end namespace mlir
+
+#endif // MLIR_LIB_PARSER_PARSERSTATE_H
diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp
index 7b5fb9a545bb..db6f716ed0a0 100644
--- a/mlir/lib/Parser/Token.cpp
+++ b/mlir/lib/Parser/Token.cpp
@@ -124,6 +124,18 @@ std::string Token::getStringValue() const {
return result;
}
+/// Return the name referenced by this `@`-identifier token, with the leading
+/// '@' removed and any quoted form unescaped.
+std::string Token::getSymbolReference() const {
+  assert(is(Token::at_identifier) && "expected valid @-identifier");
+
+  // A quoted reference (@"...") needs its escapes processed; a bare
+  // identifier can be returned verbatim.
+  StringRef name = getSpelling().drop_front();
+  bool isQuoted = name.front() == '"';
+  return isQuoted ? getStringValue() : std::string(name);
+}
+
/// Given a hash_identifier token like #123, try to parse the number out of
/// the identifier, returning None if it is a named identifier like #x or
/// if the integer doesn't fit.
diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h
index 4f9c098bb4b2..8f37b2b438d0 100644
--- a/mlir/lib/Parser/Token.h
+++ b/mlir/lib/Parser/Token.h
@@ -91,6 +91,10 @@ class Token {
/// removing the quote characters and unescaping the contents of the string.
std::string getStringValue() const;
+ /// Given an `@`-identifier token containing a symbol reference, return the
+ /// referenced name without the leading `@`. If the name is a quoted string,
+ /// the quotes are removed and its contents unescaped.
+ std::string getSymbolReference() const;
+
// Location processing.
llvm::SMLoc getLoc() const;
llvm::SMLoc getEndLoc() const;
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
new file mode 100644
index 000000000000..68d381f968ad
--- /dev/null
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -0,0 +1,570 @@
+//===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the parser for the MLIR Types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Parser.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+/// Try to parse a type. If the current token cannot start a type, nothing is
+/// consumed and `llvm::None` is returned. Otherwise the parsed type is stored
+/// in `type` and success or failure of that parse is returned.
+OptionalParseResult Parser::parseOptionalType(Type &type) {
+  // Many different tokens can begin a type; check whether this is one of them.
+  bool startsType = false;
+  switch (getToken().getKind()) {
+  case Token::l_paren:
+  case Token::kw_memref:
+  case Token::kw_tensor:
+  case Token::kw_complex:
+  case Token::kw_tuple:
+  case Token::kw_vector:
+  case Token::inttype:
+  case Token::kw_bf16:
+  case Token::kw_f16:
+  case Token::kw_f32:
+  case Token::kw_f64:
+  case Token::kw_index:
+  case Token::kw_none:
+  case Token::exclamation_identifier:
+    startsType = true;
+    break;
+  default:
+    break;
+  }
+
+  if (!startsType)
+    return llvm::None;
+
+  type = parseType();
+  return failure(!type);
+}
+
+/// Parse any type, dispatching on whether it is a function type.
+///
+///   type ::= function-type
+///          | non-function-type
+///
+Type Parser::parseType() {
+  // A leading '(' introduces a function type.
+  const bool isFunction = getToken().is(Token::l_paren);
+  return isFunction ? parseFunctionType() : parseNonFunctionType();
+}
+
+/// Parse the result portion of a function type, appending the parsed result
+/// types to `elements`.
+///
+///   function-result-type ::= type-list-parens
+///                          | non-function-type
+///
+ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
+  // A parenthesized list may hold any number of results.
+  if (getToken().is(Token::l_paren))
+    return parseTypeListParens(elements);
+
+  // Otherwise there is exactly one result, which may not be a function type.
+  if (Type result = parseNonFunctionType()) {
+    elements.push_back(result);
+    return success();
+  }
+  return failure();
+}
+
+/// Parse one or more comma-separated types without surrounding parentheses,
+/// appending each parsed type to `elements`.
+///
+///   type-list-no-parens ::= type (`,` type)*
+///
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
+  return parseCommaSeparatedList([&]() -> ParseResult {
+    Type element = parseType();
+    // NOTE: the (possibly null) result is appended even when parsing failed.
+    elements.push_back(element);
+    return failure(!element);
+  });
+}
+
+/// Parse a possibly empty, parenthesized list of types into `elements`.
+///
+///   type-list-parens ::= `(` `)`
+///                      | `(` type-list-no-parens `)`
+///
+ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
+  if (parseToken(Token::l_paren, "expected '('"))
+    return failure();
+
+  // `()` denotes an empty list.
+  if (consumeIf(Token::r_paren))
+    return success();
+
+  return failure(parseTypeListNoParens(elements) ||
+                 parseToken(Token::r_paren, "expected ')'"));
+}
+
+/// Parse a complex type. The element type must be an integer or float type.
+///
+///   complex-type ::= `complex` `<` type `>`
+///
+Type Parser::parseComplexType() {
+  consumeToken(Token::kw_complex);
+
+  if (parseToken(Token::less, "expected '<' in complex type"))
+    return nullptr;
+
+  // Remember where the element type starts so the diagnostic can point at it.
+  llvm::SMLoc elementTypeLoc = getToken().getLoc();
+  Type elementType = parseType();
+  if (!elementType)
+    return nullptr;
+  if (parseToken(Token::greater, "expected '>' in complex type"))
+    return nullptr;
+
+  const bool isValidElement =
+      elementType.isa<FloatType>() || elementType.isa<IntegerType>();
+  if (!isValidElement)
+    return emitError(elementTypeLoc, "invalid element type for complex"),
+           nullptr;
+
+  return ComplexType::get(elementType);
+}
+
+/// Parse a function type: a parenthesized input list, `->`, then the results.
+///
+///   function-type ::= type-list-parens `->` function-result-type
+///
+Type Parser::parseFunctionType() {
+  assert(getToken().is(Token::l_paren));
+
+  SmallVector<Type, 4> inputs;
+  if (parseTypeListParens(inputs))
+    return nullptr;
+  if (parseToken(Token::arrow, "expected '->' in function type"))
+    return nullptr;
+
+  SmallVector<Type, 4> outputs;
+  if (parseFunctionResultTypes(outputs))
+    return nullptr;
+
+  return builder.getFunctionType(inputs, outputs);
+}
+
+/// Parse the offset and strides of a strided layout specification. The current
+/// token must be the `offset` keyword.
+///
+///   strided-layout ::= `offset:` dimension `,` `strides:` stride-list
+///
+/// On success, `offset` holds the parsed offset (`?` is stored as
+/// MemRefType::getDynamicStrideOrOffset()) and `strides` holds the strides,
+/// none of which may be zero.
+ParseResult Parser::parseStridedLayout(int64_t &offset,
+                                       SmallVectorImpl<int64_t> &strides) {
+  // `offset` `:` (integer | `?`)
+  consumeToken(Token::kw_offset);
+  if (!consumeIf(Token::colon))
+    return emitError("expected colon after `offset` keyword");
+
+  auto literalOffset = getToken().getUnsignedIntegerValue();
+  const bool isQuestion = getToken().is(Token::question);
+  if (!isQuestion && !literalOffset)
+    return emitError("invalid offset");
+  if (literalOffset)
+    offset = static_cast<int64_t>(literalOffset.getValue());
+  else
+    offset = MemRefType::getDynamicStrideOrOffset();
+  consumeToken();
+
+  if (!consumeIf(Token::comma))
+    return emitError("expected comma after offset value");
+
+  // `strides` `:` stride-list
+  if (!consumeIf(Token::kw_strides))
+    return emitError("expected `strides` keyword after offset specification");
+  if (!consumeIf(Token::colon))
+    return emitError("expected colon after `strides` keyword");
+  if (failed(parseStrideList(strides)))
+    return emitError("invalid braces-enclosed stride list");
+
+  const bool hasZeroStride =
+      llvm::any_of(strides, [](int64_t stride) { return stride == 0; });
+  if (hasZeroStride)
+    return emitError("invalid memref stride");
+  return success();
+}
+
+/// Parse a memref type.
+///
+///   memref-type ::= ranked-memref-type | unranked-memref-type
+///
+///   ranked-memref-type ::= `memref` `<` dimension-list-ranked type
+///                          (`,` semi-affine-map-composition)? (`,`
+///                          memory-space)? `>`
+///
+///   unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
+///
+///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
+///   memory-space ::= integer-literal /* | TODO: address-space-id */
+///
+/// Each layout map is given either as an affine map attribute or as a strided
+/// layout (`offset: ..., strides: [...]`). The first map must take as many
+/// dims as the memref rank; each later map must take as many dims as the
+/// previous map has results. When no memory space is given, 0 is used.
+Type Parser::parseMemRefType() {
+  consumeToken(Token::kw_memref);
+
+  if (parseToken(Token::less, "expected '<' in memref type"))
+    return nullptr;
+
+  bool isUnranked;
+  SmallVector<int64_t, 4> dimensions;
+
+  if (consumeIf(Token::star)) {
+    // This is an unranked memref type; `*` must be followed by `x`.
+    isUnranked = true;
+    if (parseXInDimensionList())
+      return nullptr;
+  } else {
+    isUnranked = false;
+    if (parseDimensionListRanked(dimensions))
+      return nullptr;
+  }
+
+  // Parse the element type.
+  auto typeLoc = getToken().getLoc();
+  auto elementType = parseType();
+  if (!elementType)
+    return nullptr;
+
+  // Check that memref is formed from allowed types.
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
+      !elementType.isa<ComplexType>())
+    return emitError(typeLoc, "invalid memref element type"), nullptr;
+
+  // Parse semi-affine-map-composition.
+  SmallVector<AffineMap, 2> affineMapComposition;
+  Optional<unsigned> memorySpace;
+  // Number of dims the next layout map must accept: the memref rank for the
+  // first map, then the number of results of the previous map.
+  unsigned numDims = dimensions.size();
+
+  // Parse one trailing element. It is either the memory space (an integer
+  // literal, which must come last) or one layout map.
+  auto parseElt = [&]() -> ParseResult {
+    // Check for the memory space.
+    if (getToken().is(Token::integer)) {
+      if (memorySpace)
+        return emitError("multiple memory spaces specified in memref type");
+      memorySpace = getToken().getUnsignedIntegerValue();
+      if (!memorySpace.hasValue())
+        return emitError("invalid memory space in memref type");
+      consumeToken(Token::integer);
+      return success();
+    }
+    if (isUnranked)
+      return emitError("cannot have affine map for unranked memref type");
+    if (memorySpace)
+      return emitError("expected memory space to be last in memref type");
+
+    AffineMap map;
+    llvm::SMLoc mapLoc = getToken().getLoc();
+    if (getToken().is(Token::kw_offset)) {
+      int64_t offset;
+      SmallVector<int64_t, 4> strides;
+      if (failed(parseStridedLayout(offset, strides)))
+        return failure();
+      // Construct strided affine map.
+      map = makeStridedLinearLayoutMap(strides, offset, state.context);
+    } else {
+      // Parse an affine map attribute.
+      auto affineMap = parseAttribute();
+      if (!affineMap)
+        return failure();
+      auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>();
+      if (!affineMapAttr)
+        return emitError("expected affine map in memref type");
+      map = affineMapAttr.getValue();
+    }
+
+    if (map.getNumDims() != numDims) {
+      // Maps are numbered from 1 in the diagnostic. Map `i + 1` is the one
+      // just parsed; its expected dim count comes from the memref rank (when
+      // i == 0) or from the results of map `i`.
+      size_t i = affineMapComposition.size();
+      return emitError(mapLoc, "memref affine map dimension mismatch between ")
+             << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
+             << " and affine map " << i + 1 << ": " << numDims
+             << " != " << map.getNumDims();
+    }
+    numDims = map.getNumResults();
+    affineMapComposition.push_back(map);
+    return success();
+  };
+
+  // Parse a list of mappings and address space if present.
+  if (!consumeIf(Token::greater)) {
+    // Parse comma separated list of affine maps, followed by memory space.
+    if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
+        parseCommaSeparatedListUntil(Token::greater, parseElt,
+                                     /*allowEmptyList=*/false)) {
+      return nullptr;
+    }
+  }
+
+  if (isUnranked)
+    return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0));
+
+  return MemRefType::get(dimensions, elementType, affineMapComposition,
+                         memorySpace.getValueOr(0));
+}
+
+/// Parse any type other than a function type.
+///
+///   non-function-type ::= integer-type
+///                       | index-type
+///                       | float-type
+///                       | extended-type
+///                       | vector-type
+///                       | tensor-type
+///                       | memref-type
+///                       | complex-type
+///                       | tuple-type
+///                       | none-type
+///
+///   index-type ::= `index`
+///   float-type ::= `f16` | `bf16` | `f32` | `f64`
+///   none-type ::= `none`
+///
+/// Any other token produces an "expected non-function type" error and a null
+/// result.
+Type Parser::parseNonFunctionType() {
+  switch (getToken().getKind()) {
+  // Keyword types: consume the keyword and build the type directly.
+  case Token::kw_index:
+    consumeToken(Token::kw_index);
+    return builder.getIndexType();
+  case Token::kw_none:
+    consumeToken(Token::kw_none);
+    return builder.getNoneType();
+  case Token::kw_bf16:
+    consumeToken(Token::kw_bf16);
+    return builder.getBF16Type();
+  case Token::kw_f16:
+    consumeToken(Token::kw_f16);
+    return builder.getF16Type();
+  case Token::kw_f32:
+    consumeToken(Token::kw_f32);
+    return builder.getF32Type();
+  case Token::kw_f64:
+    consumeToken(Token::kw_f64);
+    return builder.getF64Type();
+
+  // Integer types: width and signedness are read from the token itself.
+  case Token::inttype: {
+    auto bitwidth = getToken().getIntTypeBitwidth();
+    if (!bitwidth.hasValue())
+      return (emitError("invalid integer width"), nullptr);
+    if (bitwidth.getValue() > IntegerType::kMaxWidth) {
+      emitError(getToken().getLoc(), "integer bitwidth is limited to ")
+          << IntegerType::kMaxWidth << " bits";
+      return nullptr;
+    }
+
+    // Without an explicit signedness the integer is signless.
+    IntegerType::SignednessSemantics semantics = IntegerType::Signless;
+    Optional<bool> isSigned = getToken().getIntTypeSignedness();
+    if (isSigned.hasValue())
+      semantics =
+          isSigned.getValue() ? IntegerType::Signed : IntegerType::Unsigned;
+
+    // Compute the location before consuming the token it refers to.
+    auto loc = getEncodedSourceLocation(getToken().getLoc());
+    consumeToken(Token::inttype);
+    return IntegerType::getChecked(bitwidth.getValue(), semantics, loc);
+  }
+
+  // Aggregate and extended (`!`-prefixed) types have dedicated parsers.
+  case Token::kw_memref:
+    return parseMemRefType();
+  case Token::kw_tensor:
+    return parseTensorType();
+  case Token::kw_complex:
+    return parseComplexType();
+  case Token::kw_tuple:
+    return parseTupleType();
+  case Token::kw_vector:
+    return parseVectorType();
+  case Token::exclamation_identifier:
+    return parseExtendedType();
+
+  default:
+    return (emitError("expected non-function type"), nullptr);
+  }
+}
+
+/// Parse a ranked or unranked tensor type.
+///
+///   tensor-type ::= `tensor` `<` dimension-list type `>`
+///   dimension-list ::= dimension-list-ranked | `*x`
+///
+Type Parser::parseTensorType() {
+  consumeToken(Token::kw_tensor);
+  if (parseToken(Token::less, "expected '<' in tensor type"))
+    return nullptr;
+
+  // `*x` marks an unranked tensor; otherwise parse the explicit shape.
+  SmallVector<int64_t, 4> shape;
+  const bool isUnranked = consumeIf(Token::star);
+  if (isUnranked) {
+    if (parseXInDimensionList())
+      return nullptr;
+  } else if (parseDimensionListRanked(shape)) {
+    return nullptr;
+  }
+
+  // Element type, followed by the closing '>'.
+  llvm::SMLoc elementTypeLoc = getToken().getLoc();
+  Type elementType = parseType();
+  if (!elementType)
+    return nullptr;
+  if (parseToken(Token::greater, "expected '>' in tensor type"))
+    return nullptr;
+  if (!TensorType::isValidElementType(elementType))
+    return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
+
+  if (isUnranked)
+    return UnrankedTensorType::get(elementType);
+  return RankedTensorType::get(shape, elementType);
+}
+
+/// Parse a tuple type with zero or more element types.
+///
+///   tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
+///
+Type Parser::parseTupleType() {
+  consumeToken(Token::kw_tuple);
+  if (parseToken(Token::less, "expected '<' in tuple type"))
+    return nullptr;
+
+  // `tuple<>` denotes the empty tuple.
+  if (consumeIf(Token::greater))
+    return TupleType::get(getContext());
+
+  SmallVector<Type, 4> elementTypes;
+  if (parseTypeListNoParens(elementTypes))
+    return nullptr;
+  if (parseToken(Token::greater, "expected '>' in tuple type"))
+    return nullptr;
+  return TupleType::get(elementTypes, getContext());
+}
+
+/// Parse a vector type. Its shape must be non-empty, static and strictly
+/// positive in every dimension.
+///
+///   vector-type ::= `vector` `<` non-empty-static-dimension-list type `>`
+///   non-empty-static-dimension-list ::= decimal-literal `x`
+///                                       static-dimension-list
+///   static-dimension-list ::= (decimal-literal `x`)*
+///
+VectorType Parser::parseVectorType() {
+  consumeToken(Token::kw_vector);
+  if (parseToken(Token::less, "expected '<' in vector type"))
+    return nullptr;
+
+  // Shape: `?` is rejected by parseDimensionListRanked when dynamic sizes are
+  // disallowed.
+  SmallVector<int64_t, 4> shape;
+  if (parseDimensionListRanked(shape, /*allowDynamic=*/false))
+    return nullptr;
+  if (shape.empty())
+    return (emitError("expected dimension size in vector type"), nullptr);
+  const bool hasNonPositiveDim =
+      llvm::any_of(shape, [](int64_t dim) { return dim <= 0; });
+  if (hasNonPositiveDim)
+    return emitError(getToken().getLoc(),
+                     "vector types must have positive constant sizes"),
+           nullptr;
+
+  // Element type, followed by the closing '>'.
+  llvm::SMLoc elementTypeLoc = getToken().getLoc();
+  Type elementType = parseType();
+  if (!elementType)
+    return nullptr;
+  if (parseToken(Token::greater, "expected '>' in vector type"))
+    return nullptr;
+  if (!VectorType::isValidElementType(elementType))
+    return emitError(elementTypeLoc, "vector elements must be int or float type"),
+           nullptr;
+
+  return VectorType::get(shape, elementType);
+}
+
+/// Parse a possibly empty, `x`-terminated dimension list of a tensor, memref or
+/// vector type, appending each size to `dimensions`. A `?` is recorded as -1
+/// when `allowDynamic` is set and is an error otherwise.
+///
+///   dimension-list-ranked ::= (dimension `x`)*
+///   dimension ::= `?` | decimal-literal
+///
+/// When `allowDynamic` is not set, this is used to parse:
+///
+///   static-dimension-list ::= (decimal-literal `x`)*
+ParseResult
+Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
+                                 bool allowDynamic) {
+  while (getToken().isAny(Token::integer, Token::question)) {
+    if (consumeIf(Token::question)) {
+      if (!allowDynamic)
+        return emitError("expected static shape");
+      dimensions.push_back(-1);
+    } else if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
+      // The lexer read something like `0xf32` as one hexadecimal literal.
+      // Hex literals are not allowed here, so reinterpret it as the dimension
+      // `0` followed by `xf32` by re-lexing from just past the `0`. Only a
+      // literal starting with `0x` can reach this point (`1x` lexes as `1`).
+      assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
+      dimensions.push_back(0);
+      state.lex.resetPointer(getTokenSpelling().data() + 1);
+      consumeToken();
+    } else {
+      // A decimal dimension; make sure it is in bounds and valid.
+      auto size = getToken().getUnsignedIntegerValue();
+      if (!size.hasValue())
+        return emitError("invalid dimension");
+      dimensions.push_back(static_cast<int64_t>(size.getValue()));
+      consumeToken(Token::integer);
+    }
+
+    // Each dimension is followed by `x`, possibly fused with the next token
+    // (as in `xbf32`).
+    if (parseXInDimensionList())
+      return failure();
+  }
+
+  return success();
+}
+
+/// Consume the `x` separator of a dimension list. The lexer may have fused the
+/// `x` with the following element type (as in `xf32`). In that case only the
+/// `x` is consumed, and lexing resumes right after it so that `f32` becomes
+/// the next token.
+ParseResult Parser::parseXInDimensionList() {
+  const bool startsWithX = getToken().is(Token::bare_identifier) &&
+                           getTokenSpelling()[0] == 'x';
+  if (!startsWithX)
+    return emitError("expected 'x' in dimension list");
+
+  // Re-lex from just after the `x` when it is followed by more characters.
+  auto spelling = getTokenSpelling();
+  if (spelling.size() != 1)
+    state.lex.resetPointer(spelling.data() + 1);
+
+  consumeToken(Token::bare_identifier);
+  return success();
+}
+
+/// Parse a bracketed, comma-separated, possibly empty list of strides,
+/// appending each one to `dimensions`:
+///
+///   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
+///
+/// A `?` entry is recorded as MemRefType::getDynamicStrideOrOffset(). If the
+/// opening or closing bracket is missing, failure is returned without emitting
+/// a diagnostic, so the caller must report the error.
+ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
+  if (!consumeIf(Token::l_square))
+    return failure();
+  // Empty list early exit.
+  if (consumeIf(Token::r_square))
+    return success();
+  while (true) {
+    if (consumeIf(Token::question)) {
+      dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
+    } else {
+      // This must be an integer value.
+      int64_t val;
+      if (getToken().getSpelling().getAsInteger(10, val))
+        return emitError("invalid integer value: ") << getToken().getSpelling();
+      // A literal must not collide with the sentinel that `?` is encoded as
+      // above. Otherwise it would silently become a dynamic stride. Strides
+      // use the stride/offset sentinel, not the dynamic-size one.
+      if (val == MemRefType::getDynamicStrideOrOffset())
+        return emitError("invalid integer value: ")
+               << getToken().getSpelling()
+               << ", use `?` to specify a dynamic dimension";
+      dimensions.push_back(val);
+      consumeToken(Token::integer);
+    }
+    if (!consumeIf(Token::comma))
+      break;
+  }
+  if (!consumeIf(Token::r_square))
+    return failure();
+  return success();
+}
More information about the Mlir-commits
mailing list