[Mlir-commits] [mlir] [mlir][bytecode] Add support for deferred attribute/type parsing. (PR #170993)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Dec 6 10:48:39 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir
Author: Jacques Pienaar (jpienaar)
<details>
<summary>Changes</summary>
Add the ability for an attribute/type parser to defer its own parsing and re-enqueue itself. This enables changing CallSiteLoc parsing to not recurse as deeply: previously this could fail (especially on large inputs in debug mode, where the recursion could overflow).
I tried a few different variants of this (peeking, thunks, adding a preload section, and I considered a breaking encoding change), but this felt the cleanest, at the cost of introducing a worklist that mostly has only one item.
---
Full diff: https://github.com/llvm/llvm-project/pull/170993.diff
5 Files Affected:
- (modified) mlir/docs/DefiningDialects/_index.md (+20)
- (modified) mlir/include/mlir/Bytecode/BytecodeImplementation.h (+21)
- (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+8-2)
- (modified) mlir/lib/Bytecode/Reader/BytecodeReader.cpp (+115-24)
- (modified) mlir/unittests/Bytecode/BytecodeTest.cpp (+37)
``````````diff
diff --git a/mlir/docs/DefiningDialects/_index.md b/mlir/docs/DefiningDialects/_index.md
index 987b51b4ab4ef..9c9f9c93fcf39 100644
--- a/mlir/docs/DefiningDialects/_index.md
+++ b/mlir/docs/DefiningDialects/_index.md
@@ -425,6 +425,26 @@ struct FooDialectBytecodeInterface : public BytecodeDialectInterface {
along with defining the corresponding build rules to invoke generator
(`-gen-bytecode -bytecode-dialect="Quant"`).
+#### Deferred Parsing for Recursive Dependencies
+
+When parsing attributes or types that reference other attributes or types (e.g.,
+`CallSiteLoc` which contains nested location attributes), the referenced entries
+may not yet be resolved. The `DialectBytecodeReader` provides helpers to handle
+this:
+
+```c++
+Attribute attr = reader.getOrDeferParsingAttribute();
+if (!attr)
+ return failure(); // Will be retried after dependencies are resolved
+```
+
+The `getOrDeferParsingAttribute()` method reads the attribute index from the
+stream and attempts to resolve it. If the referenced attribute hasn't been
+parsed yet, it registers for deferred parsing and returns nullptr. The bytecode
+reader will automatically retry parsing after processing the dependencies.
+
+Note: for error cases, one needs to return failure *before* deferring parsing.
+
## Defining an Extensible dialect
This section documents the design and API of the extensible dialects. Extensible
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
index 0ddc531073e23..65e7d23fa1139 100644
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -103,6 +103,16 @@ class DialectBytecodeReader {
/// the Attribute isn't present.
virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0;
+ /// Try to get an attribute, deferring parsing if not yet resolved (returning
+ /// nullptr and enqueuing for deferred parsing).
+ virtual Attribute getOrDeferParsingAttribute() = 0;
+
+ /// Typed version of getOrDeferParsingAttribute. Returns the attribute cast
+ /// to the specified type, or nullptr if not resolved or cast fails.
+ template <typename T> T getOrDeferParsingAttribute() {
+ return llvm::dyn_cast_or_null<T>(getOrDeferParsingAttribute());
+ }
+
template <typename T>
LogicalResult readAttributes(SmallVectorImpl<T> &attrs) {
return readList(attrs, [this](T &attr) { return readAttribute(attr); });
@@ -132,6 +142,17 @@ class DialectBytecodeReader {
/// Read a reference to the given type.
virtual LogicalResult readType(Type &result) = 0;
+
+ /// Try to get a type, deferring parsing if not yet resolved (returning
+ /// nullptr and enqueuing for deferred parsing).
+ virtual Type getOrDeferParsingType() = 0;
+
+ /// Typed version of getOrDeferParsingType. Returns the type cast
+ /// to the specified type, or nullptr if not resolved or cast fails.
+ template <typename T> T getOrDeferParsingType() {
+ return llvm::dyn_cast_or_null<T>(getOrDeferParsingType());
+ }
+
template <typename T>
LogicalResult readTypes(SmallVectorImpl<T> &types) {
return readList(types, [this](T &type) { return readType(type); });
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..4162d6dad3c67 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -25,6 +25,12 @@ def Location : CompositeBytecode {
let cBuilder = "Location($_args)";
}
+def MaybeDeferredLocationAttr :
+ WithParser <"($_var = $_reader.getOrDeferParsingAttribute<LocationAttr>())",
+ WithBuilder<"$_args",
+ WithPrinter<"$_writer.writeAttribute($_getter)",
+ WithType <"LocationAttr">>>>;
+
def String :
WithParser <"succeeded($_reader.readString($_var))",
WithBuilder<"$_args",
@@ -91,8 +97,8 @@ def FloatAttr : DialectAttribute<(attr
}
def CallSiteLoc : DialectAttribute<(attr
- LocationAttr:$callee,
- LocationAttr:$caller
+ MaybeDeferredLocationAttr:$callee,
+ MaybeDeferredLocationAttr:$caller
)>;
let cType = "FileLineColRange" in {
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 1659437e1eb24..9c6de9ed6ebec 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -27,6 +27,7 @@
#include <cstddef>
#include <cstdint>
+#include <deque>
#include <list>
#include <memory>
#include <numeric>
@@ -925,7 +926,7 @@ class AttrTypeReader {
/// bytecode format.
template <typename T>
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
- StringRef entryType);
+ StringRef entryType, uint64_t index);
/// The string section reader used to resolve string references when parsing
/// custom encoded attribute/type entries.
@@ -951,6 +952,28 @@ class AttrTypeReader {
/// Reference to the parser configuration.
const ParserConfig &parserConfig;
+
+ /// Worklist for deferred attribute/type parsing. This is used to
+ /// handle deeply nested structures like CallSiteLoc iteratively.
+ std::vector<uint64_t> deferredWorklist;
+
+public:
+ /// Get the attribute at the given index, returning null if not resolved.
+ Attribute getAttributeOrSentinel(size_t index) {
+ if (index >= attributes.size())
+ return {};
+ return attributes[index].entry;
+ }
+
+ /// Get the type at the given index, returning null if not resolved.
+ Type getTypeOrSentinel(size_t index) {
+ if (index >= types.size())
+ return {};
+ return types[index].entry;
+ }
+
+ /// Add an index to the deferred worklist for re-parsing.
+ void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }
};
class DialectReader : public DialectBytecodeReader {
@@ -959,10 +982,12 @@ class DialectReader : public DialectBytecodeReader {
const StringSectionReader &stringReader,
const ResourceSectionReader &resourceReader,
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
- EncodingReader &reader, uint64_t &bytecodeVersion)
+ EncodingReader &reader, uint64_t &bytecodeVersion,
+ uint64_t currentIndex = 0)
: attrTypeReader(attrTypeReader), stringReader(stringReader),
resourceReader(resourceReader), dialectsMap(dialectsMap),
- reader(reader), bytecodeVersion(bytecodeVersion) {}
+ reader(reader), bytecodeVersion(bytecodeVersion),
+ currentIndex(currentIndex) {}
InFlightDiagnostic emitError(const Twine &msg) const override {
return reader.emitError(msg);
@@ -989,7 +1014,7 @@ class DialectReader : public DialectBytecodeReader {
DialectReader withEncodingReader(EncodingReader &encReader) const {
return DialectReader(attrTypeReader, stringReader, resourceReader,
- dialectsMap, encReader, bytecodeVersion);
+ dialectsMap, encReader, bytecodeVersion, currentIndex);
}
Location getLoc() const { return reader.getLoc(); }
@@ -1004,9 +1029,27 @@ class DialectReader : public DialectBytecodeReader {
LogicalResult readOptionalAttribute(Attribute &result) override {
return attrTypeReader.parseOptionalAttribute(reader, result);
}
+ Attribute getOrDeferParsingAttribute() override {
+ uint64_t index;
+ if (failed(reader.parseVarInt(index)))
+ return nullptr;
+ Attribute attr = attrTypeReader.getAttributeOrSentinel(index);
+ if (!attr)
+ attrTypeReader.addDeferredParsing(index);
+ return attr;
+ }
LogicalResult readType(Type &result) override {
return attrTypeReader.parseType(reader, result);
}
+ Type getOrDeferParsingType() override {
+ uint64_t index;
+ if (failed(reader.parseVarInt(index)))
+ return nullptr;
+ Type type = attrTypeReader.getTypeOrSentinel(index);
+ if (!type)
+ attrTypeReader.addDeferredParsing(index);
+ return type;
+ }
FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
AsmDialectResourceHandle handle;
@@ -1095,6 +1138,7 @@ class DialectReader : public DialectBytecodeReader {
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
EncodingReader &reader;
uint64_t &bytecodeVersion;
+ uint64_t currentIndex;
};
/// Wraps the properties section and handles reading properties out of it.
@@ -1245,27 +1289,74 @@ T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
return {};
}
- // If the entry has already been resolved, there is nothing left to do.
- Entry<T> &entry = entries[index];
- if (entry.entry)
- return entry.entry;
+ // Use a deque to iteratively resolve entries with dependencies.
+ // - Pop from front to process
+ // - Push new dependencies to front (depth-first)
+ // - Move failed entries to back (retry after dependencies)
+ std::deque<size_t> worklist;
+ llvm::DenseSet<size_t> inWorklist;
+ worklist.push_back(index);
+ inWorklist.insert(index);
- // Parse the entry.
- EncodingReader reader(entry.data, fileLoc);
+ while (!worklist.empty()) {
+ size_t currentIndex = worklist.front();
+ worklist.pop_front();
+
+ if (currentIndex >= entries.size()) {
+ emitError(fileLoc) << "invalid " << entryType
+ << " index: " << currentIndex;
+ return {};
+ }
- // Parse based on how the entry was encoded.
- if (entry.hasCustomEncoding) {
- if (failed(parseCustomEntry(entry, reader, entryType)))
+ Entry<T> &entry = entries[currentIndex];
+
+ // If already resolved, continue.
+ if (entry.entry) {
+ inWorklist.erase(currentIndex);
+ continue;
+ }
+
+ // Clear the deferred worklist before parsing to capture any new entries.
+ deferredWorklist.clear();
+
+ // Parse the entry.
+ EncodingReader reader(entry.data, fileLoc);
+
+ // Parse based on how the entry was encoded.
+ LogicalResult parsed =
+ entry.hasCustomEncoding
+ ? parseCustomEntry(entry, reader, entryType, currentIndex)
+ : parseAsmEntry(entry.entry, reader, entryType);
+ bool parseSucceeded = succeeded(parsed);
+
+ if (parseSucceeded && !reader.empty()) {
+ reader.emitError("unexpected trailing bytes after " + entryType +
+ " entry");
+ parseSucceeded = false;
+ }
+
+ if (parseSucceeded && entry.entry) {
+ // Successfully parsed, done with this entry.
+ inWorklist.erase(currentIndex);
+ } else if (!deferredWorklist.empty()) {
+ // Check if deferred parsing was requested.
+
+ // Move this entry to the back to retry after dependencies.
+ worklist.push_back(currentIndex);
+
+ // Add dependencies to the front (in reverse so they maintain order).
+ for (uint64_t idx : llvm::reverse(deferredWorklist)) {
+ if (inWorklist.insert(idx).second)
+ worklist.push_front(idx);
+ }
+ deferredWorklist.clear();
+ } else {
+ // Parsing failed with no deferred entries which implies an error.
return T();
- } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
- return T();
+ }
}
- if (!reader.empty()) {
- reader.emitError("unexpected trailing bytes after " + entryType + " entry");
- return T();
- }
- return entry.entry;
+ return entries[index].entry;
}
template <typename T>
@@ -1296,11 +1387,11 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
}
template <typename T>
-LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
- EncodingReader &reader,
- StringRef entryType) {
+LogicalResult
+AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
+ StringRef entryType, uint64_t index) {
DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
- reader, bytecodeVersion);
+ reader, bytecodeVersion, index);
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index d7b442f6832d0..d5c6f010f5b8a 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -228,3 +228,40 @@ TEST(Bytecode, OpWithoutProperties) {
EXPECT_TRUE(OperationEquivalence::computeHash(op.get()) ==
OperationEquivalence::computeHash(roundtripped));
}
+
+TEST(Bytecode, DeepCallSiteLoc) {
+ MLIRContext context;
+ ParserConfig config(&context);
+
+ // Create a deep CallSiteLoc chain to test iterative parsing.
+ // Use a depth that fits in the stack for writing but is still substantial.
+ Location baseLoc = FileLineColLoc::get(&context, "test.mlir", 1, 1);
+ Location loc = baseLoc;
+ constexpr int kDepth = 1000;
+ for (int i = 0; i < kDepth; ++i) {
+ loc = CallSiteLoc::get(loc, baseLoc);
+ }
+
+ // Create a simple module with the deep location.
+ OwningOpRef<Operation *> module =
+ parseSourceString<Operation *>("module {}", config);
+ ASSERT_TRUE(module);
+ module.get()->setLoc(loc);
+
+ // Write to bytecode.
+ std::string bytecode;
+ llvm::raw_string_ostream os(bytecode);
+ ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), os)));
+
+ // Parse it back using the bytecode reader.
+ std::unique_ptr<Block> block = std::make_unique<Block>();
+ ASSERT_TRUE(succeeded(readBytecodeFile(
+ llvm::MemoryBufferRef(bytecode, "string-buffer"), block.get(), config)));
+
+ // Verify we got the roundtripped module.
+ ASSERT_FALSE(block->empty());
+ Operation *roundTripped = &block->front();
+
+ // Verify the location matches.
+ EXPECT_EQ(module.get()->getLoc(), roundTripped->getLoc());
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/170993
More information about the Mlir-commits
mailing list