[Mlir-commits] [mlir] 537c601 - [mlir][AttrType] Emit unbalanced character errors using the last punctuation used

River Riddle llvmlistbot at llvm.org
Thu Jan 12 14:58:26 PST 2023


Author: River Riddle
Date: 2023-01-12T14:58:05-08:00
New Revision: 537c6016c5855614d12c9ae35032b7bb0fd4f57b

URL: https://github.com/llvm/llvm-project/commit/537c6016c5855614d12c9ae35032b7bb0fd4f57b
DIFF: https://github.com/llvm/llvm-project/commit/537c6016c5855614d12c9ae35032b7bb0fd4f57b.diff

LOG: [mlir][AttrType] Emit unbalanced character errors using the last punctuation used

This generally gives a better diagnostic, because it indicates that the user didn't
close the most recently opened punctuation range. For example:

```
foo.op {
  some.op -> !blah.pointer<
}
```
We want the error to hint about the unclosed `<`, not the `}` (which isn't really in the context of the type).

Differential Revision: https://reviews.llvm.org/D141635

Added: 
    

Modified: 
    mlir/lib/AsmParser/DialectSymbolParser.cpp
    mlir/test/Dialect/SPIRV/IR/types.mlir
    mlir/test/IR/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 8f1d69b4966ad..5902b6f381256 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -64,6 +64,19 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
   assert(*curPtr == '<');
   SmallVector<char, 8> nestedPunctuation;
   const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
+
+  // Functor used to emit an unbalanced punctuation error.
+  auto emitPunctError = [&] {
+    return emitError() << "unbalanced '" << nestedPunctuation.back()
+                       << "' character in pretty dialect name";
+  };
+  // Functor used to check for unbalanced punctuation.
+  auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
+    if (nestedPunctuation.back() != expectedToken)
+      return emitPunctError();
+    nestedPunctuation.pop_back();
+    return success();
+  };
   do {
     // Handle code completions, which may appear in the middle of the symbol
     // body.
@@ -77,10 +90,8 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
     switch (c) {
     case '\0':
       // This also handles the EOF case.
-      if (!nestedPunctuation.empty()) {
-        return emitError() << "unbalanced '" << nestedPunctuation.back()
-                           << "' character in pretty dialect name";
-      }
+      if (!nestedPunctuation.empty())
+        return emitPunctError();
       return emitError("unexpected nul or EOF in pretty dialect name");
     case '<':
     case '[':
@@ -96,20 +107,20 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
       continue;
 
     case '>':
-      if (nestedPunctuation.pop_back_val() != '<')
-        return emitError("unbalanced '>' character in pretty dialect name");
+      if (failed(checkNestedPunctuation('<')))
+        return failure();
       break;
     case ']':
-      if (nestedPunctuation.pop_back_val() != '[')
-        return emitError("unbalanced ']' character in pretty dialect name");
+      if (failed(checkNestedPunctuation('[')))
+        return failure();
       break;
     case ')':
-      if (nestedPunctuation.pop_back_val() != '(')
-        return emitError("unbalanced ')' character in pretty dialect name");
+      if (failed(checkNestedPunctuation('(')))
+        return failure();
       break;
     case '}':
-      if (nestedPunctuation.pop_back_val() != '{')
-        return emitError("unbalanced '}' character in pretty dialect name");
+      if (failed(checkNestedPunctuation('{')))
+        return failure();
       break;
     case '"': {
       // Dispatch to the lexer to lex past strings.

diff  --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index 28b9241f00852..7e2833e79646e 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -320,12 +320,12 @@ func.func private @struct_type_missing_comma2(!spirv.struct<(f32 [0] i32)>) -> (
 
 // -----
 
-//  expected-error @+1 {{unbalanced ')' character in pretty dialect name}}
+//  expected-error @+1 {{unbalanced '[' character in pretty dialect name}}
 func.func private @struct_type_neg_offset(!spirv.struct<(f32 [0)>) -> ()
 
 // -----
 
-//  expected-error @+1 {{unbalanced ']' character in pretty dialect name}}
+//  expected-error @+1 {{unbalanced '(' character in pretty dialect name}}
 func.func private @struct_type_neg_offset(!spirv.struct<(f32 0])>) -> ()
 
 // -----
@@ -497,7 +497,7 @@ func.func private @matrix_invalid_format(!spirv.matrix 3 x vector<3xf32>>) -> ()
 
 // -----
 
-// expected-error @+1 {{unbalanced ')' character in pretty dialect name}}
+// expected-error @+1 {{unbalanced '<' character in pretty dialect name}}
 func.func private @matrix_invalid_format(!spirv.matrix< 3 x vector<3xf32>) -> ()
 
 // -----

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index dd97fcbf8e17e..9bf7efd8ae5d8 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -369,7 +369,7 @@ func.func @dialect_type_empty_namespace(!<"">) -> () { // expected-error {{inval
 
 // -----
 
-func.func @dialect_type_missing_greater(!foo<) -> () { // expected-error {{unbalanced ')' character in pretty dialect name}}
+func.func @dialect_type_missing_greater(!foo<) -> () { // expected-error {{unbalanced '<' character in pretty dialect name}}
   return
 
 // -----
@@ -414,7 +414,7 @@ func.func @invalid_nested_dominance() {
 
 // -----
 
-// expected-error @+1 {{unbalanced ']' character in pretty dialect name}}
+// expected-error @+1 {{unbalanced '<' character in pretty dialect name}}
 func.func @invalid_unknown_type_dialect_name() -> !invalid.dialect<!x@#]!@#>
 
 // -----
@@ -582,7 +582,7 @@ func.func @invalid_region_dominance() {
 
 // -----
 
-// expected-error @+1 {{unbalanced ')' character in pretty dialect name}}
+// expected-error @+1 {{unbalanced '<' character in pretty dialect name}}
 func.func @bad_arrow(%arg : !unreg.ptr<(i32)->)
 
 // -----


        


More information about the Mlir-commits mailing list