[llvm] Fix compress/decompress in LLVM Offloading API (PR #150064)

David Salinas via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 19 08:39:22 PDT 2025


================
@@ -259,87 +295,288 @@ static std::string formatWithCommas(unsigned long long Value) {
   return Num;
 }
 
-llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
-CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input,
-                                    bool Verbose) {
-  StringRef Blob = Input.getBuffer();
+Expected<std::unique_ptr<MemoryBuffer>>
+CompressedOffloadBundle::compress(compression::Params P,
+                                  const MemoryBuffer &Input,
+                                  uint16_t Version, bool Verbose) {
+  if (!compression::zstd::isAvailable() &&
+      !compression::zlib::isAvailable())
+    return createStringError(
+                             "Compression not supported");
+  Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
+                        OffloadBundlerTimerGroup);
+  if (Verbose)
+    HashTimer.startTimer();
+  MD5 Hash;
+  MD5::MD5Result Result;
+  Hash.update(Input.getBuffer());
+  Hash.final(Result);
+  uint64_t TruncatedHash = Result.low();
+  if (Verbose)
+    HashTimer.stopTimer();
+
+  SmallVector<uint8_t, 0> CompressedBuffer;
+  auto BufferUint8 = ArrayRef<uint8_t>(
+      reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
+      Input.getBuffer().size());
+  Timer CompressTimer("Compression Timer", "Compression time",
+                            OffloadBundlerTimerGroup);
+  if (Verbose)
+    CompressTimer.startTimer();
+  compression::compress(P, BufferUint8, CompressedBuffer);
+  if (Verbose)
+    CompressTimer.stopTimer();
 
-  if (Blob.size() < V1HeaderSize)
-    return llvm::MemoryBuffer::getMemBufferCopy(Blob);
+  uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
 
-  if (llvm::identify_magic(Blob) !=
-      llvm::file_magic::offload_bundle_compressed) {
-    if (Verbose)
-      llvm::errs() << "Uncompressed bundle.\n";
-    return llvm::MemoryBuffer::getMemBufferCopy(Blob);
+  // Store sizes in 64-bit variables first.
+  uint64_t UncompressedSize64 = Input.getBuffer().size();
+  uint64_t TotalFileSize64;
+
+  // Calculate total file size based on version.
+  if (Version == 2) {
+    // For V2, ensure the sizes don't exceed 32-bit limit.
+    if (UncompressedSize64 > std::numeric_limits<uint32_t>::max())
+      return createStringError(inconvertibleErrorCode(),
+                               "Uncompressed size exceeds version 2 limit");
+    if ((MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) +
+         sizeof(CompressionMethod) + sizeof(uint32_t) + sizeof(TruncatedHash) +
+         CompressedBuffer.size()) > std::numeric_limits<uint32_t>::max())
+      return createStringError(inconvertibleErrorCode(),
+                               "Total file size exceeds version 2 limit");
+
+    TotalFileSize64 = MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) +
+                      sizeof(CompressionMethod) + sizeof(uint32_t) +
+                      sizeof(TruncatedHash) + CompressedBuffer.size();
+  } else { // Version 3
+    TotalFileSize64 = MagicNumber.size() + sizeof(uint64_t) + sizeof(Version) +
+                      sizeof(CompressionMethod) + sizeof(uint64_t) +
+                      sizeof(TruncatedHash) + CompressedBuffer.size();
   }
 
-  size_t CurrentOffset = MagicSize;
+  SmallVector<char, 0> FinalBuffer;
+  raw_svector_ostream OS(FinalBuffer);
+  OS << MagicNumber;
+  OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version));
+  OS.write(reinterpret_cast<const char *>(&CompressionMethod),
+           sizeof(CompressionMethod));
 
-  uint16_t ThisVersion;
-  memcpy(&ThisVersion, Blob.data() + CurrentOffset, sizeof(uint16_t));
-  CurrentOffset += VersionFieldSize;
+  // Write size fields according to version
----------------
david-salinas wrote:

ok

https://github.com/llvm/llvm-project/pull/150064


More information about the llvm-commits mailing list