[Mlir-commits] [mlir] 537c601 - [mlir][AttrType] Emit unbalanced character errors using the last punctuation used
River Riddle
llvmlistbot at llvm.org
Thu Jan 12 14:58:26 PST 2023
Author: River Riddle
Date: 2023-01-12T14:58:05-08:00
New Revision: 537c6016c5855614d12c9ae35032b7bb0fd4f57b
URL: https://github.com/llvm/llvm-project/commit/537c6016c5855614d12c9ae35032b7bb0fd4f57b
DIFF: https://github.com/llvm/llvm-project/commit/537c6016c5855614d12c9ae35032b7bb0fd4f57b.diff
LOG: [mlir][AttrType] Emit unbalanced character errors using the last punctuation used
This gives a better diagnostic in general, because it indicates that the user didn't
close out the last open punctuation range. For example:
```
foo.op {
some.op -> !blah.pointer<
}
```
We want the error to hint about the unclosed `<`, not the `}` (which isn't really in the context of the type).
Differential Revision: https://reviews.llvm.org/D141635
Added:
Modified:
mlir/lib/AsmParser/DialectSymbolParser.cpp
mlir/test/Dialect/SPIRV/IR/types.mlir
mlir/test/IR/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 8f1d69b4966ad..5902b6f381256 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -64,6 +64,19 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
assert(*curPtr == '<');
SmallVector<char, 8> nestedPunctuation;
const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
+
+ // Functor used to emit an unbalanced punctuation error.
+ auto emitPunctError = [&] {
+ return emitError() << "unbalanced '" << nestedPunctuation.back()
+ << "' character in pretty dialect name";
+ };
+ // Functor used to check for unbalanced punctuation.
+ auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
+ if (nestedPunctuation.back() != expectedToken)
+ return emitPunctError();
+ nestedPunctuation.pop_back();
+ return success();
+ };
do {
// Handle code completions, which may appear in the middle of the symbol
// body.
@@ -77,10 +90,8 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
switch (c) {
case '\0':
// This also handles the EOF case.
- if (!nestedPunctuation.empty()) {
- return emitError() << "unbalanced '" << nestedPunctuation.back()
- << "' character in pretty dialect name";
- }
+ if (!nestedPunctuation.empty())
+ return emitPunctError();
return emitError("unexpected nul or EOF in pretty dialect name");
case '<':
case '[':
@@ -96,20 +107,20 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
continue;
case '>':
- if (nestedPunctuation.pop_back_val() != '<')
- return emitError("unbalanced '>' character in pretty dialect name");
+ if (failed(checkNestedPunctuation('<')))
+ return failure();
break;
case ']':
- if (nestedPunctuation.pop_back_val() != '[')
- return emitError("unbalanced ']' character in pretty dialect name");
+ if (failed(checkNestedPunctuation('[')))
+ return failure();
break;
case ')':
- if (nestedPunctuation.pop_back_val() != '(')
- return emitError("unbalanced ')' character in pretty dialect name");
+ if (failed(checkNestedPunctuation('(')))
+ return failure();
break;
case '}':
- if (nestedPunctuation.pop_back_val() != '{')
- return emitError("unbalanced '}' character in pretty dialect name");
+ if (failed(checkNestedPunctuation('{')))
+ return failure();
break;
case '"': {
// Dispatch to the lexer to lex past strings.
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 28b9241f00852..7e2833e79646e 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -320,12 +320,12 @@ func.func private @struct_type_missing_comma2(!spirv.struct<(f32 [0] i32)>) -> (
// -----
-// expected-error @+1 {{unbalanced ')' character in pretty dialect name}}
+// expected-error @+1 {{unbalanced '[' character in pretty dialect name}}
func.func private @struct_type_neg_offset(!spirv.struct<(f32 [0)>) -> ()
// -----
-// expected-error @+1 {{unbalanced ']' character in pretty dialect name}}
+// expected-error @+1 {{unbalanced '(' character in pretty dialect name}}
func.func private @struct_type_neg_offset(!spirv.struct<(f32 0])>) -> ()
// -----
@@ -497,7 +497,7 @@ func.func private @matrix_invalid_format(!spirv.matrix 3 x vector<3xf32>>) -> ()
// -----
-// expected-error @+1 {{unbalanced ')' character in pretty dialect name}}
+// expected-error @+1 {{unbalanced '<' character in pretty dialect name}}
func.func private @matrix_invalid_format(!spirv.matrix< 3 x vector<3xf32>) -> ()
// -----
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index dd97fcbf8e17e..9bf7efd8ae5d8 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -369,7 +369,7 @@ func.func @dialect_type_empty_namespace(!<"">) -> () { // expected-error {{inval
// -----
-func.func @dialect_type_missing_greater(!foo<) -> () { // expected-error {{unbalanced ')' character in pretty dialect name}}
+func.func @dialect_type_missing_greater(!foo<) -> () { // expected-error {{unbalanced '<' character in pretty dialect name}}
return
// -----
@@ -414,7 +414,7 @@ func.func @invalid_nested_dominance() {
// -----
-// expected-error @+1 {{unbalanced ']' character in pretty dialect name}}
+// expected-error @+1 {{unbalanced '<' character in pretty dialect name}}
func.func @invalid_unknown_type_dialect_name() -> !invalid.dialect<!x@#]!@#>
// -----
@@ -582,7 +582,7 @@ func.func @invalid_region_dominance() {
// -----
-// expected-error @+1 {{unbalanced ')' character in pretty dialect name}}
+// expected-error @+1 {{unbalanced '<' character in pretty dialect name}}
func.func @bad_arrow(%arg : !unreg.ptr<(i32)->)
// -----
More information about the Mlir-commits
mailing list