[llvm] Fix compress/decompress in LLVM Offloading API (PR #150064)

James Henderson via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 23 00:22:31 PDT 2025


================
@@ -260,11 +299,233 @@ static std::string formatWithCommas(unsigned long long Value) {
 }
 
 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
-CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input,
+CompressedOffloadBundle::compress(llvm::compression::Params P,
+                                  const llvm::MemoryBuffer &Input,
+                                  uint16_t Version, bool Verbose) {
+  // Serialize Input as a compressed offload bundle. Input is hashed (MD5,
+  // keeping only the low 64 bits) and compressed with the format and level
+  // given in P. A header is then prepended: Version 2 headers use 32-bit size
+  // fields, and any other Version uses 64-bit size fields. An error is
+  // returned when neither zstd nor zlib is available, or when a Version 2
+  // size does not fit in 32 bits. With Verbose set, the hash and compression
+  // steps are timed and statistics are printed to stderr.
+  //
+  // NOTE(review): this only checks that *some* backend (zstd or zlib) is
+  // available, not the one selected by P.format. Presumably callers only
+  // pass a usable format; confirm.
+  if (!llvm::compression::zstd::isAvailable() &&
+      !llvm::compression::zlib::isAvailable())
+    return createStringError(llvm::inconvertibleErrorCode(),
+                             "Compression not supported");
+  llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
+                        OffloadBundlerTimerGroup);
+  if (Verbose)
+    HashTimer.startTimer();
+  llvm::MD5 Hash;
+  llvm::MD5::MD5Result Result;
+  Hash.update(Input.getBuffer());
+  Hash.final(Result);
+  // Only the low 64 bits of the MD5 digest are stored in the header.
+  uint64_t TruncatedHash = Result.low();
+  if (Verbose)
+    HashTimer.stopTimer();
+
+  SmallVector<uint8_t, 0> CompressedBuffer;
+  // View the input bytes as uint8_t without copying them.
+  auto BufferUint8 = llvm::ArrayRef<uint8_t>(
+      reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
+      Input.getBuffer().size());
+  llvm::Timer CompressTimer("Compression Timer", "Compression time",
+                            OffloadBundlerTimerGroup);
+  if (Verbose)
+    CompressTimer.startTimer();
+  llvm::compression::compress(P, BufferUint8, CompressedBuffer);
+  if (Verbose)
+    CompressTimer.stopTimer();
+
+  // The numeric value of the compression format is what gets serialized.
+  uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
+
+  // Store sizes in 64-bit variables first.
+  uint64_t UncompressedSize64 = Input.getBuffer().size();
+  uint64_t TotalFileSize64;
+
+  // Calculate total file size based on version.
+  if (Version == 2) {
+    // For V2, ensure the sizes don't exceed the 32-bit limit.
+    if (UncompressedSize64 > std::numeric_limits<uint32_t>::max())
+      return createStringError(llvm::inconvertibleErrorCode(),
+                               "Uncompressed size exceeds version 2 limit");
+    // This is the same sum that is assigned to TotalFileSize64 below. It is
+    // checked here because the value is narrowed to 32 bits on write.
+    if ((MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) +
+         sizeof(CompressionMethod) + sizeof(uint32_t) + sizeof(TruncatedHash) +
+         CompressedBuffer.size()) > std::numeric_limits<uint32_t>::max())
+      return createStringError(llvm::inconvertibleErrorCode(),
+                               "Total file size exceeds version 2 limit");
+
+    TotalFileSize64 = MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) +
+                      sizeof(CompressionMethod) + sizeof(uint32_t) +
+                      sizeof(TruncatedHash) + CompressedBuffer.size();
+  } else { // Version 3
+    // NOTE(review): every Version other than 2 takes this branch. For example,
+    // a Version of 1 would be labeled 1 in the header but written with the
+    // 64-bit (V3) size layout. Consider rejecting unsupported versions
+    // explicitly.
+    TotalFileSize64 = MagicNumber.size() + sizeof(uint64_t) + sizeof(Version) +
+                      sizeof(CompressionMethod) + sizeof(uint64_t) +
+                      sizeof(TruncatedHash) + CompressedBuffer.size();
+  }
+
+  // Header layout: magic, Version, CompressionMethod, total file size,
+  // uncompressed size, truncated hash. The compressed payload follows.
+  // NOTE(review): multi-byte fields are written from their in-memory
+  // representation, i.e. in host byte order.
+  SmallVector<char, 0> FinalBuffer;
+  llvm::raw_svector_ostream OS(FinalBuffer);
+  OS << MagicNumber;
+  OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version));
+  OS.write(reinterpret_cast<const char *>(&CompressionMethod),
+           sizeof(CompressionMethod));
+
+  // Write size fields according to version.
+  if (Version == 2) {
+    // Narrowing is safe: both values were range-checked in the V2 branch above.
+    uint32_t TotalFileSize32 = static_cast<uint32_t>(TotalFileSize64);
+    uint32_t UncompressedSize32 = static_cast<uint32_t>(UncompressedSize64);
+    OS.write(reinterpret_cast<const char *>(&TotalFileSize32),
+             sizeof(TotalFileSize32));
+    OS.write(reinterpret_cast<const char *>(&UncompressedSize32),
+             sizeof(UncompressedSize32));
+  } else { // Version 3
+    OS.write(reinterpret_cast<const char *>(&TotalFileSize64),
+             sizeof(TotalFileSize64));
+    OS.write(reinterpret_cast<const char *>(&UncompressedSize64),
+             sizeof(UncompressedSize64));
+  }
+
+  OS.write(reinterpret_cast<const char *>(&TruncatedHash),
+           sizeof(TruncatedHash));
+  OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()),
+           CompressedBuffer.size());
+
+  if (Verbose) {
+    auto MethodUsed =
+        P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib";
+    // Rate is uncompressed/compressed size. The printed "ratio" below is its
+    // reciprocal, expressed as a percentage.
+    double CompressionRate =
+        static_cast<double>(UncompressedSize64) / CompressedBuffer.size();
+    double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime();
+    // NOTE(review): if the measured wall time is 0 (e.g. for tiny inputs),
+    // this division may yield inf or NaN.
+    double CompressionSpeedMBs =
+        (UncompressedSize64 / (1024.0 * 1024.0)) / CompressionTimeSeconds;
+    llvm::errs() << "Compressed bundle format version: " << Version << "\n"
+                 << "Total file size (including headers): "
+                 << formatWithCommas(TotalFileSize64) << " bytes\n"
+                 << "Compression method used: " << MethodUsed << "\n"
+                 << "Compression level: " << P.level << "\n"
+                 << "Binary size before compression: "
+                 << formatWithCommas(UncompressedSize64) << " bytes\n"
+                 << "Binary size after compression: "
+                 << formatWithCommas(CompressedBuffer.size()) << " bytes\n"
+                 << "Compression rate: "
+                 << llvm::format("%.2lf", CompressionRate) << "\n"
+                 << "Compression ratio: "
+                 << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
+                 << "Compression speed: "
+                 << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n"
+                 << "Truncated MD5 hash: "
+                 << llvm::format_hex(TruncatedHash, 16) << "\n";
+  }
+
+  // Copy the assembled bytes into an owned MemoryBuffer.
+  return llvm::MemoryBuffer::getMemBufferCopy(
+      llvm::StringRef(FinalBuffer.data(), FinalBuffer.size()));
+}
+
+// Use packed structs to avoid padding, such that the structs map the serialized
+// format. Every versioned header starts with CommonFields, so the Common member
+// can be read first to see which versioned view applies.
+// NOTE(review): with no padding, the sizes are 8 (CommonFields), 20 (V1),
+// 24 (V2) and 32 (V3) bytes. Consider static_asserts to lock these in.
+LLVM_PACKED_START
+union RawCompressedBundleHeader {
+  // Leading fields shared by every header version.
+  // NOTE(review): Magic is 4 bytes here, while compress() writes
+  // MagicNumber.size() bytes. This assumes MagicNumber is exactly 4 bytes;
+  // confirm.
+  struct CommonFields {
+    uint32_t Magic;
+    uint16_t Version;
+    uint16_t Method;
+  };
+
+  // Version 1: no total-file-size field; 32-bit uncompressed size.
+  struct V1Header {
+    CommonFields Common;
+    uint32_t UncompressedFileSize;
+    uint64_t Hash;
+  };
+
+  // Version 2: 32-bit total and uncompressed sizes. Field order matches what
+  // compress() writes when Version == 2.
+  struct V2Header {
+    CommonFields Common;
+    uint32_t FileSize;
+    uint32_t UncompressedFileSize;
+    uint64_t Hash;
+  };
+
+  // Version 3: 64-bit total and uncompressed sizes.
+  struct V3Header {
+    CommonFields Common;
+    uint64_t FileSize;
+    uint64_t UncompressedFileSize;
+    uint64_t Hash;
+  };
+
+  // Overlapping views of the same header bytes. Presumably the valid view is
+  // chosen according to Common.Version; verify against the reader.
+  CommonFields Common;
+  V1Header V1;
+  V2Header V2;
+  V3Header V3;
+};
+LLVM_PACKED_END
+
+// Helper method to get header size based on version.
----------------
jh7370 wrote:

```suggestion
// Helper method to get header size based on version.
```
The same trailing-period fix applies in many other places.

https://github.com/llvm/llvm-project/pull/150064


More information about the llvm-commits mailing list