[llvm] [Offload] Copy loaded images into managed storage (PR #158748)

Joseph Huber via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 15 18:20:10 PDT 2025


https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/158748

>From b7274b6adc811a9732ecb8c66110733fbc021de5 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Mon, 15 Sep 2025 17:06:57 -0500
Subject: [PATCH] [Offload] Copy loaded images into managed storage

Summary:
Currently we have this `__tgt_device_image` indirection which just takes
a reference to some pointers. This was all fine and good when the only
usage of this was from a section of GPU code that came from an ELF
constant section. However, we have expanded beyond that and now need to
worry about managing lifetimes. We have code that references the image
even after it was loaded internally. This patch changes the
implementation to instead copy the memory buffer and manage it locally.

This PR reworks the JIT and other image handling to directly manage its
own memory. We now don't need to duplicate this behavior externally at
the Offload API level. Also we actually free these if the user unloads
them.

Upside, less likely to crash and burn. Downside, more latency when
loading an image.
---
 offload/liboffload/src/OffloadImpl.cpp        | 35 +++-------
 offload/plugins-nextgen/amdgpu/src/rtl.cpp    | 28 ++++----
 offload/plugins-nextgen/common/include/JIT.h  | 23 ++-----
 .../common/include/PluginInterface.h          | 44 +++---------
 offload/plugins-nextgen/common/src/JIT.cpp    | 67 ++++---------------
 .../common/src/PluginInterface.cpp            | 54 ++++++---------
 offload/plugins-nextgen/cuda/src/rtl.cpp      | 25 +++----
 offload/plugins-nextgen/host/src/rtl.cpp      | 11 +--
 8 files changed, 89 insertions(+), 198 deletions(-)

diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 7e8e297831f45..b5b9b0e83b975 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -157,14 +157,11 @@ struct ol_event_impl_t {
 
 struct ol_program_impl_t {
   ol_program_impl_t(plugin::DeviceImageTy *Image,
-                    std::unique_ptr<llvm::MemoryBuffer> ImageData,
-                    const __tgt_device_image &DeviceImage)
-      : Image(Image), ImageData(std::move(ImageData)),
-        DeviceImage(DeviceImage) {}
+                    llvm::MemoryBufferRef DeviceImage)
+      : Image(Image), DeviceImage(DeviceImage) {}
   plugin::DeviceImageTy *Image;
-  std::unique_ptr<llvm::MemoryBuffer> ImageData;
   std::mutex SymbolListMutex;
-  __tgt_device_image DeviceImage;
+  llvm::MemoryBufferRef DeviceImage;
   llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> KernelSymbols;
   llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> GlobalSymbols;
 };
@@ -891,28 +888,14 @@ Error olMemFill_impl(ol_queue_handle_t Queue, void *Ptr, size_t PatternSize,
 Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
                            size_t ProgDataSize, ol_program_handle_t *Program) {
   // Make a copy of the program binary in case it is released by the caller.
-  auto ImageData = MemoryBuffer::getMemBufferCopy(
-      StringRef(reinterpret_cast<const char *>(ProgData), ProgDataSize));
-
-  auto DeviceImage = __tgt_device_image{
-      const_cast<char *>(ImageData->getBuffer().data()),
-      const_cast<char *>(ImageData->getBuffer().data()) + ProgDataSize, nullptr,
-      nullptr};
-
-  ol_program_handle_t Prog =
-      new ol_program_impl_t(nullptr, std::move(ImageData), DeviceImage);
-
-  auto Res =
-      Device->Device->loadBinary(Device->Device->Plugin, &Prog->DeviceImage);
-  if (!Res) {
-    delete Prog;
+  StringRef Buffer(reinterpret_cast<const char *>(ProgData), ProgDataSize);
+  Expected<plugin::DeviceImageTy *> Res =
+      Device->Device->loadBinary(Device->Device->Plugin, Buffer);
+  if (!Res)
     return Res.takeError();
-  }
-  assert(*Res != nullptr && "loadBinary returned nullptr");
-
-  Prog->Image = *Res;
-  *Program = Prog;
+  assert(*Res && "loadBinary returned nullptr");
 
+  *Program = new ol_program_impl_t(*Res, (*Res)->getMemoryBuffer());
   return Error::success();
 }
 
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index c26cfe961aa0e..1d33bfc1a0be9 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -464,8 +464,8 @@ struct AMDGPUMemoryManagerTy : public DeviceAllocatorTy {
 struct AMDGPUDeviceImageTy : public DeviceImageTy {
   /// Create the AMDGPU image with the id and the target image pointer.
   AMDGPUDeviceImageTy(int32_t ImageId, GenericDeviceTy &Device,
-                      const __tgt_device_image *TgtImage)
-      : DeviceImageTy(ImageId, Device, TgtImage) {}
+                      std::unique_ptr<MemoryBuffer> &&TgtImage)
+      : DeviceImageTy(ImageId, Device, std::move(TgtImage)) {}
 
   /// Prepare and load the executable corresponding to the image.
   Error loadExecutable(const AMDGPUDeviceTy &Device);
@@ -2160,7 +2160,12 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     AMDGPUDeviceImageTy &AMDImage = static_cast<AMDGPUDeviceImageTy &>(*Image);
 
     // Unload the executable of the image.
-    return AMDImage.unloadExecutable();
+    if (auto Err = AMDImage.unloadExecutable())
+      return Err;
+
+    // Destroy the associated memory and invalidate the object.
+    Plugin.free(Image);
+    return Error::success();
   }
 
   /// Deinitialize the device and release its resources.
@@ -2183,18 +2188,12 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
 
   virtual Error callGlobalConstructors(GenericPluginTy &Plugin,
                                        DeviceImageTy &Image) override {
-    GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
-    if (Handler.isSymbolInImage(*this, Image, "amdgcn.device.fini"))
-      Image.setPendingGlobalDtors();
-
     return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/true);
   }
 
   virtual Error callGlobalDestructors(GenericPluginTy &Plugin,
                                       DeviceImageTy &Image) override {
-    if (Image.hasPendingGlobalDtors())
-      return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/false);
-    return Plugin::success();
+    return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/false);
   }
 
   uint64_t getStreamBusyWaitMicroseconds() const { return OMPX_StreamBusyWait; }
@@ -2321,11 +2320,12 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   }
 
   /// Load the binary image into the device and allocate an image object.
-  Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage,
-                                           int32_t ImageId) override {
+  Expected<DeviceImageTy *>
+  loadBinaryImpl(std::unique_ptr<MemoryBuffer> &&TgtImage,
+                 int32_t ImageId) override {
     // Allocate and initialize the image object.
     AMDGPUDeviceImageTy *AMDImage = Plugin.allocate<AMDGPUDeviceImageTy>();
-    new (AMDImage) AMDGPUDeviceImageTy(ImageId, *this, TgtImage);
+    new (AMDImage) AMDGPUDeviceImageTy(ImageId, *this, std::move(TgtImage));
 
     // Load the HSA executable.
     if (Error Err = AMDImage->loadExecutable(*this))
@@ -3105,7 +3105,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     // Perform a quick check for the named kernel in the image. The kernel
     // should be created by the 'amdgpu-lower-ctor-dtor' pass.
     GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
-    if (IsCtor && !Handler.isSymbolInImage(*this, Image, KernelName))
+    if (!Handler.isSymbolInImage(*this, Image, KernelName))
       return Plugin::success();
 
     // Allocate and construct the AMDGPU kernel.
diff --git a/offload/plugins-nextgen/common/include/JIT.h b/offload/plugins-nextgen/common/include/JIT.h
index d62516d20764a..b4e3712d9c980 100644
--- a/offload/plugins-nextgen/common/include/JIT.h
+++ b/offload/plugins-nextgen/common/include/JIT.h
@@ -51,27 +51,22 @@ struct JITEngine {
   /// Run jit compilation if \p Image is a bitcode image, otherwise simply
   /// return \p Image. It is expected to return a memory buffer containing the
   /// generated device image that could be loaded to the device directly.
-  Expected<const __tgt_device_image *>
-  process(const __tgt_device_image &Image,
-          target::plugin::GenericDeviceTy &Device);
-
-  /// Remove \p Image from the jit engine's cache
-  void erase(const __tgt_device_image &Image,
-             target::plugin::GenericDeviceTy &Device);
+  Expected<std::unique_ptr<MemoryBuffer>>
+  process(StringRef Image, target::plugin::GenericDeviceTy &Device);
 
 private:
   /// Compile the bitcode image \p Image and generate the binary image that can
   /// be loaded to the target device of the triple \p Triple architecture \p
   /// MCpu. \p PostProcessing will be called after codegen to handle cases such
   /// as assembler as an external tool.
-  Expected<const __tgt_device_image *>
-  compile(const __tgt_device_image &Image, const std::string &ComputeUnitKind,
+  Expected<std::unique_ptr<MemoryBuffer>>
+  compile(StringRef Image, const std::string &ComputeUnitKind,
           PostProcessingFn PostProcessing);
 
   /// Create or retrieve the object image file from the file system or via
   /// compilation of the \p Image.
   Expected<std::unique_ptr<MemoryBuffer>>
-  getOrCreateObjFile(const __tgt_device_image &Image, LLVMContext &Ctx,
+  getOrCreateObjFile(StringRef Image, LLVMContext &Ctx,
                      const std::string &ComputeUnitKind);
 
   /// Run backend, which contains optimization and code generation.
@@ -92,14 +87,6 @@ struct JITEngine {
   struct ComputeUnitInfo {
     /// LLVM Context in which the modules will be constructed.
     LLVMContext Context;
-
-    /// A map of embedded IR images to the buffer used to store JITed code
-    DenseMap<const __tgt_device_image *, std::unique_ptr<MemoryBuffer>>
-        JITImages;
-
-    /// A map of embedded IR images to JITed images.
-    DenseMap<const __tgt_device_image *, std::unique_ptr<__tgt_device_image>>
-        TgtImageMap;
   };
 
   /// Map from (march) "CPUs" (e.g., sm_80, or gfx90a), which we call compute
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index 6ff3ef8cda177..ce66d277d6187 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -306,26 +306,18 @@ class DeviceImageTy {
   /// not unique between different device; they may overlap.
   int32_t ImageId;
 
-  /// The pointer to the raw __tgt_device_image.
-  const __tgt_device_image *TgtImage;
-  const __tgt_device_image *TgtImageBitcode;
+  /// The managed image data.
+  std::unique_ptr<MemoryBuffer> Image;
 
   /// Reference to the device this image is loaded on.
   GenericDeviceTy &Device;
 
-  /// If this image has any global destructors that much be called.
-  /// FIXME: This is only required because we currently have no invariants
-  ///        towards the lifetime of the underlying image. We should either copy
-  ///        the image into memory locally or erase the pointers after init.
-  bool PendingGlobalDtors;
-
 public:
+  virtual ~DeviceImageTy() = default;
+
   DeviceImageTy(int32_t Id, GenericDeviceTy &Device,
-                const __tgt_device_image *Image)
-      : ImageId(Id), TgtImage(Image), TgtImageBitcode(nullptr), Device(Device),
-        PendingGlobalDtors(false) {
-    assert(TgtImage && "Invalid target image");
-  }
+                std::unique_ptr<MemoryBuffer> &&Image)
+      : ImageId(Id), Image(std::move(Image)), Device(Device) {}
 
   /// Get the image identifier within the device.
   int32_t getId() const { return ImageId; }
@@ -333,33 +325,17 @@ class DeviceImageTy {
   /// Get the device that this image is loaded onto.
   GenericDeviceTy &getDevice() const { return Device; }
 
-  /// Get the pointer to the raw __tgt_device_image.
-  const __tgt_device_image *getTgtImage() const { return TgtImage; }
-
-  void setTgtImageBitcode(const __tgt_device_image *TgtImageBitcode) {
-    this->TgtImageBitcode = TgtImageBitcode;
-  }
-
-  const __tgt_device_image *getTgtImageBitcode() const {
-    return TgtImageBitcode;
-  }
-
   /// Get the image starting address.
-  void *getStart() const { return TgtImage->ImageStart; }
+  const void *getStart() const { return Image->getBufferStart(); }
 
   /// Get the image size.
-  size_t getSize() const {
-    return utils::getPtrDiff(TgtImage->ImageEnd, TgtImage->ImageStart);
-  }
+  size_t getSize() const { return Image->getBufferSize(); }
 
   /// Get a memory buffer reference to the whole image.
   MemoryBufferRef getMemoryBuffer() const {
     return MemoryBufferRef(StringRef((const char *)getStart(), getSize()),
                            "Image");
   }
-  /// Accessors to the boolean value
-  bool setPendingGlobalDtors() { return PendingGlobalDtors = true; }
-  bool hasPendingGlobalDtors() const { return PendingGlobalDtors; }
 };
 
 /// Class implementing common functionalities of offload kernels. Each plugin
@@ -831,9 +807,9 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
 
   /// Load the binary image into the device and return the target table.
   Expected<DeviceImageTy *> loadBinary(GenericPluginTy &Plugin,
-                                       const __tgt_device_image *TgtImage);
+                                       StringRef TgtImage);
   virtual Expected<DeviceImageTy *>
-  loadBinaryImpl(const __tgt_device_image *TgtImage, int32_t ImageId) = 0;
+  loadBinaryImpl(std::unique_ptr<MemoryBuffer> &&TgtImage, int32_t ImageId) = 0;
 
   /// Unload a previously loaded Image from the device
   Error unloadBinary(DeviceImageTy *Image);
diff --git a/offload/plugins-nextgen/common/src/JIT.cpp b/offload/plugins-nextgen/common/src/JIT.cpp
index 00720fa2d8103..07ef05e7e9d38 100644
--- a/offload/plugins-nextgen/common/src/JIT.cpp
+++ b/offload/plugins-nextgen/common/src/JIT.cpp
@@ -49,13 +49,6 @@ using namespace omp::target;
 
 namespace {
 
-bool isImageBitcode(const __tgt_device_image &Image) {
-  StringRef Binary(reinterpret_cast<const char *>(Image.ImageStart),
-                   utils::getPtrDiff(Image.ImageEnd, Image.ImageStart));
-
-  return identify_magic(Binary) == file_magic::bitcode;
-}
-
 Expected<std::unique_ptr<Module>>
 createModuleFromMemoryBuffer(std::unique_ptr<MemoryBuffer> &MB,
                              LLVMContext &Context) {
@@ -66,12 +59,10 @@ createModuleFromMemoryBuffer(std::unique_ptr<MemoryBuffer> &MB,
                                      "failed to create module");
   return std::move(Mod);
 }
-Expected<std::unique_ptr<Module>>
-createModuleFromImage(const __tgt_device_image &Image, LLVMContext &Context) {
-  StringRef Data((const char *)Image.ImageStart,
-                 utils::getPtrDiff(Image.ImageEnd, Image.ImageStart));
+Expected<std::unique_ptr<Module>> createModuleFromImage(StringRef Image,
+                                                        LLVMContext &Context) {
   std::unique_ptr<MemoryBuffer> MB = MemoryBuffer::getMemBuffer(
-      Data, /*BufferName=*/"", /*RequiresNullTerminator=*/false);
+      Image, /*BufferName=*/"", /*RequiresNullTerminator=*/false);
   return createModuleFromMemoryBuffer(MB, Context);
 }
 
@@ -238,7 +229,7 @@ JITEngine::backend(Module &M, const std::string &ComputeUnitKind,
 }
 
 Expected<std::unique_ptr<MemoryBuffer>>
-JITEngine::getOrCreateObjFile(const __tgt_device_image &Image, LLVMContext &Ctx,
+JITEngine::getOrCreateObjFile(StringRef Image, LLVMContext &Ctx,
                               const std::string &ComputeUnitKind) {
 
   // Check if the user replaces the module at runtime with a finished object.
@@ -277,58 +268,28 @@ JITEngine::getOrCreateObjFile(const __tgt_device_image &Image, LLVMContext &Ctx,
   return backend(*Mod, ComputeUnitKind, JITOptLevel);
 }
 
-Expected<const __tgt_device_image *>
-JITEngine::compile(const __tgt_device_image &Image,
-                   const std::string &ComputeUnitKind,
+Expected<std::unique_ptr<MemoryBuffer>>
+JITEngine::compile(StringRef Image, const std::string &ComputeUnitKind,
                    PostProcessingFn PostProcessing) {
   std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex);
 
-  // Check if we JITed this image for the given compute unit kind before.
-  ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
-  if (CUI.TgtImageMap.contains(&Image))
-    return CUI.TgtImageMap[&Image].get();
-
-  auto ObjMBOrErr = getOrCreateObjFile(Image, CUI.Context, ComputeUnitKind);
+  LLVMContext Ctz;
+  auto ObjMBOrErr = getOrCreateObjFile(Image, Ctz, ComputeUnitKind);
   if (!ObjMBOrErr)
     return ObjMBOrErr.takeError();
 
-  auto ImageMBOrErr = PostProcessing(std::move(*ObjMBOrErr));
-  if (!ImageMBOrErr)
-    return ImageMBOrErr.takeError();
-
-  CUI.JITImages.insert({&Image, std::move(*ImageMBOrErr)});
-  auto &ImageMB = CUI.JITImages[&Image];
-  CUI.TgtImageMap.insert({&Image, std::make_unique<__tgt_device_image>()});
-  auto &JITedImage = CUI.TgtImageMap[&Image];
-  *JITedImage = Image;
-  JITedImage->ImageStart = const_cast<char *>(ImageMB->getBufferStart());
-  JITedImage->ImageEnd = const_cast<char *>(ImageMB->getBufferEnd());
-
-  return JITedImage.get();
+  return PostProcessing(std::move(*ObjMBOrErr));
 }
 
-Expected<const __tgt_device_image *>
-JITEngine::process(const __tgt_device_image &Image,
-                   target::plugin::GenericDeviceTy &Device) {
-  const std::string &ComputeUnitKind = Device.getComputeUnitKind();
+Expected<std::unique_ptr<MemoryBuffer>>
+JITEngine::process(StringRef Image, target::plugin::GenericDeviceTy &Device) {
+  assert(identify_magic(Image) == file_magic::bitcode && "Image not LLVM-IR");
 
+  const std::string &ComputeUnitKind = Device.getComputeUnitKind();
   PostProcessingFn PostProcessing = [&Device](std::unique_ptr<MemoryBuffer> MB)
       -> Expected<std::unique_ptr<MemoryBuffer>> {
     return Device.doJITPostProcessing(std::move(MB));
   };
 
-  if (isImageBitcode(Image))
-    return compile(Image, ComputeUnitKind, PostProcessing);
-
-  return &Image;
-}
-
-void JITEngine::erase(const __tgt_device_image &Image,
-                      target::plugin::GenericDeviceTy &Device) {
-  std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex);
-  const std::string &ComputeUnitKind = Device.getComputeUnitKind();
-  ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
-
-  CUI.TgtImageMap.erase(&Image);
-  CUI.JITImages.erase(&Image);
+  return compile(Image, ComputeUnitKind, PostProcessing);
 }
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 36cdd6035e26d..4f83caf7a1187 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -214,15 +214,7 @@ struct RecordReplayTy {
     raw_fd_ostream OS(ImageName, EC);
     if (EC)
       report_fatal_error("Error saving image : " + StringRef(EC.message()));
-    if (const auto *TgtImageBitcode = Image.getTgtImageBitcode()) {
-      size_t Size = utils::getPtrDiff(TgtImageBitcode->ImageEnd,
-                                      TgtImageBitcode->ImageStart);
-      MemoryBufferRef MBR = MemoryBufferRef(
-          StringRef((const char *)TgtImageBitcode->ImageStart, Size), "");
-      OS << MBR.getBuffer();
-    } else {
-      OS << Image.getMemoryBuffer().getBuffer();
-    }
+    OS << Image.getMemoryBuffer().getBuffer();
     OS.close();
   }
 
@@ -813,9 +805,6 @@ Error GenericDeviceTy::unloadBinary(DeviceImageTy *Image) {
       return Err;
   }
 
-  if (Image->getTgtImageBitcode())
-    Plugin.getJIT().erase(*Image->getTgtImageBitcode(), Image->getDevice());
-
   return unloadBinaryImpl(Image);
 }
 
@@ -865,32 +854,29 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
 
   return deinitImpl();
 }
-Expected<DeviceImageTy *>
-GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
-                            const __tgt_device_image *InputTgtImage) {
-  assert(InputTgtImage && "Expected non-null target image");
+Expected<DeviceImageTy *> GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
+                                                      StringRef InputTgtImage) {
   DP("Load data from image " DPxMOD "\n", DPxPTR(InputTgtImage->ImageStart));
 
-  auto PostJITImageOrErr = Plugin.getJIT().process(*InputTgtImage, *this);
-  if (!PostJITImageOrErr) {
-    auto Err = PostJITImageOrErr.takeError();
-    REPORT("Failure to jit IR image %p on device %d: %s\n", InputTgtImage,
-           DeviceId, toStringWithoutConsuming(Err).data());
-    return Plugin::error(ErrorCode::COMPILE_FAILURE, std::move(Err),
-                         "failure to jit IR image");
+  std::unique_ptr<MemoryBuffer> Buffer;
+  if (identify_magic(InputTgtImage) == file_magic::bitcode) {
+    auto CompiledImageOrErr = Plugin.getJIT().process(InputTgtImage, *this);
+    if (!CompiledImageOrErr) {
+      return Plugin::error(ErrorCode::COMPILE_FAILURE,
+                           CompiledImageOrErr.takeError(),
+                           "failure to jit IR image");
+    }
+    Buffer = std::move(*CompiledImageOrErr);
+  } else {
+    Buffer = MemoryBuffer::getMemBufferCopy(InputTgtImage);
   }
 
   // Load the binary and allocate the image object. Use the next available id
   // for the image id, which is the number of previously loaded images.
-  auto ImageOrErr =
-      loadBinaryImpl(PostJITImageOrErr.get(), LoadedImages.size());
+  auto ImageOrErr = loadBinaryImpl(std::move(Buffer), LoadedImages.size());
   if (!ImageOrErr)
     return ImageOrErr.takeError();
-
   DeviceImageTy *Image = *ImageOrErr;
-  assert(Image != nullptr && "Invalid image");
-  if (InputTgtImage != PostJITImageOrErr.get())
-    Image->setTgtImageBitcode(InputTgtImage);
 
   // Add the image to list.
   LoadedImages.push_back(Image);
@@ -912,12 +898,12 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
 
 #ifdef OMPT_SUPPORT
   if (ompt::Initialized) {
-    size_t Bytes =
-        utils::getPtrDiff(InputTgtImage->ImageEnd, InputTgtImage->ImageStart);
+    size_t Bytes = InputTgtImage.size();
     performOmptCallback(
         device_load, Plugin.getUserId(DeviceId),
         /*FileName=*/nullptr, /*FileOffset=*/0, /*VmaInFile=*/nullptr,
-        /*ImgSize=*/Bytes, /*HostAddr=*/InputTgtImage->ImageStart,
+        /*ImgSize=*/Bytes,
+        /*HostAddr=*/const_cast<unsigned char *>(InputTgtImage.bytes_begin()),
         /*DeviceAddr=*/nullptr, /* FIXME: ModuleId */ 0);
   }
 #endif
@@ -1848,7 +1834,9 @@ int32_t GenericPluginTy::load_binary(int32_t DeviceId,
                                      __tgt_device_binary *Binary) {
   GenericDeviceTy &Device = getDevice(DeviceId);
 
-  auto ImageOrErr = Device.loadBinary(*this, TgtImage);
+  StringRef Buffer(reinterpret_cast<const char *>(TgtImage->ImageStart),
+                   utils::getPtrDiff(TgtImage->ImageEnd, TgtImage->ImageStart));
+  auto ImageOrErr = Device.loadBinary(*this, Buffer);
   if (!ImageOrErr) {
     auto Err = ImageOrErr.takeError();
     REPORT("Failure to load binary image %p on device %d: %s\n", TgtImage,
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index af3c74636bff3..99195cd8d7c99 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -81,8 +81,8 @@ CUresult cuMemFreeAsync(CUdeviceptr dptr, CUstream hStream) {}
 struct CUDADeviceImageTy : public DeviceImageTy {
   /// Create the CUDA image with the id and the target image pointer.
   CUDADeviceImageTy(int32_t ImageId, GenericDeviceTy &Device,
-                    const __tgt_device_image *TgtImage)
-      : DeviceImageTy(ImageId, Device, TgtImage), Module(nullptr) {}
+                    std::unique_ptr<MemoryBuffer> &&TgtImage)
+      : DeviceImageTy(ImageId, Device, std::move(TgtImage)), Module(nullptr) {}
 
   /// Load the image as a CUDA module.
   Error loadModule() {
@@ -385,6 +385,8 @@ struct CUDADeviceTy : public GenericDeviceTy {
     if (auto Err = CUDAImage.unloadModule())
       return Err;
 
+    // Destroy the associated memory and invalidate the object.
+    Plugin.free(Image);
     return Plugin::success();
   }
 
@@ -418,20 +420,12 @@ struct CUDADeviceTy : public GenericDeviceTy {
 
   virtual Error callGlobalConstructors(GenericPluginTy &Plugin,
                                        DeviceImageTy &Image) override {
-    // Check for the presence of global destructors at initialization time. This
-    // is required when the image may be deallocated before destructors are run.
-    GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
-    if (Handler.isSymbolInImage(*this, Image, "nvptx$device$fini"))
-      Image.setPendingGlobalDtors();
-
     return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/true);
   }
 
   virtual Error callGlobalDestructors(GenericPluginTy &Plugin,
                                       DeviceImageTy &Image) override {
-    if (Image.hasPendingGlobalDtors())
-      return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/false);
-    return Plugin::success();
+    return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/false);
   }
 
   Expected<std::unique_ptr<MemoryBuffer>>
@@ -549,14 +543,15 @@ struct CUDADeviceTy : public GenericDeviceTy {
   CUdevice getCUDADevice() const { return Device; }
 
   /// Load the binary image into the device and allocate an image object.
-  Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage,
-                                           int32_t ImageId) override {
+  Expected<DeviceImageTy *>
+  loadBinaryImpl(std::unique_ptr<MemoryBuffer> &&TgtImage,
+                 int32_t ImageId) override {
     if (auto Err = setContext())
       return std::move(Err);
 
     // Allocate and initialize the image object.
     CUDADeviceImageTy *CUDAImage = Plugin.allocate<CUDADeviceImageTy>();
-    new (CUDAImage) CUDADeviceImageTy(ImageId, *this, TgtImage);
+    new (CUDAImage) CUDADeviceImageTy(ImageId, *this, std::move(TgtImage));
 
     // Load the CUDA module.
     if (auto Err = CUDAImage->loadModule())
@@ -1299,7 +1294,7 @@ struct CUDADeviceTy : public GenericDeviceTy {
     // Perform a quick check for the named kernel in the image. The kernel
     // should be created by the 'nvptx-lower-ctor-dtor' pass.
     GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
-    if (IsCtor && !Handler.isSymbolInImage(*this, Image, KernelName))
+    if (!Handler.isSymbolInImage(*this, Image, KernelName))
       return Plugin::success();
 
     // The Nvidia backend cannot handle creating the ctor / dtor array
diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp
index 5436cae3b0293..0db01ca09ab02 100644
--- a/offload/plugins-nextgen/host/src/rtl.cpp
+++ b/offload/plugins-nextgen/host/src/rtl.cpp
@@ -131,8 +131,8 @@ struct GenELF64KernelTy : public GenericKernelTy {
 struct GenELF64DeviceImageTy : public DeviceImageTy {
   /// Create the GenELF64 image with the id and the target image pointer.
   GenELF64DeviceImageTy(int32_t ImageId, GenericDeviceTy &Device,
-                        const __tgt_device_image *TgtImage)
-      : DeviceImageTy(ImageId, Device, TgtImage), DynLib() {}
+                        std::unique_ptr<MemoryBuffer> &&TgtImage)
+      : DeviceImageTy(ImageId, Device, std::move(TgtImage)), DynLib() {}
 
   /// Getter and setter for the dynamic library.
   DynamicLibrary &getDynamicLibrary() { return DynLib; }
@@ -189,11 +189,12 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
   Error setContext() override { return Plugin::success(); }
 
   /// Load the binary image into the device and allocate an image object.
-  Expected<DeviceImageTy *> loadBinaryImpl(const __tgt_device_image *TgtImage,
-                                           int32_t ImageId) override {
+  Expected<DeviceImageTy *>
+  loadBinaryImpl(std::unique_ptr<MemoryBuffer> &&TgtImage,
+                 int32_t ImageId) override {
     // Allocate and initialize the image object.
     GenELF64DeviceImageTy *Image = Plugin.allocate<GenELF64DeviceImageTy>();
-    new (Image) GenELF64DeviceImageTy(ImageId, *this, TgtImage);
+    new (Image) GenELF64DeviceImageTy(ImageId, *this, std::move(TgtImage));
 
     // Create a temporary file.
     char TmpFileName[] = "/tmp/tmpfile_XXXXXX";



More information about the llvm-commits mailing list