[Mlir-commits] [mlir] [MLIR][Bytecode] Enforce alignment requirements (PR #157004)
Nikhil Kalra
llvmlistbot at llvm.org
Thu Sep 4 22:09:41 PDT 2025
https://github.com/nikalra updated https://github.com/llvm/llvm-project/pull/157004
>From a26424427b53440c7c22e9899c9d46e895e34f3f Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Thu, 4 Sep 2025 19:55:53 -0700
Subject: [PATCH 1/2] [MLIR][Bytecode] Enforce alignment requirements
Adds a check that the bytecode buffer satisfies the alignment requirements of every section. Without this check, if the source buffer is not sufficiently aligned, aligning the data pointer may skip the wrong number of padding bytes (for example, returning early), because alignment is computed on the absolute address rather than on the offset from the start of the buffer. Successive sections are then read from an incorrect offset, making valid bytecode appear invalid.
This requirement is documented in the bytecode unit tests, but not elsewhere in the code or in the bytecode format reference.
---
mlir/docs/BytecodeFormat.md | 2 +-
mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 75 +++++++++++++++++++--
mlir/unittests/Bytecode/BytecodeTest.cpp | 53 +++++++++++++++
3 files changed, 124 insertions(+), 6 deletions(-)
diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
index ebc94c9f0d8ba..9846df8726295 100644
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -125,7 +125,7 @@ lazy-loading, and more. Each section contains a Section ID, whose high bit
indicates if the section has alignment requirements, a length (which allows for
skipping over the section), and an optional alignment. When an alignment is
present, a variable number of padding bytes (0xCB) may appear before the section
-data. The alignment of a section must be a power of 2.
+data. The alignment of a section must be a power of 2. The start address of the input bytecode buffer must be aligned to at least the largest alignment required by any of its sections.
## MLIR Encoding
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 44458d010c6c8..4156f05b45301 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -22,10 +22,13 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"
#include <cstddef>
+#include <cstdint>
#include <list>
#include <memory>
#include <numeric>
@@ -111,6 +114,9 @@ class EncodingReader {
};
// Shift the reader position to the next alignment boundary.
+ // Note: this assumes the pointer alignment matches the alignment of the
+ // data from the start of the buffer. In other words, this code is only
+ // valid if the buffer `dataIt` is offsetting into is already aligned.
while (isUnaligned(dataIt)) {
uint8_t padding;
if (failed(parseByte(padding)))
@@ -258,9 +264,13 @@ class EncodingReader {
return success();
}
+ /// Validate that the alignment requested in the section is valid.
+ using ValidateAlignment = function_ref<LogicalResult(unsigned alignment)>;
+
/// Parse a section header, placing the kind of section in `sectionID` and the
/// contents of the section in `sectionData`.
LogicalResult parseSection(bytecode::Section::ID &sectionID,
+ ValidateAlignment alignmentValidator,
ArrayRef<uint8_t> &sectionData) {
uint8_t sectionIDAndHasAlignment;
uint64_t length;
@@ -281,8 +291,22 @@ class EncodingReader {
// Process the section alignment if present.
if (hasAlignment) {
+ // Read the requested alignment from the bytecode parser.
uint64_t alignment;
- if (failed(parseVarInt(alignment)) || failed(alignTo(alignment)))
+ if (failed(parseVarInt(alignment)))
+ return failure();
+
+ // Check that the requested alignment is less than or equal to the
+ // alignment of the root buffer. If it is not, we cannot safely guarantee
+ // that the specified alignment is globally correct.
+ //
+ // e.g. if the buffer is 8k aligned and the section is 16k aligned,
+ // we could end up at an offset of 24k, which is not globally 16k aligned.
+ if (failed(alignmentValidator(alignment)))
+ return emitError("failed to align section ID: ", unsigned(sectionID));
+
+ // Align the buffer.
+ if (failed(alignTo(alignment)))
return failure();
}
@@ -1396,6 +1420,30 @@ class mlir::BytecodeReader::Impl {
return success();
}
+  LogicalResult checkSectionAlignment(
+      unsigned alignment,
+      function_ref<InFlightDiagnostic(const Twine &error)> emitError) {
+    // Check that the bytecode buffer meets
+    // the requested section alignment.
+    //
+    // If it does not, the virtual address of the item in the section will
+    // not be aligned to the requested alignment.
+    //
+    // The typical case where this is necessary is the resource blob
+    // optimization in `parseAsBlob` where we reference the weights from the
+    // provided buffer instead of copying them to a new allocation.
+    // The `alignment - 1` mask below is only meaningful for powers of two.
+    if (alignment == 0 || (alignment & (alignment - 1)) != 0)
+      return emitError("expected power-of-two section alignment, got ")
+             << alignment;
+    const auto address = reinterpret_cast<uintptr_t>(buffer.getBufferStart());
+    if ((address & (alignment - 1)) != 0)
+      return emitError("expected section alignment ")
+             << alignment << " but bytecode buffer 0x"
+             << Twine::utohexstr(address) << " is not aligned";
+    return success();
+  }
+
/// Return the context for this config.
MLIRContext *getContext() const { return config.getContext(); }
@@ -1506,7 +1554,7 @@ class mlir::BytecodeReader::Impl {
UseListOrderStorage(bool isIndexPairEncoding,
SmallVector<unsigned, 4> &&indices)
: indices(std::move(indices)),
- isIndexPairEncoding(isIndexPairEncoding){};
+ isIndexPairEncoding(isIndexPairEncoding) {};
/// The vector containing the information required to reorder the
/// use-list of a value.
SmallVector<unsigned, 4> indices;
@@ -1651,6 +1699,11 @@ LogicalResult BytecodeReader::Impl::read(
return failure();
});
+ const auto checkSectionAlignment = [&](unsigned alignment) {
+ return this->checkSectionAlignment(
+ alignment, [&](const auto &msg) { return reader.emitError(msg); });
+ };
+
// Parse the raw data for each of the top-level sections of the bytecode.
std::optional<ArrayRef<uint8_t>>
sectionDatas[bytecode::Section::kNumSections];
@@ -1658,7 +1711,8 @@ LogicalResult BytecodeReader::Impl::read(
// Read the next section from the bytecode.
bytecode::Section::ID sectionID;
ArrayRef<uint8_t> sectionData;
- if (failed(reader.parseSection(sectionID, sectionData)))
+ if (failed(
+ reader.parseSection(sectionID, checkSectionAlignment, sectionData)))
return failure();
// Check for duplicate sections, we only expect one instance of each.
@@ -1778,6 +1832,11 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return failure();
dialects.resize(numDialects);
+ const auto checkSectionAlignment = [&](unsigned alignment) {
+ return this->checkSectionAlignment(
+ alignment, [&](const auto &msg) { return emitError(fileLoc, msg); });
+ };
+
// Parse each of the dialects.
for (uint64_t i = 0; i < numDialects; ++i) {
dialects[i] = std::make_unique<BytecodeDialect>();
@@ -1800,7 +1859,7 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return failure();
if (versionAvailable) {
bytecode::Section::ID sectionID;
- if (failed(sectionReader.parseSection(sectionID,
+ if (failed(sectionReader.parseSection(sectionID, checkSectionAlignment,
dialects[i]->versionBuffer)))
return failure();
if (sectionID != bytecode::Section::kDialectVersions) {
@@ -2121,6 +2180,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
LogicalResult
BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
RegionReadState &readState) {
+ const auto checkSectionAlignment = [&](unsigned alignment) {
+ return this->checkSectionAlignment(
+ alignment, [&](const auto &msg) { return emitError(fileLoc, msg); });
+ };
+
// Process regions, blocks, and operations until the end or if a nested
// region is encountered. In this case we push a new state in regionStack and
// return, the processing of the current region will resume afterward.
@@ -2161,7 +2225,8 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
bytecode::Section::ID sectionID;
ArrayRef<uint8_t> sectionData;
- if (failed(reader.parseSection(sectionID, sectionData)))
+ if (failed(reader.parseSection(sectionID, checkSectionAlignment,
+ sectionData)))
return failure();
if (sectionID != bytecode::Section::kIR)
return emitError(fileLoc, "expected IR section for region");
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index c036fe26b1b36..9ea6560f712a1 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -10,11 +10,13 @@
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Alignment.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/raw_ostream.h"
@@ -117,6 +119,57 @@ TEST(Bytecode, MultiModuleWithResource) {
checkResourceAttribute(*roundTripModule);
}
+TEST(Bytecode, AlignmentFailure) {
+  MLIRContext context;
+  Builder builder(&context);
+  ParserConfig parseConfig(&context);
+  OwningOpRef<Operation *> module =
+      parseSourceString<Operation *>(irWithResources, parseConfig);
+  ASSERT_TRUE(module);
+
+  // Write the module to bytecode.
+  MockOstream ostream;
+  EXPECT_CALL(ostream, reserveExtraSpace).WillOnce([&](uint64_t space) {
+    ostream.buffer = std::make_unique<std::byte[]>(space);
+    ostream.size = space;
+  });
+  ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
+
+  // Create copy of buffer which is not aligned to requested resource alignment.
+  std::string buffer((char *)ostream.buffer.get(),
+                     (char *)ostream.buffer.get() + ostream.size);
+  size_t bufferSize = buffer.size();
+
+  // Advance the start offset until the data pointer is 2-byte aligned but not
+  // 32-byte aligned, so it cannot satisfy a 32-byte section alignment.
+  size_t pad = 0;
+  while (true) {
+    if (llvm::isAddrAligned(Align(2), &buffer[pad]) &&
+        !llvm::isAddrAligned(Align(32), &buffer[pad]))
+      break;
+    pad++;
+    buffer.reserve(bufferSize + pad);
+  }
+
+  buffer.insert(0, pad, ' ');
+  StringRef misalignedBuffer(buffer.data() + pad, bufferSize);
+
+  // Attach a diagnostic handler to get the error message.
+  llvm::SmallVector<std::string> msg;
+  ScopedDiagnosticHandler handler(
+      &context, [&msg](Diagnostic &diag) { msg.push_back(diag.str()); });
+
+  // Parse it back; this must fail and report both alignment diagnostics.
+  OwningOpRef<Operation *> roundTripModule =
+      parseSourceString<Operation *>(misalignedBuffer, parseConfig);
+  ASSERT_FALSE(roundTripModule);
+  ASSERT_GE(msg.size(), 2u);
+  ASSERT_THAT(msg[0].data(), ::testing::StartsWith(
+                                 "expected section alignment 32 but bytecode "
+                                 "buffer"));
+  ASSERT_STREQ(msg[1].data(), "failed to align section ID: 5");
+}
+
namespace {
/// A custom operation for the purpose of showcasing how discardable attributes
/// are handled in absence of properties.
>From 09ba822e15a2c615e0c8ad7bdb41bfd4ea18b293 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Thu, 4 Sep 2025 22:09:05 -0700
Subject: [PATCH 2/2] address comments
---
mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 4156f05b45301..d29053a2b6e65 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -116,7 +116,7 @@ class EncodingReader {
// Shift the reader position to the next alignment boundary.
// Note: this assumes the pointer alignment matches the alignment of the
// data from the start of the buffer. In other words, this code is only
- // valid if the buffer `dataIt` is offsetting into is already aligned.
+ // valid if `dataIt` is offsetting into an already aligned buffer.
while (isUnaligned(dataIt)) {
uint8_t padding;
if (failed(parseByte(padding)))
@@ -265,12 +265,12 @@ class EncodingReader {
}
/// Validate that the alignment requested in the section is valid.
- using ValidateAlignment = function_ref<LogicalResult(unsigned alignment)>;
+ using ValidateAlignmentFn = function_ref<LogicalResult(unsigned alignment)>;
/// Parse a section header, placing the kind of section in `sectionID` and the
/// contents of the section in `sectionData`.
LogicalResult parseSection(bytecode::Section::ID &sectionID,
- ValidateAlignment alignmentValidator,
+ ValidateAlignmentFn alignmentValidator,
ArrayRef<uint8_t> &sectionData) {
uint8_t sectionIDAndHasAlignment;
uint64_t length;
@@ -300,7 +300,7 @@ class EncodingReader {
// alignment of the root buffer. If it is not, we cannot safely guarantee
// that the specified alignment is globally correct.
//
- // e.g. if the buffer is 8k aligned and the section is 16k aligned,
+ // E.g. if the buffer is 8k aligned and the section is 16k aligned,
// we could end up at an offset of 24k, which is not globally 16k aligned.
if (failed(alignmentValidator(alignment)))
return emitError("failed to align section ID: ", unsigned(sectionID));
@@ -1423,8 +1423,7 @@ class mlir::BytecodeReader::Impl {
LogicalResult checkSectionAlignment(
unsigned alignment,
function_ref<InFlightDiagnostic(const Twine &error)> emitError) {
- // Check that the bytecode buffer meets
- // the requested section alignment.
+ // Check that the bytecode buffer meets the requested section alignment.
//
// If it does not, the virtual address of the item in the section will
// not be aligned to the requested alignment.
@@ -1833,8 +1832,9 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
dialects.resize(numDialects);
const auto checkSectionAlignment = [&](unsigned alignment) {
- return this->checkSectionAlignment(
- alignment, [&](const auto &msg) { return emitError(fileLoc, msg); });
+ return this->checkSectionAlignment(alignment, [&](const auto &msg) {
+ return sectionReader.emitError(msg);
+ });
};
// Parse each of the dialects.
More information about the Mlir-commits
mailing list