[Mlir-commits] [mlir] 18546ff - [mlir:Bytecode] Add shared_ptr<SourceMgr> overloads to allow safe mmap of data
River Riddle
llvmlistbot at llvm.org
Sun Dec 11 22:45:54 PST 2022
Author: River Riddle
Date: 2022-12-11T22:45:34-08:00
New Revision: 18546ff8dd45a81e72c0a2ed0561b5aec8c15ca3
URL: https://github.com/llvm/llvm-project/commit/18546ff8dd45a81e72c0a2ed0561b5aec8c15ca3
DIFF: https://github.com/llvm/llvm-project/commit/18546ff8dd45a81e72c0a2ed0561b5aec8c15ca3.diff
LOG: [mlir:Bytecode] Add shared_ptr<SourceMgr> overloads to allow safe mmap of data
The bytecode reader currently has no mechanism that allows for directly referencing
data from the input buffer safely. This commit adds shared_ptr<SourceMgr> overloads
that provide an explicit and safe way of extending the lifetime of the input. These new
overloads are adopted in all of our tooling, and are implicitly used by the filename-only
parser methods.
Differential Revision: https://reviews.llvm.org/D139366
Added:
Modified:
mlir/include/mlir/Bytecode/BytecodeReader.h
mlir/include/mlir/IR/AsmState.h
mlir/include/mlir/Parser/Parser.h
mlir/include/mlir/Tools/ParseUtilities.h
mlir/include/mlir/Tools/mlir-translate/Translation.h
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/lib/ExecutionEngine/JitRunner.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
mlir/lib/Tools/mlir-translate/Translation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h
index 68e1d1a0e53f4..d7cb916646035 100644
--- a/mlir/include/mlir/Bytecode/BytecodeReader.h
+++ b/mlir/include/mlir/Bytecode/BytecodeReader.h
@@ -18,6 +18,7 @@
namespace llvm {
class MemoryBufferRef;
+class SourceMgr;
} // namespace llvm
namespace mlir {
@@ -29,6 +30,12 @@ bool isBytecode(llvm::MemoryBufferRef buffer);
/// bytecode, into the provided block.
LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
const ParserConfig &config);
+/// An overload with a source manager whose main file buffer is used for
+/// parsing. The lifetime of the source manager may be freely extended during
+/// parsing such that the source manager is not destroyed before the parsed IR.
+LogicalResult
+readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ Block *block, const ParserConfig &config);
} // namespace mlir
#endif // MLIR_BYTECODE_BYTECODEREADER_H
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index adec4b721b909..1f12f7c6ad4c5 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -215,17 +215,19 @@ class UnmanagedAsmResourceBlob {
/// Create a new unmanaged resource directly referencing the provided data.
/// `dataIsMutable` indicates if the allocated data can be mutated. By
/// default, we treat unmanaged blobs as immutable.
- static AsmResourceBlob allocateWithAlign(ArrayRef<char> data, size_t align,
- bool dataIsMutable = false) {
- return AsmResourceBlob(data, align, /*deleter=*/{},
- /*dataIsMutable=*/false);
+ static AsmResourceBlob
+ allocateWithAlign(ArrayRef<char> data, size_t align,
+ AsmResourceBlob::DeleterFn deleter = {},
+ bool dataIsMutable = false) {
+ return AsmResourceBlob(data, align, std::move(deleter), dataIsMutable);
}
template <typename T>
- static AsmResourceBlob allocateInferAlign(ArrayRef<T> data,
- bool dataIsMutable = false) {
+ static AsmResourceBlob
+ allocateInferAlign(ArrayRef<T> data, AsmResourceBlob::DeleterFn deleter = {},
+ bool dataIsMutable = false) {
return allocateWithAlign(
ArrayRef<char>((const char *)data.data(), data.size() * sizeof(T)),
- alignof(T));
+ alignof(T), std::move(deleter), dataIsMutable);
}
};
diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h
index c17b568165ccb..1f38a2e8c7e02 100644
--- a/mlir/include/mlir/Parser/Parser.h
+++ b/mlir/include/mlir/Parser/Parser.h
@@ -93,6 +93,14 @@ inline OwningOpRef<ContainerOpT> constructContainerOpForParserIfNecessary(
LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
const ParserConfig &config,
LocationAttr *sourceFileLoc = nullptr);
+/// An overload with a source manager that may have references taken during the
+/// parsing process, and whose lifetime can be freely extended (such that the
+/// source manager is not destroyed before the parsed IR). This is useful, for
+/// example, to avoid copying some large resources into the MLIRContext and
+/// instead referencing the data directly from the input buffers.
+LogicalResult parseSourceFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ Block *block, const ParserConfig &config,
+ LocationAttr *sourceFileLoc = nullptr);
/// This parses the file specified by the indicated filename and appends parsed
/// operations to the given block. If the block is non-empty, the operations are
@@ -116,6 +124,15 @@ LogicalResult parseSourceFile(llvm::StringRef filename,
llvm::SourceMgr &sourceMgr, Block *block,
const ParserConfig &config,
LocationAttr *sourceFileLoc = nullptr);
+/// An overload with a source manager that may have references taken during the
+/// parsing process, and whose lifetime can be freely extended (such that the
+/// source manager is not destroyed before the parsed IR). This is useful, for
+/// example, to avoid copying some large resources into the MLIRContext and
+/// instead referencing the data directly from the input buffers.
+LogicalResult parseSourceFile(llvm::StringRef filename,
+ const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ Block *block, const ParserConfig &config,
+ LocationAttr *sourceFileLoc = nullptr);
/// This parses the IR string and appends parsed operations to the given block.
/// If the block is non-empty, the operations are placed before the current
@@ -157,6 +174,17 @@ inline OwningOpRef<ContainerOpT>
parseSourceFile(const llvm::SourceMgr &sourceMgr, const ParserConfig &config) {
return detail::parseSourceFile<ContainerOpT>(config, sourceMgr);
}
+/// An overload with a source manager that may have references taken during the
+/// parsing process, and whose lifetime can be freely extended (such that the
+/// source manager is not destroyed before the parsed IR). This is useful, for
+/// example, to avoid copying some large resources into the MLIRContext and
+/// instead referencing the data directly from the input buffers.
+template <typename ContainerOpT = Operation *>
+inline OwningOpRef<ContainerOpT>
+parseSourceFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ const ParserConfig &config) {
+ return detail::parseSourceFile<ContainerOpT>(config, sourceMgr);
+}
/// This parses the file specified by the indicated filename. If the source IR
/// contained a single instance of `ContainerOpT`, it is returned. Otherwise, a
@@ -186,6 +214,18 @@ inline OwningOpRef<ContainerOpT> parseSourceFile(llvm::StringRef filename,
const ParserConfig &config) {
return detail::parseSourceFile<ContainerOpT>(config, filename, sourceMgr);
}
+/// An overload with a source manager that may have references taken during the
+/// parsing process, and whose lifetime can be freely extended (such that the
+/// source manager is not destroyed before the parsed IR). This is useful, for
+/// example, to avoid copying some large resources into the MLIRContext and
+/// instead referencing the data directly from the input buffers.
+template <typename ContainerOpT = Operation *>
+inline OwningOpRef<ContainerOpT>
+parseSourceFile(llvm::StringRef filename,
+ const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ const ParserConfig &config) {
+ return detail::parseSourceFile<ContainerOpT>(config, filename, sourceMgr);
+}
/// This parses the provided string containing MLIR. If the source IR contained
/// a single instance of `ContainerOpT`, it is returned. Otherwise, a new
diff --git a/mlir/include/mlir/Tools/ParseUtilities.h b/mlir/include/mlir/Tools/ParseUtilities.h
index 75b18f89e75a3..f366f6826c9a1 100644
--- a/mlir/include/mlir/Tools/ParseUtilities.h
+++ b/mlir/include/mlir/Tools/ParseUtilities.h
@@ -24,8 +24,8 @@ namespace mlir {
/// If 'insertImplicitModule' is true a top-level 'builtin.module' op will be
/// inserted that contains the parsed IR, unless one exists already.
inline OwningOpRef<Operation *>
-parseSourceFileForTool(llvm::SourceMgr &sourceMgr, const ParserConfig &config,
- bool insertImplicitModule) {
+parseSourceFileForTool(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ const ParserConfig &config, bool insertImplicitModule) {
if (insertImplicitModule) {
// TODO: Move implicit module logic out of 'parseSourceFile' and into here.
return parseSourceFile<ModuleOp>(sourceMgr, config);
diff --git a/mlir/include/mlir/Tools/mlir-translate/Translation.h b/mlir/include/mlir/Tools/mlir-translate/Translation.h
index 0fe6fdc89d882..7d1896f6db7a4 100644
--- a/mlir/include/mlir/Tools/mlir-translate/Translation.h
+++ b/mlir/include/mlir/Tools/mlir-translate/Translation.h
@@ -25,7 +25,10 @@ class OwningOpRef;
/// should create a new MLIR Operation in the given context and return a
/// pointer to it, or a nullptr in case of any error.
using TranslateSourceMgrToMLIRFunction = std::function<OwningOpRef<Operation *>(
- llvm::SourceMgr &sourceMgr, MLIRContext *)>;
+ const std::shared_ptr<llvm::SourceMgr> &sourceMgr, MLIRContext *)>;
+using TranslateRawSourceMgrToMLIRFunction =
+ std::function<OwningOpRef<Operation *>(llvm::SourceMgr &sourceMgr,
+ MLIRContext *)>;
/// Interface of the function that translates the given string to MLIR. The
/// implementation should create a new MLIR Operation in the given context. If
@@ -45,7 +48,8 @@ using TranslateFromMLIRFunction =
/// all MLIR constructs needed during the process inside the given context. This
/// can be used for round-tripping external formats through the MLIR system.
using TranslateFunction = std::function<LogicalResult(
- llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output, MLIRContext *)>;
+ const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ llvm::raw_ostream &output, MLIRContext *)>;
/// This class contains all of the components necessary for performing a
/// translation.
@@ -64,7 +68,7 @@ class Translation {
Optional<llvm::Align> getInputAlignment() const { return inputAlignment; }
/// Invoke the translation function with the given input and output streams.
- LogicalResult operator()(llvm::SourceMgr &sourceMgr,
+ LogicalResult operator()(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
llvm::raw_ostream &output,
MLIRContext *context) const {
return function(sourceMgr, output, context);
@@ -101,6 +105,10 @@ struct TranslateToMLIRRegistration {
llvm::StringRef name, llvm::StringRef description,
const TranslateSourceMgrToMLIRFunction &function,
Optional<llvm::Align> inputAlignment = std::nullopt);
+ TranslateToMLIRRegistration(
+ llvm::StringRef name, llvm::StringRef description,
+ const TranslateRawSourceMgrToMLIRFunction &function,
+ Optional<llvm::Align> inputAlignment = std::nullopt);
TranslateToMLIRRegistration(
llvm::StringRef name, llvm::StringRef description,
const TranslateStringRefToMLIRFunction &function,
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 1c291ab86e19f..d19dd91b3340b 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -23,6 +23,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SaveAndRestore.h"
+#include "llvm/Support/SourceMgr.h"
#include <optional>
#define DEBUG_TYPE "mlir-bytecode-reader"
@@ -492,11 +493,12 @@ namespace {
class ResourceSectionReader {
public:
/// Initialize the resource section reader with the given section data.
- LogicalResult initialize(Location fileLoc, const ParserConfig &config,
- MutableArrayRef<BytecodeDialect> dialects,
- StringSectionReader &stringReader,
- ArrayRef<uint8_t> sectionData,
- ArrayRef<uint8_t> offsetSectionData);
+ LogicalResult
+ initialize(Location fileLoc, const ParserConfig &config,
+ MutableArrayRef<BytecodeDialect> dialects,
+ StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData,
+ const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
/// Parse a dialect resource handle from the resource section.
LogicalResult parseResourceHandle(EncodingReader &reader,
@@ -512,8 +514,10 @@ class ResourceSectionReader {
class ParsedResourceEntry : public AsmParsedResourceEntry {
public:
ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind,
- EncodingReader &reader, StringSectionReader &stringReader)
- : key(key), kind(kind), reader(reader), stringReader(stringReader) {}
+ EncodingReader &reader, StringSectionReader &stringReader,
+ const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
+ : key(key), kind(kind), reader(reader), stringReader(stringReader),
+ bufferOwnerRef(bufferOwnerRef) {}
~ParsedResourceEntry() override = default;
StringRef getKey() const final { return key; }
@@ -554,11 +558,22 @@ class ParsedResourceEntry : public AsmParsedResourceEntry {
if (failed(reader.parseBlobAndAlignment(data, alignment)))
return failure();
+ // If we have an extendable reference to the buffer owner, we don't need to
+ // allocate a new buffer for the data, and can use the data directly.
+ if (bufferOwnerRef) {
+ ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()),
+ data.size());
+
+ // Allocate an unmanaged buffer which captures a reference to the owner.
+ // For now we just mark this as immutable, but in the future we should
+ // explore marking this as mutable when desired.
+ return UnmanagedAsmResourceBlob::allocateWithAlign(
+ charData, alignment,
+ [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {});
+ }
+
// Allocate memory for the blob using the provided allocator and copy the
// data into it.
- // FIXME: If the current holder of the bytecode can ensure its lifetime
- // (e.g. when mmap'd), we should not copy the data. We should use the data
- // from the bytecode directly.
AsmResourceBlob blob = allocator(data.size(), alignment);
assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&
blob.isMutable() &&
@@ -572,6 +587,7 @@ class ParsedResourceEntry : public AsmParsedResourceEntry {
AsmResourceEntryKind kind;
EncodingReader &reader;
StringSectionReader &stringReader;
+ const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
};
} // namespace
@@ -580,6 +596,7 @@ static LogicalResult
parseResourceGroup(Location fileLoc, bool allowEmpty,
EncodingReader &offsetReader, EncodingReader &resourceReader,
StringSectionReader &stringReader, T *handler,
+ const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
uint64_t numResources;
if (failed(offsetReader.parseVarInt(numResources)))
@@ -611,7 +628,8 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
// Otherwise, parse the resource value.
EncodingReader entryReader(data, fileLoc);
- ParsedResourceEntry entry(key, kind, entryReader, stringReader);
+ ParsedResourceEntry entry(key, kind, entryReader, stringReader,
+ bufferOwnerRef);
if (failed(handler->parseResource(entry)))
return failure();
if (!entryReader.empty()) {
@@ -622,12 +640,12 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
return success();
}
-LogicalResult
-ResourceSectionReader::initialize(Location fileLoc, const ParserConfig &config,
- MutableArrayRef<BytecodeDialect> dialects,
- StringSectionReader &stringReader,
- ArrayRef<uint8_t> sectionData,
- ArrayRef<uint8_t> offsetSectionData) {
+LogicalResult ResourceSectionReader::initialize(
+ Location fileLoc, const ParserConfig &config,
+ MutableArrayRef<BytecodeDialect> dialects,
+ StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData,
+ const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
EncodingReader resourceReader(sectionData, fileLoc);
EncodingReader offsetReader(offsetSectionData, fileLoc);
@@ -641,7 +659,7 @@ ResourceSectionReader::initialize(Location fileLoc, const ParserConfig &config,
auto parseGroup = [&](auto *handler, bool allowEmpty = false,
function_ref<LogicalResult(StringRef)> keyFn = {}) {
return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
- stringReader, handler, keyFn);
+ stringReader, handler, bufferOwnerRef, keyFn);
};
// Read the external resources from the bytecode.
@@ -1058,14 +1076,16 @@ namespace {
/// This class is used to read a bytecode buffer and translate it into MLIR.
class BytecodeReader {
public:
- BytecodeReader(Location fileLoc, const ParserConfig &config)
+ BytecodeReader(Location fileLoc, const ParserConfig &config,
+ const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
: config(config), fileLoc(fileLoc),
attrTypeReader(stringReader, resourceReader, fileLoc),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
"builtin.unrealized_conversion_cast", ValueRange(),
- NoneType::get(config.getContext())) {}
+ NoneType::get(config.getContext())),
+ bufferOwnerRef(bufferOwnerRef) {}
/// Read the bytecode defined within `buffer` into the given block.
LogicalResult read(llvm::MemoryBufferRef buffer, Block *block);
@@ -1222,6 +1242,10 @@ class BytecodeReader {
Block openForwardRefOps;
/// An operation state used when instantiating forward references.
OperationState forwardRefOpState;
+
+ /// The optional owning source manager, which when present may be used to
+ /// extend the lifetime of the input buffer.
+ const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
};
} // namespace
@@ -1383,7 +1407,8 @@ LogicalResult BytecodeReader::parseResourceSection(
// Initialize the resource reader with the resource sections.
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
- *resourceData, *resourceOffsetData);
+ *resourceData, *resourceOffsetData,
+ bufferOwnerRef);
}
//===----------------------------------------------------------------------===//
@@ -1719,8 +1744,13 @@ bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
return buffer.getBuffer().startswith("ML\xefR");
}
-LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
- const ParserConfig &config) {
+/// Read the bytecode from the provided memory buffer reference.
+/// `bufferOwnerRef`, if provided, is the owning source manager for the buffer,
+/// and may be used to extend the lifetime of the buffer.
+static LogicalResult
+readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
+ const ParserConfig &config,
+ const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
Location sourceFileLoc =
FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
/*line=*/0, /*column=*/0);
@@ -1729,6 +1759,18 @@ LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
"input buffer is not an MLIR bytecode file");
}
- BytecodeReader reader(sourceFileLoc, config);
+ BytecodeReader reader(sourceFileLoc, config, bufferOwnerRef);
return reader.read(buffer, block);
}
+
+LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
+ const ParserConfig &config) {
+ return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{});
+}
+LogicalResult
+mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ Block *block, const ParserConfig &config) {
+ return readBytecodeFileImpl(
+ *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config,
+ sourceMgr);
+}
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index c7210afdc7d55..518fd16066780 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -129,8 +129,8 @@ static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename,
return nullptr;
}
- llvm::SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+ auto sourceMgr = std::make_shared<llvm::SourceMgr>();
+ sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc());
OwningOpRef<Operation *> module =
parseSourceFileForTool(sourceMgr, context, insertImplicitModule);
if (!module)
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 4c0742895b608..57dd3eeb2714a 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -30,30 +30,60 @@ LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
return readBytecodeFile(*sourceBuf, block, config);
return parseAsmSourceFile(sourceMgr, block, config);
}
+LogicalResult
+mlir::parseSourceFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ Block *block, const ParserConfig &config,
+ LocationAttr *sourceFileLoc) {
+ const auto *sourceBuf =
+ sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID());
+ if (sourceFileLoc) {
+ *sourceFileLoc = FileLineColLoc::get(config.getContext(),
+ sourceBuf->getBufferIdentifier(),
+ /*line=*/0, /*column=*/0);
+ }
+ if (isBytecode(*sourceBuf))
+ return readBytecodeFile(sourceMgr, block, config);
+ return parseAsmSourceFile(*sourceMgr, block, config);
+}
LogicalResult mlir::parseSourceFile(llvm::StringRef filename, Block *block,
const ParserConfig &config,
LocationAttr *sourceFileLoc) {
- llvm::SourceMgr sourceMgr;
+ auto sourceMgr = std::make_shared<llvm::SourceMgr>();
return parseSourceFile(filename, sourceMgr, block, config, sourceFileLoc);
}
-LogicalResult mlir::parseSourceFile(llvm::StringRef filename,
- llvm::SourceMgr &sourceMgr, Block *block,
- const ParserConfig &config,
- LocationAttr *sourceFileLoc) {
+static LogicalResult loadSourceFileBuffer(llvm::StringRef filename,
+ llvm::SourceMgr &sourceMgr,
+ MLIRContext *ctx) {
if (sourceMgr.getNumBuffers() != 0) {
// TODO: Extend to support multiple buffers.
- return emitError(mlir::UnknownLoc::get(config.getContext()),
+ return emitError(mlir::UnknownLoc::get(ctx),
"only main buffer parsed at the moment");
}
auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename);
if (std::error_code error = fileOrErr.getError())
- return emitError(mlir::UnknownLoc::get(config.getContext()),
+ return emitError(mlir::UnknownLoc::get(ctx),
"could not open input file " + filename);
// Load the MLIR source file.
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), SMLoc());
+ return success();
+}
+
+LogicalResult mlir::parseSourceFile(llvm::StringRef filename,
+ llvm::SourceMgr &sourceMgr, Block *block,
+ const ParserConfig &config,
+ LocationAttr *sourceFileLoc) {
+ if (failed(loadSourceFileBuffer(filename, sourceMgr, config.getContext())))
+ return failure();
+ return parseSourceFile(sourceMgr, block, config, sourceFileLoc);
+}
+LogicalResult mlir::parseSourceFile(
+ llvm::StringRef filename, const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc) {
+ if (failed(loadSourceFileBuffer(filename, *sourceMgr, config.getContext())))
+ return failure();
return parseSourceFile(sourceMgr, block, config, sourceFileLoc);
}
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index d75bff3ce33b5..62c84e2884a1f 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -46,11 +46,11 @@ using namespace llvm;
/// This typically parses the main source file, runs zero or more optimization
/// passes, then prints the output.
///
-static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
- bool verifyPasses, SourceMgr &sourceMgr,
- MLIRContext *context,
- PassPipelineFn passManagerSetupFn,
- bool emitBytecode, bool implicitModule) {
+static LogicalResult
+performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
+ const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ MLIRContext *context, PassPipelineFn passManagerSetupFn,
+ bool emitBytecode, bool implicitModule) {
DefaultTimingManager tm;
applyDefaultTimingManagerCLOptions(tm);
TimingScope timing = tm.getRootScope();
@@ -115,8 +115,8 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
llvm::ThreadPool *threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
- SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
+ auto sourceMgr = std::make_shared<SourceMgr>();
+ sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
// Create a context just for the current buffer. Disable threading on creation
// since we'll inject the thread-pool separately.
@@ -135,13 +135,13 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
// If we are in verify diagnostics mode then we have a lot of work to do,
// otherwise just perform the actions without worrying about it.
if (!verifyDiagnostics) {
- SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
+ SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
&context, passManagerSetupFn, emitBytecode,
implicitModule);
}
- SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
+ SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
// Do any processing requested by command line flags. We don't care whether
// these actions succeed or fail, we only care what diagnostics they produce
diff --git a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
index c1570dda412b1..2f0ab8aa1a4fe 100644
--- a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
+++ b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp
@@ -41,8 +41,8 @@ OwningOpRef<Operation *> loadModule(MLIRContext &context,
return nullptr;
}
- llvm::SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+ auto sourceMgr = std::make_shared<llvm::SourceMgr>();
+ sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc());
return parseSourceFileForTool(sourceMgr, &context, insertImplictModule);
}
diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
index 51b21f251747a..02c91285dbedf 100644
--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
+++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
@@ -87,18 +87,18 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
MLIRContext context;
context.allowUnregisteredDialects(allowUnregisteredDialects);
context.printOpOnDiagnostic(!verifyDiagnostics);
- llvm::SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
+ auto sourceMgr = std::make_shared<llvm::SourceMgr>();
+ sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
if (!verifyDiagnostics) {
- SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
+ SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
return (*translationRequested)(sourceMgr, os, &context);
}
// In the diagnostic verification flow, we ignore whether the translation
// failed (in most cases, it is expected to fail). Instead, we check if the
// diagnostics were produced as expected.
- SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
+ SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
(void)(*translationRequested)(sourceMgr, os, &context);
return sourceMgrHandler.verify();
};
diff --git a/mlir/lib/Tools/mlir-translate/Translation.cpp b/mlir/lib/Tools/mlir-translate/Translation.cpp
index d8a2e9c4b5c00..afeaed52e329e 100644
--- a/mlir/lib/Tools/mlir-translate/Translation.cpp
+++ b/mlir/lib/Tools/mlir-translate/Translation.cpp
@@ -75,8 +75,8 @@ TranslateRegistration::TranslateRegistration(
static void registerTranslateToMLIRFunction(
StringRef name, StringRef description, Optional<llvm::Align> inputAlignment,
const TranslateSourceMgrToMLIRFunction &function) {
- auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
- MLIRContext *context) {
+ auto wrappedFn = [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ raw_ostream &output, MLIRContext *context) {
OwningOpRef<Operation *> op = function(sourceMgr, context);
if (!op || failed(verify(*op)))
return failure();
@@ -92,6 +92,15 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
Optional<llvm::Align> inputAlignment) {
registerTranslateToMLIRFunction(name, description, inputAlignment, function);
}
+TranslateToMLIRRegistration::TranslateToMLIRRegistration(
+ StringRef name, StringRef description,
+ const TranslateRawSourceMgrToMLIRFunction &function,
+ Optional<llvm::Align> inputAlignment) {
+ registerTranslateToMLIRFunction(
+ name, description, inputAlignment,
+ [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ MLIRContext *ctx) { return function(*sourceMgr, ctx); });
+}
/// Wraps `function` with a lambda that extracts a StringRef from a source
/// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
@@ -100,9 +109,10 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
Optional<llvm::Align> inputAlignment) {
registerTranslateToMLIRFunction(
name, description, inputAlignment,
- [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
+ [function](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ MLIRContext *ctx) {
const llvm::MemoryBuffer *buffer =
- sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
+ sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID());
return function(buffer->getBuffer(), ctx);
});
}
@@ -117,9 +127,9 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
const std::function<void(DialectRegistry &)> &dialectRegistration) {
registerTranslation(
name, description, /*inputAlignment=*/std::nullopt,
- [function, dialectRegistration](llvm::SourceMgr &sourceMgr,
- raw_ostream &output,
- MLIRContext *context) {
+ [function,
+ dialectRegistration](const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
+ raw_ostream &output, MLIRContext *context) {
DialectRegistry registry;
dialectRegistration(registry);
context->appendDialectRegistry(registry);
More information about the Mlir-commits
mailing list