[llvm] Extend LLVM Offloading API for binary fatbin Bundles (PR #114833)

David Salinas via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 4 09:19:26 PST 2024


https://github.com/david-salinas created https://github.com/llvm/llvm-project/pull/114833

  To provide a common API for offloading, this extension to the
  existing LLVM Offloading API adds support for Binary Fatbin Bundles,
  moving some of that support over from the Clang offloading API.
  Functionality for Binary Fatbin Bundles will be added to LLVM
  tooling in subsequent commits.

Change-Id: I907fdcbcd0545162a0ce1cf17ebf7c9f3a4dbde6

>From 71c9c5cb43c750ce35136b183d28a8a138d6f6e5 Mon Sep 17 00:00:00 2001
From: David <dsalinas at amd.com>
Date: Wed, 28 Aug 2024 04:01:38 +0100
Subject: [PATCH] Extend LLVM Offloading API for binary fatbin Bundles

  To provide a common API for offloading, this extension to the
  existing LLVM Offloading API adds support for Binary Fatbin Bundles,
  moving some of that support over from the Clang offloading API.
  Functionality for Binary Fatbin Bundles will be added to LLVM
  tooling in subsequent commits.

Change-Id: I907fdcbcd0545162a0ce1cf17ebf7c9f3a4dbde6
---
 llvm/include/llvm/Object/OffloadBinary.h | 142 ++++++++
 llvm/lib/Object/OffloadBinary.cpp        | 441 +++++++++++++++++++++++
 2 files changed, 583 insertions(+)

diff --git a/llvm/include/llvm/Object/OffloadBinary.h b/llvm/include/llvm/Object/OffloadBinary.h
index c02aec8d956ed6..c63ef4824bc7a4 100644
--- a/llvm/include/llvm/Object/OffloadBinary.h
+++ b/llvm/include/llvm/Object/OffloadBinary.h
@@ -17,10 +17,12 @@
 #ifndef LLVM_OBJECT_OFFLOADBINARY_H
 #define LLVM_OBJECT_OFFLOADBINARY_H
 
+#include "llvm/Support/Compression.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Object/Binary.h"
+#include "llvm/Object/ObjectFile.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include <memory>
@@ -49,6 +51,31 @@ enum ImageKind : uint16_t {
   IMG_LAST,
 };
 
+class CompressedOffloadBundle { // Static helpers that write and read the compressed offload bundle container.
+private:
+  static inline const size_t MagicSize = 4; // Length of MagicNumber below.
+  static inline const size_t VersionFieldSize = sizeof(uint16_t);
+  static inline const size_t MethodFieldSize = sizeof(uint16_t);
+  static inline const size_t FileSizeFieldSize = sizeof(uint32_t); // Only read by decompress() when the version is >= 2.
+  static inline const size_t UncompressedSizeFieldSize = sizeof(uint32_t);
+  static inline const size_t HashFieldSize = sizeof(uint64_t);
+  static inline const size_t V1HeaderSize = // Header size without the file size field.
+      MagicSize + VersionFieldSize + MethodFieldSize +
+      UncompressedSizeFieldSize + HashFieldSize;
+  static inline const size_t V2HeaderSize = // V1 header plus the file size field.
+      MagicSize + VersionFieldSize + FileSizeFieldSize + MethodFieldSize +
+      UncompressedSizeFieldSize + HashFieldSize;
+  static inline const llvm::StringRef MagicNumber = "CCOB"; // Written first by compress().
+  static inline const uint16_t Version = 2; // Format version written by compress().
+
+public:
+  static llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> // Header followed by the compressed bytes of Input.
+  compress(llvm::compression::Params P, const llvm::MemoryBuffer &Input,
+           bool Verbose = false); // Verbose prints statistics to llvm::errs().
+  static llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> // Decompressed data, or a copy of Input if it is not a compressed bundle.
+  decompress(llvm::MemoryBufferRef &Input, bool Verbose = false);
+};
+
 /// A simple binary serialization of an offloading file. We use this format to
 /// embed the offloading image into the host executable so it can be extracted
 /// and used by the linker.
@@ -183,11 +210,126 @@ class OffloadFile : public OwningBinary<OffloadBinary> {
   }
 };
 
+/// One uncompressed offload bundle found in a fat binary section, described by
+/// the code-object entries that ReadEntries parses from its header.
+class OffloadFatBinBundle {
+private:
+  uint64_t Size = 0u;
+  StringRef FileName;
+  int64_t NumberOfEntries = 0;
+
+public:
+  /// Location and identifier of one code object inside the bundle.
+  struct BundleEntry {
+    uint64_t Offset = 0u;
+    uint64_t Size = 0u;
+    uint64_t IDLength = 0u;
+    StringRef ID;
+    BundleEntry(uint64_t O, uint64_t S, uint64_t I, StringRef T)
+        : Offset(O), Size(S), IDLength(I), ID(T) {}
+    void dump(raw_ostream &OS) {
+      OS << "Offset = " << Offset << ", Size = " << Size
+         << ", ID Length = " << IDLength << ", ID = " << ID;
+    }
+    /// Print "<ID>\tfile://<path>#offset=<N>&size=<M>"; ID is not NUL-terminated.
+    void dumpURI(raw_ostream &OS, StringRef filePath) {
+      OS << ID << "\tfile://" << filePath << "#offset=" << Offset
+         << "&size=" << Size << "\n";
+    }
+  };
+
+  uint64_t getSize() const { return Size; }
+  StringRef getFileName() const { return FileName; }
+  int64_t getNumEntries() const { return NumberOfEntries; }
+
+  std::unique_ptr<SmallVector<BundleEntry>> Entries;
+  static Expected<std::unique_ptr<OffloadFatBinBundle>>
+  create(MemoryBufferRef, uint64_t SectionOffset, StringRef fileName);
+  Error extractBundle(const ObjectFile &Source);
+
+  Error ReadEntries(StringRef Section, uint64_t SectionOffset);
+  void DumpEntries() { // Dump every entry to stdout.
+    for (BundleEntry &Entry : *Entries)
+      Entry.dump(outs());
+  }
+
+  void PrintEntriesAsURI() { // Print every entry to stdout as a file URI.
+    for (BundleEntry &Entry : *Entries)
+      Entry.dumpURI(outs(), FileName);
+  }
+
+  OffloadFatBinBundle(MemoryBufferRef Source, StringRef file)
+      : FileName(file), Entries(std::make_unique<SmallVector<BundleEntry>>()) {}
+};
+
+enum uri_type_t {FILE_URI, MEMORY_URI}; // Selects which parser the OffloadBundleURI string constructor runs.
+
+/// A parsed "file://<path>#offset=<N>&size=<M>" reference to one code object.
+struct OffloadBundleURI {
+  int64_t Offset = 0;
+  int64_t Size = 0;
+  uint64_t ProcessID = 0;
+  StringRef FileName;
+  uri_type_t URIType;
+
+  // Constructors
+  // TODO: add a Copy ctor ?
+  OffloadBundleURI(StringRef file, int64_t off, int64_t size)
+      : Offset(off), Size(size), ProcessID(0), FileName(file),
+        URIType(FILE_URI) {}
+
+  /// Parse \p str as a URI of kind \p type; an unknown kind is fatal.
+  OffloadBundleURI(StringRef str, uri_type_t type) : URIType(type) {
+    switch (URIType) {
+    case FILE_URI:
+      parseFileName(str);
+      break;
+    case MEMORY_URI:
+      parseMemoryURI(str);
+      break;
+    default:
+      report_fatal_error("Unrecognized URI type.");
+    }
+  }
+
+  /// Parse a file URI into FileName, Offset and Size. FileName refers to the
+  /// storage behind \p str. A missing or malformed component is fatal.
+  void parseFileName(StringRef str) {
+    ProcessID = 0;
+    URIType = FILE_URI;
+    if (!str.consume_front("file://"))
+      report_fatal_error("Reading type of URI.");
+    FileName = str.take_until([](char c) { return c == '#' || c == '?'; });
+    str = str.drop_front(FileName.size());
+    if (!str.consume_front("#offset="))
+      report_fatal_error("Reading 'offset' in URI.");
+    StringRef OffsetStr = str.take_until([](char c) { return c == '&'; });
+    if (OffsetStr.getAsInteger(10, Offset)) // true means parsing failed
+      report_fatal_error("Reading 'offset' in URI.");
+    str = str.drop_front(OffsetStr.size());
+
+    if (!str.consume_front("&size=") || str.getAsInteger(10, Size))
+      report_fatal_error("Reading 'size' in URI.");
+  }
+
+  void parseMemoryURI(StringRef str) {
+    // TODO: add parseMemoryURI type
+  }
+
+  StringRef getFileName() const { return FileName; }
+};
+
 /// Extracts embedded device offloading code from a memory \p Buffer to a list
 /// of \p Binaries.
 Error extractOffloadBinaries(MemoryBufferRef Buffer,
                              SmallVectorImpl<OffloadFile> &Binaries);
 
+Error extractFatBinaryFromObject(const ObjectFile &Obj, SmallVectorImpl<OffloadFatBinBundle> &Bundles); // Appends the offload bundles found in Obj's sections to Bundles.
+
+Error extractCodeObject(const ObjectFile &Source, int64_t Offset, int64_t Size, StringRef OutputFileName); // Writes Size bytes at Offset of Source's buffer to OutputFileName.
+
+Error extractURI(StringRef URIstr); // Parses URIstr as a file URI and extracts the byte range it names.
+
 /// Convert a string \p Name to an image kind.
 ImageKind getImageKind(StringRef Name);
 
diff --git a/llvm/lib/Object/OffloadBinary.cpp b/llvm/lib/Object/OffloadBinary.cpp
index 89dc12551494fd..75085b96a960bf 100644
--- a/llvm/lib/Object/OffloadBinary.cpp
+++ b/llvm/lib/Object/OffloadBinary.cpp
@@ -17,6 +17,7 @@
 #include "llvm/Object/Archive.h"
 #include "llvm/Object/ArchiveWriter.h"
 #include "llvm/Object/Binary.h"
+#include "llvm/BinaryFormat/COFF.h"
 #include "llvm/Object/COFF.h"
 #include "llvm/Object/ELFObjectFile.h"
 #include "llvm/Object/Error.h"
@@ -25,12 +26,18 @@
 #include "llvm/Support/Alignment.h"
 #include "llvm/Support/FileOutputBuffer.h"
 #include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/Timer.h"
+#include "llvm/Support/BinaryStreamReader.h"
 
 using namespace llvm;
 using namespace llvm::object;
 
 namespace {
 
+static llvm::TimerGroup
+    ClangOffloadBundlerTimerGroup("Clang Offload Bundler Timer Group",
+                                  "Timer group for clang offload bundler"); // Group used by the timers in CompressedOffloadBundle::compress/decompress.
+
 /// Attempts to extract all the embedded device images contained inside the
 /// buffer \p Contents. The buffer is expected to contain a valid offloading
 /// binary format.
@@ -99,6 +106,42 @@ Error extractFromObject(const ObjectFile &Obj,
   return Error::success();
 }
 
+// Extract every Offload bundle (usually a Clang Offload Bundle) stored in a
+// fat_bin section.
+// \p Contents is the section data and \p SectionOffset is added to the offsets
+// of the entries that are read. Every bundle found is appended to \p Bundles.
+// Returns the error from OffloadFatBinBundle::create if a bundle fails to
+// parse (bundles appended before the failure stay in \p Bundles).
+Error extractOffloadBundle(MemoryBufferRef Contents, uint64_t SectionOffset,
+                           StringRef fileName,
+                           SmallVectorImpl<OffloadFatBinBundle> &Bundles) {
+  const StringRef BundleMagic("__CLANG_OFFLOAD_BUNDLE__");
+  uint64_t Offset = 0;
+
+  // There could be multiple offloading bundles stored at this section.
+  while (true) {
+    // View of the section starting at the current bundle; no copy is made.
+    StringRef Remaining = Contents.getBuffer().drop_front(Offset);
+    MemoryBufferRef Buffer(Remaining, "");
+
+    // Create the FatBinBundle object. This also reads its Bundle Entry list.
+    auto FatBundleOrErr =
+        OffloadFatBinBundle::create(Buffer, SectionOffset + Offset, fileName);
+    if (!FatBundleOrErr)
+      return FatBundleOrErr.takeError();
+
+    // Add the current Bundle to the list.
+    Bundles.emplace_back(std::move(**FatBundleOrErr));
+
+    // Find the next bundle by searching for the magic string, starting past
+    // the magic of the bundle that was just parsed.
+    size_t NextBundleStart = Remaining.find(BundleMagic, BundleMagic.size());
+    if (NextBundleStart == StringRef::npos)
+      return Error::success();
+    Offset += NextBundleStart;
+  }
+}
+
 Error extractFromBitcode(MemoryBufferRef Buffer,
                          SmallVectorImpl<OffloadFile> &Binaries) {
   LLVMContext Context;
@@ -170,6 +213,95 @@ Error extractFromArchive(const Archive &Library,
 
 } // namespace
 
+
+/// Parse the header of one uncompressed offload bundle from \p Buffer: a
+/// 24-byte magic string, a 64-bit entry count, then per entry a 64-bit offset,
+/// size and ID length followed by the ID string (all integers little-endian).
+///
+/// Each entry is appended to Entries with \p SectionOffset added to its offset.
+/// Entry IDs are StringRefs into \p Buffer, so \p Buffer's storage must outlive
+/// this bundle. Returns object_error::parse_failed if any read fails.
+Error OffloadFatBinBundle::ReadEntries(StringRef Buffer,
+                                       uint64_t SectionOffset) {
+  // An llvm::Error that holds a failure must be handled before it is
+  // destroyed, so consume the reader's error before reporting our own.
+  auto ParseFailed = [](Error E) -> Error {
+    consumeError(std::move(E));
+    return errorCodeToError(object_error::parse_failed);
+  };
+
+  BinaryStreamReader Reader(Buffer, llvm::endianness::little);
+
+  // Read the Magic String first.
+  StringRef Magic;
+  if (Error E = Reader.readFixedString(Magic, 24))
+    return ParseFailed(std::move(E));
+
+  // Read the number of Code Objects (Entries) in the current Bundle.
+  uint64_t NumOfEntries = 0;
+  if (Error E = Reader.readInteger(NumOfEntries))
+    return ParseFailed(std::move(E));
+  NumberOfEntries = NumOfEntries;
+
+  // For each Bundle Entry (code object)
+  for (uint64_t I = 0; I < NumOfEntries; I++) {
+    uint64_t EntryOffset = 0;
+    uint64_t EntrySize = 0;
+    uint64_t EntryIDSize = 0;
+    StringRef EntryID;
+
+    if (Error E = Reader.readInteger(EntryOffset))
+      return ParseFailed(std::move(E));
+    if (Error E = Reader.readInteger(EntrySize))
+      return ParseFailed(std::move(E));
+    if (Error E = Reader.readInteger(EntryIDSize))
+      return ParseFailed(std::move(E));
+    if (Error E = Reader.readFixedString(EntryID, EntryIDSize))
+      return ParseFailed(std::move(E));
+
+    // Construct the entry in place; nothing is heap-allocated separately.
+    Entries->emplace_back(EntryOffset + SectionOffset, EntrySize, EntryIDSize,
+                          EntryID);
+  }
+
+  return Error::success();
+}
+
+/// Create a bundle from \p Buf. Returns object_error::parse_failed if \p Buf is
+/// too short or lacks the bundle magic, and forwards any ReadEntries error.
+Expected<std::unique_ptr<OffloadFatBinBundle>>
+OffloadFatBinBundle::create(MemoryBufferRef Buf, uint64_t SectionOffset,
+                            StringRef fileName) {
+  if (Buf.getBufferSize() < 24)
+    return errorCodeToError(object_error::parse_failed);
+  // Check for magic bytes.
+  if (identify_magic(Buf.getBuffer()) != file_magic::offload_bundle)
+    return errorCodeToError(object_error::parse_failed);
+
+  // Owned from the start so the bundle is freed if reading the entries fails.
+  auto TheBundle = std::make_unique<OffloadFatBinBundle>(Buf, fileName);
+  if (Error Err = TheBundle->ReadEntries(Buf.getBuffer(), SectionOffset))
+    return std::move(Err);
+  return std::move(TheBundle);
+}
+
+/// Write every non-empty entry to "<fileName>-offset<Offset>-size<Size>.co",
+/// copying from \p Source; returns the first error from extractCodeObject.
+Error OffloadFatBinBundle::extractBundle(const ObjectFile &Source) {
+  // Iterate the container itself so the walk can never run past its end.
+  for (const BundleEntry &Entry : *Entries) {
+    if (Entry.Size == 0)
+      continue;
+    std::string OutputName = getFileName().str() + "-offset" +
+                             itostr(Entry.Offset) + "-size" +
+                             itostr(Entry.Size) + ".co";
+    if (Error Err = object::extractCodeObject(Source, Entry.Offset, Entry.Size,
+                                              OutputName))
+      return Err;
+  }
+  return Error::success();
+}
+
 Expected<std::unique_ptr<OffloadBinary>>
 OffloadBinary::create(MemoryBufferRef Buf) {
   if (Buf.getBufferSize() < sizeof(Header) + sizeof(Entry))
@@ -299,6 +431,94 @@ Error object::extractOffloadBinaries(MemoryBufferRef Buffer,
   }
 }
 
+/// Collect the offload bundles stored in the sections of \p Obj into
+/// \p Bundles. Sections not starting with the (optionally compressed) offload
+/// bundle magic are skipped; compressed ones are decompressed before parsing.
+Error object::extractFatBinaryFromObject(
+    const ObjectFile &Obj, SmallVectorImpl<OffloadFatBinBundle> &Bundles) {
+  assert((Obj.isELF() || Obj.isCOFF()) && "Invalid file type");
+
+  for (SectionRef Sec : Obj.sections()) {
+    Expected<StringRef> Buffer = Sec.getContents();
+    if (!Buffer)
+      return Buffer.takeError();
+    // If it does not start with an offload bundle magic, skip this section.
+    const file_magic Magic = identify_magic(*Buffer);
+    const bool IsCompressed = Magic == file_magic::offload_bundle_compressed;
+    if (Magic != file_magic::offload_bundle && !IsCompressed)
+      continue;
+
+    // File offset of the section data; it is added to every entry offset.
+    uint64_t SectionOffset = 0;
+    if (Obj.isELF())
+      SectionOffset = ELFSectionRef(Sec).getOffset();
+    else if (const auto *COFFObj = dyn_cast<COFFObjectFile>(&Obj))
+      SectionOffset = COFFObj->getCOFFSection(Sec)->PointerToRawData;
+    MemoryBufferRef Contents(*Buffer, Obj.getFileName());
+    if (!IsCompressed) {
+      if (Error Err = extractOffloadBundle(Contents, SectionOffset,
+                                           Obj.getFileName(), Bundles))
+        return Err;
+      continue;
+    }
+    // Decompress before parsing. NOTE(review): entry offsets then index the
+    // decompressed data and entry IDs point into DecompressedInput, which is
+    // freed at the end of this iteration -- confirm how callers use them.
+    Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
+        CompressedOffloadBundle::decompress(Contents, false);
+    if (!DecompressedBufferOrErr)
+      return createStringError(
+          inconvertibleErrorCode(),
+          "Failed to decompress input: " +
+              llvm::toString(DecompressedBufferOrErr.takeError()));
+    MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr;
+    if (Error Err = extractOffloadBundle(DecompressedInput, SectionOffset,
+                                         Obj.getFileName(), Bundles))
+      return Err;
+  }
+  return Error::success();
+}
+
+/// Copy \p Size bytes at \p Offset of \p Source's buffer into a new file
+/// \p OutputFileName; fails if the range does not lie inside that buffer.
+Error object::extractCodeObject(const ObjectFile &Source, int64_t Offset,
+                                int64_t Size, StringRef OutputFileName) {
+  // Validate before creating the output file or reading out of bounds.
+  MemoryBufferRef Input = Source.getMemoryBufferRef();
+  const uint64_t Available = Input.getBufferSize();
+  if (Offset < 0 || Size < 0 || uint64_t(Offset) > Available ||
+      uint64_t(Size) > Available - uint64_t(Offset))
+    return createStringError(inconvertibleErrorCode(),
+                             "Code object range is outside of the input");
+  auto BufferOrErr = FileOutputBuffer::create(OutputFileName, Size);
+  if (!BufferOrErr)
+    return BufferOrErr.takeError();
+  const char *Begin = Input.getBufferStart() + Offset;
+  std::copy(Begin, Begin + Size, (*BufferOrErr)->getBufferStart());
+  return (*BufferOrErr)->commit();
+}
+
+// Given a file URI of the form file://<SourceFile>#offset=<Offset>&size=<Size>,
+// extract that byte range of <SourceFile> into the code object file
+// <SourceFile>-offset<Offset>-size<Size>.co
+Error object::extractURI(StringRef URIstr) {
+  // The URI only lives for this call, so keep it on the stack instead of
+  // heap-allocating an object that is never freed.
+  const object::OffloadBundleURI Uri(URIstr, FILE_URI);
+
+  std::string OutputFile = Uri.FileName.str();
+  OutputFile +=
+      "-offset" + itostr(Uri.Offset) + "-size" + itostr(Uri.Size) + ".co";
+
+  // Create an ObjectFile object from the file named by the URI.
+  auto ObjOrErr = ObjectFile::createObjectFile(Uri.FileName);
+  if (!ObjOrErr)
+    return ObjOrErr.takeError();
+
+  return object::extractCodeObject(*ObjOrErr->getBinary(), Uri.Offset,
+                                   Uri.Size, OutputFile);
+}
+
 OffloadKind object::getOffloadKind(StringRef Name) {
   return llvm::StringSwitch<OffloadKind>(Name)
       .Case("openmp", OFK_OpenMP)
@@ -382,3 +602,224 @@ bool object::areTargetsCompatible(const OffloadFile::TargetID &LHS,
     return false;
   return true;
 }
+
+
+// Utility function to format numbers with commas, e.g. 1234567 becomes
+// "1,234,567".
+static std::string formatWithCommas(unsigned long long Value) {
+  std::string Digits = std::to_string(Value);
+  // Work from the right-hand side: inserting a separator never shifts the
+  // positions to its left, so stepping back by three stays aligned.
+  for (int Pos = static_cast<int>(Digits.length()) - 3; Pos > 0; Pos -= 3)
+    Digits.insert(Pos, ",");
+  return Digits;
+}
+
+llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> // Decompressed payload of Input, or a plain copy when Input is not a compressed bundle.
+CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input,
+
+                                    bool Verbose) { // Verbose prints statistics and the hash check to llvm::errs().
+  StringRef Blob = Input.getBuffer();
+
+  if (Blob.size() < V1HeaderSize) // Too small to hold a header: hand back a copy.
+    return llvm::MemoryBuffer::getMemBufferCopy(Blob);
+
+  if (llvm::identify_magic(Blob) !=
+      llvm::file_magic::offload_bundle_compressed) {
+    if (Verbose)
+      llvm::errs() << "Uncompressed bundle.\n";
+    return llvm::MemoryBuffer::getMemBufferCopy(Blob);
+  }
+
+  size_t CurrentOffset = MagicSize; // Header fields are read in order, starting after the magic.
+
+  uint16_t ThisVersion;
+  memcpy(&ThisVersion, Blob.data() + CurrentOffset, sizeof(uint16_t));
+  CurrentOffset += VersionFieldSize;
+
+  uint16_t CompressionMethod;
+  memcpy(&CompressionMethod, Blob.data() + CurrentOffset, sizeof(uint16_t));
+  CurrentOffset += MethodFieldSize;
+
+  uint32_t TotalFileSize; // Only assigned, and only printed, when ThisVersion >= 2.
+  if (ThisVersion >= 2) {
+    if (Blob.size() < V2HeaderSize)
+      return createStringError(inconvertibleErrorCode(),
+                               "Compressed bundle header size too small");
+    memcpy(&TotalFileSize, Blob.data() + CurrentOffset, sizeof(uint32_t));
+    CurrentOffset += FileSizeFieldSize;
+  }
+
+  uint32_t UncompressedSize;
+  memcpy(&UncompressedSize, Blob.data() + CurrentOffset, sizeof(uint32_t));
+  CurrentOffset += UncompressedSizeFieldSize;
+
+  uint64_t StoredHash;
+  memcpy(&StoredHash, Blob.data() + CurrentOffset, sizeof(uint64_t));
+  CurrentOffset += HashFieldSize;
+
+  llvm::compression::Format CompressionFormat; // Only Zlib and Zstd method values are accepted.
+  if (CompressionMethod ==
+      static_cast<uint16_t>(llvm::compression::Format::Zlib))
+    CompressionFormat = llvm::compression::Format::Zlib;
+  else if (CompressionMethod ==
+           static_cast<uint16_t>(llvm::compression::Format::Zstd))
+    CompressionFormat = llvm::compression::Format::Zstd;
+  else
+    return createStringError(inconvertibleErrorCode(),
+                             "Unknown compressing method");
+
+  llvm::Timer DecompressTimer("Decompression Timer", "Decompression time",
+                              ClangOffloadBundlerTimerGroup);
+  if (Verbose)
+    DecompressTimer.startTimer();
+
+  SmallVector<uint8_t, 0> DecompressedData;
+  StringRef CompressedData = Blob.substr(CurrentOffset); // Everything after the header.
+  if (llvm::Error DecompressionError = llvm::compression::decompress(
+          CompressionFormat, llvm::arrayRefFromStringRef(CompressedData),
+          DecompressedData, UncompressedSize))
+    return createStringError(inconvertibleErrorCode(),
+                             "Could not decompress embedded file contents: " +
+                                 llvm::toString(std::move(DecompressionError)));
+
+  if (Verbose) { // The hash is recomputed and compared only here; a mismatch is printed, not returned as an error.
+    DecompressTimer.stopTimer();
+
+    double DecompressionTimeSeconds =
+        DecompressTimer.getTotalTime().getWallTime();
+
+    // Recalculate MD5 hash for integrity check
+    llvm::Timer HashRecalcTimer("Hash Recalculation Timer",
+                                "Hash recalculation time",
+                                ClangOffloadBundlerTimerGroup);
+    HashRecalcTimer.startTimer();
+    llvm::MD5 Hash;
+    llvm::MD5::MD5Result Result;
+    Hash.update(llvm::ArrayRef<uint8_t>(DecompressedData.data(),
+                                        DecompressedData.size()));
+    Hash.final(Result);
+    uint64_t RecalculatedHash = Result.low();
+    HashRecalcTimer.stopTimer();
+    bool HashMatch = (StoredHash == RecalculatedHash);
+
+    double CompressionRate =
+        static_cast<double>(UncompressedSize) / CompressedData.size();
+    double DecompressionSpeedMBs =
+        (UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds;
+
+    llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n";
+    if (ThisVersion >= 2)
+      llvm::errs() << "Total file size (from header): "
+                   << formatWithCommas(TotalFileSize) << " bytes\n";
+    llvm::errs() << "Decompression method: "
+                 << (CompressionFormat == llvm::compression::Format::Zlib
+                         ? "zlib"
+                         : "zstd")
+                 << "\n"
+                 << "Size before decompression: "
+                 << formatWithCommas(CompressedData.size()) << " bytes\n"
+                 << "Size after decompression: "
+                 << formatWithCommas(UncompressedSize) << " bytes\n"
+                 << "Compression rate: "
+                 << llvm::format("%.2lf", CompressionRate) << "\n"
+                 << "Compression ratio: "
+                 << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
+                 << "Decompression speed: "
+                 << llvm::format("%.2lf MB/s", DecompressionSpeedMBs) << "\n"
+                 << "Stored hash: " << llvm::format_hex(StoredHash, 16) << "\n"
+                 << "Recalculated hash: "
+                 << llvm::format_hex(RecalculatedHash, 16) << "\n"
+                 << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n";
+  }
+
+  return llvm::MemoryBuffer::getMemBufferCopy(
+      llvm::toStringRef(DecompressedData));
+}
+
+llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>> // New buffer: header followed by the compressed bytes of Input.
+CompressedOffloadBundle::compress(llvm::compression::Params P,
+                                  const llvm::MemoryBuffer &Input,
+                                  bool Verbose) { // Verbose prints statistics to llvm::errs().
+  if (!llvm::compression::zstd::isAvailable() &&
+      !llvm::compression::zlib::isAvailable()) // Fails only when neither library is available.
+    return createStringError(llvm::inconvertibleErrorCode(),
+                             "Compression not supported");
+
+  llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
+                        ClangOffloadBundlerTimerGroup);
+  if (Verbose)
+    HashTimer.startTimer();
+  llvm::MD5 Hash;
+  llvm::MD5::MD5Result Result;
+  Hash.update(Input.getBuffer());
+  Hash.final(Result);
+  uint64_t TruncatedHash = Result.low(); // 64 bits of the MD5 of the uncompressed input; stored in the header.
+  if (Verbose)
+    HashTimer.stopTimer();
+
+  SmallVector<uint8_t, 0> CompressedBuffer;
+  auto BufferUint8 = llvm::ArrayRef<uint8_t>(
+      reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
+      Input.getBuffer().size());
+
+  llvm::Timer CompressTimer("Compression Timer", "Compression time",
+                            ClangOffloadBundlerTimerGroup);
+  if (Verbose)
+    CompressTimer.startTimer();
+  llvm::compression::compress(P, BufferUint8, CompressedBuffer);
+  if (Verbose)
+    CompressTimer.stopTimer();
+
+  uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
+  uint32_t UncompressedSize = Input.getBuffer().size();
+  uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) + // Sum of all header fields plus the payload.
+                           sizeof(Version) + sizeof(CompressionMethod) +
+                           sizeof(UncompressedSize) + sizeof(TruncatedHash) +
+                           CompressedBuffer.size();
+
+  SmallVector<char, 0> FinalBuffer;
+  llvm::raw_svector_ostream OS(FinalBuffer);
+  OS << MagicNumber; // Write order: magic, version, method, total size, uncompressed size, hash, payload.
+  OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version));
+  OS.write(reinterpret_cast<const char *>(&CompressionMethod),
+           sizeof(CompressionMethod));
+  OS.write(reinterpret_cast<const char *>(&TotalFileSize),
+           sizeof(TotalFileSize));
+  OS.write(reinterpret_cast<const char *>(&UncompressedSize),
+           sizeof(UncompressedSize));
+  OS.write(reinterpret_cast<const char *>(&TruncatedHash),
+           sizeof(TruncatedHash));
+  OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()),
+           CompressedBuffer.size());
+
+  if (Verbose) {
+    auto MethodUsed =
+        P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib";
+    double CompressionRate =
+        static_cast<double>(UncompressedSize) / CompressedBuffer.size();
+    double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime();
+    double CompressionSpeedMBs =
+        (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds;
+
+    llvm::errs() << "Compressed bundle format version: " << Version << "\n"
+                 << "Total file size (including headers): "
+                 << formatWithCommas(TotalFileSize) << " bytes\n"
+                 << "Compression method used: " << MethodUsed << "\n"
+                 << "Compression level: " << P.level << "\n"
+                 << "Binary size before compression: "
+                 << formatWithCommas(UncompressedSize) << " bytes\n"
+                 << "Binary size after compression: "
+                 << formatWithCommas(CompressedBuffer.size()) << " bytes\n"
+                 << "Compression rate: "
+                 << llvm::format("%.2lf", CompressionRate) << "\n"
+                 << "Compression ratio: "
+                 << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
+                 << "Compression speed: "
+                 << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n"
+                 << "Truncated MD5 hash: "
+                 << llvm::format_hex(TruncatedHash, 16) << "\n";
+  }
+  return llvm::MemoryBuffer::getMemBufferCopy(
+      llvm::StringRef(FinalBuffer.data(), FinalBuffer.size())); // Copies FinalBuffer into the returned buffer.
+}



More information about the llvm-commits mailing list