[llvm] Fix compress/decompress in LLVM Offloading API (PR #150064)
David Salinas via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 18 12:57:57 PDT 2025
================
@@ -259,87 +293,282 @@ static std::string formatWithCommas(unsigned long long Value) {
return Num;
}
-llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
-CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input,
- bool Verbose) {
- StringRef Blob = Input.getBuffer();
+Expected<std::unique_ptr<MemoryBuffer>>
+CompressedOffloadBundle::compress(compression::Params P,
+ const MemoryBuffer &Input, uint16_t Version,
+ raw_ostream *VerboseStream) {
+ if (!compression::zstd::isAvailable() && !compression::zlib::isAvailable())
+ return createStringError("compression not supported.");
+ Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
+ OffloadBundlerTimerGroup);
+ if (VerboseStream)
+ HashTimer.startTimer();
+ MD5 Hash;
+ MD5::MD5Result Result;
+ Hash.update(Input.getBuffer());
+ Hash.final(Result);
+ uint64_t TruncatedHash = Result.low();
+ if (VerboseStream)
+ HashTimer.stopTimer();
- if (Blob.size() < V1HeaderSize)
- return llvm::MemoryBuffer::getMemBufferCopy(Blob);
+ SmallVector<uint8_t, 0> CompressedBuffer;
+ auto BufferUint8 = ArrayRef<uint8_t>(
+ reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
+ Input.getBuffer().size());
+ Timer CompressTimer("Compression Timer", "Compression time",
+ OffloadBundlerTimerGroup);
+ if (VerboseStream)
+ CompressTimer.startTimer();
+ compression::compress(P, BufferUint8, CompressedBuffer);
+ if (VerboseStream)
+ CompressTimer.stopTimer();
- if (llvm::identify_magic(Blob) !=
- llvm::file_magic::offload_bundle_compressed) {
- if (Verbose)
- llvm::errs() << "Uncompressed bundle.\n";
- return llvm::MemoryBuffer::getMemBufferCopy(Blob);
+ uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
+
+ // Store sizes in 64-bit variables first.
+ uint64_t UncompressedSize64 = Input.getBuffer().size();
+ uint64_t TotalFileSize64;
+
+ // Calculate total file size based on version.
+ if (Version == 2) {
+ // For V2, ensure the sizes don't exceed 32-bit limit.
+ if (UncompressedSize64 > std::numeric_limits<uint32_t>::max())
+ return createStringError(inconvertibleErrorCode(),
+ "uncompressed size exceeds version 2 limit");
----------------
david-salinas wrote:
ok
https://github.com/llvm/llvm-project/pull/150064
More information about the llvm-commits mailing list