[Mlir-commits] [llvm] [mlir] [mlir] API to serialize bytecode to mmap'd buffer (PR #126953)
Nikhil Kalra
llvmlistbot at llvm.org
Wed Feb 12 12:52:25 PST 2025
https://github.com/nikalra updated https://github.com/llvm/llvm-project/pull/126953
>From e8aa8969d4b0cea12f91c6c68afd761d59f98ad6 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Wed, 12 Feb 2025 09:34:01 -0800
Subject: [PATCH 1/4] add api to write to mmap'd buffer
---
llvm/include/llvm/Support/raw_ostream.h | 13 ++++
llvm/lib/Support/raw_ostream.cpp | 13 ++++
mlir/include/mlir/Bytecode/BytecodeWriter.h | 6 ++
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 71 ++++++++++++++++++---
mlir/unittests/Bytecode/BytecodeTest.cpp | 22 +++++++
5 files changed, 117 insertions(+), 8 deletions(-)
diff --git a/llvm/include/llvm/Support/raw_ostream.h b/llvm/include/llvm/Support/raw_ostream.h
index d3b411590e7fd..90c0c013e38c8 100644
--- a/llvm/include/llvm/Support/raw_ostream.h
+++ b/llvm/include/llvm/Support/raw_ostream.h
@@ -13,6 +13,7 @@
#ifndef LLVM_SUPPORT_RAW_OSTREAM_H
#define LLVM_SUPPORT_RAW_OSTREAM_H
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DataTypes.h"
@@ -769,6 +770,18 @@ class buffer_unique_ostream : public raw_svector_ostream {
~buffer_unique_ostream() override { *OS << str(); }
};
+// Creates an output stream with a fixed size buffer.
+class fixed_buffer_ostream : public raw_ostream {
+ MutableArrayRef<std::byte> Buffer;
+ size_t Pos = 0;
+
+ void write_impl(const char *Ptr, size_t Size) final;
+ uint64_t current_pos() const final { return Pos; }
+
+public:
+ fixed_buffer_ostream(MutableArrayRef<std::byte> Buffer);
+};
+
// Helper struct to add indentation to raw_ostream. Instead of
// OS.indent(6) << "more stuff";
// you can use
diff --git a/llvm/lib/Support/raw_ostream.cpp b/llvm/lib/Support/raw_ostream.cpp
index e75ddc66b7d16..875c14782dd2e 100644
--- a/llvm/lib/Support/raw_ostream.cpp
+++ b/llvm/lib/Support/raw_ostream.cpp
@@ -1009,6 +1009,19 @@ void buffer_ostream::anchor() {}
void buffer_unique_ostream::anchor() {}
+void fixed_buffer_ostream::write_impl(const char *Ptr, size_t Size) {
+ if (Pos + Size <= Buffer.size()) {
+ memcpy((void *)(Buffer.data() + Pos), Ptr, Size);
+ Pos += Size;
+ } else {
+ report_fatal_error(
+ "Attempted to write past the end of the fixed size buffer.");
+ }
+}
+
+fixed_buffer_ostream::fixed_buffer_ostream(MutableArrayRef<std::byte> Buffer)
+ : raw_ostream(true), Buffer{Buffer} {}
+
Error llvm::writeToOutput(StringRef OutputFileName,
std::function<Error(raw_ostream &)> Write) {
if (OutputFileName == "-")
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index c6cff0bc81314..4945adc3e9304 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -192,6 +192,12 @@ class BytecodeWriterConfig {
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config = {});
+/// Writes the bytecode for the given operation to a memory-mapped buffer.
+/// It only ever fails if setDesiredByteCodeVersion can't be honored.
+/// Returns nullptr on failure.
+std::shared_ptr<ArrayRef<std::byte>>
+writeBytecode(Operation *op, const BytecodeWriterConfig &config = {});
+
} // namespace mlir
#endif // MLIR_BYTECODE_BYTECODEWRITER_H
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 2b4697434717d..c2a33e897ec07 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -20,8 +20,11 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Endian.h"
+#include "llvm/Support/Memory.h"
#include "llvm/Support/raw_ostream.h"
+#include <cstddef>
#include <optional>
+#include <system_error>
#define DEBUG_TYPE "mlir-bytecode-writer"
@@ -652,7 +655,7 @@ class BytecodeWriter {
propertiesSection(numberingState, stringSection, config.getImpl()) {}
/// Write the bytecode for the given root operation.
- LogicalResult write(Operation *rootOp, raw_ostream &os);
+ LogicalResult writeInto(Operation *rootOp, EncodingEmitter &emitter);
private:
//===--------------------------------------------------------------------===//
@@ -718,9 +721,8 @@ class BytecodeWriter {
};
} // namespace
-LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
- EncodingEmitter emitter;
-
+LogicalResult BytecodeWriter::writeInto(Operation *rootOp,
+ EncodingEmitter &emitter) {
// Emit the bytecode file header. This is how we identify the output as a
// bytecode file.
emitter.emitString("ML\xefR", "bytecode header");
@@ -761,9 +763,6 @@ LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
return rootOp->emitError(
"unexpected properties emitted incompatible with bytecode <5");
- // Write the generated bytecode to the provided output stream.
- emitter.writeTo(os);
-
return success();
}
@@ -1348,5 +1347,61 @@ void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) {
LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config) {
BytecodeWriter writer(op, config);
- return writer.write(op, os);
+ EncodingEmitter emitter;
+
+ if (succeeded(writer.writeInto(op, emitter))) {
+ emitter.writeTo(os);
+ return success();
+ }
+
+ return failure();
+}
+
+namespace {
+struct MemoryMappedBlock {
+ static std::shared_ptr<MemoryMappedBlock>
+ createMemoryMappedBlock(size_t numBytes) {
+ auto instance = std::make_shared<MemoryMappedBlock>();
+
+ std::error_code ec;
+ instance->mmapBlock =
+ llvm::sys::OwningMemoryBlock{llvm::sys::Memory::allocateMappedMemory(
+ numBytes, nullptr, llvm::sys::Memory::MF_WRITE, ec)};
+ if (ec)
+ return nullptr;
+
+ instance->writableView = MutableArrayRef<std::byte>(
+ (std::byte *)instance->mmapBlock.base(), numBytes);
+
+ return instance;
+ }
+
+ llvm::sys::OwningMemoryBlock mmapBlock;
+ MutableArrayRef<std::byte> writableView;
+};
+} // namespace
+
+std::shared_ptr<ArrayRef<std::byte>>
+mlir::writeBytecode(Operation *op, const BytecodeWriterConfig &config) {
+ BytecodeWriter writer(op, config);
+ EncodingEmitter emitter;
+ if (succeeded(writer.writeInto(op, emitter))) {
+ // Allocate a new memory block for the emitter to write into.
+ auto block = MemoryMappedBlock::createMemoryMappedBlock(emitter.size());
+ if (!block)
+ return nullptr;
+
+ // Wrap the block in an output stream.
+ llvm::fixed_buffer_ostream stream(block->writableView);
+ emitter.writeTo(stream);
+
+ // Write protect the block.
+ if (llvm::sys::Memory::protectMappedMemory(
+ block->mmapBlock.getMemoryBlock(), llvm::sys::Memory::MF_READ))
+ return nullptr;
+
+ return std::shared_ptr<ArrayRef<std::byte>>(block, &block->writableView);
+ }
+
+ return nullptr;
}
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index cb915a092a0be..a3c069fbcab58 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -16,9 +16,11 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include <cstring>
using namespace llvm;
using namespace mlir;
@@ -88,6 +90,26 @@ TEST(Bytecode, MultiModuleWithResource) {
checkResourceAttribute(*roundTripModule);
}
+TEST(Bytecode, WriteEquivalence) {
+ MLIRContext context;
+ Builder builder(&context);
+ ParserConfig parseConfig(&context);
+ OwningOpRef<Operation *> module =
+ parseSourceString<Operation *>(irWithResources, parseConfig);
+ ASSERT_TRUE(module);
+
+ // Write the module to bytecode
+ std::string buffer;
+ llvm::raw_string_ostream ostream(buffer);
+ ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
+
+ // Write the module to bytecode using the mmap API.
+ auto writeBuffer = writeBytecode(module.get());
+ ASSERT_TRUE(writeBuffer);
+ ASSERT_EQ(writeBuffer->size(), buffer.size());
+ ASSERT_EQ(memcmp(buffer.data(), writeBuffer->data(), writeBuffer->size()), 0);
+}
+
namespace {
/// A custom operation for the purpose of showcasing how discardable attributes
/// are handled in absence of properties.
>From 0906d51cd3e9c820515f18c04d695d600f303f97 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Wed, 12 Feb 2025 11:01:56 -0800
Subject: [PATCH 2/4] Revert "add api to write to mmap'd buffer"
This reverts commit e8aa8969d4b0cea12f91c6c68afd761d59f98ad6.
---
llvm/include/llvm/Support/raw_ostream.h | 13 ----
llvm/lib/Support/raw_ostream.cpp | 13 ----
mlir/include/mlir/Bytecode/BytecodeWriter.h | 6 --
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 71 +++------------------
mlir/unittests/Bytecode/BytecodeTest.cpp | 22 -------
5 files changed, 8 insertions(+), 117 deletions(-)
diff --git a/llvm/include/llvm/Support/raw_ostream.h b/llvm/include/llvm/Support/raw_ostream.h
index 90c0c013e38c8..d3b411590e7fd 100644
--- a/llvm/include/llvm/Support/raw_ostream.h
+++ b/llvm/include/llvm/Support/raw_ostream.h
@@ -13,7 +13,6 @@
#ifndef LLVM_SUPPORT_RAW_OSTREAM_H
#define LLVM_SUPPORT_RAW_OSTREAM_H
-#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DataTypes.h"
@@ -770,18 +769,6 @@ class buffer_unique_ostream : public raw_svector_ostream {
~buffer_unique_ostream() override { *OS << str(); }
};
-// Creates an output stream with a fixed size buffer.
-class fixed_buffer_ostream : public raw_ostream {
- MutableArrayRef<std::byte> Buffer;
- size_t Pos = 0;
-
- void write_impl(const char *Ptr, size_t Size) final;
- uint64_t current_pos() const final { return Pos; }
-
-public:
- fixed_buffer_ostream(MutableArrayRef<std::byte> Buffer);
-};
-
// Helper struct to add indentation to raw_ostream. Instead of
// OS.indent(6) << "more stuff";
// you can use
diff --git a/llvm/lib/Support/raw_ostream.cpp b/llvm/lib/Support/raw_ostream.cpp
index 875c14782dd2e..e75ddc66b7d16 100644
--- a/llvm/lib/Support/raw_ostream.cpp
+++ b/llvm/lib/Support/raw_ostream.cpp
@@ -1009,19 +1009,6 @@ void buffer_ostream::anchor() {}
void buffer_unique_ostream::anchor() {}
-void fixed_buffer_ostream::write_impl(const char *Ptr, size_t Size) {
- if (Pos + Size <= Buffer.size()) {
- memcpy((void *)(Buffer.data() + Pos), Ptr, Size);
- Pos += Size;
- } else {
- report_fatal_error(
- "Attempted to write past the end of the fixed size buffer.");
- }
-}
-
-fixed_buffer_ostream::fixed_buffer_ostream(MutableArrayRef<std::byte> Buffer)
- : raw_ostream(true), Buffer{Buffer} {}
-
Error llvm::writeToOutput(StringRef OutputFileName,
std::function<Error(raw_ostream &)> Write) {
if (OutputFileName == "-")
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index 4945adc3e9304..c6cff0bc81314 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -192,12 +192,6 @@ class BytecodeWriterConfig {
LogicalResult writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config = {});
-/// Writes the bytecode for the given operation to a memory-mapped buffer.
-/// It only ever fails if setDesiredByteCodeVersion can't be honored.
-/// Returns nullptr on failure.
-std::shared_ptr<ArrayRef<std::byte>>
-writeBytecode(Operation *op, const BytecodeWriterConfig &config = {});
-
} // namespace mlir
#endif // MLIR_BYTECODE_BYTECODEWRITER_H
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index c2a33e897ec07..2b4697434717d 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -20,11 +20,8 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Endian.h"
-#include "llvm/Support/Memory.h"
#include "llvm/Support/raw_ostream.h"
-#include <cstddef>
#include <optional>
-#include <system_error>
#define DEBUG_TYPE "mlir-bytecode-writer"
@@ -655,7 +652,7 @@ class BytecodeWriter {
propertiesSection(numberingState, stringSection, config.getImpl()) {}
/// Write the bytecode for the given root operation.
- LogicalResult writeInto(Operation *rootOp, EncodingEmitter &emitter);
+ LogicalResult write(Operation *rootOp, raw_ostream &os);
private:
//===--------------------------------------------------------------------===//
@@ -721,8 +718,9 @@ class BytecodeWriter {
};
} // namespace
-LogicalResult BytecodeWriter::writeInto(Operation *rootOp,
- EncodingEmitter &emitter) {
+LogicalResult BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
+ EncodingEmitter emitter;
+
// Emit the bytecode file header. This is how we identify the output as a
// bytecode file.
emitter.emitString("ML\xefR", "bytecode header");
@@ -763,6 +761,9 @@ LogicalResult BytecodeWriter::writeInto(Operation *rootOp,
return rootOp->emitError(
"unexpected properties emitted incompatible with bytecode <5");
+ // Write the generated bytecode to the provided output stream.
+ emitter.writeTo(os);
+
return success();
}
@@ -1347,61 +1348,5 @@ void BytecodeWriter::writePropertiesSection(EncodingEmitter &emitter) {
LogicalResult mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
const BytecodeWriterConfig &config) {
BytecodeWriter writer(op, config);
- EncodingEmitter emitter;
-
- if (succeeded(writer.writeInto(op, emitter))) {
- emitter.writeTo(os);
- return success();
- }
-
- return failure();
-}
-
-namespace {
-struct MemoryMappedBlock {
- static std::shared_ptr<MemoryMappedBlock>
- createMemoryMappedBlock(size_t numBytes) {
- auto instance = std::make_shared<MemoryMappedBlock>();
-
- std::error_code ec;
- instance->mmapBlock =
- llvm::sys::OwningMemoryBlock{llvm::sys::Memory::allocateMappedMemory(
- numBytes, nullptr, llvm::sys::Memory::MF_WRITE, ec)};
- if (ec)
- return nullptr;
-
- instance->writableView = MutableArrayRef<std::byte>(
- (std::byte *)instance->mmapBlock.base(), numBytes);
-
- return instance;
- }
-
- llvm::sys::OwningMemoryBlock mmapBlock;
- MutableArrayRef<std::byte> writableView;
-};
-} // namespace
-
-std::shared_ptr<ArrayRef<std::byte>>
-mlir::writeBytecode(Operation *op, const BytecodeWriterConfig &config) {
- BytecodeWriter writer(op, config);
- EncodingEmitter emitter;
- if (succeeded(writer.writeInto(op, emitter))) {
- // Allocate a new memory block for the emitter to write into.
- auto block = MemoryMappedBlock::createMemoryMappedBlock(emitter.size());
- if (!block)
- return nullptr;
-
- // Wrap the block in an output stream.
- llvm::fixed_buffer_ostream stream(block->writableView);
- emitter.writeTo(stream);
-
- // Write protect the block.
- if (llvm::sys::Memory::protectMappedMemory(
- block->mmapBlock.getMemoryBlock(), llvm::sys::Memory::MF_READ))
- return nullptr;
-
- return std::shared_ptr<ArrayRef<std::byte>>(block, &block->writableView);
- }
-
- return nullptr;
+ return writer.write(op, os);
}
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index a3c069fbcab58..cb915a092a0be 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -16,11 +16,9 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
-#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-#include <cstring>
using namespace llvm;
using namespace mlir;
@@ -90,26 +88,6 @@ TEST(Bytecode, MultiModuleWithResource) {
checkResourceAttribute(*roundTripModule);
}
-TEST(Bytecode, WriteEquivalence) {
- MLIRContext context;
- Builder builder(&context);
- ParserConfig parseConfig(&context);
- OwningOpRef<Operation *> module =
- parseSourceString<Operation *>(irWithResources, parseConfig);
- ASSERT_TRUE(module);
-
- // Write the module to bytecode
- std::string buffer;
- llvm::raw_string_ostream ostream(buffer);
- ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
-
- // Write the module to bytecode using the mmap API.
- auto writeBuffer = writeBytecode(module.get());
- ASSERT_TRUE(writeBuffer);
- ASSERT_EQ(writeBuffer->size(), buffer.size());
- ASSERT_EQ(memcmp(buffer.data(), writeBuffer->data(), writeBuffer->size()), 0);
-}
-
namespace {
/// A custom operation for the purpose of showcasing how discardable attributes
/// are handled in absence of properties.
>From d105382fb9f974a72d132bcdf19ceee398124fe9 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Wed, 12 Feb 2025 12:48:02 -0800
Subject: [PATCH 3/4] use raw_ostream::reserveExtraSpace instead
---
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp | 3 ++
mlir/unittests/Bytecode/BytecodeTest.cpp | 33 +++++++++++++++++++--
2 files changed, 34 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index 2b4697434717d..cc5aaed416512 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -613,6 +613,9 @@ class RawEmitterOstream : public raw_ostream {
} // namespace
void EncodingEmitter::writeTo(raw_ostream &os) const {
+ // Reserve space in the ostream for the encoded contents.
+ os.reserveExtraSpace(size());
+
for (auto &prevResult : prevResultList)
os.write((const char *)prevResult.data(), prevResult.size());
os.write((const char *)currentResult.data(), currentResult.size());
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index cb915a092a0be..292e827544f19 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -17,6 +17,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MemoryBufferRef.h"
+#include "llvm/Support/raw_ostream.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -37,6 +38,33 @@ module @TestDialectResources attributes {
#-}
)";
+struct AllocatingOstream final : public raw_ostream {
+ std::unique_ptr<std::byte[]> buffer;
+ size_t size = 0;
+
+ void reserveExtraSpace(uint64_t extraSpace) override {
+ ASSERT_TRUE(buffer == nullptr);
+ buffer = std::make_unique<std::byte[]>(extraSpace);
+ size = extraSpace;
+ }
+
+ AllocatingOstream() : raw_ostream(true) {}
+ uint64_t current_pos() const override { return size; }
+
+private:
+ size_t pos = 0;
+
+ void write_impl(const char *ptr, size_t length) override {
+ // Compare against the remaining capacity so `pos + length` cannot wrap.
+ if (length > size - pos)
+ report_fatal_error(
+ "Attempted to write past the end of the fixed size buffer.");
+ if (length != 0)
+ memcpy(buffer.get() + pos, ptr, length);
+ pos += length;
+ }
+};
+
TEST(Bytecode, MultiModuleWithResource) {
MLIRContext context;
Builder builder(&context);
@@ -46,11 +74,12 @@ TEST(Bytecode, MultiModuleWithResource) {
ASSERT_TRUE(module);
// Write the module to bytecode
- std::string buffer;
- llvm::raw_string_ostream ostream(buffer);
+ AllocatingOstream ostream;
ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
// Create copy of buffer which is aligned to requested resource alignment.
+ std::string buffer((char *)ostream.buffer.get(),
+ (char *)ostream.buffer.get() + ostream.size);
constexpr size_t kAlignment = 0x20;
size_t bufferSize = buffer.size();
buffer.reserve(bufferSize + kAlignment - 1);
>From 8aba4526148f35626c4d0c5a2e2bf089ffc9aa8d Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Wed, 12 Feb 2025 12:52:10 -0800
Subject: [PATCH 4/4] use a gmock ostream to check reserveExtraSpace
---
mlir/unittests/Bytecode/BytecodeTest.cpp | 20 ++++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp
index 292e827544f19..c036fe26b1b36 100644
--- a/mlir/unittests/Bytecode/BytecodeTest.cpp
+++ b/mlir/unittests/Bytecode/BytecodeTest.cpp
@@ -38,18 +38,14 @@ module @TestDialectResources attributes {
#-}
)";
-struct AllocatingOstream final : public raw_ostream {
+struct MockOstream final : public raw_ostream {
std::unique_ptr<std::byte[]> buffer;
size_t size = 0;
- void reserveExtraSpace(uint64_t extraSpace) override {
- ASSERT_TRUE(buffer == nullptr);
- buffer = std::make_unique<std::byte[]>(extraSpace);
- size = extraSpace;
- }
+ MOCK_METHOD(void, reserveExtraSpace, (uint64_t extraSpace), (override));
- AllocatingOstream() : raw_ostream(true) {}
- uint64_t current_pos() const override { return size; }
+ MockOstream() : raw_ostream(true) {}
+ uint64_t current_pos() const override { return pos; }
private:
size_t pos = 0;
@@ -73,8 +69,12 @@ TEST(Bytecode, MultiModuleWithResource) {
parseSourceString<Operation *>(irWithResources, parseConfig);
ASSERT_TRUE(module);
- // Write the module to bytecode
- AllocatingOstream ostream;
+ // Write the module to bytecode.
+ MockOstream ostream;
+ EXPECT_CALL(ostream, reserveExtraSpace).WillOnce([&](uint64_t space) {
+ ostream.buffer = std::make_unique<std::byte[]>(space);
+ ostream.size = space;
+ });
ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
// Create copy of buffer which is aligned to requested resource alignment.
More information about the Mlir-commits
mailing list