[Mlir-commits] [mlir] 58abc8c - [OpAsmParser] Add a parseCommaSeparatedList helper and beef up Delimiter.

Chris Lattner llvmlistbot at llvm.org
Mon Sep 20 20:59:19 PDT 2021


Author: Chris Lattner
Date: 2021-09-20T20:59:11-07:00
New Revision: 58abc8c34bde7021bbfa0a7bdfd2af9524cba263

URL: https://github.com/llvm/llvm-project/commit/58abc8c34bde7021bbfa0a7bdfd2af9524cba263
DIFF: https://github.com/llvm/llvm-project/commit/58abc8c34bde7021bbfa0a7bdfd2af9524cba263.diff

LOG: [OpAsmParser] Add a parseCommaSeparatedList helper and beef up Delimiter.

Lots of custom ops have hand-rolled comma-delimited parsing loops, as does
the MLIR parser itself.  This patch provides a standard interface for doing
this that is less error-prone and requires less boilerplate.

While here, extend Delimiter to support <>- and {}-delimited sequences as
well (I have a use for <> in CIRCT specifically).

Differential Revision: https://reviews.llvm.org/D110122

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Dialect/Async/IR/Async.cpp
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Parser/AffineParser.cpp
    mlir/lib/Parser/AttributeParser.cpp
    mlir/lib/Parser/LocationParser.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Parser/Parser.h
    mlir/lib/Parser/TypeParser.cpp
    mlir/test/IR/invalid-affinemap.mlir
    mlir/test/IR/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 0cfb5448894fa..979b86c60194f 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -473,6 +473,47 @@ class OpAsmParser {
     return success();
   }
 
+  /// These are the supported delimiters around operand lists, region argument
+  /// lists, and generic comma-separated lists (see parseCommaSeparatedList).
+  enum class Delimiter {
+    /// Zero or more operands with no delimiters.
+    None,
+    /// Parens surrounding zero or more operands.
+    Paren,
+    /// Square brackets surrounding zero or more operands.
+    Square,
+    /// <> brackets surrounding zero or more operands.
+    LessGreater,
+    /// {} brackets surrounding zero or more operands.
+    Braces,
+    /// Parens surrounding zero or more operands, or nothing.
+    OptionalParen,
+    /// Square brackets surrounding zero or more operands, or nothing.
+    OptionalSquare,
+    /// <> brackets surrounding zero or more operands, or nothing.
+    OptionalLessGreater,
+    /// {} brackets surrounding zero or more operands, or nothing.
+    OptionalBraces,
+  };
+
+  /// Parse a list of comma-separated items, calling parseElementFn for each.
+  /// With a delimiter an empty list is allowed (Optional* delimiters also
+  /// accept a missing list); Delimiter::None requires at least one element.
+  ///
+  /// contextMessage is an optional message appended to "expected '('" sorts of
+  /// diagnostics when parsing the delimiters.
+  virtual ParseResult
+  parseCommaSeparatedList(Delimiter delimiter,
+                          function_ref<ParseResult()> parseElementFn,
+                          StringRef contextMessage = StringRef()) = 0;
+
+  /// Parse a comma separated list of elements that must have at least one entry
+  /// in it.
+  ParseResult
+  parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
+    return parseCommaSeparatedList(Delimiter::None, parseElementFn);
+  }
+
   //===--------------------------------------------------------------------===//
   // Attribute Parsing
   //===--------------------------------------------------------------------===//
@@ -610,21 +651,6 @@ class OpAsmParser {
   /// Parse a single operand if present.
   virtual OptionalParseResult parseOptionalOperand(OperandType &result) = 0;
 
-  /// These are the supported delimiters around operand lists and region
-  /// argument lists, used by parseOperandList and parseRegionArgumentList.
-  enum class Delimiter {
-    /// Zero or more operands with no delimiters.
-    None,
-    /// Parens surrounding zero or more operands.
-    Paren,
-    /// Square brackets surrounding zero or more operands.
-    Square,
-    /// Parens supporting zero or more operands, or nothing.
-    OptionalParen,
-    /// Square brackets supporting zero or more ops, or nothing.
-    OptionalSquare,
-  };
-
   /// Parse zero or more SSA comma-separated operand references with a specified
   /// surrounding delimiter, and an optional required operand count.
   virtual ParseResult

diff  --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index 064edce46bc31..d21ff1fef3a2c 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -158,7 +158,6 @@ static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
 
   // Sizes of parsed variadic operands, will be updated below after parsing.
   int32_t numDependencies = 0;
-  int32_t numOperands = 0;
 
   auto tokenTy = TokenType::get(ctx);
 
@@ -179,38 +178,27 @@ static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
   SmallVector<Type, 4> valueTypes;
   SmallVector<Type, 4> unwrappedTypes;
 
-  if (succeeded(parser.parseOptionalLParen())) {
-    auto argsLoc = parser.getCurrentLocation();
-
-    // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
-    auto parseAsyncValueArg = [&]() -> ParseResult {
-      if (parser.parseOperand(valueArgs.emplace_back()) ||
-          parser.parseKeyword("as") ||
-          parser.parseOperand(unwrappedArgs.emplace_back()) ||
-          parser.parseColonType(valueTypes.emplace_back()))
-        return failure();
-
-      auto valueTy = valueTypes.back().dyn_cast<ValueType>();
-      unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
-
-      return success();
-    };
-
-    // If the next token is `)` skip async value arguments parsing.
-    if (failed(parser.parseOptionalRParen())) {
-      do {
-        if (parseAsyncValueArg())
-          return failure();
-      } while (succeeded(parser.parseOptionalComma()));
-
-      if (parser.parseRParen() ||
-          parser.resolveOperands(valueArgs, valueTypes, argsLoc,
-                                 result.operands))
-        return failure();
-    }
+  // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
+  auto parseAsyncValueArg = [&]() -> ParseResult {
+    if (parser.parseOperand(valueArgs.emplace_back()) ||
+        parser.parseKeyword("as") ||
+        parser.parseOperand(unwrappedArgs.emplace_back()) ||
+        parser.parseColonType(valueTypes.emplace_back()))
+      return failure();
 
-    numOperands = valueArgs.size();
-  }
+    auto valueTy = valueTypes.back().dyn_cast<ValueType>();
+    unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
+
+    return success();
+  };
+
+  auto argsLoc = parser.getCurrentLocation();
+  if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
+                                     parseAsyncValueArg) ||
+      parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
+    return failure();
+
+  int32_t numOperands = valueArgs.size();
 
   // Add derived `operand_segment_sizes` attribute based on parsed operands.
   auto operandSegmentSizes = DenseIntElementsAttr::get(

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 90ae3bb990c89..94be8c252a348 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -77,22 +77,16 @@ static ParseResult
 parseOperandAndTypeList(OpAsmParser &parser,
                         SmallVectorImpl<OpAsmParser::OperandType> &operands,
                         SmallVectorImpl<Type> &types) {
-  if (parser.parseLParen())
-    return failure();
-
-  do {
-    OpAsmParser::OperandType operand;
-    Type type;
-    if (parser.parseOperand(operand) || parser.parseColonType(type))
-      return failure();
-    operands.push_back(operand);
-    types.push_back(type);
-  } while (succeeded(parser.parseOptionalComma()));
-
-  if (parser.parseRParen())
-    return failure();
-
-  return success();
+  return parser.parseCommaSeparatedList(
+      OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+        OpAsmParser::OperandType operand;
+        Type type;
+        if (parser.parseOperand(operand) || parser.parseColonType(type))
+          return failure();
+        operands.push_back(operand);
+        types.push_back(type);
+        return success();
+      });
 }
 
 /// Parse an allocate clause with allocators and a list of operands with types.
@@ -108,30 +102,24 @@ static ParseResult parseAllocateAndAllocator(
     SmallVectorImpl<Type> &typesAllocate,
     SmallVectorImpl<OpAsmParser::OperandType> &operandsAllocator,
     SmallVectorImpl<Type> &typesAllocator) {
-  if (parser.parseLParen())
-    return failure();
 
-  do {
-    OpAsmParser::OperandType operand;
-    Type type;
-
-    if (parser.parseOperand(operand) || parser.parseColonType(type))
-      return failure();
-    operandsAllocator.push_back(operand);
-    typesAllocator.push_back(type);
-    if (parser.parseArrow())
-      return failure();
-    if (parser.parseOperand(operand) || parser.parseColonType(type))
-      return failure();
-
-    operandsAllocate.push_back(operand);
-    typesAllocate.push_back(type);
-  } while (succeeded(parser.parseOptionalComma()));
-
-  if (parser.parseRParen())
-    return failure();
+  return parser.parseCommaSeparatedList(
+      OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+        OpAsmParser::OperandType operand;
+        Type type;
+        if (parser.parseOperand(operand) || parser.parseColonType(type))
+          return failure();
+        operandsAllocator.push_back(operand);
+        typesAllocator.push_back(type);
+        if (parser.parseArrow())
+          return failure();
+        if (parser.parseOperand(operand) || parser.parseColonType(type))
+          return failure();
 
-  return success();
+        operandsAllocate.push_back(operand);
+        typesAllocate.push_back(type);
+        return success();
+      });
 }
 
 static LogicalResult verifyParallelOp(ParallelOp op) {

diff  --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 0cee846a75d94..900251f8ad7bc 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1778,16 +1778,16 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
 
   if (!parser.parseOptionalComma()) {
     // Parse the interface variables
-    do {
-      // The name of the interface variable attribute isnt important
-      auto attrName = "var_symbol";
-      FlatSymbolRefAttr var;
-      NamedAttrList attrs;
-      if (parser.parseAttribute(var, Type(), attrName, attrs)) {
-        return failure();
-      }
-      interfaceVars.push_back(var);
-    } while (!parser.parseOptionalComma());
+    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
+          // The name of the interface variable attribute isn't important.
+          FlatSymbolRefAttr var;
+          NamedAttrList attrs;
+          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
+            return failure();
+          interfaceVars.push_back(var);
+          return success();
+        }))
+      return failure();
   }
   state.addAttribute(kInterfaceAttrName,
                      parser.getBuilder().getArrayAttr(interfaceVars));

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index ed0417634e309..6f95bf03e91dd 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -2197,13 +2197,12 @@ static ParseResult parseSwitchOpCases(
     SmallVectorImpl<Block *> &caseDestinations,
     SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
     SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
-  if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
-      failed(parser.parseSuccessor(defaultDestination)))
+  if (parser.parseKeyword("default") || parser.parseColon() ||
+      parser.parseSuccessor(defaultDestination))
     return failure();
   if (succeeded(parser.parseOptionalLParen())) {
-    if (failed(parser.parseRegionArgumentList(defaultOperands)) ||
-        failed(parser.parseColonTypeList(defaultOperandTypes)) ||
-        failed(parser.parseRParen()))
+    if (parser.parseRegionArgumentList(defaultOperands) ||
+        parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
       return failure();
   }
 

diff  --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp
index 567767265c6e8..708983709f6cc 100644
--- a/mlir/lib/Parser/AffineParser.cpp
+++ b/mlir/lib/Parser/AffineParser.cpp
@@ -474,26 +474,22 @@ ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) {
 
 /// Parse the list of dimensional identifiers to an affine map.
 ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
-  if (parseToken(Token::l_paren,
-                 "expected '(' at start of dimensional identifiers list")) {
-    return failure();
-  }
-
   auto parseElt = [&]() -> ParseResult {
     auto dimension = getAffineDimExpr(numDims++, getContext());
     return parseIdentifierDefinition(dimension);
   };
-  return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
+  return parseCommaSeparatedList(Delimiter::Paren, parseElt,
+                                 " in dimensional identifier list");
 }
 
 /// Parse the list of symbolic identifiers to an affine map.
 ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
-  consumeToken(Token::l_square);
   auto parseElt = [&]() -> ParseResult {
     auto symbol = getAffineSymbolExpr(numSymbols++, getContext());
     return parseIdentifierDefinition(symbol);
   };
-  return parseCommaSeparatedListUntil(Token::r_square, parseElt);
+  return parseCommaSeparatedList(Delimiter::Square, parseElt,
+                                 " in symbol list");
 }
 
 /// Parse the list of symbolic identifiers to an affine map.
@@ -544,21 +540,6 @@ ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
 ParseResult
 AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
                                      OpAsmParser::Delimiter delimiter) {
-  Token::Kind rightToken;
-  switch (delimiter) {
-  case OpAsmParser::Delimiter::Square:
-    if (parseToken(Token::l_square, "expected '['"))
-      return failure();
-    rightToken = Token::r_square;
-    break;
-  case OpAsmParser::Delimiter::Paren:
-    if (parseToken(Token::l_paren, "expected '('"))
-      return failure();
-    rightToken = Token::r_paren;
-    break;
-  default:
-    return emitError("unexpected delimiter");
-  }
 
   SmallVector<AffineExpr, 4> exprs;
   auto parseElt = [&]() -> ParseResult {
@@ -571,9 +552,9 @@ AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
   // 1-d affine expressions); the list can be empty. Grammar:
   // multi-dim-affine-expr ::= `(` `)`
   //                         | `(` affine-expr (`,` affine-expr)* `)`
-  if (parseCommaSeparatedListUntil(rightToken, parseElt,
-                                   /*allowEmptyList=*/true))
+  if (parseCommaSeparatedList(delimiter, parseElt, " in affine map"))
     return failure();
+
   // Parsed a valid affine map.
   map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
                        exprs, getContext());
@@ -594,8 +575,6 @@ ParseResult AffineParser::parseAffineExprOfSSAIds(AffineExpr &expr) {
 ///  multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
 AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
                                             unsigned numSymbols) {
-  parseToken(Token::l_paren, "expected '(' at start of affine map range");
-
   SmallVector<AffineExpr, 4> exprs;
   auto parseElt = [&]() -> ParseResult {
     auto elt = parseAffineExpr();
@@ -608,7 +587,8 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
   // 1-d affine expressions). Grammar:
   // multi-dim-affine-expr ::= `(` `)`
   //                         | `(` affine-expr (`,` affine-expr)* `)`
-  if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+  if (parseCommaSeparatedList(Delimiter::Paren, parseElt,
+                              " in affine map range"))
     return AffineMap();
 
   // Parsed a valid affine map.
@@ -662,10 +642,6 @@ AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
 ///
 IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
                                                     unsigned numSymbols) {
-  if (parseToken(Token::l_paren,
-                 "expected '(' at start of integer set constraint list"))
-    return IntegerSet();
-
   SmallVector<AffineExpr, 4> constraints;
   SmallVector<bool, 4> isEqs;
   auto parseElt = [&]() -> ParseResult {
@@ -680,7 +656,8 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims,
   };
 
   // Parse a list of affine constraints (comma-separated).
-  if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
+  if (parseCommaSeparatedList(Delimiter::Paren, parseElt,
+                              " in integer set constraint list"))
     return IntegerSet();
 
   // If no constraints were parsed, then treat this as a degenerate 'true' case.

diff  --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 6e512bdf9747c..20cfdf27a1001 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -67,15 +67,13 @@ Attribute Parser::parseAttribute(Type type) {
 
   // Parse an array attribute.
   case Token::l_square: {
-    consumeToken(Token::l_square);
-
     SmallVector<Attribute, 4> elements;
     auto parseElt = [&]() -> ParseResult {
       elements.push_back(parseAttribute());
       return elements.back() ? success() : failure();
     };
 
-    if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
+    if (parseCommaSeparatedList(Delimiter::Square, parseElt))
       return nullptr;
     return builder.getArrayAttr(elements);
   }
@@ -262,9 +260,6 @@ OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
 ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
 ///
 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
-  if (parseToken(Token::l_brace, "expected '{' in attribute dictionary"))
-    return failure();
-
   llvm::SmallDenseSet<Identifier> seenKeys;
   auto parseElt = [&]() -> ParseResult {
     // The name of an attribute can either be a bare identifier, or a string.
@@ -300,7 +295,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
     return success();
   };
 
-  if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
+  if (parseCommaSeparatedList(Delimiter::Braces, parseElt,
+                              " in attribute dictionary"))
     return failure();
 
   return success();
@@ -769,8 +765,6 @@ ParseResult TensorLiteralParser::parseElement() {
 ///   parseList([[1, 2], 3]) -> Failure
 ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
-  p.consumeToken(Token::l_square);
-
   auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
                        const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
     if (prevDims == newDims)
@@ -782,7 +776,7 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
   bool first = true;
   SmallVector<int64_t, 4> newDims;
   unsigned size = 0;
-  auto parseCommaSeparatedList = [&]() -> ParseResult {
+  auto parseOneElement = [&]() -> ParseResult {
     SmallVector<int64_t, 4> thisDims;
     if (p.getToken().getKind() == Token::l_square) {
       if (parseList(thisDims))
@@ -797,7 +791,7 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
     first = false;
     return success();
   };
-  if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
+  if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOneElement))
     return failure();
 
   // Return the sublists' dimensions with 'size' prepended.

diff  --git a/mlir/lib/Parser/LocationParser.cpp b/mlir/lib/Parser/LocationParser.cpp
index d0bd5db5267db..d010260e84836 100644
--- a/mlir/lib/Parser/LocationParser.cpp
+++ b/mlir/lib/Parser/LocationParser.cpp
@@ -82,9 +82,8 @@ ParseResult Parser::parseFusedLocation(LocationAttr &loc) {
     return success();
   };
 
-  if (parseToken(Token::l_square, "expected '[' in fused location") ||
-      parseCommaSeparatedList(parseElt) ||
-      parseToken(Token::r_square, "expected ']' in fused location"))
+  if (parseCommaSeparatedList(Delimiter::Square, parseElt,
+                              " in fused location"))
     return failure();
 
   // Return the fused location.

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 38d6c0b444dc8..6c8375427e791 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -35,20 +35,90 @@ using llvm::SourceMgr;
 // Parser
 //===----------------------------------------------------------------------===//
 
-/// Parse a comma separated list of elements that must have at least one entry
-/// in it.
+/// Parse a list of comma-separated items, calling parseElementFn for each.
+/// With a delimiter an empty list is allowed (Optional* delimiters also
+/// accept a missing list); Delimiter::None requires at least one element.
 ParseResult
-Parser::parseCommaSeparatedList(function_ref<ParseResult()> parseElement) {
+Parser::parseCommaSeparatedList(Delimiter delimiter,
+                                function_ref<ParseResult()> parseElementFn,
+                                StringRef contextMessage) {
+  switch (delimiter) {
+  case Delimiter::None:
+    break;
+  case Delimiter::OptionalParen:
+    if (getToken().isNot(Token::l_paren))
+      return success();
+    LLVM_FALLTHROUGH;
+  case Delimiter::Paren:
+    if (parseToken(Token::l_paren, "expected '('" + contextMessage))
+      return failure();
+    // Check for empty list.
+    if (consumeIf(Token::r_paren))
+      return success();
+    break;
+  case Delimiter::OptionalLessGreater:
+    // Check for absent list.
+    if (getToken().isNot(Token::less))
+      return success();
+    LLVM_FALLTHROUGH;
+  case Delimiter::LessGreater:
+    if (parseToken(Token::less, "expected '<'" + contextMessage))
+      return failure();
+    // Check for empty list.
+    if (consumeIf(Token::greater))
+      return success();
+    break;
+  case Delimiter::OptionalSquare:
+    if (getToken().isNot(Token::l_square))
+      return success();
+    LLVM_FALLTHROUGH;
+  case Delimiter::Square:
+    if (parseToken(Token::l_square, "expected '['" + contextMessage))
+      return failure();
+    // Check for empty list.
+    if (consumeIf(Token::r_square))
+      return success();
+    break;
+  case Delimiter::OptionalBraces:
+    if (getToken().isNot(Token::l_brace))
+      return success();
+    LLVM_FALLTHROUGH;
+  case Delimiter::Braces:
+    if (parseToken(Token::l_brace, "expected '{'" + contextMessage))
+      return failure();
+    // Check for empty list.
+    if (consumeIf(Token::r_brace))
+      return success();
+    break;
+  }
+
   // Non-empty case starts with an element.
-  if (parseElement())
+  if (parseElementFn())
     return failure();
 
   // Otherwise we have a list of comma separated elements.
   while (consumeIf(Token::comma)) {
-    if (parseElement())
+    if (parseElementFn())
       return failure();
   }
-  return success();
+
+  switch (delimiter) {
+  case Delimiter::None:
+    return success();
+  case Delimiter::OptionalParen:
+  case Delimiter::Paren:
+    return parseToken(Token::r_paren, "expected ')'" + contextMessage);
+  case Delimiter::OptionalLessGreater:
+  case Delimiter::LessGreater:
+    return parseToken(Token::greater, "expected '>'" + contextMessage);
+  case Delimiter::OptionalSquare:
+  case Delimiter::Square:
+    return parseToken(Token::r_square, "expected ']'" + contextMessage);
+  case Delimiter::OptionalBraces:
+  case Delimiter::Braces:
+    return parseToken(Token::r_brace, "expected '}'" + contextMessage);
+  }
+  llvm_unreachable("Unknown delimiter");
 }
 
 /// Parse a comma-separated list of elements, terminated with an arbitrary
@@ -1282,6 +1352,15 @@ class CustomOpAsmParser : public OpAsmParser {
     return parser.parseOptionalInteger(result);
   }
 
+  /// Parse a list of comma-separated items, calling parseElt for each.
+  /// With a delimiter an empty list is allowed (Optional* delimiters also
+  /// accept a missing list); Delimiter::None requires at least one element.
+  ParseResult parseCommaSeparatedList(Delimiter delimiter,
+                                      function_ref<ParseResult()> parseElt,
+                                      StringRef contextMessage) override {
+    return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
+  }
+
   //===--------------------------------------------------------------------===//
   // Attribute Parsing
   //===--------------------------------------------------------------------===//
@@ -1467,67 +1546,37 @@ class CustomOpAsmParser : public OpAsmParser {
                               Delimiter delimiter = Delimiter::None) {
     auto startLoc = parser.getToken().getLoc();
 
-    // Handle delimiters.
-    switch (delimiter) {
-    case Delimiter::None:
-      // Don't check for the absence of a delimiter if the number of operands
-      // is unknown (and hence the operand list could be empty).
-      if (requiredOperandCount == -1)
-        break;
-      // Token already matches an identifier and so can't be a delimiter.
-      if (parser.getToken().is(Token::percent_identifier))
-        break;
-      // Test against known delimiters.
-      if (parser.getToken().is(Token::l_paren) ||
-          parser.getToken().is(Token::l_square))
-        return emitError(startLoc, "unexpected delimiter");
-      return emitError(startLoc, "invalid operand");
-    case Delimiter::OptionalParen:
-      if (parser.getToken().isNot(Token::l_paren))
-        return success();
-      LLVM_FALLTHROUGH;
-    case Delimiter::Paren:
-      if (parser.parseToken(Token::l_paren, "expected '(' in operand list"))
-        return failure();
-      break;
-    case Delimiter::OptionalSquare:
-      if (parser.getToken().isNot(Token::l_square))
-        return success();
-      LLVM_FALLTHROUGH;
-    case Delimiter::Square:
-      if (parser.parseToken(Token::l_square, "expected '[' in operand list"))
-        return failure();
-      break;
-    }
-
-    // Check for zero operands.
-    if (parser.getToken().is(Token::percent_identifier)) {
-      do {
-        OperandType operandOrArg;
-        if (isOperandList ? parseOperand(operandOrArg)
-                          : parseRegionArgument(operandOrArg))
-          return failure();
-        result.push_back(operandOrArg);
-      } while (parser.consumeIf(Token::comma));
+    // The no-delimiter case has some special handling for better diagnostics.
+    if (delimiter == Delimiter::None) {
+      // parseCommaSeparatedList requires at least one element for "none", so
+      // an absent operand list is handled here instead.
+      if (parser.getToken().isNot(Token::percent_identifier)) {
+        // If we didn't require any operands or required exactly zero (weird)
+        // then this is success.
+        if (requiredOperandCount == -1 || requiredOperandCount == 0)
+          return success();
+
+        // Otherwise, try to produce a nice error message.
+        if (parser.getToken().is(Token::l_paren) ||
+            parser.getToken().is(Token::l_square))
+          return emitError(startLoc, "unexpected delimiter");
+        return emitError(startLoc, "invalid operand");
+      }
     }
 
-    // Handle delimiters.   If we reach here, the optional delimiters were
-    // present, so we need to parse their closing one.
-    switch (delimiter) {
-    case Delimiter::None:
-      break;
-    case Delimiter::OptionalParen:
-    case Delimiter::Paren:
-      if (parser.parseToken(Token::r_paren, "expected ')' in operand list"))
-        return failure();
-      break;
-    case Delimiter::OptionalSquare:
-    case Delimiter::Square:
-      if (parser.parseToken(Token::r_square, "expected ']' in operand list"))
+    auto parseOneOperand = [&]() -> ParseResult {
+      OperandType operandOrArg;
+      if (isOperandList ? parseOperand(operandOrArg)
+                        : parseRegionArgument(operandOrArg))
         return failure();
-      break;
-    }
+      result.push_back(operandOrArg);
+      return success();
+    };
+
+    if (parseCommaSeparatedList(delimiter, parseOneOperand, " in operand list"))
+      return failure();
 
+    // Check that we got the expected # of elements.
     if (requiredOperandCount != -1 &&
         result.size() != static_cast<size_t>(requiredOperandCount))
       return emitError(startLoc, "expected ")

diff  --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index f8a2c74e455f4..9a1894deee7bf 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -24,6 +24,8 @@ namespace detail {
 /// include state.
 class Parser {
 public:
+  using Delimiter = OpAsmParser::Delimiter;
+
   Builder builder;
 
   Parser(ParserState &state) : builder(state.context), state(state) {}
@@ -39,9 +41,20 @@ class Parser {
                                function_ref<ParseResult()> parseElement,
                                bool allowEmptyList = true);
 
+  /// Parse a list of comma-separated items, calling parseElementFn for each.
+  /// With a delimiter an empty list is allowed (Optional* delimiters also
+  /// accept a missing list); Delimiter::None requires at least one element.
+  ParseResult
+  parseCommaSeparatedList(Delimiter delimiter,
+                          function_ref<ParseResult()> parseElementFn,
+                          StringRef contextMessage = StringRef());
+
   /// Parse a comma separated list of elements that must have at least one entry
   /// in it.
-  ParseResult parseCommaSeparatedList(function_ref<ParseResult()> parseElement);
+  ParseResult
+  parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
+    return parseCommaSeparatedList(Delimiter::None, parseElementFn);
+  }
 
   ParseResult parsePrettyDialectSymbolName(StringRef &prettyName);
 
@@ -276,7 +289,7 @@ class Parser {
   ParseResult
   parseAffineMapOfSSAIds(AffineMap &map,
                          function_ref<ParseResult(bool)> parseElement,
-                         OpAsmParser::Delimiter delimiter);
+                         Delimiter delimiter);
 
   /// Parse an AffineExpr where dim and symbol identifiers are SSA ids.
   ParseResult

diff  --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
index 66ec3b4237081..46fbf52d25bb6 100644
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -165,15 +165,12 @@ ParseResult Parser::parseStridedLayout(int64_t &offset,
     return emitError("expected comma after offset value");
 
   // Parse stride list.
-  if (!consumeIf(Token::kw_strides))
-    return emitError("expected `strides` keyword after offset specification");
-  if (!consumeIf(Token::colon))
-    return emitError("expected colon after `strides` keyword");
-  if (failed(parseStrideList(strides)))
-    return emitError("invalid braces-enclosed stride list");
-  if (llvm::any_of(strides, [](int64_t st) { return st == 0; }))
-    return emitError("invalid memref stride");
+  if (parseToken(Token::kw_strides,
+                 "expected `strides` keyword after offset specification") ||
 
+      parseToken(Token::colon, "expected colon after `strides` keyword") ||
+      parseStrideList(strides))
+    return failure();
   return success();
 }
 
@@ -560,31 +557,30 @@ ParseResult Parser::parseXInDimensionList() {
 // Parse a comma-separated list of dimensions, possibly empty:
 //   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
-  if (!consumeIf(Token::l_square))
-    return failure();
-  // Empty list early exit.
-  if (consumeIf(Token::r_square))
-    return success();
-  while (true) {
-    if (consumeIf(Token::question)) {
-      dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
-    } else {
-      // This must be an integer value.
-      int64_t val;
-      if (getToken().getSpelling().getAsInteger(10, val))
-        return emitError("invalid integer value: ") << getToken().getSpelling();
-      // Make sure it is not the one value for `?`.
-      if (ShapedType::isDynamic(val))
-        return emitError("invalid integer value: ")
-               << getToken().getSpelling()
-               << ", use `?` to specify a dynamic dimension";
-      dimensions.push_back(val);
-      consumeToken(Token::integer);
-    }
-    if (!consumeIf(Token::comma))
-      break;
-  }
-  if (!consumeIf(Token::r_square))
-    return failure();
-  return success();
+  return parseCommaSeparatedList(
+      Delimiter::Square,
+      [&]() -> ParseResult {
+        if (consumeIf(Token::question)) {
+          dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
+        } else {
+          // This must be an integer value.
+          int64_t val;
+          if (getToken().getSpelling().getAsInteger(10, val))
+            return emitError("invalid integer value: ")
+                   << getToken().getSpelling();
+          // Make sure it is not the one value for `?`.
+          if (ShapedType::isDynamic(val))
+            return emitError("invalid integer value: ")
+                   << getToken().getSpelling()
+                   << ", use `?` to specify a dynamic dimension";
+
+          if (val == 0)
+            return emitError("invalid memref stride");
+
+          dimensions.push_back(val);
+          consumeToken(Token::integer);
+        }
+        return success();
+      },
+      " in stride list");
 }

diff  --git a/mlir/test/IR/invalid-affinemap.mlir b/mlir/test/IR/invalid-affinemap.mlir
index 9377824f006a5..acca3cd32fe42 100644
--- a/mlir/test/IR/invalid-affinemap.mlir
+++ b/mlir/test/IR/invalid-affinemap.mlir
@@ -29,7 +29,9 @@
 #hello_world = affine_map<(i, j) [s0] -> (((s0 + (i + j) + 5), j)> // expected-error {{expected ')'}}
 
 // -----
-#hello_world = affine_map<(i, j) [s0] -> i + s0, j)> // expected-error {{expected '(' at start of affine map range}}
+ 
+// expected-error @+1 {{expected '(' in affine map range}}
+#hello_world = affine_map<(i, j) [s0] -> i + s0, j)>
 
 // -----
 #hello_world = affine_map<(i, j) [s0] -> (x)> // expected-error {{use of undeclared identifier}}
@@ -47,6 +49,8 @@
 #hello_world = affine_map<(i, j) [s0, s1] -> (+i, j)> // expected-error {{missing left operand of binary op}}
 
 // -----
+
+
 #hello_world = affine_map<(i, j) [s0, s1] -> (i, *j)> // expected-error {{missing left operand of binary op}}
 
 // -----
@@ -91,7 +95,8 @@
 #hello_world = affine_map<(i, j) [s0, s1] -> (i, i mod (2+i))> // expected-error {{non-affine expression: right operand of mod has to be either a constant or symbolic}}
 
 // -----
-#hello_world = affine_map<(i, j) [s0, s1] -> (-1*i j, j)> // expected-error {{expected ',' or ')'}}
+// expected-error @+1 {{expected ')' in affine map range}}
+#hello_world = affine_map<(i, j) [s0, s1] -> (-1*i j, j)>
 
 // -----
 #hello_world = affine_map<(i, j) -> (i, 3*d0 + )> // expected-error {{use of undeclared identifier}}

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index ee4291606e42d..80b0fae523e2a 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -92,7 +92,8 @@ func @memref_stride_missing_colon_2(memref<42x42xi8, offset: 0, strides [?, ?]>)
 
 // -----
 
-func @memref_stride_invalid_strides(memref<42x42xi8, offset: 0, strides: ()>) // expected-error {{invalid braces-enclosed stride list}}
+// expected-error @+1 {{expected '['}}
+func @memref_stride_invalid_strides(memref<42x42xi8, offset: 0, strides: ()>)
 
 // -----
 
@@ -633,7 +634,8 @@ func @invalid_bound_map(%N : i32) {
 
 // -----
 
-#set0 = affine_set<(i)[N, M] : )i >= 0)> // expected-error {{expected '(' at start of integer set constraint list}}
+// expected-error @+1 {{expected '(' in integer set constraint list}}
+#set0 = affine_set<(i)[N, M] : )i >= 0)>
 
 // -----
 #set0 = affine_set<(i)[N] : (i >= 0, N - i >= 0)>


        


More information about the Mlir-commits mailing list