[Mlir-commits] [mlir] 58abc8c - [OpAsmParser] Add a parseCommaSeparatedList helper and beef up Delimiter.
Chris Lattner
llvmlistbot at llvm.org
Mon Sep 20 20:59:19 PDT 2021
Author: Chris Lattner
Date: 2021-09-20T20:59:11-07:00
New Revision: 58abc8c34bde7021bbfa0a7bdfd2af9524cba263
URL: https://github.com/llvm/llvm-project/commit/58abc8c34bde7021bbfa0a7bdfd2af9524cba263
DIFF: https://github.com/llvm/llvm-project/commit/58abc8c34bde7021bbfa0a7bdfd2af9524cba263.diff
LOG: [OpAsmParser] Add a parseCommaSeparatedList helper and beef up Delimiter.
Lots of custom ops have hand-rolled comma-delimited parsing loops, as does
the MLIR parser itself. Provides a standard interface for doing this that
is less error prone and less boilerplate.
While here, extend Delimiter to support <> and {} delimited sequences as
well (I have a use for <> in CIRCT specifically).
Differential Revision: https://reviews.llvm.org/D110122
Added:
Modified:
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/Async/IR/Async.cpp
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Parser/AffineParser.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/LocationParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Parser.h
mlir/lib/Parser/TypeParser.cpp
mlir/test/IR/invalid-affinemap.mlir
mlir/test/IR/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 0cfb5448894fa..979b86c60194f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -473,6 +473,47 @@ class OpAsmParser {
return success();
}
+ /// These are the supported delimiters around operand lists and region
+ /// argument lists, used by parseOperandList and parseRegionArgumentList.
+ enum class Delimiter {
+ /// Zero or more operands with no delimiters.
+ None,
+ /// Parens surrounding zero or more operands.
+ Paren,
+ /// Square brackets surrounding zero or more operands.
+ Square,
+ /// <> brackets surrounding zero or more operands.
+ LessGreater,
+ /// {} brackets surrounding zero or more operands.
+ Braces,
+ /// Parens supporting zero or more operands, or nothing.
+ OptionalParen,
+ /// Square brackets supporting zero or more ops, or nothing.
+ OptionalSquare,
+ /// <> brackets supporting zero or more ops, or nothing.
+ OptionalLessGreater,
+ /// {} brackets surrounding zero or more operands, or nothing.
+ OptionalBraces,
+ };
+
+ /// Parse a list of comma-separated items with an optional delimiter. If a
+ /// delimiter is provided, then an empty list is allowed. If not, then at
+ /// least one element will be parsed.
+ ///
+ /// contextMessage is an optional message appended to "expected '('" sorts of
+ /// diagnostics when parsing the delimiters.
+ virtual ParseResult
+ parseCommaSeparatedList(Delimiter delimiter,
+ function_ref<ParseResult()> parseElementFn,
+ StringRef contextMessage = StringRef()) = 0;
+
+ /// Parse a comma separated list of elements that must have at least one entry
+ /// in it.
+ ParseResult
+ parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
+ return parseCommaSeparatedList(Delimiter::None, parseElementFn);
+ }
+
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
@@ -610,21 +651,6 @@ class OpAsmParser {
/// Parse a single operand if present.
virtual OptionalParseResult parseOptionalOperand(OperandType &result) = 0;
- /// These are the supported delimiters around operand lists and region
- /// argument lists, used by parseOperandList and parseRegionArgumentList.
- enum class Delimiter {
- /// Zero or more operands with no delimiters.
- None,
- /// Parens surrounding zero or more operands.
- Paren,
- /// Square brackets surrounding zero or more operands.
- Square,
- /// Parens supporting zero or more operands, or nothing.
- OptionalParen,
- /// Square brackets supporting zero or more ops, or nothing.
- OptionalSquare,
- };
-
/// Parse zero or more SSA comma-separated operand references with a specified
/// surrounding delimiter, and an optional required operand count.
virtual ParseResult
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 064edce46bc31..d21ff1fef3a2c 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -158,7 +158,6 @@ static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
// Sizes of parsed variadic operands, will be updated below after parsing.
int32_t numDependencies = 0;
- int32_t numOperands = 0;
auto tokenTy = TokenType::get(ctx);
@@ -179,38 +178,27 @@ static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
SmallVector<Type, 4> valueTypes;
SmallVector<Type, 4> unwrappedTypes;
- if (succeeded(parser.parseOptionalLParen())) {
- auto argsLoc = parser.getCurrentLocation();
-
- // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
- auto parseAsyncValueArg = [&]() -> ParseResult {
- if (parser.parseOperand(valueArgs.emplace_back()) ||
- parser.parseKeyword("as") ||
- parser.parseOperand(unwrappedArgs.emplace_back()) ||
- parser.parseColonType(valueTypes.emplace_back()))
- return failure();
-
- auto valueTy = valueTypes.back().dyn_cast<ValueType>();
- unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
-
- return success();
- };
-
- // If the next token is `)` skip async value arguments parsing.
- if (failed(parser.parseOptionalRParen())) {
- do {
- if (parseAsyncValueArg())
- return failure();
- } while (succeeded(parser.parseOptionalComma()));
-
- if (parser.parseRParen() ||
- parser.resolveOperands(valueArgs, valueTypes, argsLoc,
- result.operands))
- return failure();
- }
+ // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
+ auto parseAsyncValueArg = [&]() -> ParseResult {
+ if (parser.parseOperand(valueArgs.emplace_back()) ||
+ parser.parseKeyword("as") ||
+ parser.parseOperand(unwrappedArgs.emplace_back()) ||
+ parser.parseColonType(valueTypes.emplace_back()))
+ return failure();
- numOperands = valueArgs.size();
- }
+ auto valueTy = valueTypes.back().dyn_cast<ValueType>();
+ unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
+
+ return success();
+ };
+
+ auto argsLoc = parser.getCurrentLocation();
+ if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
+ parseAsyncValueArg) ||
+ parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
+ return failure();
+
+ int32_t numOperands = valueArgs.size();
// Add derived `operand_segment_sizes` attribute based on parsed operands.
auto operandSegmentSizes = DenseIntElementsAttr::get(
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 90ae3bb990c89..94be8c252a348 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -77,22 +77,16 @@ static ParseResult
parseOperandAndTypeList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::OperandType> &operands,
SmallVectorImpl<Type> &types) {
- if (parser.parseLParen())
- return failure();
-
- do {
- OpAsmParser::OperandType operand;
- Type type;
- if (parser.parseOperand(operand) || parser.parseColonType(type))
- return failure();
- operands.push_back(operand);
- types.push_back(type);
- } while (succeeded(parser.parseOptionalComma()));
-
- if (parser.parseRParen())
- return failure();
-
- return success();
+ return parser.parseCommaSeparatedList(
+ OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+ OpAsmParser::OperandType operand;
+ Type type;
+ if (parser.parseOperand(operand) || parser.parseColonType(type))
+ return failure();
+ operands.push_back(operand);
+ types.push_back(type);
+ return success();
+ });
}
/// Parse an allocate clause with allocators and a list of operands with types.
@@ -108,30 +102,24 @@ static ParseResult parseAllocateAndAllocator(
SmallVectorImpl<Type> &typesAllocate,
SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
SmallVectorImpl<Type> &typesAllocator) {
- if (parser.parseLParen())
- return failure();
- do {
- OpAsmParser::OperandType operand;
- Type type;
-
- if (parser.parseOperand(operand) || parser.parseColonType(type))
- return failure();
- operandsAllocator.push_back(operand);
- typesAllocator.push_back(type);
- if (parser.parseArrow())
- return failure();
- if (parser.parseOperand(operand) || parser.parseColonType(type))
- return failure();
-
- operandsAllocate.push_back(operand);
- typesAllocate.push_back(type);
- } while (succeeded(parser.parseOptionalComma()));
-
- if (parser.parseRParen())
- return failure();
+ return parser.parseCommaSeparatedList(
+ OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+ OpAsmParser::OperandType operand;
+ Type type;
+ if (parser.parseOperand(operand) || parser.parseColonType(type))
+ return failure();
+ operandsAllocator.push_back(operand);
+ typesAllocator.push_back(type);
+ if (parser.parseArrow())
+ return failure();
+ if (parser.parseOperand(operand) || parser.parseColonType(type))
+ return failure();
- return success();
+ operandsAllocate.push_back(operand);
+ typesAllocate.push_back(type);
+ return success();
+ });
}
static LogicalResult verifyParallelOp(ParallelOp op) {
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 0cee846a75d94..900251f8ad7bc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1778,16 +1778,16 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
if (!parser.parseOptionalComma()) {
// Parse the interface variables
- do {
- // The name of the interface variable attribute isnt important
- auto attrName = "var_symbol";
- FlatSymbolRefAttr var;
- NamedAttrList attrs;
- if (parser.parseAttribute(var, Type(), attrName, attrs)) {
- return failure();
- }
- interfaceVars.push_back(var);
- } while (!parser.parseOptionalComma());
+ if (parser.parseCommaSeparatedList([&]() -> ParseResult {
+ // The name of the interface variable attribute isnt important
+ FlatSymbolRefAttr var;
+ NamedAttrList attrs;
+ if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
+ return failure();
+ interfaceVars.push_back(var);
+ return success();
+ }))
+ return failure();
}
state.addAttribute(kInterfaceAttrName,
parser.getBuilder().getArrayAttr(interfaceVars));
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index ed0417634e309..6f95bf03e91dd 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2197,13 +2197,12 @@ static ParseResult parseSwitchOpCases(
SmallVectorImpl<Block *> &caseDestinations,
SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
- if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
- failed(parser.parseSuccessor(defaultDestination)))
+ if (parser.parseKeyword("default") || parser.parseColon() ||
+ parser.parseSuccessor(defaultDestination))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
- if (failed(parser.parseRegionArgumentList(defaultOperands)) ||
- failed(parser.parseColonTypeList(defaultOperandTypes)) ||
- failed(parser.parseRParen()))
+ if (parser.parseRegionArgumentList(defaultOperands) ||
+ parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
return failure();
}
diff --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp
index 567767265c6e8..708983709f6cc 100644
--- a/mlir/lib/Parser/AffineParser.cpp
+++ b/mlir/lib/Parser/AffineParser.cpp
@@ -474,26 +474,22 @@ ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) {
/// Parse the list of dimensional identifiers to an affine map.
ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
- if (parseToken(Token::l_paren,
- "expected '(' at start of dimensional identifiers list")) {
- return failure();
- }
-
auto parseElt = [&]() -> ParseResult {
auto dimension = getAffineDimExpr(numDims++, getContext());
return parseIdentifierDefinition(dimension);
};
- return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
+ return parseCommaSeparatedList(Delimiter::Paren, parseElt,
+ " in dimensional identifier list");
}
/// Parse the list of symbolic identifiers to an affine map.
ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
- consumeToken(Token::l_square);
auto parseElt = [&]() -> ParseResult {
auto symbol = getAffineSymbolExpr(numSymbols++, getContext());
return parseIdentifierDefinition(symbol);
};
- return parseCommaSeparatedListUntil(Token::r_square, parseElt);
+ return parseCommaSeparatedList(Delimiter::Square, parseElt,
+ " in symbol list");
}
/// Parse the list of symbolic identifiers to an affine map.
@@ -544,21 +540,6 @@ ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
ParseResult
AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
OpAsmParser::Delimiter delimiter) {
- Token::Kind rightToken;
- switch (delimiter) {
- case OpAsmParser::Delimiter::Square:
- if (parseToken(Token::l_square, "expected '['"))
- return failure();
- rightToken = Token::r_square;
- break;
- case OpAsmParser::Delimiter::Paren:
- if (parseToken(Token::l_paren, "expected '('"))
- return failure();
- rightToken = Token::r_paren;
- break;
- default:
- return emitError("unexpected delimiter");
- }
SmallVector<AffineExpr, 4> exprs;
auto parseElt = [&]() -> ParseResult {
@@ -571,9 +552,9 @@ AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
// 1-d affine expressions); the list can be empty. Grammar:
// multi-dim-affine-expr ::= `(` `)`
// | `(` affine-expr (`,` affine-expr)* `)`
- if (parseCommaSeparatedListUntil(rightToken, parseElt,
- /*allowEmptyList=*/true))
+ if (parseCommaSeparatedList(delimiter, parseElt, " in affine map"))
return failure();
+
// Parsed a valid affine map.
map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
exprs, getContext());
@@ -594,8 +575,6 @@ ParseResult AffineParser::parseAffineExprOfSSAIds(AffineExpr &expr) {
/// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
unsigned numSymbols) {
- parseToken(Token::l_paren, "expected '(' at start of affine map range");
-
SmallVector<AffineExpr, 4> exprs;
auto parseElt = [&]() -> ParseResult {
auto elt = parseAffineExpr();
@@ -608,7 +587,8 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
// 1-d affine expressions). Grammar:
// multi-dim-affine-expr ::= `(` `)`
// | `(` affine-expr (`,` affine-expr)* `)`
- if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+ if (parseCommaSeparatedList(Delimiter::Paren, parseElt,
+ " in affine map range"))
return AffineMap();
// Parsed a valid affine map.
@@ -662,10 +642,6 @@ AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
///
IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
unsigned numSymbols) {
- if (parseToken(Token::l_paren,
- "expected '(' at start of integer set constraint list"))
- return IntegerSet();
-
SmallVector<AffineExpr, 4> constraints;
SmallVector<bool, 4> isEqs;
auto parseElt = [&]() -> ParseResult {
@@ -680,7 +656,8 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
};
// Parse a list of affine constraints (comma-separated).
- if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+ if (parseCommaSeparatedList(Delimiter::Paren, parseElt,
+ " in integer set constraint list"))
return IntegerSet();
// If no constraints were parsed, then treat this as a degenerate 'true' case.
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 6e512bdf9747c..20cfdf27a1001 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -67,15 +67,13 @@ Attribute Parser::parseAttribute(Type type) {
// Parse an array attribute.
case Token::l_square: {
- consumeToken(Token::l_square);
-
SmallVector<Attribute, 4> elements;
auto parseElt = [&]() -> ParseResult {
elements.push_back(parseAttribute());
return elements.back() ? success() : failure();
};
- if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
+ if (parseCommaSeparatedList(Delimiter::Square, parseElt))
return nullptr;
return builder.getArrayAttr(elements);
}
@@ -262,9 +260,6 @@ OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
/// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
///
ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
- if (parseToken(Token::l_brace, "expected '{' in attribute dictionary"))
- return failure();
-
llvm::SmallDenseSet<Identifier> seenKeys;
auto parseElt = [&]() -> ParseResult {
// The name of an attribute can either be a bare identifier, or a string.
@@ -300,7 +295,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return success();
};
- if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
+ if (parseCommaSeparatedList(Delimiter::Braces, parseElt,
+ " in attribute dictionary"))
return failure();
return success();
@@ -769,8 +765,6 @@ ParseResult TensorLiteralParser::parseElement() {
/// parseList([[1, 2], 3]) -> Failure
/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
- p.consumeToken(Token::l_square);
-
auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
if (prevDims == newDims)
@@ -782,7 +776,7 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
bool first = true;
SmallVector<int64_t, 4> newDims;
unsigned size = 0;
- auto parseCommaSeparatedList = [&]() -> ParseResult {
+ auto parseOneElement = [&]() -> ParseResult {
SmallVector<int64_t, 4> thisDims;
if (p.getToken().getKind() == Token::l_square) {
if (parseList(thisDims))
@@ -797,7 +791,7 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
first = false;
return success();
};
- if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
+ if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
return failure();
// Return the sublists' dimensions with 'size' prepended.
diff --git a/mlir/lib/Parser/LocationParser.cpp b/mlir/lib/Parser/LocationParser.cpp
index d0bd5db5267db..d010260e84836 100644
--- a/mlir/lib/Parser/LocationParser.cpp
+++ b/mlir/lib/Parser/LocationParser.cpp
@@ -82,9 +82,8 @@ ParseResult Parser::parseFusedLocation(LocationAttr &loc) {
return success();
};
- if (parseToken(Token::l_square, "expected '[' in fused location") ||
- parseCommaSeparatedList(parseElt) ||
- parseToken(Token::r_square, "expected ']' in fused location"))
+ if (parseCommaSeparatedList(Delimiter::Square, parseElt,
+ " in fused location"))
return failure();
// Return the fused location.
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 38d6c0b444dc8..6c8375427e791 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -35,20 +35,90 @@ using llvm::SourceMgr;
// Parser
//===----------------------------------------------------------------------===//
-/// Parse a comma separated list of elements that must have at least one entry
-/// in it.
+/// Parse a list of comma-separated items with an optional delimiter. If a
+/// delimiter is provided, then an empty list is allowed. If not, then at
+/// least one element will be parsed.
ParseResult
-Parser::parseCommaSeparatedList(function_ref<ParseResult()> parseElement) {
+Parser::parseCommaSeparatedList(Delimiter delimiter,
+ function_ref<ParseResult()> parseElementFn,
+ StringRef contextMessage) {
+ switch (delimiter) {
+ case Delimiter::None:
+ break;
+ case Delimiter::OptionalParen:
+ if (getToken().isNot(Token::l_paren))
+ return success();
+ LLVM_FALLTHROUGH;
+ case Delimiter::Paren:
+ if (parseToken(Token::l_paren, "expected '('" + contextMessage))
+ return failure();
+ // Check for empty list.
+ if (consumeIf(Token::r_paren))
+ return success();
+ break;
+ case Delimiter::OptionalLessGreater:
+ // Check for absent list.
+ if (getToken().isNot(Token::less))
+ return success();
+ LLVM_FALLTHROUGH;
+ case Delimiter::LessGreater:
+ if (parseToken(Token::less, "expected '<'" + contextMessage))
+ return failure();
+ // Check for empty list.
+ if (consumeIf(Token::greater))
+ return success();
+ break;
+ case Delimiter::OptionalSquare:
+ if (getToken().isNot(Token::l_square))
+ return success();
+ LLVM_FALLTHROUGH;
+ case Delimiter::Square:
+ if (parseToken(Token::l_square, "expected '['" + contextMessage))
+ return failure();
+ // Check for empty list.
+ if (consumeIf(Token::r_square))
+ return success();
+ break;
+ case Delimiter::OptionalBraces:
+ if (getToken().isNot(Token::l_brace))
+ return success();
+ LLVM_FALLTHROUGH;
+ case Delimiter::Braces:
+ if (parseToken(Token::l_brace, "expected '{'" + contextMessage))
+ return failure();
+ // Check for empty list.
+ if (consumeIf(Token::r_brace))
+ return success();
+ break;
+ }
+
// Non-empty case starts with an element.
- if (parseElement())
+ if (parseElementFn())
return failure();
// Otherwise we have a list of comma separated elements.
while (consumeIf(Token::comma)) {
- if (parseElement())
+ if (parseElementFn())
return failure();
}
- return success();
+
+ switch (delimiter) {
+ case Delimiter::None:
+ return success();
+ case Delimiter::OptionalParen:
+ case Delimiter::Paren:
+ return parseToken(Token::r_paren, "expected ')'" + contextMessage);
+ case Delimiter::OptionalLessGreater:
+ case Delimiter::LessGreater:
+ return parseToken(Token::greater, "expected '>'" + contextMessage);
+ case Delimiter::OptionalSquare:
+ case Delimiter::Square:
+ return parseToken(Token::r_square, "expected ']'" + contextMessage);
+ case Delimiter::OptionalBraces:
+ case Delimiter::Braces:
+ return parseToken(Token::r_brace, "expected '}'" + contextMessage);
+ }
+ llvm_unreachable("Unknown delimiter");
}
/// Parse a comma-separated list of elements, terminated with an arbitrary
@@ -1282,6 +1352,15 @@ class CustomOpAsmParser : public OpAsmParser {
return parser.parseOptionalInteger(result);
}
+ /// Parse a list of comma-separated items with an optional delimiter. If a
+ /// delimiter is provided, then an empty list is allowed. If not, then at
+ /// least one element will be parsed.
+ ParseResult parseCommaSeparatedList(Delimiter delimiter,
+ function_ref<ParseResult()> parseElt,
+ StringRef contextMessage) override {
+ return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+ }
+
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
@@ -1467,67 +1546,37 @@ class CustomOpAsmParser : public OpAsmParser {
Delimiter delimiter = Delimiter::None) {
auto startLoc = parser.getToken().getLoc();
- // Handle delimiters.
- switch (delimiter) {
- case Delimiter::None:
- // Don't check for the absence of a delimiter if the number of operands
- // is unknown (and hence the operand list could be empty).
- if (requiredOperandCount == -1)
- break;
- // Token already matches an identifier and so can't be a delimiter.
- if (parser.getToken().is(Token::percent_identifier))
- break;
- // Test against known delimiters.
- if (parser.getToken().is(Token::l_paren) ||
- parser.getToken().is(Token::l_square))
- return emitError(startLoc, "unexpected delimiter");
- return emitError(startLoc, "invalid operand");
- case Delimiter::OptionalParen:
- if (parser.getToken().isNot(Token::l_paren))
- return success();
- LLVM_FALLTHROUGH;
- case Delimiter::Paren:
- if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
- return failure();
- break;
- case Delimiter::OptionalSquare:
- if (parser.getToken().isNot(Token::l_square))
- return success();
- LLVM_FALLTHROUGH;
- case Delimiter::Square:
- if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
- return failure();
- break;
- }
-
- // Check for zero operands.
- if (parser.getToken().is(Token::percent_identifier)) {
- do {
- OperandType operandOrArg;
- if (isOperandList ? parseOperand(operandOrArg)
- : parseRegionArgument(operandOrArg))
- return failure();
- result.push_back(operandOrArg);
- } while (parser.consumeIf(Token::comma));
+ // The no-delimiter case has some special handling for better diagnostics.
+ if (delimiter == Delimiter::None) {
+ // parseCommaSeparatedList doesn't handle the missing case for "none",
+ // so we handle it custom here.
+ if (parser.getToken().isNot(Token::percent_identifier)) {
+ // If we didn't require any operands or required exactly zero (weird)
+ // then this is success.
+ if (requiredOperandCount == -1 || requiredOperandCount == 0)
+ return success();
+
+ // Otherwise, try to produce a nice error message.
+ if (parser.getToken().is(Token::l_paren) ||
+ parser.getToken().is(Token::l_square))
+ return emitError(startLoc, "unexpected delimiter");
+ return emitError(startLoc, "invalid operand");
+ }
}
- // Handle delimiters. If we reach here, the optional delimiters were
- // present, so we need to parse their closing one.
- switch (delimiter) {
- case Delimiter::None:
- break;
- case Delimiter::OptionalParen:
- case Delimiter::Paren:
- if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
- return failure();
- break;
- case Delimiter::OptionalSquare:
- case Delimiter::Square:
- if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
+ auto parseOneOperand = [&]() -> ParseResult {
+ OperandType operandOrArg;
+ if (isOperandList ? parseOperand(operandOrArg)
+ : parseRegionArgument(operandOrArg))
return failure();
- break;
- }
+ result.push_back(operandOrArg);
+ return success();
+ };
+
+ if (parseCommaSeparatedList(delimiter, parseOneOperand, " in operand list"))
+ return failure();
+ // Check that we got the expected # of elements.
if (requiredOperandCount != -1 &&
result.size() != static_cast<size_t>(requiredOperandCount))
return emitError(startLoc, "expected ")
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index f8a2c74e455f4..9a1894deee7bf 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -24,6 +24,8 @@ namespace detail {
/// include state.
class Parser {
public:
+ using Delimiter = OpAsmParser::Delimiter;
+
Builder builder;
Parser(ParserState &state) : builder(state.context), state(state) {}
@@ -39,9 +41,20 @@ class Parser {
function_ref<ParseResult()> parseElement,
bool allowEmptyList = true);
+ /// Parse a list of comma-separated items with an optional delimiter. If a
+ /// delimiter is provided, then an empty list is allowed. If not, then at
+ /// least one element will be parsed.
+ ParseResult
+ parseCommaSeparatedList(Delimiter delimiter,
+ function_ref<ParseResult()> parseElementFn,
+ StringRef contextMessage = StringRef());
+
/// Parse a comma separated list of elements that must have at least one entry
/// in it.
- ParseResult parseCommaSeparatedList(function_ref<ParseResult()> parseElement);
+ ParseResult
+ parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
+ return parseCommaSeparatedList(Delimiter::None, parseElementFn);
+ }
ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);
@@ -276,7 +289,7 @@ class Parser {
ParseResult
parseAffineMapOfSSAIds(AffineMap &map,
function_ref<ParseResult(bool)> parseElement,
- OpAsmParser::Delimiter delimiter);
+ Delimiter delimiter);
/// Parse an AffineExpr where dim and symbol identifiers are SSA ids.
ParseResult
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 66ec3b4237081..46fbf52d25bb6 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -165,15 +165,12 @@ ParseResult Parser::parseStridedLayout(int64_t &offset,
return emitError("expected comma after offset value");
// Parse stride list.
- if (!consumeIf(Token::kw_strides))
- return emitError("expected `strides` keyword after offset specification");
- if (!consumeIf(Token::colon))
- return emitError("expected colon after `strides` keyword");
- if (failed(parseStrideList(strides)))
- return emitError("invalid braces-enclosed stride list");
- if (llvm::any_of(strides, [](int64_t st) { return st == 0; }))
- return emitError("invalid memref stride");
+ if (parseToken(Token::kw_strides,
+ "expected `strides` keyword after offset specification") ||
+ parseToken(Token::colon, "expected colon after `strides` keyword") ||
+ parseStrideList(strides))
+ return failure();
return success();
}
@@ -560,31 +557,30 @@ ParseResult Parser::parseXInDimensionList() {
// Parse a comma-separated list of dimensions, possibly empty:
// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
- if (!consumeIf(Token::l_square))
- return failure();
- // Empty list early exit.
- if (consumeIf(Token::r_square))
- return success();
- while (true) {
- if (consumeIf(Token::question)) {
- dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
- } else {
- // This must be an integer value.
- int64_t val;
- if (getToken().getSpelling().getAsInteger(10, val))
- return emitError("invalid integer value: ") << getToken().getSpelling();
- // Make sure it is not the one value for `?`.
- if (ShapedType::isDynamic(val))
- return emitError("invalid integer value: ")
- << getToken().getSpelling()
- << ", use `?` to specify a dynamic dimension";
- dimensions.push_back(val);
- consumeToken(Token::integer);
- }
- if (!consumeIf(Token::comma))
- break;
- }
- if (!consumeIf(Token::r_square))
- return failure();
- return success();
+ return parseCommaSeparatedList(
+ Delimiter::Square,
+ [&]() -> ParseResult {
+ if (consumeIf(Token::question)) {
+ dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
+ } else {
+ // This must be an integer value.
+ int64_t val;
+ if (getToken().getSpelling().getAsInteger(10, val))
+ return emitError("invalid integer value: ")
+ << getToken().getSpelling();
+ // Make sure it is not the one value for `?`.
+ if (ShapedType::isDynamic(val))
+ return emitError("invalid integer value: ")
+ << getToken().getSpelling()
+ << ", use `?` to specify a dynamic dimension";
+
+ if (val == 0)
+ return emitError("invalid memref stride");
+
+ dimensions.push_back(val);
+ consumeToken(Token::integer);
+ }
+ return success();
+ },
+ " in stride list");
}
diff --git a/mlir/test/IR/invalid-affinemap.mlir b/mlir/test/IR/invalid-affinemap.mlir
index 9377824f006a5..acca3cd32fe42 100644
--- a/mlir/test/IR/invalid-affinemap.mlir
+++ b/mlir/test/IR/invalid-affinemap.mlir
@@ -29,7 +29,9 @@
#hello_world = affine_map<(i, j) [s0] -> (((s0 + (i + j) + 5), j)> // expected-error {{expected ')'}}
// -----
-#hello_world = affine_map<(i, j) [s0] -> i + s0, j)> // expected-error {{expected '(' at start of affine map range}}
+
+// expected-error @+1 {{expected '(' in affine map range}}
+#hello_world = affine_map<(i, j) [s0] -> i + s0, j)>
// -----
#hello_world = affine_map<(i, j) [s0] -> (x)> // expected-error {{use of undeclared identifier}}
@@ -47,6 +49,8 @@
#hello_world = affine_map<(i, j) [s0, s1] -> (+i, j)> // expected-error {{missing left operand of binary op}}
// -----
+
+
#hello_world = affine_map<(i, j) [s0, s1] -> (i, *j)> // expected-error {{missing left operand of binary op}}
// -----
@@ -91,7 +95,8 @@
#hello_world = affine_map<(i, j) [s0, s1] -> (i, i mod (2+i))> // expected-error {{non-affine expression: right operand of mod has to be either a constant or symbolic}}
// -----
-#hello_world = affine_map<(i, j) [s0, s1] -> (-1*i j, j)> // expected-error {{expected ',' or ')'}}
+// expected-error @+1 {{expected ')' in affine map range}}
+#hello_world = affine_map<(i, j) [s0, s1] -> (-1*i j, j)>
// -----
#hello_world = affine_map<(i, j) -> (i, 3*d0 + )> // expected-error {{use of undeclared identifier}}
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index ee4291606e42d..80b0fae523e2a 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -92,7 +92,8 @@ func @memref_stride_missing_colon_2(memref<42x42xi8, offset: 0, strides [?, ?]>)
// -----
-func @memref_stride_invalid_strides(memref<42x42xi8, offset: 0, strides: ()>) // expected-error {{invalid braces-enclosed stride list}}
+// expected-error @+1 {{expected '['}}
+func @memref_stride_invalid_strides(memref<42x42xi8, offset: 0, strides: ()>)
// -----
@@ -633,7 +634,8 @@ func @invalid_bound_map(%N : i32) {
// -----
-#set0 = affine_set<(i)[N, M] : )i >= 0)> // expected-error {{expected '(' at start of integer set constraint list}}
+// expected-error @+1 {{expected '(' in integer set constraint list}}
+#set0 = affine_set<(i)[N, M] : )i >= 0)>
// -----
#set0 = affine_set<(i)[N] : (i >= 0, N - i >= 0)>
More information about the Mlir-commits
mailing list