[Mlir-commits] [mlir] f5f8a46 - [mlir][AsmParser] Improve parse{Attribute, Type} error handling

Rahul Kayaith llvmlistbot at llvm.org
Wed Mar 1 14:15:04 PST 2023


Author: Rahul Kayaith
Date: 2023-03-01T17:14:59-05:00
New Revision: f5f8a46bb0ce3dd2f3e024696d5a0aef5fb12a29

URL: https://github.com/llvm/llvm-project/commit/f5f8a46bb0ce3dd2f3e024696d5a0aef5fb12a29
DIFF: https://github.com/llvm/llvm-project/commit/f5f8a46bb0ce3dd2f3e024696d5a0aef5fb12a29.diff

LOG: [mlir][AsmParser] Improve parse{Attribute,Type} error handling

Currently these functions report errors directly to stderr; this updates
them to use diagnostics instead. This also makes partially-consumed
strings an error if the `numRead` parameter isn't provided (the
docstrings already claimed this happened, but it didn't).

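While here, I also reduced the number of overloads by switching to
default parameters. Input strings must now be null-terminated, since the
parser's memory buffer is created with `RequiresNullTerminator=true`.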

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D144804

Added: 
    

Modified: 
    mlir/include/mlir/AsmParser/AsmParser.h
    mlir/lib/AsmParser/DialectSymbolParser.cpp
    mlir/lib/Bytecode/Reader/BytecodeReader.cpp
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/test/Dialect/Linalg/transform-op-pad.mlir
    mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
    mlir/unittests/Parser/ParserTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/AsmParser/AsmParser.h b/mlir/include/mlir/AsmParser/AsmParser.h
index 60ce797f01589..d60df41198986 100644
--- a/mlir/include/mlir/AsmParser/AsmParser.h
+++ b/mlir/include/mlir/AsmParser/AsmParser.h
@@ -43,38 +43,22 @@ parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
                    AsmParserState *asmState = nullptr,
                    AsmParserCodeCompleteContext *codeCompleteContext = nullptr);
 
-/// This parses a single MLIR attribute to an MLIR context if it was valid.  If
-/// not, an error message is emitted through a new SourceMgrDiagnosticHandler
-/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
-/// `attrStr`. If the passed `attrStr` has additional tokens that were not part
-/// of the type, an error is emitted.
-// TODO: Improve diagnostic reporting.
-Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context);
-Attribute parseAttribute(llvm::StringRef attrStr, Type type);
-
-/// This parses a single MLIR attribute to an MLIR context if it was valid.  If
-/// not, an error message is emitted through a new SourceMgrDiagnosticHandler
-/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
-/// `attrStr`. The number of characters of `attrStr` parsed in the process is
-/// returned in `numRead`.
+/// This parses a single MLIR attribute to an MLIR context if it was valid. If
+/// not, an error diagnostic is emitted to the context and a null value is
+/// returned. `attrStr` must be null-terminated. A non-null `type` is passed to
+/// the attribute parser. If `numRead` is provided, it is set to the number of
+/// consumed characters on successful parse. Otherwise, parsing fails if the
+/// entire string is not consumed.
 Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
-                         size_t &numRead);
-Attribute parseAttribute(llvm::StringRef attrStr, Type type, size_t &numRead);
-
-/// This parses a single MLIR type to an MLIR context if it was valid.  If not,
-/// an error message is emitted through a new SourceMgrDiagnosticHandler
-/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
-/// `typeStr`. If the passed `typeStr` has additional tokens that were not part
-/// of the type, an error is emitted.
-// TODO: Improve diagnostic reporting.
-Type parseType(llvm::StringRef typeStr, MLIRContext *context);
+                         Type type = {}, size_t *numRead = nullptr);
 
-/// This parses a single MLIR type to an MLIR context if it was valid.  If not,
-/// an error message is emitted through a new SourceMgrDiagnosticHandler
-/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
-/// `typeStr`. The number of characters of `typeStr` parsed in the process is
-/// returned in `numRead`.
-Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t &numRead);
+/// This parses a single MLIR type to an MLIR context if it was valid. If not,
+/// an error diagnostic is emitted to the context and a null type is returned.
+/// `typeStr` must be null-terminated. If `numRead` is provided, it is set to
+/// the number of consumed characters on successful parse. Otherwise, parsing
+/// fails if the entire string is not consumed.
+Type parseType(llvm::StringRef typeStr, MLIRContext *context,
+               size_t *numRead = nullptr);
 
 /// This parses a single IntegerSet/AffineMap to an MLIR context if it was
 /// valid. If not, an error message is emitted through a new

diff  --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 5902b6f381256..a3198e050b149 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -309,12 +309,13 @@ Type Parser::parseExtendedType() {
 /// parsing failed, nullptr is returned. The number of bytes read from the input
 /// string is returned in 'numRead'.
 template <typename T, typename ParserFn>
-static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
-                     ParserFn &&parserFn) {
+static T parseSymbol(StringRef inputStr, MLIRContext *context,
+                     size_t *numReadOut, ParserFn &&parserFn) {
+  // Set the buffer name to the string being parsed, so that it appears in error
+  // diagnostics.
+  auto memBuffer = MemoryBuffer::getMemBuffer(inputStr, /*BufferName=*/inputStr,
+                                              /*RequiresNullTerminator=*/true);
   SourceMgr sourceMgr;
-  auto memBuffer = MemoryBuffer::getMemBuffer(
-      inputStr, /*BufferName=*/"<mlir_parser_buffer>",
-      /*RequiresNullTerminator=*/false);
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
   SymbolState aliasState;
   ParserConfig config(context);
@@ -322,9 +323,6 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
                     /*codeCompleteContext=*/nullptr);
   Parser parser(state);
 
-  SourceMgrDiagnosticHandler handler(
-      const_cast<llvm::SourceMgr &>(parser.getSourceMgr()),
-      parser.getContext());
   Token startTok = parser.getToken();
   T symbol = parserFn(parser);
   if (!symbol)
@@ -332,38 +330,25 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
 
   // Provide the number of bytes that were read.
   Token endTok = parser.getToken();
-  numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
-                                startTok.getLoc().getPointer());
+  size_t numRead =
+      endTok.getLoc().getPointer() - startTok.getLoc().getPointer();
+  if (numReadOut) {
+    *numReadOut = numRead;
+  } else if (numRead != inputStr.size()) {
+    parser.emitError(endTok.getLoc()) << "found trailing characters: '"
+                                      << inputStr.drop_front(numRead) << "'";
+    return T();
+  }
   return symbol;
 }
 
-Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) {
-  size_t numRead = 0;
-  return parseAttribute(attrStr, context, numRead);
-}
-Attribute mlir::parseAttribute(StringRef attrStr, Type type) {
-  size_t numRead = 0;
-  return parseAttribute(attrStr, type, numRead);
-}
-
 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
-                               size_t &numRead) {
-  return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) {
-    return parser.parseAttribute();
-  });
-}
-Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) {
+                               Type type, size_t *numRead) {
   return parseSymbol<Attribute>(
-      attrStr, type.getContext(), numRead,
+      attrStr, context, numRead, // `type` may be null; use `context`.
       [type](Parser &parser) { return parser.parseAttribute(type); });
 }
-
-Type mlir::parseType(StringRef typeStr, MLIRContext *context) {
-  size_t numRead = 0;
-  return parseType(typeStr, context, numRead);
-}
-
-Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) {
+Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead) {
   return parseSymbol<Type>(typeStr, context, numRead,
                            [](Parser &parser) { return parser.parseType(); });
 }

diff  --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 4a09cb78368b1..f00962151b61b 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1031,9 +1031,9 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
   size_t numRead = 0;
   MLIRContext *context = fileLoc->getContext();
   if constexpr (std::is_same_v<T, Type>)
-    result = ::parseType(asmStr, context, numRead);
+    result = ::parseType(asmStr, context, &numRead);
   else
-    result = ::parseAttribute(asmStr, context, numRead);
+    result = ::parseAttribute(asmStr, context, Type(), &numRead);
   if (!result)
     return failure();
 

diff  --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b82c51e0dc556..dfcc2bc72c95e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1692,14 +1692,15 @@ transform::PadOp::applyToOne(LinalgOp target,
     Type elementType = getElementTypeOrSelf(std::get<1>(it));
     // Try to parse string attributes to obtain an attribute of element type.
     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
-      paddingValues.push_back(
-          parseAttribute(attr.cast<StringAttr>(), elementType));
-      if (!paddingValues.back()) {
+      auto parsedAttr = dyn_cast_if_present<TypedAttr>(
+          parseAttribute(stringAttr, getContext(), elementType));
+      if (!parsedAttr || parsedAttr.getType() != elementType) {
         auto diag = this->emitOpError("expects a padding that parses to ")
                     << elementType << ", got " << std::get<0>(it);
         diag.attachNote(target.getLoc()) << "when applied to this op";
         return DiagnosedSilenceableFailure::definiteFailure();
       }
+      paddingValues.push_back(parsedAttr);
       continue;
     }
     // Otherwise, add the attribute directly.

diff  --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index cf01e4715697a..685f70648b043 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -117,9 +117,9 @@ func.func @pad(%arg0: tensor<24x12xf32>,
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
   %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-  // expected-error @below {{expects a padding that parses to 'f32', got "foo"}}
+  // expected-error @below {{expects a padding that parses to 'f32', got "{foo}"}}
   %1 = transform.structured.pad %0 {
-    padding_values=["foo", 0.0 : f32, 0.0 : f32],
+    padding_values=["{foo}", 0.0 : f32, 0.0 : f32],
     padding_dimensions=[0, 1, 2],
     pack_paddings=[1, 1, 0]
   }

diff  --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index a51bdb50da9ed..f88e0fab4404c 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -317,8 +317,10 @@ struct ScalarTraits<SerializedAffineMap> {
                          SerializedAffineMap &value) {
     assert(rawYamlContext);
     auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
-    if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
-                        .dyn_cast_or_null<AffineMapAttr>())
+    std::string nullTerminatedScalar(scalar);
+    if (auto attr =
+            mlir::parseAttribute(nullTerminatedScalar, yamlContext->mlirContext)
+                .dyn_cast_or_null<AffineMapAttr>())
       value.affineMapAttr = attr;
     else if (!value.affineMapAttr || !value.affineMapAttr.isa<AffineMapAttr>())
       return "could not parse as an affine map attribute";

diff  --git a/mlir/unittests/Parser/ParserTest.cpp b/mlir/unittests/Parser/ParserTest.cpp
index ef951102cc121..6b3ac5c5ddf9c 100644
--- a/mlir/unittests/Parser/ParserTest.cpp
+++ b/mlir/unittests/Parser/ParserTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Parser/Parser.h"
+#include "mlir/AsmParser/AsmParser.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Verifier.h"
 
@@ -55,4 +56,44 @@ TEST(MLIRParser, ParseAtEnd) {
   EXPECT_EQ(block.front().getName().getStringRef(), "test.first");
   EXPECT_EQ(block.back().getName().getStringRef(), "test.second");
 }
+
+TEST(MLIRParser, ParseAttr) {
+  using namespace testing;
+  MLIRContext context;
+  Builder b(&context);
+  { // Successful parse: numRead covers the whole input.
+    StringLiteral attrAsm = "array<i64: 1, 2, 3>";
+    size_t numRead = 0;
+    Attribute attr = parseAttribute(attrAsm, &context, Type(), &numRead);
+    EXPECT_EQ(attr, b.getDenseI64ArrayAttr({1, 2, 3}));
+    EXPECT_EQ(numRead, attrAsm.size());
+  }
+  { // Failed parse: error goes to the diagnostic handler; numRead untouched.
+    std::vector<std::string> diagnostics;
+    ScopedDiagnosticHandler handler(&context, [&](Diagnostic &d) {
+      llvm::raw_string_ostream(diagnostics.emplace_back())
+          << d.getLocation() << ": " << d;
+    });
+    size_t numRead = 0;
+    EXPECT_FALSE(parseAttribute("dense<>", &context, Type(), &numRead));
+    EXPECT_THAT(diagnostics, ElementsAre("loc(\"dense<>\":1:7): expected ':'"));
+    EXPECT_EQ(numRead, size_t(0));
+  }
+  { // Trailing characters: an error without numRead, accepted with it.
+    std::vector<std::string> diagnostics;
+    ScopedDiagnosticHandler handler(&context, [&](Diagnostic &d) {
+      llvm::raw_string_ostream(diagnostics.emplace_back())
+          << d.getLocation() << ": " << d;
+    });
+    EXPECT_FALSE(parseAttribute("10  foo", &context));
+    EXPECT_THAT(
+        diagnostics,
+        ElementsAre("loc(\"10  foo\":1:5): found trailing characters: 'foo'"));
+
+    size_t numRead = 0;
+    EXPECT_EQ(parseAttribute("10  foo", &context, Type(), &numRead),
+              b.getI64IntegerAttr(10));
+    EXPECT_EQ(numRead, size_t(4)); // includes trailing whitespace
+  }
+}
 } // namespace


        


More information about the Mlir-commits mailing list