[Mlir-commits] [mlir] f5f8a46 - [mlir][AsmParser] Improve parse{Attribute, Type} error handling
Rahul Kayaith
llvmlistbot at llvm.org
Wed Mar 1 14:15:04 PST 2023
Author: Rahul Kayaith
Date: 2023-03-01T17:14:59-05:00
New Revision: f5f8a46bb0ce3dd2f3e024696d5a0aef5fb12a29
URL: https://github.com/llvm/llvm-project/commit/f5f8a46bb0ce3dd2f3e024696d5a0aef5fb12a29
DIFF: https://github.com/llvm/llvm-project/commit/f5f8a46bb0ce3dd2f3e024696d5a0aef5fb12a29.diff
LOG: [mlir][AsmParser] Improve parse{Attribute,Type} error handling
Currently these functions report errors directly to stderr; this change
updates them to use diagnostics instead. This also makes partially-consumed
strings an error if the `numRead` parameter isn't provided (the
docstrings already claimed this happened, but it didn't).
While here I also tried to reduce the number of overloads by switching
to using default parameters.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D144804
Added:
Modified:
mlir/include/mlir/AsmParser/AsmParser.h
mlir/lib/AsmParser/DialectSymbolParser.cpp
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-pad.mlir
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
mlir/unittests/Parser/ParserTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/AsmParser/AsmParser.h b/mlir/include/mlir/AsmParser/AsmParser.h
index 60ce797f01589..d60df41198986 100644
--- a/mlir/include/mlir/AsmParser/AsmParser.h
+++ b/mlir/include/mlir/AsmParser/AsmParser.h
@@ -43,38 +43,22 @@ parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
AsmParserState *asmState = nullptr,
AsmParserCodeCompleteContext *codeCompleteContext = nullptr);
-/// This parses a single MLIR attribute to an MLIR context if it was valid. If
-/// not, an error message is emitted through a new SourceMgrDiagnosticHandler
-/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
-/// `attrStr`. If the passed `attrStr` has additional tokens that were not part
-/// of the type, an error is emitted.
-// TODO: Improve diagnostic reporting.
-Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context);
-Attribute parseAttribute(llvm::StringRef attrStr, Type type);
-
-/// This parses a single MLIR attribute to an MLIR context if it was valid. If
-/// not, an error message is emitted through a new SourceMgrDiagnosticHandler
-/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
-/// `attrStr`. The number of characters of `attrStr` parsed in the process is
-/// returned in `numRead`.
+/// This parses a single MLIR attribute to an MLIR context if it was valid. If
+/// not, an error diagnostic is emitted to the context and a null value is
+/// returned.
+/// If `numRead` is provided, it is set to the number of consumed characters on
+/// successful parse. Otherwise, parsing fails if the entire string is not
+/// consumed.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
- size_t &numRead);
-Attribute parseAttribute(llvm::StringRef attrStr, Type type, size_t &numRead);
-
-/// This parses a single MLIR type to an MLIR context if it was valid. If not,
-/// an error message is emitted through a new SourceMgrDiagnosticHandler
-/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
-/// `typeStr`. If the passed `typeStr` has additional tokens that were not part
-/// of the type, an error is emitted.
-// TODO: Improve diagnostic reporting.
-Type parseType(llvm::StringRef typeStr, MLIRContext *context);
+ Type type = {}, size_t *numRead = nullptr);
-/// This parses a single MLIR type to an MLIR context if it was valid. If not,
-/// an error message is emitted through a new SourceMgrDiagnosticHandler
-/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
-/// `typeStr`. The number of characters of `typeStr` parsed in the process is
-/// returned in `numRead`.
-Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t &numRead);
+/// This parses a single MLIR type to an MLIR context if it was valid. If not,
+/// an error diagnostic is emitted to the context.
+/// If `numRead` is provided, it is set to the number of consumed characters on
+/// successful parse. Otherwise, parsing fails if the entire string is not
+/// consumed.
+Type parseType(llvm::StringRef typeStr, MLIRContext *context,
+ size_t *numRead = nullptr);
/// This parses a single IntegerSet/AffineMap to an MLIR context if it was
/// valid. If not, an error message is emitted through a new
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 5902b6f381256..a3198e050b149 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -309,12 +309,13 @@ Type Parser::parseExtendedType() {
/// parsing failed, nullptr is returned. The number of bytes read from the input
/// string is returned in 'numRead'.
template <typename T, typename ParserFn>
-static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
- ParserFn &&parserFn) {
+static T parseSymbol(StringRef inputStr, MLIRContext *context,
+ size_t *numReadOut, ParserFn &&parserFn) {
+ // Set the buffer name to the string being parsed, so that it appears in error
+ // diagnostics.
+ auto memBuffer = MemoryBuffer::getMemBuffer(inputStr, /*BufferName=*/inputStr,
+ /*RequiresNullTerminator=*/true);
SourceMgr sourceMgr;
- auto memBuffer = MemoryBuffer::getMemBuffer(
- inputStr, /*BufferName=*/"<mlir_parser_buffer>",
- /*RequiresNullTerminator=*/false);
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
SymbolState aliasState;
ParserConfig config(context);
@@ -322,9 +323,6 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
/*codeCompleteContext=*/nullptr);
Parser parser(state);
- SourceMgrDiagnosticHandler handler(
- const_cast<llvm::SourceMgr &>(parser.getSourceMgr()),
- parser.getContext());
Token startTok = parser.getToken();
T symbol = parserFn(parser);
if (!symbol)
@@ -332,38 +330,25 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
// Provide the number of bytes that were read.
Token endTok = parser.getToken();
- numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
- startTok.getLoc().getPointer());
+ size_t numRead =
+ endTok.getLoc().getPointer() - startTok.getLoc().getPointer();
+ if (numReadOut) {
+ *numReadOut = numRead;
+ } else if (numRead != inputStr.size()) {
+ parser.emitError(endTok.getLoc()) << "found trailing characters: '"
+ << inputStr.drop_front(numRead) << "'";
+ return T();
+ }
return symbol;
}
-Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) {
- size_t numRead = 0;
- return parseAttribute(attrStr, context, numRead);
-}
-Attribute mlir::parseAttribute(StringRef attrStr, Type type) {
- size_t numRead = 0;
- return parseAttribute(attrStr, type, numRead);
-}
-
Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
- size_t &numRead) {
- return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) {
- return parser.parseAttribute();
- });
-}
-Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) {
+ Type type, size_t *numRead) {
return parseSymbol<Attribute>(
- attrStr, type.getContext(), numRead,
+ attrStr, context, numRead,
[type](Parser &parser) { return parser.parseAttribute(type); });
}
-
-Type mlir::parseType(StringRef typeStr, MLIRContext *context) {
- size_t numRead = 0;
- return parseType(typeStr, context, numRead);
-}
-
-Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) {
+Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead) {
return parseSymbol<Type>(typeStr, context, numRead,
[](Parser &parser) { return parser.parseType(); });
}
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 4a09cb78368b1..f00962151b61b 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1031,9 +1031,9 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
size_t numRead = 0;
MLIRContext *context = fileLoc->getContext();
if constexpr (std::is_same_v<T, Type>)
- result = ::parseType(asmStr, context, numRead);
+ result = ::parseType(asmStr, context, &numRead);
else
- result = ::parseAttribute(asmStr, context, numRead);
+ result = ::parseAttribute(asmStr, context, Type(), &numRead);
if (!result)
return failure();
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b82c51e0dc556..dfcc2bc72c95e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1692,14 +1692,15 @@ transform::PadOp::applyToOne(LinalgOp target,
Type elementType = getElementTypeOrSelf(std::get<1>(it));
// Try to parse string attributes to obtain an attribute of element type.
if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
- paddingValues.push_back(
- parseAttribute(attr.cast<StringAttr>(), elementType));
- if (!paddingValues.back()) {
+ auto parsedAttr = dyn_cast_if_present<TypedAttr>(
+ parseAttribute(stringAttr, getContext(), elementType));
+ if (!parsedAttr || parsedAttr.getType() != elementType) {
auto diag = this->emitOpError("expects a padding that parses to ")
<< elementType << ", got " << std::get<0>(it);
diag.attachNote(target.getLoc()) << "when applied to this op";
return DiagnosedSilenceableFailure::definiteFailure();
}
+ paddingValues.push_back(parsedAttr);
continue;
}
// Otherwise, add the attribute directly.
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index cf01e4715697a..685f70648b043 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -117,9 +117,9 @@ func.func @pad(%arg0: tensor<24x12xf32>,
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!pdl.operation) -> !pdl.operation
- // expected-error @below {{expects a padding that parses to 'f32', got "foo"}}
+ // expected-error @below {{expects a padding that parses to 'f32', got "{foo}"}}
%1 = transform.structured.pad %0 {
- padding_values=["foo", 0.0 : f32, 0.0 : f32],
+ padding_values=["{foo}", 0.0 : f32, 0.0 : f32],
padding_dimensions=[0, 1, 2],
pack_paddings=[1, 1, 0]
}
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index a51bdb50da9ed..f88e0fab4404c 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -317,8 +317,10 @@ struct ScalarTraits<SerializedAffineMap> {
SerializedAffineMap &value) {
assert(rawYamlContext);
auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
- if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
- .dyn_cast_or_null<AffineMapAttr>())
+ std::string nullTerminatedScalar(scalar);
+ if (auto attr =
+ mlir::parseAttribute(nullTerminatedScalar, yamlContext->mlirContext)
+ .dyn_cast_or_null<AffineMapAttr>())
value.affineMapAttr = attr;
else if (!value.affineMapAttr || !value.affineMapAttr.isa<AffineMapAttr>())
return "could not parse as an affine map attribute";
diff --git a/mlir/unittests/Parser/ParserTest.cpp b/mlir/unittests/Parser/ParserTest.cpp
index ef951102cc121..6b3ac5c5ddf9c 100644
--- a/mlir/unittests/Parser/ParserTest.cpp
+++ b/mlir/unittests/Parser/ParserTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Parser/Parser.h"
+#include "mlir/AsmParser/AsmParser.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
@@ -55,4 +56,44 @@ TEST(MLIRParser, ParseAtEnd) {
EXPECT_EQ(block.front().getName().getStringRef(), "test.first");
EXPECT_EQ(block.back().getName().getStringRef(), "test.second");
}
+
+TEST(MLIRParser, ParseAttr) {
+ using namespace testing;
+ MLIRContext context;
+ Builder b(&context);
+ { // Successful parse
+ StringLiteral attrAsm = "array<i64: 1, 2, 3>";
+ size_t numRead = 0;
+ Attribute attr = parseAttribute(attrAsm, &context, Type(), &numRead);
+ EXPECT_EQ(attr, b.getDenseI64ArrayAttr({1, 2, 3}));
+ EXPECT_EQ(numRead, attrAsm.size());
+ }
+ { // Failed parse
+ std::vector<std::string> diagnostics;
+ ScopedDiagnosticHandler handler(&context, [&](Diagnostic &d) {
+ llvm::raw_string_ostream(diagnostics.emplace_back())
+ << d.getLocation() << ": " << d;
+ });
+ size_t numRead = 0;
+ EXPECT_FALSE(parseAttribute("dense<>", &context, Type(), &numRead));
+ EXPECT_THAT(diagnostics, ElementsAre("loc(\"dense<>\":1:7): expected ':'"));
+ EXPECT_EQ(numRead, size_t(0));
+ }
+ { // Parse with trailing characters
+ std::vector<std::string> diagnostics;
+ ScopedDiagnosticHandler handler(&context, [&](Diagnostic &d) {
+ llvm::raw_string_ostream(diagnostics.emplace_back())
+ << d.getLocation() << ": " << d;
+ });
+ EXPECT_FALSE(parseAttribute("10 foo", &context));
+ EXPECT_THAT(
+ diagnostics,
+ ElementsAre("loc(\"10 foo\":1:5): found trailing characters: 'foo'"));
+
+ size_t numRead = 0;
+ EXPECT_EQ(parseAttribute("10 foo", &context, Type(), &numRead),
+ b.getI64IntegerAttr(10));
+ EXPECT_EQ(numRead, size_t(4)); // includes trailing whitespace
+ }
+}
} // namespace
More information about the Mlir-commits
mailing list