[Mlir-commits] [mlir] 55cf53f - [mlir][Parser] Make parse{Attribute, Type} null-terminate input
Rahul Kayaith
llvmlistbot at llvm.org
Fri Mar 3 14:03:32 PST 2023
Author: Rahul Kayaith
Date: 2023-03-03T17:03:27-05:00
New Revision: 55cf53fd0f5594eb701b5760729fdc2bd4a70584
URL: https://github.com/llvm/llvm-project/commit/55cf53fd0f5594eb701b5760729fdc2bd4a70584
DIFF: https://github.com/llvm/llvm-project/commit/55cf53fd0f5594eb701b5760729fdc2bd4a70584.diff
LOG: [mlir][Parser] Make parse{Attribute,Type} null-terminate input
`parseAttribute` and `parseType` require null-terminated strings as
input, but this isn't great considering the argument type is
`StringRef`. This changes them to copy to a null-terminated buffer by
default, with an `isKnownNullTerminated` flag added to disable the
copying.
closes #58964
Reviewed By: rriddle, kuhar, lattner
Differential Revision: https://reviews.llvm.org/D145182
Added:
Modified:
mlir/include/mlir/AsmParser/AsmParser.h
mlir/lib/AsmParser/DialectSymbolParser.cpp
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
mlir/unittests/Parser/ParserTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/AsmParser/AsmParser.h b/mlir/include/mlir/AsmParser/AsmParser.h
index d60df41198986..3c1bff1fbc7f1 100644
--- a/mlir/include/mlir/AsmParser/AsmParser.h
+++ b/mlir/include/mlir/AsmParser/AsmParser.h
@@ -49,16 +49,21 @@ parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
/// If `numRead` is provided, it is set to the number of consumed characters on
/// successful parse. Otherwise, parsing fails if the entire string is not
/// consumed.
+/// Some internal copying can be skipped if the source string is known to be
+/// null terminated.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
- Type type = {}, size_t *numRead = nullptr);
+ Type type = {}, size_t *numRead = nullptr,
+ bool isKnownNullTerminated = false);
/// This parses a single MLIR type to an MLIR context if it was valid. If not,
/// an error diagnostic is emitted to the context.
/// If `numRead` is provided, it is set to the number of consumed characters on
/// successful parse. Otherwise, parsing fails if the entire string is not
/// consumed.
+/// Some internal copying can be skipped if the source string is known to be
+/// null terminated.
Type parseType(llvm::StringRef typeStr, MLIRContext *context,
- size_t *numRead = nullptr);
+ size_t *numRead = nullptr, bool isKnownNullTerminated = false);
/// This parses a single IntegerSet/AffineMap to an MLIR context if it was
/// valid. If not, an error message is emitted through a new
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index a3198e050b149..c98b36862fb53 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -306,15 +306,18 @@ Type Parser::parseExtendedType() {
//===----------------------------------------------------------------------===//
/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
-/// parsing failed, nullptr is returned. The number of bytes read from the input
-/// string is returned in 'numRead'.
+/// parsing failed, nullptr is returned.
template <typename T, typename ParserFn>
static T parseSymbol(StringRef inputStr, MLIRContext *context,
- size_t *numReadOut, ParserFn &&parserFn) {
+ size_t *numReadOut, bool isKnownNullTerminated,
+ ParserFn &&parserFn) {
// Set the buffer name to the string being parsed, so that it appears in error
// diagnostics.
- auto memBuffer = MemoryBuffer::getMemBuffer(inputStr, /*BufferName=*/inputStr,
- /*RequiresNullTerminator=*/true);
+ auto memBuffer =
+ isKnownNullTerminated
+ ? MemoryBuffer::getMemBuffer(inputStr,
+ /*BufferName=*/inputStr)
+ : MemoryBuffer::getMemBufferCopy(inputStr, /*BufferName=*/inputStr);
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
SymbolState aliasState;
@@ -343,12 +346,14 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
}
Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
- Type type, size_t *numRead) {
+ Type type, size_t *numRead,
+ bool isKnownNullTerminated) {
return parseSymbol<Attribute>(
- attrStr, context, numRead,
+ attrStr, context, numRead, isKnownNullTerminated,
[type](Parser &parser) { return parser.parseAttribute(type); });
}
-Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead) {
- return parseSymbol<Type>(typeStr, context, numRead,
+Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
+ bool isKnownNullTerminated) {
+ return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
[](Parser &parser) { return parser.parseType(); });
}
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index f00962151b61b..5e71c3a9e5f45 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1031,9 +1031,11 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
size_t numRead = 0;
MLIRContext *context = fileLoc->getContext();
if constexpr (std::is_same_v<T, Type>)
- result = ::parseType(asmStr, context, &numRead);
+ result =
+ ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
else
- result = ::parseAttribute(asmStr, context, Type(), &numRead);
+ result = ::parseAttribute(asmStr, context, Type(), &numRead,
+ /*isKnownNullTerminated=*/true);
if (!result)
return failure();
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index dfcc2bc72c95e..600cddec47d17 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1693,7 +1693,8 @@ transform::PadOp::applyToOne(LinalgOp target,
// Try to parse string attributes to obtain an attribute of element type.
if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
auto parsedAttr = dyn_cast_if_present<TypedAttr>(
- parseAttribute(stringAttr, getContext(), elementType));
+ parseAttribute(stringAttr, getContext(), elementType,
+ /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
if (!parsedAttr || parsedAttr.getType() != elementType) {
auto diag = this->emitOpError("expects a padding that parses to ")
<< elementType << ", got " << std::get<0>(it);
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
index f88e0fab4404c..a51bdb50da9ed 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -317,10 +317,8 @@ struct ScalarTraits<SerializedAffineMap> {
SerializedAffineMap &value) {
assert(rawYamlContext);
auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
- std::string nullTerminatedScalar(scalar);
- if (auto attr =
- mlir::parseAttribute(nullTerminatedScalar, yamlContext->mlirContext)
- .dyn_cast_or_null<AffineMapAttr>())
+ if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
+ .dyn_cast_or_null<AffineMapAttr>())
value.affineMapAttr = attr;
else if (!value.affineMapAttr || !value.affineMapAttr.isa<AffineMapAttr>())
return "could not parse as an affine map attribute";
diff --git a/mlir/unittests/Parser/ParserTest.cpp b/mlir/unittests/Parser/ParserTest.cpp
index 6b3ac5c5ddf9c..62f609ecf8049 100644
--- a/mlir/unittests/Parser/ParserTest.cpp
+++ b/mlir/unittests/Parser/ParserTest.cpp
@@ -95,5 +95,10 @@ TEST(MLIRParser, ParseAttr) {
b.getI64IntegerAttr(10));
EXPECT_EQ(numRead, size_t(4)); // includes trailing whitespace
}
+ { // Parse without null-terminator
+ StringRef attrAsm("999", 1);
+ Attribute attr = parseAttribute(attrAsm, &context);
+ EXPECT_EQ(attr, b.getI64IntegerAttr(9));
+ }
}
} // namespace
More information about the Mlir-commits
mailing list