[Mlir-commits] [mlir] [mlir][bytecode] Add support for deferred attribute/type parsing. (PR #170993)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Dec 6 10:48:39 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Jacques Pienaar (jpienaar)

<details>
<summary>Changes</summary>

Add the ability to defer parsing of an attribute/type and re-enqueue the current entry. This enables changing CallSiteLoc parsing to not recurse as deeply: previously this could fail (especially on large inputs in debug mode, where the recursion could overflow).

I tried a few different variants of this (peeking, thunks, adding a preload section, and I considered a breaking encoding change), but this felt the cleanest, albeit at the cost of introducing a worklist that mostly has only 1 item.

---
Full diff: https://github.com/llvm/llvm-project/pull/170993.diff


5 Files Affected:

- (modified) mlir/docs/DefiningDialects/_index.md (+20) 
- (modified) mlir/include/mlir/Bytecode/BytecodeImplementation.h (+21) 
- (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+8-2) 
- (modified) mlir/lib/Bytecode/Reader/BytecodeReader.cpp (+115-24) 
- (modified) mlir/unittests/Bytecode/BytecodeTest.cpp (+37) 


``````````diff
diff --git a/mlir/docs/DefiningDialects/_index.md b/mlir/docs/DefiningDialects/_index.md
index 987b51b4ab4ef..9c9f9c93fcf39 100644
--- a/mlir/docs/DefiningDialects/_index.md
+++ b/mlir/docs/DefiningDialects/_index.md
@@ -425,6 +425,26 @@ struct FooDialectBytecodeInterface : public BytecodeDialectInterface {
 along with defining the corresponding build rules to invoke generator
 (`-gen-bytecode -bytecode-dialect="Quant"`).
 
+#### Deferred Parsing for Recursive Dependencies
+
+When parsing attributes or types that reference other attributes or types (e.g.,
+`CallSiteLoc` which contains nested location attributes), the referenced entries
+may not yet be resolved. The `DialectBytecodeReader` provides helpers to handle
+this:
+
+```c++
+Attribute attr = reader.getOrDeferParsingAttribute();
+if (!attr)
+  return failure();  // Will be retried after dependencies are resolved
+```
+
+The `getOrDeferParsingAttribute()` method reads the attribute index from the
+stream and attempts to resolve it. If the referenced attribute hasn't been
+parsed yet, it registers for deferred parsing and returns nullptr. The bytecode
+reader will automatically retry parsing after processing the dependencies.
+
+Note: for error cases, one needs to return failure *before* deferring parsing.
+
 ## Defining an Extensible dialect
 
 This section documents the design and API of the extensible dialects. Extensible
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 0ddc531073e23..65e7d23fa1139 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -103,6 +103,16 @@ class DialectBytecodeReader {
   /// the Attribute isn't present.
   virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0;
 
+  /// Try to get an attribute, deferring parsing if not yet resolved (returning
+  /// nullptr and enqueuing for deferred parsing).
+  virtual Attribute getOrDeferParsingAttribute() = 0;
+
+  /// Typed version of getOrDeferParsingAttribute. Returns the attribute cast
+  /// to the specified type, or nullptr if not resolved or cast fails.
+  template <typename T> T getOrDeferParsingAttribute() {
+    return llvm::dyn_cast_or_null<T>(getOrDeferParsingAttribute());
+  }
+
   template <typename T>
   LogicalResult readAttributes(SmallVectorImpl<T> &attrs) {
     return readList(attrs, [this](T &attr) { return readAttribute(attr); });
@@ -132,6 +142,17 @@ class DialectBytecodeReader {
 
   /// Read a reference to the given type.
   virtual LogicalResult readType(Type &result) = 0;
+
+  /// Try to get a type, deferring parsing if not yet resolved (returning
+  /// nullptr and enqueuing for deferred parsing).
+  virtual Type getOrDeferParsingType() = 0;
+
+  /// Typed version of getOrDeferParsingType. Returns the type cast
+  /// to the specified type, or nullptr if not resolved or cast fails.
+  template <typename T> T getOrDeferParsingType() {
+    return llvm::dyn_cast_or_null<T>(getOrDeferParsingType());
+  }
+
   template <typename T>
   LogicalResult readTypes(SmallVectorImpl<T> &types) {
     return readList(types, [this](T &type) { return readType(type); });
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..4162d6dad3c67 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -25,6 +25,12 @@ def Location : CompositeBytecode {
   let cBuilder = "Location($_args)";
 }
 
+def MaybeDeferredLocationAttr :
+  WithParser <"($_var = $_reader.getOrDeferParsingAttribute<LocationAttr>())",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeAttribute($_getter)",
+  WithType   <"LocationAttr">>>>;
+
 def String :
   WithParser <"succeeded($_reader.readString($_var))",
   WithBuilder<"$_args",
@@ -91,8 +97,8 @@ def FloatAttr : DialectAttribute<(attr
 }
 
 def CallSiteLoc : DialectAttribute<(attr
-  LocationAttr:$callee,
-  LocationAttr:$caller
+  MaybeDeferredLocationAttr:$callee,
+  MaybeDeferredLocationAttr:$caller
 )>;
 
 let cType = "FileLineColRange" in {
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 1659437e1eb24..9c6de9ed6ebec 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -27,6 +27,7 @@
 
 #include <cstddef>
 #include <cstdint>
+#include <deque>
 #include <list>
 #include <memory>
 #include <numeric>
@@ -925,7 +926,7 @@ class AttrTypeReader {
   /// bytecode format.
   template <typename T>
   LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
-                                 StringRef entryType);
+                                 StringRef entryType, uint64_t index);
 
   /// The string section reader used to resolve string references when parsing
   /// custom encoded attribute/type entries.
@@ -951,6 +952,28 @@ class AttrTypeReader {
 
   /// Reference to the parser configuration.
   const ParserConfig &parserConfig;
+
+  /// Worklist for deferred attribute/type parsing. This is used to
+  /// handle deeply nested structures like CallSiteLoc iteratively.
+  std::vector<uint64_t> deferredWorklist;
+
+public:
+  /// Get the attribute at the given index, returning null if not resolved.
+  Attribute getAttributeOrSentinel(size_t index) {
+    if (index >= attributes.size())
+      return {};
+    return attributes[index].entry;
+  }
+
+  /// Get the type at the given index, returning null if not resolved.
+  Type getTypeOrSentinel(size_t index) {
+    if (index >= types.size())
+      return {};
+    return types[index].entry;
+  }
+
+  /// Add an index to the deferred worklist for re-parsing.
+  void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }
 };
 
 class DialectReader : public DialectBytecodeReader {
@@ -959,10 +982,12 @@ class DialectReader : public DialectBytecodeReader {
                 const StringSectionReader &stringReader,
                 const ResourceSectionReader &resourceReader,
                 const llvm::StringMap<BytecodeDialect *> &dialectsMap,
-                EncodingReader &reader, uint64_t &bytecodeVersion)
+                EncodingReader &reader, uint64_t &bytecodeVersion,
+                uint64_t currentIndex = 0)
       : attrTypeReader(attrTypeReader), stringReader(stringReader),
         resourceReader(resourceReader), dialectsMap(dialectsMap),
-        reader(reader), bytecodeVersion(bytecodeVersion) {}
+        reader(reader), bytecodeVersion(bytecodeVersion),
+        currentIndex(currentIndex) {}
 
   InFlightDiagnostic emitError(const Twine &msg) const override {
     return reader.emitError(msg);
@@ -989,7 +1014,7 @@ class DialectReader : public DialectBytecodeReader {
 
   DialectReader withEncodingReader(EncodingReader &encReader) const {
     return DialectReader(attrTypeReader, stringReader, resourceReader,
-                         dialectsMap, encReader, bytecodeVersion);
+                         dialectsMap, encReader, bytecodeVersion, currentIndex);
   }
 
   Location getLoc() const { return reader.getLoc(); }
@@ -1004,9 +1029,27 @@ class DialectReader : public DialectBytecodeReader {
   LogicalResult readOptionalAttribute(Attribute &result) override {
     return attrTypeReader.parseOptionalAttribute(reader, result);
   }
+  Attribute getOrDeferParsingAttribute() override {
+    uint64_t index;
+    if (failed(reader.parseVarInt(index)))
+      return nullptr;
+    Attribute attr = attrTypeReader.getAttributeOrSentinel(index);
+    if (!attr)
+      attrTypeReader.addDeferredParsing(index);
+    return attr;
+  }
   LogicalResult readType(Type &result) override {
     return attrTypeReader.parseType(reader, result);
   }
+  Type getOrDeferParsingType() override {
+    uint64_t index;
+    if (failed(reader.parseVarInt(index)))
+      return nullptr;
+    Type type = attrTypeReader.getTypeOrSentinel(index);
+    if (!type)
+      attrTypeReader.addDeferredParsing(index);
+    return type;
+  }
 
   FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
     AsmDialectResourceHandle handle;
@@ -1095,6 +1138,7 @@ class DialectReader : public DialectBytecodeReader {
   const llvm::StringMap<BytecodeDialect *> &dialectsMap;
   EncodingReader &reader;
   uint64_t &bytecodeVersion;
+  uint64_t currentIndex;
 };
 
 /// Wraps the properties section and handles reading properties out of it.
@@ -1245,27 +1289,74 @@ T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
     return {};
   }
 
-  // If the entry has already been resolved, there is nothing left to do.
-  Entry<T> &entry = entries[index];
-  if (entry.entry)
-    return entry.entry;
+  // Use a deque to iteratively resolve entries with dependencies.
+  // - Pop from front to process
+  // - Push new dependencies to front (depth-first)
+  // - Move failed entries to back (retry after dependencies)
+  std::deque<size_t> worklist;
+  llvm::DenseSet<size_t> inWorklist;
+  worklist.push_back(index);
+  inWorklist.insert(index);
 
-  // Parse the entry.
-  EncodingReader reader(entry.data, fileLoc);
+  while (!worklist.empty()) {
+    size_t currentIndex = worklist.front();
+    worklist.pop_front();
+
+    if (currentIndex >= entries.size()) {
+      emitError(fileLoc) << "invalid " << entryType
+                         << " index: " << currentIndex;
+      return {};
+    }
 
-  // Parse based on how the entry was encoded.
-  if (entry.hasCustomEncoding) {
-    if (failed(parseCustomEntry(entry, reader, entryType)))
+    Entry<T> &entry = entries[currentIndex];
+
+    // If already resolved, continue.
+    if (entry.entry) {
+      inWorklist.erase(currentIndex);
+      continue;
+    }
+
+    // Clear the deferred worklist before parsing to capture any new entries.
+    deferredWorklist.clear();
+
+    // Parse the entry.
+    EncodingReader reader(entry.data, fileLoc);
+
+    // Parse based on how the entry was encoded.
+    LogicalResult parsed =
+        entry.hasCustomEncoding
+            ? parseCustomEntry(entry, reader, entryType, currentIndex)
+            : parseAsmEntry(entry.entry, reader, entryType);
+    bool parseSucceeded = succeeded(parsed);
+
+    if (parseSucceeded && !reader.empty()) {
+      reader.emitError("unexpected trailing bytes after " + entryType +
+                       " entry");
+      parseSucceeded = false;
+    }
+
+    if (parseSucceeded && entry.entry) {
+      // Successfully parsed, done with this entry.
+      inWorklist.erase(currentIndex);
+    } else if (!deferredWorklist.empty()) {
+      // Check if deferred parsing was requested.
+
+      // Move this entry to the back to retry after dependencies.
+      worklist.push_back(currentIndex);
+
+      // Add dependencies to the front (in reverse so they maintain order).
+      for (uint64_t idx : llvm::reverse(deferredWorklist)) {
+        if (inWorklist.insert(idx).second)
+          worklist.push_front(idx);
+      }
+      deferredWorklist.clear();
+    } else {
+      // Parsing failed with no deferred entries which implies an error.
       return T();
-  } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
-    return T();
+    }
   }
 
-  if (!reader.empty()) {
-    reader.emitError("unexpected trailing bytes after " + entryType + " entry");
-    return T();
-  }
-  return entry.entry;
+  return entries[index].entry;
 }
 
 template <typename T>
@@ -1296,11 +1387,11 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
 }
 
 template <typename T>
-LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
-                                               EncodingReader &reader,
-                                               StringRef entryType) {
+LogicalResult
+AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
+                                 StringRef entryType, uint64_t index) {
   DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
-                              reader, bytecodeVersion);
+                              reader, bytecodeVersion, index);
   if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
     return failure();
 
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index d7b442f6832d0..d5c6f010f5b8a 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -228,3 +228,40 @@ TEST(Bytecode, OpWithoutProperties) {
   EXPECT_TRUE(OperationEquivalence::computeHash(op.get()) ==
               OperationEquivalence::computeHash(roundtripped));
 }
+
+TEST(Bytecode, DeepCallSiteLoc) {
+  MLIRContext context;
+  ParserConfig config(&context);
+
+  // Create a deep CallSiteLoc chain to test iterative parsing.
+  // Use a depth that fits in the stack for writing but is still substantial.
+  Location baseLoc = FileLineColLoc::get(&context, "test.mlir", 1, 1);
+  Location loc = baseLoc;
+  constexpr int kDepth = 1000;
+  for (int i = 0; i < kDepth; ++i) {
+    loc = CallSiteLoc::get(loc, baseLoc);
+  }
+
+  // Create a simple module with the deep location.
+  OwningOpRef<Operation *> module =
+      parseSourceString<Operation *>("module {}", config);
+  ASSERT_TRUE(module);
+  module.get()->setLoc(loc);
+
+  // Write to bytecode.
+  std::string bytecode;
+  llvm::raw_string_ostream os(bytecode);
+  ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), os)));
+
+  // Parse it back using the bytecode reader.
+  std::unique_ptr<Block> block = std::make_unique<Block>();
+  ASSERT_TRUE(succeeded(readBytecodeFile(
+      llvm::MemoryBufferRef(bytecode, "string-buffer"), block.get(), config)));
+
+  // Verify we got the roundtripped module.
+  ASSERT_FALSE(block->empty());
+  Operation *roundTripped = &block->front();
+
+  // Verify the location matches.
+  EXPECT_EQ(module.get()->getLoc(), roundTripped->getLoc());
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/170993


More information about the Mlir-commits mailing list