[Mlir-commits] [mlir] 8106c81 - [MLIR][Bytecode] Enforce alignment requirements (#157004)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 4 22:31:44 PDT 2025
Author: Nikhil Kalra
Date: 2025-09-04T22:31:40-07:00
New Revision: 8106c816eb8b279ee4220936c43a0e495d1bb1a0
URL: https://github.com/llvm/llvm-project/commit/8106c816eb8b279ee4220936c43a0e495d1bb1a0
DIFF: https://github.com/llvm/llvm-project/commit/8106c816eb8b279ee4220936c43a0e495d1bb1a0.diff
LOG: [MLIR][Bytecode] Enforce alignment requirements (#157004)
Adds a check that the bytecode buffer is aligned to any section
alignment requirements. Without this check, if the source buffer is not
sufficiently aligned, we may return early when aligning the data
pointer. In that case, we may end up trying to read successive sections
from an incorrect offset, giving the appearance of invalid bytecode.
This requirement is documented in the bytecode unit tests, but is not
otherwise documented in the code or bytecode reference.
Added:
Modified:
mlir/docs/BytecodeFormat.md
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/unittests/Bytecode/BytecodeTest.cpp
Removed:
################################################################################
diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
index ebc94c9f0d8ba..9846df8726295 100644
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -125,7 +125,7 @@ lazy-loading, and more. Each section contains a Section ID, whose high bit
indicates if the section has alignment requirements, a length (which allows for
skipping over the section), and an optional alignment. When an alignment is
present, a variable number of padding bytes (0xCB) may appear before the section
-data. The alignment of a section must be a power of 2.
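+data. The alignment of a section must be a power of 2. The start of the input bytecode buffer must itself be aligned to at least the largest alignment requested by any section, because padding is computed from memory addresses; the reader rejects buffers that do not meet this requirement.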
## MLIR Encoding
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
index 44458d010c6c8..d29053a2b6e65 100644
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -22,10 +22,13 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"
#include <cstddef>
+#include <cstdint>
#include <list>
#include <memory>
#include <numeric>
@@ -111,6 +114,9 @@ class EncodingReader {
};
// Shift the reader position to the next alignment boundary.
+ // Note: this assumes the pointer alignment matches the alignment of the
+ // data from the start of the buffer. In other words, this code is only
+ // valid if `dataIt` is offsetting into an already aligned buffer.
while (isUnaligned(dataIt)) {
uint8_t padding;
if (failed(parseByte(padding)))
@@ -258,9 +264,13 @@ class EncodingReader {
return success();
}
+ /// Validate that the alignment requested in the section is valid.
+ using ValidateAlignmentFn = function_ref<LogicalResult(unsigned alignment)>;
+
/// Parse a section header, placing the kind of section in `sectionID` and the
/// contents of the section in `sectionData`.
LogicalResult parseSection(bytecode::Section::ID §ionID,
+ ValidateAlignmentFn alignmentValidator,
ArrayRef<uint8_t> §ionData) {
uint8_t sectionIDAndHasAlignment;
uint64_t length;
@@ -281,8 +291,22 @@ class EncodingReader {
// Process the section alignment if present.
if (hasAlignment) {
+ // Read the requested alignment from the bytecode parser.
uint64_t alignment;
- if (failed(parseVarInt(alignment)) || failed(alignTo(alignment)))
+ if (failed(parseVarInt(alignment)))
+ return failure();
+
+ // Check that the requested alignment is less than or equal to the
+ // alignment of the root buffer. If it is not, we cannot safely guarantee
+ // that the specified alignment is globally correct.
+ //
+ // E.g. if the buffer is 8k aligned and the section is 16k aligned,
+ // we could end up at an offset of 24k, which is not globally 16k aligned.
+ if (failed(alignmentValidator(alignment)))
+ return emitError("failed to align section ID: ", unsigned(sectionID));
+
+ // Align the buffer.
+ if (failed(alignTo(alignment)))
return failure();
}
@@ -1396,6 +1420,29 @@ class mlir::BytecodeReader::Impl {
return success();
}
+  /// Verify that the start of the bytecode buffer satisfies the `alignment`
+  /// requested by a section, reporting any problem through `emitError`.
+  LogicalResult checkSectionAlignment(
+      unsigned alignment,
+      function_ref<InFlightDiagnostic(const Twine &error)> emitError) {
+    // A zero or non-power-of-two alignment is malformed bytecode. Reject it
+    // explicitly: the mask below would otherwise wrap (`0 - 1`) or test the
+    // wrong bits, producing a misleading "buffer is not aligned" diagnostic.
+    if (alignment == 0 || (alignment & (alignment - 1)) != 0)
+      return emitError("expected section alignment ")
+             << alignment << " to be a power of two";
+
+    // Check that the bytecode buffer meets the requested section alignment.
+    //
+    // If it does not, the virtual address of the item in the section will
+    // not be aligned to the requested alignment.
+    //
+    // The typical case where this is necessary is the resource blob
+    // optimization in `parseAsBlob` where we reference the weights from the
+    // provided buffer instead of copying them to a new allocation.
+    const auto bufferAddress =
+        reinterpret_cast<uintptr_t>(buffer.getBufferStart());
+    const bool isGloballyAligned = (bufferAddress & (alignment - 1)) == 0;
+
+    if (!isGloballyAligned)
+      return emitError("expected section alignment ")
+             << alignment << " but bytecode buffer 0x"
+             << Twine::utohexstr(bufferAddress) << " is not aligned";
+
+    return success();
+  }
+
/// Return the context for this config.
MLIRContext *getContext() const { return config.getContext(); }
@@ -1506,7 +1553,7 @@ class mlir::BytecodeReader::Impl {
UseListOrderStorage(bool isIndexPairEncoding,
SmallVector<unsigned, 4> &&indices)
: indices(std::move(indices)),
- isIndexPairEncoding(isIndexPairEncoding){};
+ isIndexPairEncoding(isIndexPairEncoding) {};
/// The vector containing the information required to reorder the
/// use-list of a value.
SmallVector<unsigned, 4> indices;
@@ -1651,6 +1698,11 @@ LogicalResult BytecodeReader::Impl::read(
return failure();
});
+ const auto checkSectionAlignment = [&](unsigned alignment) {
+ return this->checkSectionAlignment(
+ alignment, [&](const auto &msg) { return reader.emitError(msg); });
+ };
+
// Parse the raw data for each of the top-level sections of the bytecode.
std::optional<ArrayRef<uint8_t>>
sectionDatas[bytecode::Section::kNumSections];
@@ -1658,7 +1710,8 @@ LogicalResult BytecodeReader::Impl::read(
// Read the next section from the bytecode.
bytecode::Section::ID sectionID;
ArrayRef<uint8_t> sectionData;
- if (failed(reader.parseSection(sectionID, sectionData)))
+ if (failed(
+ reader.parseSection(sectionID, checkSectionAlignment, sectionData)))
return failure();
// Check for duplicate sections, we only expect one instance of each.
@@ -1778,6 +1831,12 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return failure();
dialects.resize(numDialects);
+ const auto checkSectionAlignment = [&](unsigned alignment) {
+ return this->checkSectionAlignment(alignment, [&](const auto &msg) {
+ return sectionReader.emitError(msg);
+ });
+ };
+
// Parse each of the dialects.
for (uint64_t i = 0; i < numDialects; ++i) {
dialects[i] = std::make_unique<BytecodeDialect>();
@@ -1800,7 +1859,7 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return failure();
if (versionAvailable) {
bytecode::Section::ID sectionID;
- if (failed(sectionReader.parseSection(sectionID,
+ if (failed(sectionReader.parseSection(sectionID, checkSectionAlignment,
dialects[i]->versionBuffer)))
return failure();
if (sectionID != bytecode::Section::kDialectVersions) {
@@ -2121,6 +2180,11 @@ BytecodeReader::Impl::parseIRSection(ArrayRef<uint8_t> sectionData,
LogicalResult
BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack,
RegionReadState &readState) {
+ const auto checkSectionAlignment = [&](unsigned alignment) {
+ return this->checkSectionAlignment(
+ alignment, [&](const auto &msg) { return emitError(fileLoc, msg); });
+ };
+
// Process regions, blocks, and operations until the end or if a nested
// region is encountered. In this case we push a new state in regionStack and
// return, the processing of the current region will resume afterward.
@@ -2161,7 +2225,8 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> ®ionStack,
if (version >= bytecode::kLazyLoading && isIsolatedFromAbove) {
bytecode::Section::ID sectionID;
ArrayRef<uint8_t> sectionData;
- if (failed(reader.parseSection(sectionID, sectionData)))
+ if (failed(reader.parseSection(sectionID, checkSectionAlignment,
+ sectionData)))
return failure();
if (sectionID != bytecode::Section::kIR)
return emitError(fileLoc, "expected IR section for region");
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index c036fe26b1b36..9ea6560f712a1 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -10,11 +10,13 @@
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Alignment.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/raw_ostream.h"
@@ -117,6 +119,57 @@ TEST(Bytecode, MultiModuleWithResource) {
checkResourceAttribute(*roundTripModule);
}
+TEST(Bytecode, AlignmentFailure) {
+  MLIRContext context;
+  ParserConfig parseConfig(&context);
+  OwningOpRef<Operation *> module =
+      parseSourceString<Operation *>(irWithResources, parseConfig);
+  ASSERT_TRUE(module);
+
+  // Write the module to bytecode.
+  MockOstream ostream;
+  EXPECT_CALL(ostream, reserveExtraSpace).WillOnce([&](uint64_t space) {
+    ostream.buffer = std::make_unique<std::byte[]>(space);
+    ostream.size = space;
+  });
+  ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
+
+  // Copy the bytecode into a string that we can deliberately misalign with
+  // respect to the 32-byte alignment requested by the resource section.
+  std::string buffer((char *)ostream.buffer.get(),
+                     (char *)ostream.buffer.get() + ostream.size);
+  const size_t bufferSize = buffer.size();
+
+  // Reserve room for the largest padding we could need up front, so neither
+  // the search below nor the insert() afterwards reallocates the string and
+  // invalidates the address whose alignment we inspected.
+  constexpr size_t kMaxPad = 32;
+  buffer.reserve(bufferSize + kMaxPad);
+
+  // Find a padding amount that leaves the bytecode 2-byte aligned but NOT
+  // 32-byte aligned.
+  size_t pad = 0;
+  while (!llvm::isAddrAligned(Align(2), buffer.data() + pad) ||
+         llvm::isAddrAligned(Align(32), buffer.data() + pad))
+    ++pad;
+  ASSERT_LT(pad, kMaxPad);
+
+  buffer.insert(0, pad, ' ');
+  StringRef misalignedBuffer(buffer.data() + pad, bufferSize);
+
+  // Collect every diagnostic emitted while parsing.
+  llvm::SmallVector<std::string> msg;
+  ScopedDiagnosticHandler handler(
+      &context, [&msg](Diagnostic &diag) { msg.push_back(diag.str()); });
+
+  // Parsing must fail: the resource section requests 32-byte alignment, which
+  // the misaligned buffer cannot provide.
+  OwningOpRef<Operation *> roundTripModule =
+      parseSourceString<Operation *>(misalignedBuffer, parseConfig);
+  ASSERT_FALSE(roundTripModule);
+  // Guard the indexing below so a missing diagnostic fails cleanly.
+  ASSERT_GE(msg.size(), 2u);
+  ASSERT_THAT(msg[0].data(), ::testing::StartsWith(
+                                 "expected section alignment 32 but bytecode "
+                                 "buffer"));
+  ASSERT_STREQ(msg[1].data(), "failed to align section ID: 5");
+}
+
namespace {
/// A custom operation for the purpose of showcasing how discardable attributes
/// are handled in absence of properties.
More information about the Mlir-commits
mailing list