[Mlir-commits] [mlir] 3128b31 - Add support for Lazyloading to the MLIR bytecode

Mehdi Amini llvmlistbot at llvm.org
Sat May 20 15:25:44 PDT 2023


Author: Mehdi Amini
Date: 2023-05-20T15:24:33-07:00
New Revision: 3128b3105d7a226fc26174be265da479ff619f3e

URL: https://github.com/llvm/llvm-project/commit/3128b3105d7a226fc26174be265da479ff619f3e
DIFF: https://github.com/llvm/llvm-project/commit/3128b3105d7a226fc26174be265da479ff619f3e.diff

LOG: Add support for Lazyloading to the MLIR bytecode

IsolatedRegions are emitted in sections in order for the reader to be
able to skip over them. A new class is exposed to manage the state and
allow the readers to load these IsolatedRegions on-demand.

Differential Revision: https://reviews.llvm.org/D149515

Added: 
    mlir/test/Bytecode/bytecode-lazy-loading.mlir
    mlir/test/lib/IR/TestLazyLoading.cpp

Modified: 
    mlir/docs/BytecodeFormat.md
    mlir/include/mlir/Bytecode/BytecodeReader.h
    mlir/lib/Bytecode/Encoding.h
    mlir/lib/Bytecode/Reader/BytecodeReader.cpp
    mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
    mlir/test/Bytecode/invalid/invalid-structure.mlir
    mlir/test/lib/IR/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
index b4f7400274f43..9586c262399a4 100644
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -314,6 +314,12 @@ offsets provides more effective compression.
 
 The IR section contains the encoded form of operations within the bytecode.
 
+```
+ir_section {
+  block: block; // Single block without arguments.
+}
+```
+
 #### Operation Encoding
 
 ```
@@ -334,7 +340,9 @@ op {
   successors: varint[],
 
   regionEncoding: varint?, // (numRegions << 1) | (isIsolatedFromAbove)
-  regions: region[]
+
+  // regions are stored in a section if isIsolatedFromAbove
+  regions: (region | region_section)[]
 }
 ```
 

diff  --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h
index d7cb916646035..206e42870ad85 100644
--- a/mlir/include/mlir/Bytecode/BytecodeReader.h
+++ b/mlir/include/mlir/Bytecode/BytecodeReader.h
@@ -15,6 +15,9 @@
 
 #include "mlir/IR/AsmState.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include <functional>
+#include <memory>
 
 namespace llvm {
 class MemoryBufferRef;
@@ -22,6 +25,59 @@ class SourceMgr;
 } // namespace llvm
 
 namespace mlir {
+
+/// The BytecodeReader loads MLIR bytecode files while keeping the reader
+/// state explicitly available in order to support lazy loading.
+/// The `finalize` method must be called before destruction.
+class BytecodeReader {
+public:
+  /// Create a bytecode reader for the given buffer. If `lazyLoad` is true,
+  /// isolated regions aren't loaded eagerly.
+  explicit BytecodeReader(
+      llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoad,
+      const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef = {});
+  ~BytecodeReader();
+
+  /// Read the operations defined within the given memory buffer, containing
+  /// MLIR bytecode, into the provided block. If the reader was created with
+  /// `lazyLoad` enabled, isolated regions aren't loaded eagerly.
+  /// The lazyOps callback is invoked for every op that can be lazy-loaded.
+  /// This lets the client decide if the op should be materialized
+  /// immediately (the callback returns true) or delayed.
+  LogicalResult readTopLevel(
+      Block *block, llvm::function_ref<bool(Operation *)> lazyOps =
+                        [](Operation *) { return false; });
+
+  /// Return the number of ops that haven't been materialized yet.
+  int64_t getNumOpsToMaterialize() const;
+
+  /// Return true if the provided op is materializable.
+  bool isMaterializable(Operation *op);
+
+  /// Materialize the provided operation. The provided operation must be
+  /// materializable.
+  /// The lazyOps callback is invoked for every op that can be lazy-loaded.
+  /// This lets the client decide if the op should be materialized immediately
+  /// or delayed.
+  /// !! Using this materialize within an IR walk() can be confusing: make sure
+  /// to use a PreOrder traversal !!
+  LogicalResult materialize(
+      Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback =
+                         [](Operation *) { return false; });
+
+  /// Finalize the lazy-loading by calling back with every op that hasn't been
+  /// materialized to let the client decide if the op should be deleted or
+  /// materialized. The op is materialized if the callback returns true, deleted
+  /// otherwise. The implementation of the callback must be thread-safe.
+  LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize =
+                             [](Operation *) { return true; });
+
+  class Impl;
+
+private:
+  std::unique_ptr<Impl> impl;
+};
+
 /// Returns true if the given buffer starts with the magic bytes that signal
 /// MLIR bytecode.
 bool isBytecode(llvm::MemoryBufferRef buffer);
@@ -36,6 +92,7 @@ LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
 LogicalResult
 readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
                  Block *block, const ParserConfig &config);
+
 } // namespace mlir
 
 #endif // MLIR_BYTECODE_BYTECODEREADER_H

diff  --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
index 0072538154806..20096ec4928e1 100644
--- a/mlir/lib/Bytecode/Encoding.h
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -27,7 +27,7 @@ enum {
   kMinSupportedVersion = 0,
 
   /// The current bytecode version.
-  kVersion = 1,
+  kVersion = 2,
 
   /// An arbitrary value used to fill alignment padding.
   kAlignmentByte = 0xCB,

diff  --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 9344ec9214c18..58145fa80db3c 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -17,6 +17,9 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallString.h"
@@ -24,6 +27,8 @@
 #include "llvm/Support/MemoryBufferRef.h"
 #include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/SourceMgr.h"
+#include <list>
+#include <memory>
 #include <optional>
 
 #define DEBUG_TYPE "mlir-bytecode-reader"
@@ -1092,25 +1097,93 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
 // Bytecode Reader
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// This class is used to read a bytecode buffer and translate it into MLIR.
-class BytecodeReader {
+class mlir::BytecodeReader::Impl {
+  struct RegionReadState;
+  using LazyLoadableOpsInfo =
+      std::list<std::pair<Operation *, RegionReadState>>;
+  using LazyLoadableOpsMap =
+      DenseMap<Operation *, LazyLoadableOpsInfo::iterator>;
+
 public:
-  BytecodeReader(Location fileLoc, const ParserConfig &config,
-                 const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
-      : config(config), fileLoc(fileLoc),
+  Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading,
+       llvm::MemoryBufferRef buffer,
+       const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef)
+      : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading),
         attrTypeReader(stringReader, resourceReader, fileLoc),
         // Use the builtin unrealized conversion cast operation to represent
         // forward references to values that aren't yet defined.
         forwardRefOpState(UnknownLoc::get(config.getContext()),
                           "builtin.unrealized_conversion_cast", ValueRange(),
                           NoneType::get(config.getContext())),
-        bufferOwnerRef(bufferOwnerRef) {}
+        buffer(buffer), bufferOwnerRef(bufferOwnerRef) {}
 
   /// Read the bytecode defined within `buffer` into the given block.
-  LogicalResult read(llvm::MemoryBufferRef buffer, Block *block);
+  LogicalResult read(Block *block,
+                     llvm::function_ref<bool(Operation *)> lazyOps);
+
+  /// Return the number of ops that haven't been materialized yet.
+  int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); }
+
+  bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); }
+
+  /// Materialize the provided operation, invoke the lazyOpsCallback on every
+  /// newly found lazy operation.
+  LogicalResult
+  materialize(Operation *op,
+              llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
+    this->lazyOpsCallback = lazyOpsCallback;
+    auto resetlazyOpsCallback =
+        llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
+    auto it = lazyLoadableOpsMap.find(op);
+    assert(it != lazyLoadableOpsMap.end() &&
+           "materialize called on non-materializable op");
+    return materialize(it);
+  }
+
+  /// Materialize all operations.
+  LogicalResult materializeAll() {
+    while (!lazyLoadableOpsMap.empty()) {
+      if (failed(materialize(lazyLoadableOpsMap.begin())))
+        return failure();
+    }
+    return success();
+  }
+
+  /// Finalize the lazy-loading by calling back with every op that hasn't been
+  /// materialized to let the client decide if the op should be deleted or
+  /// materialized. The op is materialized if the callback returns true, deleted
+  /// otherwise.
+  LogicalResult finalize(function_ref<bool(Operation *)> shouldMaterialize) {
+    while (!lazyLoadableOps.empty()) {
+      Operation *op = lazyLoadableOps.begin()->first;
+      if (shouldMaterialize(op)) {
+        if (failed(materialize(lazyLoadableOpsMap.find(op))))
+          return failure();
+        continue;
+      }
+      op->dropAllReferences();
+      op->erase();
+      lazyLoadableOps.pop_front();
+      lazyLoadableOpsMap.erase(op);
+    }
+    return success();
+  }
 
 private:
+  LogicalResult materialize(LazyLoadableOpsMap::iterator it) {
+    assert(it != lazyLoadableOpsMap.end() &&
+           "materialize called on non-materializable op");
+    valueScopes.emplace_back();
+    std::vector<RegionReadState> regionStack;
+    regionStack.push_back(std::move(it->getSecond()->second));
+    lazyLoadableOps.erase(it->getSecond());
+    lazyLoadableOpsMap.erase(it);
+    auto result = parseRegions(regionStack, regionStack.back());
+    assert(regionStack.empty());
+    return result;
+  }
+
   /// Return the context for this config.
   MLIRContext *getContext() const { return config.getContext(); }
 
@@ -1151,14 +1224,22 @@ class BytecodeReader {
   /// This struct represents the current read state of a range of regions. This
   /// struct is used to enable iterative parsing of regions.
   struct RegionReadState {
-    RegionReadState(Operation *op, bool isIsolatedFromAbove)
-        : RegionReadState(op->getRegions(), isIsolatedFromAbove) {}
-    RegionReadState(MutableArrayRef<Region> regions, bool isIsolatedFromAbove)
-        : curRegion(regions.begin()), endRegion(regions.end()),
+    RegionReadState(Operation *op, EncodingReader *reader,
+                    bool isIsolatedFromAbove)
+        : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {}
+    RegionReadState(MutableArrayRef<Region> regions, EncodingReader *reader,
+                    bool isIsolatedFromAbove)
+        : curRegion(regions.begin()), endRegion(regions.end()), reader(reader),
           isIsolatedFromAbove(isIsolatedFromAbove) {}
 
     /// The current regions being read.
     MutableArrayRef<Region>::iterator curRegion, endRegion;
+    /// This is the reader to use for this region, this pointer is pointing to
+    /// the parent region reader unless the current region is IsolatedFromAbove,
+    /// in which case the pointer is pointing to the `owningReader` which is a
+    /// section dedicated to the current region.
+    EncodingReader *reader;
+    std::unique_ptr<EncodingReader> owningReader;
 
     /// The number of values defined immediately within this region.
     unsigned numValues = 0;
@@ -1176,15 +1257,15 @@ class BytecodeReader {
   };
 
   LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
-  LogicalResult parseRegions(EncodingReader &reader,
-                             std::vector<RegionReadState> &regionStack,
+  LogicalResult parseRegions(std::vector<RegionReadState> &regionStack,
                              RegionReadState &readState);
   FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
                                                RegionReadState &readState,
                                                bool &isIsolatedFromAbove);
 
-  LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState);
-  LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState);
+  LogicalResult parseRegion(RegionReadState &readState);
+  LogicalResult parseBlockHeader(EncodingReader &reader,
+                                 RegionReadState &readState);
   LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
 
   //===--------------------------------------------------------------------===//
@@ -1234,6 +1315,16 @@ class BytecodeReader {
   /// A location to use when emitting errors.
   Location fileLoc;
 
+  /// Flag that indicates if lazyloading is enabled.
+  bool lazyLoading;
+
+  /// Keep track of operations that have been lazy loaded (their regions haven't
+  /// been materialized), along with the `RegionReadState` that allows
+  /// lazy-loading the regions nested under the operation.
+  LazyLoadableOpsInfo lazyLoadableOps;
+  LazyLoadableOpsMap lazyLoadableOpsMap;
+  llvm::function_ref<bool(Operation *)> lazyOpsCallback;
+
   /// The reader used to process attribute and types within the bytecode.
   AttrTypeReader attrTypeReader;
 
@@ -1264,14 +1355,20 @@ class BytecodeReader {
   /// An operation state used when instantiating forward references.
   OperationState forwardRefOpState;
 
+  /// Reference to the input buffer.
+  llvm::MemoryBufferRef buffer;
+
   /// The optional owning source manager, which when present may be used to
   /// extend the lifetime of the input buffer.
   const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef;
 };
-} // namespace
 
-LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
+LogicalResult BytecodeReader::Impl::read(
+    Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
   EncodingReader reader(buffer.getBuffer(), fileLoc);
+  this->lazyOpsCallback = lazyOpsCallback;
+  auto resetlazyOpsCallback =
+      llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; });
 
   // Skip over the bytecode header, this should have already been checked.
   if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
@@ -1302,7 +1399,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
     // Check for duplicate sections, we only expect one instance of each.
     if (sectionDatas[sectionID]) {
       return reader.emitError("duplicate top-level section: ",
-                              toString(sectionID));
+                              ::toString(sectionID));
     }
     sectionDatas[sectionID] = sectionData;
   }
@@ -1311,7 +1408,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
     bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i);
     if (!sectionDatas[i] && !isSectionOptional(sectionID)) {
       return reader.emitError("missing data for top-level section: ",
-                              toString(sectionID));
+                              ::toString(sectionID));
     }
   }
 
@@ -1340,7 +1437,7 @@ LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
   return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
 }
 
-LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
+LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) {
   if (failed(reader.parseVarInt(version)))
     return failure();
 
@@ -1357,6 +1454,9 @@ LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
                             " is newer than the current version ",
                             currentVersion);
   }
+  // Override any request to lazy-load if the bytecode version is too old.
+  if (version < 2)
+    lazyLoading = false;
   return success();
 }
 
@@ -1396,7 +1496,7 @@ LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
 }
 
 LogicalResult
-BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
+BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
   EncodingReader sectionReader(sectionData, fileLoc);
 
   // Parse the number of dialects in the section.
@@ -1449,7 +1549,8 @@ BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
   return success();
 }
 
-FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
+FailureOr<OperationName>
+BytecodeReader::Impl::parseOpName(EncodingReader &reader) {
   BytecodeOperationName *opName = nullptr;
   if (failed(parseEntry(reader, opNames, opName, "operation name")))
     return failure();
@@ -1471,7 +1572,7 @@ FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
 //===----------------------------------------------------------------------===//
 // Resource Section
 
-LogicalResult BytecodeReader::parseResourceSection(
+LogicalResult BytecodeReader::Impl::parseResourceSection(
     EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
     std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
   // Ensure both sections are either present or not.
@@ -1499,8 +1600,9 @@ LogicalResult BytecodeReader::parseResourceSection(
 //===----------------------------------------------------------------------===//
 // IR Section
 
-LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
-                                             Block *block) {
+LogicalResult
+BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
+                                     Block *block) {
   EncodingReader reader(sectionData, fileLoc);
 
   // A stack of operation regions currently being read from the bytecode.
@@ -1508,17 +1610,17 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
 
   // Parse the top-level block using a temporary module operation.
   OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
-  regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true);
+  regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true);
   regionStack.back().curBlocks.push_back(moduleOp->getBody());
   regionStack.back().curBlock = regionStack.back().curRegion->begin();
-  if (failed(parseBlock(reader, regionStack.back())))
+  if (failed(parseBlockHeader(reader, regionStack.back())))
     return failure();
   valueScopes.emplace_back();
   valueScopes.back().push(regionStack.back());
 
   // Iteratively parse regions until everything has been resolved.
   while (!regionStack.empty())
-    if (failed(parseRegions(reader, regionStack, regionStack.back())))
+    if (failed(parseRegions(regionStack, regionStack.back())))
       return failure();
   if (!forwardRefOps.empty()) {
     return reader.emitError(
@@ -1549,15 +1651,18 @@ LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
 }
 
 LogicalResult
-BytecodeReader::parseRegions(EncodingReader &reader,
-                             std::vector<RegionReadState> &regionStack,
-                             RegionReadState &readState) {
-  // Read the regions of this operation.
+BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
+                                   RegionReadState &readState) {
+  // Process regions, blocks, and operations until the end or until a nested
+  // region is encountered. In that case we push a new state onto regionStack
+  // and return; processing of the current region resumes afterward.
   for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
     // If the current block hasn't been setup yet, parse the header for this
-    // region.
+    // region. The current block is already setup when this function was
+    // interrupted to recurse down in a nested region and we resume the current
+    // block after processing the nested region.
     if (readState.curBlock == Region::iterator()) {
-      if (failed(parseRegion(reader, readState)))
+      if (failed(parseRegion(readState)))
         return failure();
 
       // If the region is empty, there is nothing to more to do.
@@ -1566,6 +1671,7 @@ BytecodeReader::parseRegions(EncodingReader &reader,
     }
 
     // Parse the blocks within the region.
+    EncodingReader &reader = *readState.reader;
     do {
       while (readState.numOpsRemaining--) {
         // Read in the next operation. We don't read its regions directly, we
@@ -1576,9 +1682,38 @@ BytecodeReader::parseRegions(EncodingReader &reader,
         if (failed(op))
           return failure();
 
-        // If the op has regions, add it to the stack for processing.
+        // If the op has regions, add it to the stack for processing and return:
+        // we stop the processing of the current region and resume it after the
+        // inner one is completed. Unless LazyLoading is activated in which case
+        // nested region parsing is delayed.
         if ((*op)->getNumRegions()) {
-          regionStack.emplace_back(*op, isIsolatedFromAbove);
+          RegionReadState childState(*op, &reader, isIsolatedFromAbove);
+
+          // Isolated regions are encoded as a section in version 2 and above.
+          if (version >= 2 && isIsolatedFromAbove) {
+            bytecode::Section::ID sectionID;
+            ArrayRef<uint8_t> sectionData;
+            if (failed(reader.parseSection(sectionID, sectionData)))
+              return failure();
+            if (sectionID != bytecode::Section::kIR)
+              return emitError(fileLoc, "expected IR section for region");
+            childState.owningReader =
+                std::make_unique<EncodingReader>(sectionData, fileLoc);
+            childState.reader = childState.owningReader.get();
+          }
+
+          if (lazyLoading) {
+            // If the user has a callback set, they have the opportunity
+            // to control lazyloading as we go.
+            if (!lazyOpsCallback || !lazyOpsCallback(*op)) {
+              lazyLoadableOps.push_back(
+                  std::make_pair(*op, std::move(childState)));
+              lazyLoadableOpsMap.try_emplace(*op,
+                                             std::prev(lazyLoadableOps.end()));
+              continue;
+            }
+          }
+          regionStack.push_back(std::move(childState));
 
           // If the op is isolated from above, push a new value scope.
           if (isIsolatedFromAbove)
@@ -1590,7 +1725,7 @@ BytecodeReader::parseRegions(EncodingReader &reader,
       // Move to the next block of the region.
       if (++readState.curBlock == readState.curRegion->end())
         break;
-      if (failed(parseBlock(reader, readState)))
+      if (failed(parseBlockHeader(reader, readState)))
         return failure();
     } while (true);
 
@@ -1601,16 +1736,19 @@ BytecodeReader::parseRegions(EncodingReader &reader,
 
   // When the regions have been fully parsed, pop them off of the read stack. If
   // the regions were isolated from above, we also pop the last value scope.
-  if (readState.isIsolatedFromAbove)
+  if (readState.isIsolatedFromAbove) {
+    assert(!valueScopes.empty() && "Expect a valueScope after reading region");
     valueScopes.pop_back();
+  }
+  assert(!regionStack.empty() && "Expect a regionStack after reading region");
   regionStack.pop_back();
   return success();
 }
 
 FailureOr<Operation *>
-BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
-                                      RegionReadState &readState,
-                                      bool &isIsolatedFromAbove) {
+BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
+                                            RegionReadState &readState,
+                                            bool &isIsolatedFromAbove) {
   // Parse the name of the operation.
   FailureOr<OperationName> opName = parseOpName(reader);
   if (failed(opName))
@@ -1696,8 +1834,9 @@ BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
   return op;
 }
 
-LogicalResult BytecodeReader::parseRegion(EncodingReader &reader,
-                                          RegionReadState &readState) {
+LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) {
+  EncodingReader &reader = *readState.reader;
+
   // Parse the number of blocks in the region.
   uint64_t numBlocks;
   if (failed(reader.parseVarInt(numBlocks)))
@@ -1727,11 +1866,12 @@ LogicalResult BytecodeReader::parseRegion(EncodingReader &reader,
 
   // Parse the entry block of the region.
   readState.curBlock = readState.curRegion->begin();
-  return parseBlock(reader, readState);
+  return parseBlockHeader(reader, readState);
 }
 
-LogicalResult BytecodeReader::parseBlock(EncodingReader &reader,
-                                         RegionReadState &readState) {
+LogicalResult
+BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader,
+                                       RegionReadState &readState) {
   bool hasArgs;
   if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs)))
     return failure();
@@ -1744,8 +1884,8 @@ LogicalResult BytecodeReader::parseBlock(EncodingReader &reader,
   return success();
 }
 
-LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
-                                                  Block *block) {
+LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
+                                                        Block *block) {
   // Parse the value ID for the first argument, and the number of arguments.
   uint64_t numArgs;
   if (failed(reader.parseVarInt(numArgs)))
@@ -1773,7 +1913,7 @@ LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
 //===----------------------------------------------------------------------===//
 // Value Processing
 
-Value BytecodeReader::parseOperand(EncodingReader &reader) {
+Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) {
   std::vector<Value> &values = valueScopes.back().values;
   Value *value = nullptr;
   if (failed(parseEntry(reader, values, value, "value")))
@@ -1785,8 +1925,8 @@ Value BytecodeReader::parseOperand(EncodingReader &reader) {
   return *value;
 }
 
-LogicalResult BytecodeReader::defineValues(EncodingReader &reader,
-                                           ValueRange newValues) {
+LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader,
+                                                 ValueRange newValues) {
   ValueScope &valueScope = valueScopes.back();
   std::vector<Value> &values = valueScope.values;
 
@@ -1821,7 +1961,7 @@ LogicalResult BytecodeReader::defineValues(EncodingReader &reader,
   return success();
 }
 
-Value BytecodeReader::createForwardRef() {
+Value BytecodeReader::Impl::createForwardRef() {
   // Check for an avaliable existing operation to use. Otherwise, create a new
   // fake operation to use for the reference.
   if (!openForwardRefOps.empty()) {
@@ -1837,6 +1977,41 @@ Value BytecodeReader::createForwardRef() {
 // Entry Points
 //===----------------------------------------------------------------------===//
 
+BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); }
+
+BytecodeReader::BytecodeReader(
+    llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading,
+    const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
+  Location sourceFileLoc =
+      FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
+                          /*line=*/0, /*column=*/0);
+  impl = std::make_unique<Impl>(sourceFileLoc, config, lazyLoading, buffer,
+                                bufferOwnerRef);
+}
+
+LogicalResult BytecodeReader::readTopLevel(
+    Block *block, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
+  return impl->read(block, lazyOpsCallback);
+}
+
+int64_t BytecodeReader::getNumOpsToMaterialize() const {
+  return impl->getNumOpsToMaterialize();
+}
+
+bool BytecodeReader::isMaterializable(Operation *op) {
+  return impl->isMaterializable(op);
+}
+
+LogicalResult BytecodeReader::materialize(
+    Operation *op, llvm::function_ref<bool(Operation *)> lazyOpsCallback) {
+  return impl->materialize(op, lazyOpsCallback);
+}
+
+LogicalResult
+BytecodeReader::finalize(function_ref<bool(Operation *)> shouldMaterialize) {
+  return impl->finalize(shouldMaterialize);
+}
+
 bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
   return buffer.getBuffer().startswith("ML\xefR");
 }
@@ -1856,8 +2031,9 @@ readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block,
                      "input buffer is not an MLIR bytecode file");
   }
 
-  BytecodeReader reader(sourceFileLoc, config, bufferOwnerRef);
-  return reader.read(buffer, block);
+  BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false,
+                              buffer, bufferOwnerRef);
+  return reader.read(block, /*lazyOpsCallback=*/nullptr);
 }
 
 LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,

diff  --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 801f3022d0e47..158dbe6d161db 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -734,8 +734,18 @@ void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
     bool isIsolatedFromAbove = op->hasTrait<OpTrait::IsIsolatedFromAbove>();
     emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove);
 
-    for (Region &region : op->getRegions())
-      writeRegion(emitter, &region);
+    for (Region &region : op->getRegions()) {
+      // If the region is not isolated from above, or we are emitting bytecode
+      // targeting version <2, we don't use a section.
+      if (!isIsolatedFromAbove || config.bytecodeVersion < 2) {
+        writeRegion(emitter, &region);
+        continue;
+      }
+
+      EncodingEmitter regionEmitter;
+      writeRegion(regionEmitter, &region);
+      emitter.emitSection(bytecode::Section::kIR, std::move(regionEmitter));
+    }
   }
 }
 

diff  --git a/mlir/test/Bytecode/bytecode-lazy-loading.mlir b/mlir/test/Bytecode/bytecode-lazy-loading.mlir
new file mode 100644
index 0000000000000..a4f7974b0b690
--- /dev/null
+++ b/mlir/test/Bytecode/bytecode-lazy-loading.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(test-lazy-loading)" %s -o %t | FileCheck %s
+// RUN: mlir-opt --pass-pipeline="builtin.module(test-lazy-loading{bytecode-version=1})" %s -o %t | FileCheck %s --check-prefix=OLD-BYTECODE
+
+
+// Test input: a function holding two `test.isolated_region` ops. Each one
+// takes a value defined outside of its region as operand and consumes it
+// inside the region.
+func.func @op_with_passthrough_region_args() {
+  %0 = arith.constant 10 : index
+  // First isolated region: its operand is the result of a constant.
+  test.isolated_region %0 {
+    "test.consumer"(%0) : (index) -> ()
+  }
+  %result:2 = "test.op"() : () -> (index, index)
+  // Second isolated region: its operand is the second result of a
+  // multi-result op.
+  test.isolated_region %result#1 {
+    "test.consumer"(%result#1) : (index) -> ()
+  }
+  return
+}
+
+// Before version 2, we can't support lazy loading.
+// OLD-BYTECODE-NOT: Has 1 ops to materialize
+// OLD-BYTECODE-NOT: Materializing
+// OLD-BYTECODE: Has 0 ops to materialize
+
+
+// CHECK: Has 1 ops to materialize
+
+// CHECK: Before Materializing...
+// CHECK: "builtin.module"() ({
+// CHECK-NOT: func
+// CHECK: Materializing...
+// CHECK: "builtin.module"() ({
+// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({
+// CHECK-NOT: arith
+// CHECK: Has 1 ops to materialize
+
+// CHECK: Before Materializing...
+// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({
+// CHECK-NOT: arith
+// CHECK: Materializing...
+// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({
+// CHECK: arith
+// CHECK: isolated_region
+// CHECK-NOT: test.consumer
+// CHECK: Has 2 ops to materialize
+
+// CHECK: Before Materializing...
+// CHECK: test.isolated_region
+// CHECK-NOT:  test.consumer
+// CHECK: Materializing...
+// CHECK: test.isolated_region
+// CHECK: ^bb0(%arg0: index):
+// CHECK:  test.consumer
+// CHECK: Has 1 ops to materialize
+
+// CHECK: Before Materializing...
+// CHECK: test.isolated_region
+// CHECK-NOT: test.consumer
+// CHECK: Materializing...
+// CHECK: test.isolated_region
+// CHECK: test.consumer
+// CHECK: Has 0 ops to materialize

diff  --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir
index d98c6c6191b87..4668878b10560 100644
--- a/mlir/test/Bytecode/invalid/invalid-structure.mlir
+++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir
@@ -9,7 +9,7 @@
 //===--------------------------------------------------------------------===//
 
 // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION
-// VERSION: bytecode version 127 is newer than the current version 1
+// VERSION: bytecode version 127 is newer than the current version 2
 
 //===--------------------------------------------------------------------===//
 // Producer

diff  --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index 8b519538719e5..627036d021fb7 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_library(MLIRTestIR
   TestFunc.cpp
   TestInterfaces.cpp
   TestMatchers.cpp
+  TestLazyLoading.cpp
   TestOpaqueLoc.cpp
   TestOperationEquals.cpp
   TestPrintDefUse.cpp

diff  --git a/mlir/test/lib/IR/TestLazyLoading.cpp b/mlir/test/lib/IR/TestLazyLoading.cpp
new file mode 100644
index 0000000000000..187b977c7daaa
--- /dev/null
+++ b/mlir/test/lib/IR/TestLazyLoading.cpp
@@ -0,0 +1,93 @@
+//===- TestLazyLoading.cpp - Pass to test operation lazy loading  ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Bytecode/BytecodeReader.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/MemoryBufferRef.h"
+#include "llvm/Support/raw_ostream.h"
+#include <list>
+
+using namespace mlir;
+
+namespace {
+
+/// Test pass that serializes the current operation to bytecode, reads it back
+/// with lazy loading enabled, and then materializes the deferred operations
+/// one by one, printing the IR before and after each step to stdout.
+struct LazyLoadingPass : public PassWrapper<LazyLoadingPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LazyLoadingPass)
+
+  StringRef getArgument() const final { return "test-lazy-loading"; }
+  StringRef getDescription() const final { return "Test LazyLoading of op"; }
+  LazyLoadingPass() = default;
+  // The copy deliberately does not copy any state: `version` is re-created
+  // from its in-class initializer.
+  LazyLoadingPass(const LazyLoadingPass &) {}
+
+  /// Round-trips the current operation through bytecode and lazily
+  /// materializes it. Every failure emits a diagnostic and marks the pass as
+  /// failed.
+  void runOnOperation() override {
+    Operation *op = getOperation();
+
+    // Serialize the current operation into an in-memory bytecode buffer. The
+    // nested scope destroys `os` before `bytecode` is read below.
+    std::string bytecode;
+    {
+      BytecodeWriterConfig config;
+      // A negative value (the default) leaves the writer configuration as is.
+      if (version >= 0)
+        config.setDesiredBytecodeVersion(version);
+      llvm::raw_string_ostream os(bytecode);
+      if (failed(writeBytecodeToFile(op, os, config))) {
+        op->emitError() << "failed to write bytecode at version "
+                        << static_cast<int>(version);
+        signalPassFailure();
+        return;
+      }
+    }
+
+    llvm::MemoryBufferRef buffer(bytecode, "test-lazy-loading");
+    Block block;
+    ParserConfig config(op->getContext(), /*verifyAfterParse=*/false);
+    BytecodeReader reader(buffer, config, /*lazyLoad=*/true);
+
+    // Work list of the operations reported by the reader through the callback.
+    std::list<Operation *> toLoadOps;
+    // Callback shared by the initial read and by every materialization: it
+    // records the op for later and returns false.
+    // NOTE(review): returning false presumably asks the reader to leave the op
+    // unmaterialized - confirm against BytecodeReader.h.
+    auto deferOp = [&toLoadOps](Operation *lazyOp) {
+      toLoadOps.push_back(lazyOp);
+      return false;
+    };
+
+    if (failed(reader.readTopLevel(&block, deferOp))) {
+      op->emitError() << "failed to read bytecode";
+      // Mark the pass as failed, like the other error paths of this function.
+      signalPassFailure();
+      return;
+    }
+
+    llvm::outs() << "Has " << reader.getNumOpsToMaterialize()
+                 << " ops to materialize\n";
+
+    // Drain the work list in FIFO order, printing each operation before and
+    // after it is materialized; the callback may enqueue further operations.
+    while (!toLoadOps.empty()) {
+      Operation *toLoad = toLoadOps.front();
+      toLoadOps.pop_front();
+      llvm::outs() << "\n\nBefore Materializing...\n\n";
+      toLoad->print(llvm::outs());
+      llvm::outs() << "\n\nMaterializing...\n\n";
+      if (failed(reader.materialize(toLoad, deferOp))) {
+        toLoad->emitError() << "failed to materialize";
+        signalPassFailure();
+        return;
+      }
+      toLoad->print(llvm::outs());
+      llvm::outs() << "\n";
+      llvm::outs() << "Has " << reader.getNumOpsToMaterialize()
+                   << " ops to materialize\n";
+    }
+  }
+
+  /// Bytecode version requested from the writer. The default of -1 (or any
+  /// negative value) means no version is requested.
+  Option<int> version{*this, "bytecode-version",
+                      llvm::cl::desc("Specifies the bytecode version to use."),
+                      llvm::cl::init(-1)};
+};
+} // namespace
+
+namespace mlir {
+/// Registers the `test-lazy-loading` pass.
+void registerLazyLoadingTestPasses() {
+  PassRegistration<LazyLoadingPass>();
+}
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 13a525f0bcff8..40b9c827fa610 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -31,6 +31,7 @@ using namespace mlir;
 namespace mlir {
 void registerConvertToTargetEnvPass();
 void registerCloneTestPasses();
+void registerLazyLoadingTestPasses();
 void registerPassManagerTestPass();
 void registerPrintSpirvAvailabilityPass();
 void registerLoopLikeInterfaceTestPasses();
@@ -146,6 +147,7 @@ void registerTestPasses() {
   registerConvertToTargetEnvPass();
   registerPassManagerTestPass();
   registerPrintSpirvAvailabilityPass();
+  registerLazyLoadingTestPasses();
   registerLoopLikeInterfaceTestPasses();
   registerShapeFunctionTestPasses();
   registerSideEffectTestPasses();


        


More information about the Mlir-commits mailing list