[Mlir-commits] [mlir] 3128b31 - Add support for Lazyloading to the MLIR bytecode
Mehdi Amini
llvmlistbot at llvm.org
Sat May 20 15:25:44 PDT 2023
Author: Mehdi Amini
Date: 2023-05-20T15:24:33-07:00
New Revision: 3128b3105d7a226fc26174be265da479ff619f3e
URL: https://github.com/llvm/llvm-project/commit/3128b3105d7a226fc26174be265da479ff619f3e
DIFF: https://github.com/llvm/llvm-project/commit/3128b3105d7a226fc26174be265da479ff619f3e.diff
LOG: Add support for Lazyloading to the MLIR bytecode
IsolatedRegions are emitted in sections in order for the reader to be
able to skip over them. A new class is exposed to manage the state and
allow the readers to load these IsolatedRegions on-demand.
Differential Revision: https://reviews.llvm.org/D149515
Added:
mlir/test/Bytecode/bytecode-lazy-loading.mlir
mlir/test/lib/IR/TestLazyLoading.cpp
Modified:
mlir/docs/BytecodeFormat.md
mlir/include/mlir/Bytecode/BytecodeReader.h
mlir/lib/Bytecode/Encoding.h
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
mlir/test/Bytecode/invalid/invalid-structure.mlir
mlir/test/lib/IR/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
index b4f7400274f43..9586c262399a4 100644
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -314,6 +314,12 @@ offsets provides more effective compression.
The IR section contains the encoded form of operations within the bytecode.
+```
+ir_section {
+ block: block; // Single block without arguments.
+}
+```
+
#### Operation Encoding
```
@@ -334,7 +340,9 @@ op {
successors: varint[],
regionEncoding: varint?, // (numRegions << 1) | (isIsolatedFromAbove)
- regions: region[]
+
+ // regions are stored in a section if isIsolatedFromAbove
+ regions: (region | region_section)[]
}
```
diff --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h
index d7cb916646035..206e42870ad85 100644
--- a/mlir/include/mlir/Bytecode/BytecodeReader.h
+++ b/mlir/include/mlir/Bytecode/BytecodeReader.h
@@ -15,6 +15,9 @@
#include "mlir/IR/AsmState.h"
#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include <functional>
+#include <memory>
namespace llvm {
class MemoryBufferRef;
@@ -22,6 +25,59 @@ class SourceMgr;
} // namespace llvm
namespace mlir {
+
+ /// The BytecodeReader allows loading MLIR bytecode files, while keeping the
+/// state explicitly available in order to support lazy loading.
+/// The `finalize` method must be called before destruction.
+class BytecodeReader {
+public:
+ /// Create a bytecode reader for the given buffer. If `lazyLoad` is true,
+ /// isolated regions aren't loaded eagerly.
+ explicit BytecodeReader(
+ llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoad,
+ const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef = {});
+ ~BytecodeReader();
+
+ /// Read the operations defined within the given memory buffer, containing
+ /// MLIR bytecode, into the provided block. If the reader was created with
+ /// `lazyLoad` enabled, isolated regions aren't loaded eagerly.
+ /// The lazyOps callback is invoked for every op that can be lazy-loaded.
+ /// This lets the client decide if the op should be materialized
+ /// immediately or delayed.
+ LogicalResult readTopLevel(
+ Block *block, llvm::function_ref<bool(Operation *)> lazyOps =
+ [](Operation *) { return false; });
+
+ /// Return the number of ops that haven't been materialized yet.
+ int64_t getNumOpsToMaterialize() const;
+
+ /// Return true if the provided op is materializable.
+ bool isMaterializable(Operation *op);
+
+ /// Materialize the provided operation. The provided operation must be
+ /// materializable.
+ /// The lazyOps callback is invoked for every op that can be lazy-loaded.
+ /// This lets the client decide if the op should be materialized immediately
+ /// or delayed.
+ /// !! Using this materialize within an IR walk() can be confusing: make sure
+ /// to use a PreOrder traversal !!
+ LogicalResult materialize(
+ Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback =
+ [](Operation *) { return false; });
+
+ /// Finalize the lazy-loading by calling back with every op that hasn't been
+ /// materialized to let the client decide if the op should be deleted or
+ /// materialized. The op is materialized if the callback returns true, deleted
+ /// otherwise. The implementation of the callback must be thread-safe.
+ LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize =
+ [](Operation *) { return true; });
+
+ class Impl;
+
+private:
+ std::unique_ptr<Impl> impl;
+};
+
/// Returns true if the given buffer starts with the magic bytes that signal
/// MLIR bytecode.
bool isBytecode(llvm::MemoryBufferRef buffer);
@@ -36,6 +92,7 @@ LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
LogicalResult
readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
Block *block, const ParserConfig &config);
+
} // namespace mlir
#endif // MLIR_BYTECODE_BYTECODEREADER_H
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
index 0072538154806..20096ec4928e1 100644
--- a/mlir/lib/Bytecode/Encoding.h
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -27,7 +27,7 @@ enum {
kMinSupportedVersion = 0,
/// The current bytecode version.
- kVersion = 1,
+ kVersion = 2,
/// An arbitrary value used to fill alignment padding.
kAlignmentByte = 0xCB,
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 9344ec9214c18..58145fa80db3c 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -17,6 +17,9 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
@@ -24,6 +27,8 @@
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/SourceMgr.h"
+#include <list>
+#include <memory>
#include <optional>
#define DEBUG_TYPE "mlir-bytecode-reader"
@@ -1092,25 +1097,93 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
// Bytecode Reader
//===----------------------------------------------------------------------===//
-namespace {
/// This class is used to read a bytecode buffer and translate it into MLIR.
-class BytecodeReader {
+class mlir::BytecodeReader::Impl {
+ struct RegionReadState;
+ using LazyLoadableOpsInfo =
+ std::list<std::pair<Operation *, RegionReadState>>;
+ using LazyLoadableOpsMap =
+ DenseMap<Operation *, LazyLoadableOpsInfo::iterator>;
+
public:
- BytecodeReader(Location fileLoc, const ParserConfig &config,
- const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
- : config(config), fileLoc(fileLoc),
+ Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
+ llvm::MemoryBufferRef buffer,
+ const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
+ : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
attrTypeReader(stringReader, resourceReader, fileLoc),
// Use the builtin unrealized conversion cast operation to represent
// forward references to values that aren't yet defined.
forwardRefOpState(UnknownLoc::get(config.getContext()),
"builtin.unrealized_conversion_cast", ValueRange(),
NoneType::get(config.getContext())),
- bufferOwnerRef(bufferOwnerRef) {}
+ buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
/// Read the bytecode defined within `buffer` into the given block.
- LogicalResult read(llvm::MemoryBufferRef buffer, Block *block);
+ LogicalResult read(Block *block,
+ llvm::function_ref<bool(Operation *)> lazyOps);
+
+ /// Return the number of ops that haven't been materialized yet.
+ int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); }
+
+ bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); }
+
+ /// Materialize the provided operation, invoke the lazyOpsCallback on every
+ /// newly found lazy operation.
+ LogicalResult
+ materialize(Operation *op,
+ llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
+ this->lazyOpsCallback = lazyOpsCallback;
+ auto resetlazyOpsCallback =
+ llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
+ auto it = lazyLoadableOpsMap.find(op);
+ assert(it != lazyLoadableOpsMap.end() &&
+ "materialize called on non-materializable op");
+ return materialize(it);
+ }
+
+ /// Materialize all operations.
+ LogicalResult materializeAll() {
+ while (!lazyLoadableOpsMap.empty()) {
+ if (failed(materialize(lazyLoadableOpsMap.begin())))
+ return failure();
+ }
+ return success();
+ }
+
+ /// Finalize the lazy-loading by calling back with every op that hasn't been
+ /// materialized to let the client decide if the op should be deleted or
+ /// materialized. The op is materialized if the callback returns true, deleted
+ /// otherwise.
+ LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
+ while (!lazyLoadableOps.empty()) {
+ Operation *op = lazyLoadableOps.begin()->first;
+ if (shouldMaterialize(op)) {
+ if (failed(materialize(lazyLoadableOpsMap.find(op))))
+ return failure();
+ continue;
+ }
+ op->dropAllReferences();
+ op->erase();
+ lazyLoadableOps.pop_front();
+ lazyLoadableOpsMap.erase(op);
+ }
+ return success();
+ }
private:
+ LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
+ assert(it != lazyLoadableOpsMap.end() &&
+ "materialize called on non-materializable op");
+ valueScopes.emplace_back();
+ std::vector<RegionReadState> regionStack;
+ regionStack.push_back(std::move(it->getSecond()->second));
+ lazyLoadableOps.erase(it->getSecond());
+ lazyLoadableOpsMap.erase(it);
+ auto result = parseRegions(regionStack, regionStack.back());
+ assert(regionStack.empty());
+ return result;
+ }
+
/// Return the context for this config.
MLIRContext *getContext() const { return config.getContext(); }
@@ -1151,14 +1224,22 @@ class BytecodeReader {
/// This struct represents the current read state of a range of regions. This
/// struct is used to enable iterative parsing of regions.
struct RegionReadState {
- RegionReadState(Operation *op, bool isIsolatedFromAbove)
- : RegionReadState(op->getRegions(), isIsolatedFromAbove) {}
- RegionReadState(MutableArrayRef<Region> regions, bool isIsolatedFromAbove)
- : curRegion(regions.begin()), endRegion(regions.end()),
+ RegionReadState(Operation *op, EncodingReader *reader,
+ bool isIsolatedFromAbove)
+ : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
+ RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
+ bool isIsolatedFromAbove)
+ : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
isIsolatedFromAbove(isIsolatedFromAbove) {}
/// The current regions being read.
MutableArrayRef<Region>::iterator curRegion, endRegion;
+ /// This is the reader to use for this region, this pointer is pointing to
+ /// the parent region reader unless the current region is IsolatedFromAbove,
+ /// in which case the pointer is pointing to the `owningReader` which is a
+ /// section dedicated to the current region.
+ EncodingReader *reader;
+ std::unique_ptr<EncodingReader> owningReader;
/// The number of values defined immediately within this region.
unsigned numValues = 0;
@@ -1176,15 +1257,15 @@ class BytecodeReader {
};
LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
- LogicalResult parseRegions(EncodingReader &reader,
- std::vector<RegionReadState> &regionStack,
+ LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
RegionReadState &readState);
FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
bool &isIsolatedFromAbove);
- LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState);
- LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState);
+ LogicalResult parseRegion(RegionReadState &readState);
+ LogicalResult parseBlockHeader(EncodingReader &reader,
+ RegionReadState &readState);
LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
//===--------------------------------------------------------------------===//
@@ -1234,6 +1315,16 @@ class BytecodeReader {
/// A location to use when emitting errors.
Location fileLoc;
+ /// Flag that indicates if lazyloading is enabled.
+ bool lazyLoading;
+
+ /// Keep track of operations that have been lazy loaded (their regions haven't
+ /// been materialized), along with the `RegionReadState` that allows
+ /// lazy-loading the regions nested under the operation.
+ LazyLoadableOpsInfo lazyLoadableOps;
+ LazyLoadableOpsMap lazyLoadableOpsMap;
+ llvm::function_ref<bool(Operation *)> lazyOpsCallback;
+
/// The reader used to process attribute and types within the bytecode.
AttrTypeReader attrTypeReader;
@@ -1264,14 +1355,20 @@ class BytecodeReader {
/// An operation state used when instantiating forward references.
OperationState forwardRefOpState;
+ /// Reference to the input buffer.
+ llvm::MemoryBufferRef buffer;
+
/// The optional owning source manager, which when present may be used to
/// extend the lifetime of the input buffer.
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
};
-} // namespace
-LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
+LogicalResult BytecodeReader::Impl::read(
+ Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
EncodingReader reader(buffer.getBuffer(), fileLoc);
+ this->lazyOpsCallback = lazyOpsCallback;
+ auto resetlazyOpsCallback =
+ llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
// Skip over the bytecode header, this should have already been checked.
if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
@@ -1302,7 +1399,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
// Check for duplicate sections, we only expect one instance of each.
if (sectionDatas[sectionID]) {
return reader.emitError("duplicate top-level section: ",
- toString(sectionID));
+ ::toString(sectionID));
}
sectionDatas[sectionID] = sectionData;
}
@@ -1311,7 +1408,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
if (!sectionDatas[i] && !isSectionOptional(sectionID)) {
return reader.emitError("missing data for top-level section: ",
- toString(sectionID));
+ ::toString(sectionID));
}
}
@@ -1340,7 +1437,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
}
-LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
+LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
if (failed(reader.parseVarInt(version)))
return failure();
@@ -1357,6 +1454,9 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
" is newer than the current version ",
currentVersion);
}
+ // Override any request to lazy-load if the bytecode version is too old.
+ if (version < 2)
+ lazyLoading = false;
return success();
}
@@ -1396,7 +1496,7 @@ LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
}
LogicalResult
-BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
+BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
EncodingReader sectionReader(sectionData, fileLoc);
// Parse the number of dialects in the section.
@@ -1449,7 +1549,8 @@ BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return success();
}
-FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
+FailureOr<OperationName>
+BytecodeReader::Impl::parseOpName(EncodingReader &reader) {
BytecodeOperationName *opName = nullptr;
if (failed(parseEntry(reader, opNames, opName, "operation name")))
return failure();
@@ -1471,7 +1572,7 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
//===----------------------------------------------------------------------===//
// Resource Section
-LogicalResult BytecodeReader::parseResourceSection(
+LogicalResult BytecodeReader::Impl::parseResourceSection(
EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
// Ensure both sections are either present or not.
@@ -1499,8 +1600,9 @@ LogicalResult BytecodeReader::parseResourceSection(
//===----------------------------------------------------------------------===//
// IR Section
-LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
- Block *block) {
+LogicalResult
+BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
+ Block *block) {
EncodingReader reader(sectionData, fileLoc);
// A stack of operation regions currently being read from the bytecode.
@@ -1508,17 +1610,17 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
// Parse the top-level block using a temporary module operation.
OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
- regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true);
+ regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true);
regionStack.back().curBlocks.push_back(moduleOp->getBody());
regionStack.back().curBlock = regionStack.back().curRegion->begin();
- if (failed(parseBlock(reader, regionStack.back())))
+ if (failed(parseBlockHeader(reader, regionStack.back())))
return failure();
valueScopes.emplace_back();
valueScopes.back().push(regionStack.back());
// Iteratively parse regions until everything has been resolved.
while (!regionStack.empty())
- if (failed(parseRegions(reader, regionStack, regionStack.back())))
+ if (failed(parseRegions(regionStack, regionStack.back())))
return failure();
if (!forwardRefOps.empty()) {
return reader.emitError(
@@ -1549,15 +1651,18 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
}
LogicalResult
-BytecodeReader::parseRegions(EncodingReader &reader,
- std::vector<RegionReadState> &regionStack,
- RegionReadState &readState) {
- // Read the regions of this operation.
+BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
+ RegionReadState &readState) {
+ // Process regions, blocks, and operations until the end or if a nested
+ // region is encountered. In this case we push a new state in regionStack and
+ // return, the processing of the current region will resume afterward.
for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
// If the current block hasn't been setup yet, parse the header for this
- // region.
+ // region. The current block is already setup when this function was
+ // interrupted to recurse down in a nested region and we resume the current
+ // block after processing the nested region.
if (readState.curBlock == Region::iterator()) {
- if (failed(parseRegion(reader, readState)))
+ if (failed(parseRegion(readState)))
return failure();
// If the region is empty, there is nothing to more to do.
@@ -1566,6 +1671,7 @@ BytecodeReader::parseRegions(EncodingReader &reader,
}
// Parse the blocks within the region.
+ EncodingReader &reader = *readState.reader;
do {
while (readState.numOpsRemaining--) {
// Read in the next operation. We don't read its regions directly, we
@@ -1576,9 +1682,38 @@ BytecodeReader::parseRegions(EncodingReader &reader,
if (failed(op))
return failure();
- // If the op has regions, add it to the stack for processing.
+ // If the op has regions, add it to the stack for processing and return:
+ // we stop the processing of the current region and resume it after the
+ // inner one is completed. Unless LazyLoading is activated in which case
+ // nested region parsing is delayed.
if ((*op)->getNumRegions()) {
- regionStack.emplace_back(*op, isIsolatedFromAbove);
+ RegionReadState childState(*op, &reader, isIsolatedFromAbove);
+
+ // Isolated regions are encoded as a section in version 2 and above.
+ if (version >= 2 && isIsolatedFromAbove) {
+ bytecode::Section::ID sectionID;
+ ArrayRef<uint8_t> sectionData;
+ if (failed(reader.parseSection(sectionID, sectionData)))
+ return failure();
+ if (sectionID != bytecode::Section::kIR)
+ return emitError(fileLoc, "expected IR section for region");
+ childState.owningReader =
+ std::make_unique<EncodingReader>(sectionData, fileLoc);
+ childState.reader = childState.owningReader.get();
+ }
+
+ if (lazyLoading) {
+ // If the user has a callback set, they have the opportunity
+ // to control lazyloading as we go.
+ if (!lazyOpsCallback || !lazyOpsCallback(*op)) {
+ lazyLoadableOps.push_back(
+ std::make_pair(*op, std::move(childState)));
+ lazyLoadableOpsMap.try_emplace(*op,
+ std::prev(lazyLoadableOps.end()));
+ continue;
+ }
+ }
+ regionStack.push_back(std::move(childState));
// If the op is isolated from above, push a new value scope.
if (isIsolatedFromAbove)
@@ -1590,7 +1725,7 @@ BytecodeReader::parseRegions(EncodingReader &reader,
// Move to the next block of the region.
if (++readState.curBlock == readState.curRegion->end())
break;
- if (failed(parseBlock(reader, readState)))
+ if (failed(parseBlockHeader(reader, readState)))
return failure();
} while (true);
@@ -1601,16 +1736,19 @@ BytecodeReader::parseRegions(EncodingReader &reader,
// When the regions have been fully parsed, pop them off of the read stack. If
// the regions were isolated from above, we also pop the last value scope.
- if (readState.isIsolatedFromAbove)
+ if (readState.isIsolatedFromAbove) {
+ assert(!valueScopes.empty() && "Expect a valueScope after reading region");
valueScopes.pop_back();
+ }
+ assert(!regionStack.empty() && "Expect a regionStack after reading region");
regionStack.pop_back();
return success();
}
FailureOr<Operation *>
-BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
- RegionReadState &readState,
- bool &isIsolatedFromAbove) {
+BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
+ RegionReadState &readState,
+ bool &isIsolatedFromAbove) {
// Parse the name of the operation.
FailureOr<OperationName> opName = parseOpName(reader);
if (failed(opName))
@@ -1696,8 +1834,9 @@ BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
return op;
}
-LogicalResult BytecodeReader::parseRegion(EncodingReader &reader,
- RegionReadState &readState) {
+LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
+ EncodingReader &reader = *readState.reader;
+
// Parse the number of blocks in the region.
uint64_t numBlocks;
if (failed(reader.parseVarInt(numBlocks)))
@@ -1727,11 +1866,12 @@ LogicalResult BytecodeReader::parseRegion(EncodingReader &reader,
// Parse the entry block of the region.
readState.curBlock = readState.curRegion->begin();
- return parseBlock(reader, readState);
+ return parseBlockHeader(reader, readState);
}
-LogicalResult BytecodeReader::parseBlock(EncodingReader &reader,
- RegionReadState &readState) {
+LogicalResult
+BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
+ RegionReadState &readState) {
bool hasArgs;
if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
return failure();
@@ -1744,8 +1884,8 @@ LogicalResult BytecodeReader::parseBlock(EncodingReader &reader,
return success();
}
-LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
- Block *block) {
+LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
+ Block *block) {
// Parse the value ID for the first argument, and the number of arguments.
uint64_t numArgs;
if (failed(reader.parseVarInt(numArgs)))
@@ -1773,7 +1913,7 @@ LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
//===----------------------------------------------------------------------===//
// Value Processing
-Value BytecodeReader::parseOperand(EncodingReader &reader) {
+Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
std::vector<Value> &values = valueScopes.back().values;
Value *value = nullptr;
if (failed(parseEntry(reader, values, value, "value")))
@@ -1785,8 +1925,8 @@ Value BytecodeReader::parseOperand(EncodingReader &reader) {
return *value;
}
-LogicalResult BytecodeReader::defineValues(EncodingReader &reader,
- ValueRange newValues) {
+LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
+ ValueRange newValues) {
ValueScope &valueScope = valueScopes.back();
std::vector<Value> &values = valueScope.values;
@@ -1821,7 +1961,7 @@ LogicalResult BytecodeReader::defineValues(EncodingReader &reader,
return success();
}
-Value BytecodeReader::createForwardRef() {
+Value BytecodeReader::Impl::createForwardRef() {
// Check for an avaliable existing operation to use. Otherwise, create a new
// fake operation to use for the reference.
if (!openForwardRefOps.empty()) {
@@ -1837,6 +1977,41 @@ Value BytecodeReader::createForwardRef() {
// Entry Points
//===----------------------------------------------------------------------===//
+BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); }
+
+BytecodeReader::BytecodeReader(
+ llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading,
+ const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
+ Location sourceFileLoc =
+ FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
+ /*line=*/0, /*column=*/0);
+ impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer,
+ bufferOwnerRef);
+}
+
+LogicalResult BytecodeReader::readTopLevel(
+ Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
+ return impl->read(block, lazyOpsCallback);
+}
+
+int64_t BytecodeReader::getNumOpsToMaterialize() const {
+ return impl->getNumOpsToMaterialize();
+}
+
+bool BytecodeReader::isMaterializable(Operation *op) {
+ return impl->isMaterializable(op);
+}
+
+LogicalResult BytecodeReader::materialize(
+ Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
+ return impl->materialize(op, lazyOpsCallback);
+}
+
+LogicalResult
+BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) {
+ return impl->finalize(shouldMaterialize);
+}
+
bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
return buffer.getBuffer().startswith("ML\xefR");
}
@@ -1856,8 +2031,9 @@ readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
"input buffer is not an MLIR bytecode file");
}
- BytecodeReader reader(sourceFileLoc, config, bufferOwnerRef);
- return reader.read(buffer, block);
+ BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false,
+ buffer, bufferOwnerRef);
+ return reader.read(block, /*lazyOpsCallback=*/nullptr);
}
LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 801f3022d0e47..158dbe6d161db 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -734,8 +734,18 @@ void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
bool isIsolatedFromAbove = op->hasTrait<OpTrait::IsIsolatedFromAbove>();
emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove);
- for (Region &region : op->getRegions())
- writeRegion(emitter, &region);
+ for (Region &region : op->getRegions()) {
+ // If the region is not isolated from above, or we are emitting bytecode
+ // targeting version <2, we don't use a section.
+ if (!isIsolatedFromAbove || config.bytecodeVersion < 2) {
+ writeRegion(emitter, &region);
+ continue;
+ }
+
+ EncodingEmitter regionEmitter;
+ writeRegion(regionEmitter, &region);
+ emitter.emitSection(bytecode::Section::kIR, std::move(regionEmitter));
+ }
}
}
diff --git a/mlir/test/Bytecode/bytecode-lazy-loading.mlir b/mlir/test/Bytecode/bytecode-lazy-loading.mlir
new file mode 100644
index 0000000000000..a4f7974b0b690
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode-lazy-loading.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(test-lazy-loading)" %s -o %t | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(test-lazy-loading{bytecode-version=1})" %s -o %t | FileCheck %s --check-prefix=OLD-BYTECODE
+
+
+func.func @op_with_passthrough_region_args() {
+ %0 = arith.constant 10 : index
+ test.isolated_region %0 {
+ "test.consumer"(%0) : (index) -> ()
+ }
+ %result:2 = "test.op"() : () -> (index, index)
+ test.isolated_region %result#1 {
+ "test.consumer"(%result#1) : (index) -> ()
+ }
+ return
+}
+
+// Before version 2, we can't support lazy loading.
+// OLD-BYTECODE-NOT: Has 1 ops to materialize
+// OLD-BYTECODE-NOT: Materializing
+// OLD-BYTECODE: Has 0 ops to materialize
+
+
+// CHECK: Has 1 ops to materialize
+
+// CHECK: Before Materializing...
+// CHECK: "builtin.module"() ({
+// CHECK-NOT: func
+// CHECK: Materializing...
+// CHECK: "builtin.module"() ({
+// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({
+// CHECK-NOT: arith
+// CHECK: Has 1 ops to materialize
+
+// CHECK: Before Materializing...
+// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({
+// CHECK-NOT: arith
+// CHECK: Materializing...
+// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({
+// CHECK: arith
+// CHECK: isolated_region
+// CHECK-NOT: test.consumer
+// CHECK: Has 2 ops to materialize
+
+// CHECK: Before Materializing...
+// CHECK: test.isolated_region
+// CHECK-NOT: test.consumer
+// CHECK: Materializing...
+// CHECK: test.isolated_region
+// CHECK: ^bb0(%arg0: index):
+// CHECK: test.consumer
+// CHECK: Has 1 ops to materialize
+
+// CHECK: Before Materializing...
+// CHECK: test.isolated_region
+// CHECK-NOT: test.consumer
+// CHECK: Materializing...
+// CHECK: test.isolated_region
+// CHECK: test.consumer
+// CHECK: Has 0 ops to materialize
diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir
index d98c6c6191b87..4668878b10560 100644
--- a/mlir/test/Bytecode/invalid/invalid-structure.mlir
+++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir
@@ -9,7 +9,7 @@
//===--------------------------------------------------------------------===//
// RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION
-// VERSION: bytecode version 127 is newer than the current version 1
+// VERSION: bytecode version 127 is newer than the current version 2
//===--------------------------------------------------------------------===//
// Producer
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index 8b519538719e5..627036d021fb7 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_library(MLIRTestIR
TestFunc.cpp
TestInterfaces.cpp
TestMatchers.cpp
+ TestLazyLoading.cpp
TestOpaqueLoc.cpp
TestOperationEquals.cpp
TestPrintDefUse.cpp
diff --git a/mlir/test/lib/IR/TestLazyLoading.cpp b/mlir/test/lib/IR/TestLazyLoading.cpp
new file mode 100644
index 0000000000000..187b977c7daaa
--- /dev/null
+++ b/mlir/test/lib/IR/TestLazyLoading.cpp
@@ -0,0 +1,93 @@
+//===- TestLazyLoading.cpp - Pass to test operation lazy loading ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Bytecode/BytecodeReader.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/MemoryBufferRef.h"
+#include "llvm/Support/raw_ostream.h"
+#include <list>
+
+using namespace mlir;
+
+namespace {
+
+/// This is a test pass which LazyLoads the current operation recursively.
+struct LazyLoadingPass : public PassWrapper<LazyLoadingPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LazyLoadingPass)
+
+ StringRef getArgument() const final { return "test-lazy-loading"; }
+ StringRef getDescription() const final { return "Test LazyLoading of op"; }
+ LazyLoadingPass() = default;
+ LazyLoadingPass(const LazyLoadingPass &) {}
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ std::string bytecode;
+ {
+ BytecodeWriterConfig config;
+ if (version >= 0)
+ config.setDesiredBytecodeVersion(version);
+ llvm::raw_string_ostream os(bytecode);
+ if (failed(writeBytecodeToFile(op, os, config))) {
+ op->emitError() << "failed to write bytecode at version "
+ << (int)version;
+ signalPassFailure();
+ return;
+ }
+ }
+ llvm::MemoryBufferRef buffer(bytecode, "test-lazy-loading");
+ Block block;
+ ParserConfig config(op->getContext(), /*verifyAfterParse=*/false);
+ BytecodeReader reader(buffer, config,
+ /*lazyLoad=*/true);
+ std::list<Operation *> toLoadOps;
+ if (failed(reader.readTopLevel(&block, [&](Operation *op) {
+ toLoadOps.push_back(op);
+ return false;
+ }))) {
+ op->emitError() << "failed to read bytecode";
+ return;
+ }
+
+ llvm::outs() << "Has " << reader.getNumOpsToMaterialize()
+ << " ops to materialize\n";
+
+ // Recursively print the operations, before and after lazy loading.
+ while (!toLoadOps.empty()) {
+ Operation *toLoad = toLoadOps.front();
+ toLoadOps.pop_front();
+ llvm::outs() << "\n\nBefore Materializing...\n\n";
+ toLoad->print(llvm::outs());
+ llvm::outs() << "\n\nMaterializing...\n\n";
+ if (failed(reader.materialize(toLoad, [&](Operation *op) {
+ toLoadOps.push_back(op);
+ return false;
+ }))) {
+ toLoad->emitError() << "failed to materialize";
+ signalPassFailure();
+ return;
+ }
+ toLoad->print(llvm::outs());
+ llvm::outs() << "\n";
+ llvm::outs() << "Has " << reader.getNumOpsToMaterialize()
+ << " ops to materialize\n";
+ }
+ }
+ Option<int> version{*this, "bytecode-version",
+ llvm::cl::desc("Specifies the bytecode version to use."),
+ llvm::cl::init(-1)};
+};
+} // namespace
+
+namespace mlir {
+void registerLazyLoadingTestPasses() { PassRegistration<LazyLoadingPass>(); }
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 13a525f0bcff8..40b9c827fa610 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -31,6 +31,7 @@ using namespace mlir;
namespace mlir {
void registerConvertToTargetEnvPass();
void registerCloneTestPasses();
+void registerLazyLoadingTestPasses();
void registerPassManagerTestPass();
void registerPrintSpirvAvailabilityPass();
void registerLoopLikeInterfaceTestPasses();
@@ -146,6 +147,7 @@ void registerTestPasses() {
registerConvertToTargetEnvPass();
registerPassManagerTestPass();
registerPrintSpirvAvailabilityPass();
+ registerLazyLoadingTestPasses();
registerLoopLikeInterfaceTestPasses();
registerShapeFunctionTestPasses();
registerSideEffectTestPasses();
More information about the Mlir-commits
mailing list