[Openmp-commits] [openmp] f8e094b - [OpenMP][JIT] Cleanup JIT interface, caching, and races

Johannes Doerfert via Openmp-commits openmp-commits at lists.llvm.org
Sun Jan 15 11:44:36 PST 2023


Author: Johannes Doerfert
Date: 2023-01-15T11:43:50-08:00
New Revision: f8e094be8166443383f84831a406960a49281f04

URL: https://github.com/llvm/llvm-project/commit/f8e094be8166443383f84831a406960a49281f04
DIFF: https://github.com/llvm/llvm-project/commit/f8e094be8166443383f84831a406960a49281f04.diff

LOG: [OpenMP][JIT] Cleanup JIT interface, caching, and races

The JIT interface was somewhat irregular as it used multiple global
functions. It also did not cache the results of the JIT, hence multi-GPU
systems would perform the same work once per device. Finally, there might
have been races on the state if we had multi-threaded initialization of
different embedded images, or one image initialized on multiple devices.

This patch tries to rectify all of the above. The JITEngine is now a
part of the GenericPluginTy and tied to one target triple. To support
multiple "ComputeUnitKind"s (previously confusingly called Arch or
[M]CPU) and to avoid re-jitting for the same ComputeUnitKind, we keep a
map of JIT results per ComputeUnitKind. All interaction with the JIT
happens through the JITEngine directly; two functions are exposed (process and checkBitcodeImage). Both
use (shared) locks to avoid races and cache the result. All JIT-related
environment variables are now defined together.

Differential Revision: https://reviews.llvm.org/D141081

Added: 
    

Modified: 
    openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
    openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.cpp
    openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.h
    openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
    openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
    openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
    openmp/libomptarget/plugins-nextgen/generic-elf-64bit/src/rtl.cpp

Removed: 
    


################################################################################
diff  --git a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
index 4eefd19d2b0d1..14efe841ee99b 100644
--- a/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/openmp/libomptarget/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -1530,7 +1530,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     char GPUName[64];
     if (auto Err = getDeviceAttr(HSA_AGENT_INFO_NAME, GPUName))
       return Err;
-    Arch = GPUName;
+    ComputeUnitKind = GPUName;
 
     // Get the wavefront size.
     uint32_t WavefrontSize = 0;
@@ -1669,7 +1669,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     INFO(OMP_INFOTYPE_PLUGIN_KERNEL, getDeviceId(),
          "Using `%s` to link JITed amdgcn ouput.", LLDPath.c_str());
 
-    std::string MCPU = "-plugin-opt=mcpu=" + getArch();
+    std::string MCPU = "-plugin-opt=mcpu=" + getComputeUnitKind();
 
     StringRef Args[] = {LLDPath,
                         "-flavor",
@@ -1692,7 +1692,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
         MemoryBuffer::getFileOrSTDIN(LinkerOutputFilePath.data()).get());
   }
 
-  std::string getArch() const override { return Arch; }
+  /// See GenericDeviceTy::getComputeUnitKind().
+  std::string getComputeUnitKind() const override { return ComputeUnitKind; }
 
   /// Allocate and construct an AMDGPU kernel.
   Expected<GenericKernelTy *>
@@ -2096,7 +2097,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
   hsa_agent_t Agent;
 
   /// The GPU architecture.
-  std::string Arch;
+  std::string ComputeUnitKind;
 
   /// Reference to the host device.
   AMDHostDeviceTy &HostDevice;
@@ -2244,7 +2245,7 @@ struct AMDGPUGlobalHandlerTy final : public GenericGlobalHandlerTy {
 /// Class implementing the AMDGPU-specific functionalities of the plugin.
 struct AMDGPUPluginTy final : public GenericPluginTy {
   /// Create an AMDGPU plugin and initialize the AMDGPU driver.
-  AMDGPUPluginTy() : GenericPluginTy(), HostDevice(nullptr) {}
+  AMDGPUPluginTy() : GenericPluginTy(getTripleArch()), HostDevice(nullptr) {}
 
   /// This class should not be copied.
   AMDGPUPluginTy(const AMDGPUPluginTy &) = delete;

diff  --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.cpp
index aa0e599135261..4382135ff6758 100644
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.cpp
@@ -11,11 +11,11 @@
 #include "JIT.h"
 #include "Debug.h"
 
+#include "PluginInterface.h"
 #include "Utilities.h"
 #include "omptarget.h"
 
 #include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
 #include "llvm/CodeGen/CommandFlags.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/IR/LLVMContext.h"
@@ -28,7 +28,6 @@
 #include "llvm/Object/IRObjectFile.h"
 #include "llvm/Passes/OptimizationLevel.h"
 #include "llvm/Passes/PassBuilder.h"
-#include "llvm/Support/Error.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/TargetSelect.h"
@@ -39,15 +38,23 @@
 #include "llvm/Target/TargetOptions.h"
 
 #include <mutex>
+#include <shared_mutex>
 #include <system_error>
 
 using namespace llvm;
 using namespace llvm::object;
 using namespace omp;
+using namespace omp::target;
 
 static codegen::RegisterCodeGenFlags RCGF;
 
 namespace {
+
+/// A map from a bitcode image start address to its corresponding triple. If the
+/// image is not in the map, it is not a bitcode image.
+DenseMap<void *, Triple::ArchType> BitcodeImageMap;
+std::shared_mutex BitcodeImageMapMutex;
+
 std::once_flag InitFlag;
 
 void init(Triple TT) {
@@ -70,10 +77,8 @@ void init(Triple TT) {
     JITTargetInitialized = true;
   }
 #endif
-  if (!JITTargetInitialized) {
-    FAILURE_MESSAGE("unsupported JIT target: %s\n", TT.str().c_str());
-    abort();
-  }
+  if (!JITTargetInitialized)
+    return;
 
   // Initialize passes
   PassRegistry &Registry = *PassRegistry::getPassRegistry();
@@ -125,9 +130,9 @@ createModuleFromMemoryBuffer(std::unique_ptr<MemoryBuffer> &MB,
   return std::move(Mod);
 }
 Expected<std::unique_ptr<Module>>
-createModuleFromImage(__tgt_device_image *Image, LLVMContext &Context) {
-  StringRef Data((const char *)Image->ImageStart,
-                 target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
+createModuleFromImage(const __tgt_device_image &Image, LLVMContext &Context) {
+  StringRef Data((const char *)Image.ImageStart,
+                 target::getPtrDiff(Image.ImageEnd, Image.ImageStart));
   std::unique_ptr<MemoryBuffer> MB = MemoryBuffer::getMemBuffer(
       Data, /* BufferName */ "", /* RequiresNullTerminator */ false);
   return createModuleFromMemoryBuffer(MB, Context);
@@ -192,44 +197,11 @@ createTargetMachine(Module &M, std::string CPU, unsigned OptLevel) {
   return std::move(TM);
 }
 
-///
-class JITEngine {
-public:
-  JITEngine(Triple::ArchType TA, std::string MCpu)
-      : TT(Triple::getArchTypeName(TA)), CPU(MCpu),
-        ReplacementModuleFileName("LIBOMPTARGET_JIT_REPLACEMENT_MODULE"),
-        PreOptIRModuleFileName("LIBOMPTARGET_JIT_PRE_OPT_IR_MODULE"),
-        PostOptIRModuleFileName("LIBOMPTARGET_JIT_POST_OPT_IR_MODULE") {
-    std::call_once(InitFlag, init, TT);
-  }
-
-  /// Run jit compilation. It is expected to get a memory buffer containing the
-  /// generated device image that could be loaded to the device directly.
-  Expected<std::unique_ptr<MemoryBuffer>>
-  run(__tgt_device_image *Image, unsigned OptLevel,
-      jit::PostProcessingFn PostProcessing);
-
-private:
-  /// Run backend, which contains optimization and code generation.
-  Expected<std::unique_ptr<MemoryBuffer>> backend(Module &M, unsigned OptLevel);
-
-  /// Run optimization pipeline.
-  void opt(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
-           unsigned OptLevel);
-
-  /// Run code generation.
-  void codegen(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
-               raw_pwrite_stream &OS);
-
-  LLVMContext Context;
-  const Triple TT;
-  const std::string CPU;
+} // namespace
 
-  /// Control environment variables.
-  target::StringEnvar ReplacementModuleFileName;
-  target::StringEnvar PreOptIRModuleFileName;
-  target::StringEnvar PostOptIRModuleFileName;
-};
+JITEngine::JITEngine(Triple::ArchType TA) : TT(Triple::getArchTypeName(TA)) {
+  std::call_once(InitFlag, init, TT);
+}
 
 void JITEngine::opt(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
                     unsigned OptLevel) {
@@ -274,18 +246,19 @@ void JITEngine::codegen(TargetMachine *TM, TargetLibraryInfoImpl *TLII,
   PM.run(M);
 }
 
-Expected<std::unique_ptr<MemoryBuffer>> JITEngine::backend(Module &M,
-                                                           unsigned OptLevel) {
+Expected<std::unique_ptr<MemoryBuffer>>
+JITEngine::backend(Module &M, const std::string &ComputeUnitKind,
+                   unsigned OptLevel) {
 
   auto RemarksFileOrErr = setupLLVMOptimizationRemarks(
-      Context, /* RemarksFilename */ "", /* RemarksPasses */ "",
+      M.getContext(), /* RemarksFilename */ "", /* RemarksPasses */ "",
       /* RemarksFormat */ "", /* RemarksWithHotness */ false);
   if (Error E = RemarksFileOrErr.takeError())
     return std::move(E);
   if (*RemarksFileOrErr)
     (*RemarksFileOrErr)->keep();
 
-  auto TMOrErr = createTargetMachine(M, CPU, OptLevel);
+  auto TMOrErr = createTargetMachine(M, ComputeUnitKind, OptLevel);
   if (!TMOrErr)
     return TMOrErr.takeError();
 
@@ -323,14 +296,23 @@ Expected<std::unique_ptr<MemoryBuffer>> JITEngine::backend(Module &M,
   return MemoryBuffer::getMemBufferCopy(OS.str());
 }
 
-Expected<std::unique_ptr<MemoryBuffer>>
-JITEngine::run(__tgt_device_image *Image, unsigned OptLevel,
-               jit::PostProcessingFn PostProcessing) {
+Expected<const __tgt_device_image *>
+JITEngine::compile(const __tgt_device_image &Image,
+                   const std::string &ComputeUnitKind,
+                   PostProcessingFn PostProcessing) {
+  std::lock_guard<std::mutex> Lock(ComputeUnitMapMutex);
+
+  // Check if we JITed this image for the given compute unit kind before.
+  ComputeUnitInfo &CUI = ComputeUnitMap[ComputeUnitKind];
+  if (__tgt_device_image *JITedImage = CUI.TgtImageMap.lookup(&Image))
+    return JITedImage;
+
   Module *Mod = nullptr;
   // Check if the user replaces the module at runtime or we read it from the
   // image.
+  // TODO: Allow the user to specify images per device (Arch + ComputeUnitKind).
   if (!ReplacementModuleFileName.isPresent()) {
-    auto ModOrErr = createModuleFromImage(Image, Context);
+    auto ModOrErr = createModuleFromImage(Image, CUI.Context);
     if (!ModOrErr)
       return ModOrErr.takeError();
     Mod = ModOrErr->release();
@@ -341,44 +323,65 @@ JITEngine::run(__tgt_device_image *Image, unsigned OptLevel,
       return createStringError(MBOrErr.getError(),
                                "Could not read replacement module from %s\n",
                                ReplacementModuleFileName.get().c_str());
-    auto ModOrErr = createModuleFromMemoryBuffer(MBOrErr.get(), Context);
+    auto ModOrErr = createModuleFromMemoryBuffer(MBOrErr.get(), CUI.Context);
     if (!ModOrErr)
       return ModOrErr.takeError();
     Mod = ModOrErr->release();
   }
 
-  auto MBOrError = backend(*Mod, OptLevel);
+  auto MBOrError = backend(*Mod, ComputeUnitKind, JITOptLevel);
   if (!MBOrError)
     return MBOrError.takeError();
 
-  return PostProcessing(std::move(*MBOrError));
+  auto ImageMBOrErr = PostProcessing(std::move(*MBOrError));
+  if (!ImageMBOrErr)
+    return ImageMBOrErr.takeError();
+
+  CUI.JITImages.push_back(std::move(*ImageMBOrErr));
+  __tgt_device_image *&JITedImage = CUI.TgtImageMap[&Image];
+  JITedImage = new __tgt_device_image();
+  *JITedImage = Image;
+
+  auto &ImageMB = CUI.JITImages.back();
+
+  JITedImage->ImageStart = (void *)ImageMB->getBufferStart();
+  JITedImage->ImageEnd = (void *)ImageMB->getBufferEnd();
+
+  return JITedImage;
 }
 
-/// A map from a bitcode image start address to its corresponding triple. If the
-/// image is not in the map, it is not a bitcode image.
-DenseMap<void *, Triple::ArchType> BitcodeImageMap;
+Expected<const __tgt_device_image *>
+JITEngine::process(const __tgt_device_image &Image,
+                   target::plugin::GenericDeviceTy &Device) {
+  const std::string &ComputeUnitKind = Device.getComputeUnitKind();
 
-/// Output images generated from LLVM backend.
-SmallVector<std::unique_ptr<MemoryBuffer>, 4> JITImages;
+  PostProcessingFn PostProcessing = [&Device](std::unique_ptr<MemoryBuffer> MB)
+      -> Expected<std::unique_ptr<MemoryBuffer>> {
+    return Device.doJITPostProcessing(std::move(MB));
+  };
 
-/// A list of __tgt_device_image images.
-std::list<__tgt_device_image> TgtImages;
-} // namespace
+  {
+    std::shared_lock<std::shared_mutex> SharedLock(BitcodeImageMapMutex);
+    auto Itr = BitcodeImageMap.find(Image.ImageStart);
+    if (Itr != BitcodeImageMap.end() && Itr->second == TT.getArch())
+      return compile(Image, ComputeUnitKind, PostProcessing);
+  }
 
-namespace llvm {
-namespace omp {
-namespace jit {
-bool checkBitcodeImage(__tgt_device_image *Image, Triple::ArchType TA) {
+  return &Image;
+}
+
+bool JITEngine::checkBitcodeImage(const __tgt_device_image &Image) {
   TimeTraceScope TimeScope("Check bitcode image");
+  std::lock_guard<std::shared_mutex> Lock(BitcodeImageMapMutex);
 
   {
-    auto Itr = BitcodeImageMap.find(Image->ImageStart);
-    if (Itr != BitcodeImageMap.end() && Itr->second == TA)
+    auto Itr = BitcodeImageMap.find(Image.ImageStart);
+    if (Itr != BitcodeImageMap.end() && Itr->second == TT.getArch())
       return true;
   }
 
-  StringRef Data(reinterpret_cast<const char *>(Image->ImageStart),
-                 target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
+  StringRef Data(reinterpret_cast<const char *>(Image.ImageStart),
+                 target::getPtrDiff(Image.ImageEnd, Image.ImageStart));
   std::unique_ptr<MemoryBuffer> MB = MemoryBuffer::getMemBuffer(
       Data, /* BufferName */ "", /* RequiresNullTerminator */ false);
   if (!MB)
@@ -391,37 +394,8 @@ bool checkBitcodeImage(__tgt_device_image *Image, Triple::ArchType TA) {
   }
 
   auto ActualTriple = FOrErr->TheReader.getTargetTriple();
+  auto BitcodeTA = Triple(ActualTriple).getArch();
+  BitcodeImageMap[Image.ImageStart] = BitcodeTA;
 
-  if (Triple(ActualTriple).getArch() == TA) {
-    BitcodeImageMap[Image->ImageStart] = TA;
-    return true;
-  }
-
-  return false;
+  return BitcodeTA == TT.getArch();
 }
-
-Expected<__tgt_device_image *> compile(__tgt_device_image *Image,
-                                       Triple::ArchType TA, std::string MCPU,
-                                       unsigned OptLevel,
-                                       PostProcessingFn PostProcessing) {
-  JITEngine J(TA, MCPU);
-
-  auto ImageMBOrErr = J.run(Image, OptLevel, PostProcessing);
-  if (!ImageMBOrErr)
-    return ImageMBOrErr.takeError();
-
-  JITImages.push_back(std::move(*ImageMBOrErr));
-  TgtImages.push_back(*Image);
-
-  auto &ImageMB = JITImages.back();
-  auto *NewImage = &TgtImages.back();
-
-  NewImage->ImageStart = (void *)ImageMB->getBufferStart();
-  NewImage->ImageEnd = (void *)ImageMB->getBufferEnd();
-
-  return NewImage;
-}
-
-} // namespace jit
-} // namespace omp
-} // namespace llvm

diff  --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.h b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.h
index 73483cebc264d..0c51810690c8f 100644
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.h
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/JIT.h
@@ -11,12 +11,19 @@
 #ifndef OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_JIT_H
 #define OPENMP_LIBOMPTARGET_PLUGINS_NEXTGEN_COMMON_JIT_H
 
+#include "Utilities.h"
+
+#include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Triple.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Module.h"
 #include "llvm/Support/Error.h"
+#include "llvm/Target/TargetMachine.h"
 
 #include <functional>
 #include <memory>
+#include <shared_mutex>
 #include <string>
 
 struct __tgt_device_image;
@@ -25,25 +32,84 @@ namespace llvm {
 class MemoryBuffer;
 
 namespace omp {
-namespace jit {
-
-/// Function type for a callback that will be called after the backend is
-/// called.
-using PostProcessingFn = std::function<Expected<std::unique_ptr<MemoryBuffer>>(
-    std::unique_ptr<MemoryBuffer>)>;
-
-/// Check if \p Image contains bitcode with triple \p Triple.
-bool checkBitcodeImage(__tgt_device_image *Image, Triple::ArchType TA);
-
-/// Compile the bitcode image \p Image and generate the binary image that can be
-/// loaded to the target device of the triple \p Triple architecture \p MCpu. \p
-/// PostProcessing will be called after codegen to handle cases such as assember
-/// as an external tool.
-Expected<__tgt_device_image *> compile(__tgt_device_image *Image,
-                                       Triple::ArchType TA, std::string MCpu,
-                                       unsigned OptLevel,
-                                       PostProcessingFn PostProcessing);
-} // namespace jit
+namespace target {
+namespace plugin {
+struct GenericDeviceTy;
+} // namespace plugin
+
+/// The JIT infrastructure and caching mechanism.
+struct JITEngine {
+  /// Function type for a callback that will be called after the backend is
+  /// called.
+  using PostProcessingFn =
+      std::function<Expected<std::unique_ptr<MemoryBuffer>>(
+          std::unique_ptr<MemoryBuffer>)>;
+
+  JITEngine(Triple::ArchType TA);
+
+  /// Run jit compilation if \p Image is a bitcode image, otherwise simply
+  /// return \p Image. The returned image (JITed or original) is expected to
+  /// be loadable to \p Device directly.
+  Expected<const __tgt_device_image *>
+  process(const __tgt_device_image &Image,
+          target::plugin::GenericDeviceTy &Device);
+
+  /// Return true if \p Image is a bitcode image that can be JITed for the given
+  /// architecture.
+  bool checkBitcodeImage(const __tgt_device_image &Image);
+
+private:
+  /// Compile the bitcode image \p Image and generate the binary image that can
+  /// be loaded to a target device of this engine's triple and the compute
+  /// unit kind \p ComputeUnitKind. \p PostProcessing will be called after
+  /// codegen to handle cases such as running an assembler as an external tool.
+  Expected<const __tgt_device_image *>
+  compile(const __tgt_device_image &Image, const std::string &ComputeUnitKind,
+          PostProcessingFn PostProcessing);
+
+  /// Run backend, which contains optimization and code generation.
+  Expected<std::unique_ptr<MemoryBuffer>>
+  backend(Module &M, const std::string &ComputeUnitKind, unsigned OptLevel);
+
+  /// Run optimization pipeline.
+  void opt(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
+           unsigned OptLevel);
+
+  /// Run code generation.
+  void codegen(TargetMachine *TM, TargetLibraryInfoImpl *TLII, Module &M,
+               raw_pwrite_stream &OS);
+
+  /// The target triple used by the JIT.
+  const Triple TT;
+
+  struct ComputeUnitInfo {
+    /// LLVM Context in which the modules will be constructed.
+    LLVMContext Context;
+
+    /// Output images generated from LLVM backend.
+    SmallVector<std::unique_ptr<MemoryBuffer>, 4> JITImages;
+
+    /// A map of embedded IR images to JITed images.
+    DenseMap<const __tgt_device_image *, __tgt_device_image *> TgtImageMap;
+  };
+
+  /// Map from (march) "CPUs" (e.g., sm_80, or gfx90a), which we call compute
+  /// units as they are not CPUs, to the image information we cached for them.
+  StringMap<ComputeUnitInfo> ComputeUnitMap;
+  std::mutex ComputeUnitMapMutex;
+
+  /// Control environment variables.
+  target::StringEnvar ReplacementModuleFileName =
+      target::StringEnvar("LIBOMPTARGET_JIT_REPLACEMENT_MODULE");
+  target::StringEnvar PreOptIRModuleFileName =
+      target::StringEnvar("LIBOMPTARGET_JIT_PRE_OPT_IR_MODULE");
+  target::StringEnvar PostOptIRModuleFileName =
+      target::StringEnvar("LIBOMPTARGET_JIT_POST_OPT_IR_MODULE");
+  target::UInt32Envar JITOptLevel =
+      target::UInt32Envar("LIBOMPTARGET_JIT_OPT_LEVEL", 3);
+};
+
+} // namespace target
 } // namespace omp
 } // namespace llvm
 

diff  --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
index 3ff7e096ed306..96800fe518c40 100644
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp
@@ -212,12 +212,22 @@ Error GenericDeviceTy::deinit() {
 
 Expected<__tgt_target_table *>
 GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
-                            const __tgt_device_image *TgtImage) {
-  DP("Load data from image " DPxMOD "\n", DPxPTR(TgtImage->ImageStart));
+                            const __tgt_device_image *InputTgtImage) {
+  assert(InputTgtImage && "Expected non-null target image");
+  DP("Load data from image " DPxMOD "\n", DPxPTR(InputTgtImage->ImageStart));
+
+  auto PostJITImageOrErr = Plugin.getJIT().process(*InputTgtImage, *this);
+  if (!PostJITImageOrErr) {
+    auto Err = PostJITImageOrErr.takeError();
+    REPORT("Failure to jit IR image %p on device %d: %s\n", InputTgtImage,
+           DeviceId, toString(std::move(Err)).data());
+    return nullptr;
+  }
 
   // Load the binary and allocate the image object. Use the next available id
   // for the image id, which is the number of previously loaded images.
-  auto ImageOrErr = loadBinaryImpl(TgtImage, LoadedImages.size());
+  auto ImageOrErr =
+      loadBinaryImpl(PostJITImageOrErr.get(), LoadedImages.size());
   if (!ImageOrErr)
     return ImageOrErr.takeError();
 
@@ -668,7 +678,7 @@ int32_t __tgt_rtl_is_valid_binary(__tgt_device_image *TgtImage) {
   if (elf_check_machine(TgtImage, Plugin::get().getMagicElfBits()))
     return true;
 
-  return jit::checkBitcodeImage(TgtImage, Plugin::get().getTripleArch());
+  return Plugin::get().getJIT().checkBitcodeImage(*TgtImage);
 }
 
 int32_t __tgt_rtl_is_valid_binary_info(__tgt_device_image *TgtImage,
@@ -745,34 +755,6 @@ __tgt_target_table *__tgt_rtl_load_binary(int32_t DeviceId,
   GenericPluginTy &Plugin = Plugin::get();
   GenericDeviceTy &Device = Plugin.getDevice(DeviceId);
 
-  // If it is a bitcode image, we have to jit the binary image before loading to
-  // the device.
-  {
-    // TODO: Move this (at least the environment variable) into the JIT.h.
-    UInt32Envar JITOptLevel("LIBOMPTARGET_JIT_OPT_LEVEL", 3);
-    Triple::ArchType TA = Plugin.getTripleArch();
-    std::string Arch = Device.getArch();
-
-    jit::PostProcessingFn PostProcessing =
-        [&Device](std::unique_ptr<MemoryBuffer> MB)
-        -> Expected<std::unique_ptr<MemoryBuffer>> {
-      return Device.doJITPostProcessing(std::move(MB));
-    };
-
-    if (jit::checkBitcodeImage(TgtImage, TA)) {
-      auto TgtImageOrErr =
-          jit::compile(TgtImage, TA, Arch, JITOptLevel, PostProcessing);
-      if (!TgtImageOrErr) {
-        auto Err = TgtImageOrErr.takeError();
-        REPORT("Failure to jit binary image from bitcode image %p on device "
-               "%d: %s\n",
-               TgtImage, DeviceId, toString(std::move(Err)).data());
-        return nullptr;
-      }
-
-      TgtImage = *TgtImageOrErr;
-    }
-  }
 
   auto TableOrErr = Device.loadBinary(Plugin, TgtImage);
   if (!TableOrErr) {

diff  --git a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
index 752ee2e38de16..d65209d3cfbd8 100644
--- a/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
+++ b/openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.h
@@ -21,6 +21,7 @@
 #include "Debug.h"
 #include "DeviceEnvironment.h"
 #include "GlobalHandler.h"
+#include "JIT.h"
 #include "MemoryManager.h"
 #include "Utilities.h"
 #include "omptarget.h"
@@ -37,6 +38,7 @@
 namespace llvm {
 namespace omp {
 namespace target {
+
 namespace plugin {
 
 struct GenericPluginTy;
@@ -378,10 +380,8 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
   }
   uint32_t getDynamicMemorySize() const { return OMPX_SharedMemorySize; }
 
-  /// Get target architecture.
-  virtual std::string getArch() const {
-    return "unknown";
-  }
+  /// Get target compute unit kind (e.g., sm_80, or gfx908).
+  virtual std::string getComputeUnitKind() const { return "unknown"; }
 
   /// Post processing after jit backend. The ownership of \p MB will be taken.
   virtual Expected<std::unique_ptr<MemoryBuffer>>
@@ -513,8 +513,8 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
 struct GenericPluginTy {
 
   /// Construct a plugin instance.
-  GenericPluginTy()
-      : RequiresFlags(OMP_REQ_UNDEFINED), GlobalHandler(nullptr) {}
+  GenericPluginTy(Triple::ArchType TA)
+      : RequiresFlags(OMP_REQ_UNDEFINED), GlobalHandler(nullptr), JIT(TA) {}
 
   virtual ~GenericPluginTy() {}
 
@@ -543,9 +543,7 @@ struct GenericPluginTy {
   virtual uint16_t getMagicElfBits() const = 0;
 
   /// Get the target triple of this plugin.
-  virtual Triple::ArchType getTripleArch() const {
-    return Triple::ArchType::UnknownArch;
-  }
+  virtual Triple::ArchType getTripleArch() const = 0;
 
   /// Allocate a structure using the internal allocator.
   template <typename Ty> Ty *allocate() {
@@ -558,6 +556,10 @@ struct GenericPluginTy {
     return *GlobalHandler;
   }
 
+  /// Get the reference to the JIT used for all devices connected to this
+  /// plugin.
+  JITEngine &getJIT() { return JIT; }
+
   /// Get the OpenMP requires flags set for this plugin.
   int64_t getRequiresFlags() const { return RequiresFlags; }
 
@@ -609,6 +611,9 @@ struct GenericPluginTy {
 
   /// Internal allocator for different structures.
   BumpPtrAllocator Allocator;
+
+  /// The JIT engine shared by all devices connected to this plugin.
+  JITEngine JIT;
 };
 
 /// Class for simplifying the getter operation of the plugin. Anywhere on the

diff  --git a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
index cb5f004d81046..cfc97b6cdc46c 100644
--- a/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
+++ b/openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp
@@ -784,8 +784,10 @@ struct CUDADeviceTy : public GenericDeviceTy {
     return Plugin::check(Res, "Error in cuDeviceGetAttribute: %s");
   }
 
-  /// See GenericDeviceTy::getArch().
-  std::string getArch() const override { return ComputeCapability.str(); }
+  /// See GenericDeviceTy::getComputeUnitKind().
+  std::string getComputeUnitKind() const override {
+    return ComputeCapability.str();
+  }
 
 private:
   using CUDAStreamManagerTy = GenericDeviceResourceManagerTy<CUDAStreamRef>;
@@ -867,7 +869,7 @@ class CUDAGlobalHandlerTy final : public GenericGlobalHandlerTy {
 /// Class implementing the CUDA-specific functionalities of the plugin.
 struct CUDAPluginTy final : public GenericPluginTy {
   /// Create a CUDA plugin.
-  CUDAPluginTy() : GenericPluginTy() {}
+  CUDAPluginTy() : GenericPluginTy(getTripleArch()) {}
 
   /// This class should not be copied.
   CUDAPluginTy(const CUDAPluginTy &) = delete;

diff  --git a/openmp/libomptarget/plugins-nextgen/generic-elf-64bit/src/rtl.cpp b/openmp/libomptarget/plugins-nextgen/generic-elf-64bit/src/rtl.cpp
index ed6897a11a3b2..ab5f9e3f29faa 100644
--- a/openmp/libomptarget/plugins-nextgen/generic-elf-64bit/src/rtl.cpp
+++ b/openmp/libomptarget/plugins-nextgen/generic-elf-64bit/src/rtl.cpp
@@ -340,7 +340,7 @@ class GenELF64GlobalHandlerTy final : public GenericGlobalHandlerTy {
 /// Class implementing the plugin functionalities for GenELF64.
 struct GenELF64PluginTy final : public GenericPluginTy {
   /// Create the GenELF64 plugin.
-  GenELF64PluginTy() : GenericPluginTy() {}
+  GenELF64PluginTy() : GenericPluginTy(getTripleArch()) {}
 
   /// This class should not be copied.
   GenELF64PluginTy(const GenELF64PluginTy &) = delete;


        


More information about the Openmp-commits mailing list