[llvm] [Offload] Move RPC server handling to a dedicated thread (PR #112988)

Joseph Huber via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 2 14:49:33 PST 2024


https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/112988

From 41facc9391ad20552e4b98bf1983e07e491ca452 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Fri, 18 Oct 2024 16:48:33 -0500
Subject: [PATCH] [Offload] Move RPC server handling to a dedicated thread

Summary:
Handling the RPC server requires running through the list of jobs that
the device has requested to be done. Currently this is handled by the
thread that waits for the kernel to finish. However, this is not sound
on NVIDIA architectures and only works for async launches in the OpenMP
model that uses helper threads.

At the same time, we don't want this thread doing work unnecessarily.
For this reason we track the execution of kernels and put the thread to
sleep via a condition variable (usually backed by a futex or other
intelligent sleeping mechanism) so that the thread is idle while no
kernels are running.

Use cuLaunchHostFunc

Only create thread if used
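
A minimal standalone sketch of the wake/sleep mechanism described in the
summary (hypothetical names, not the patch's actual classes): a worker
sleeps on a condition variable until at least one kernel is in flight,
polls while the count is nonzero, and goes back to sleep at zero.

  #include <atomic>
  #include <condition_variable>
  #include <cstdint>
  #include <mutex>
  #include <thread>

  struct SleepyServer {
    std::atomic<uint32_t> NumUsers{0};
    std::atomic<bool> Running{true};
    std::mutex Mutex;
    std::condition_variable CV;
    std::thread Worker{[this] { run(); }};

    // A kernel was launched: wake the worker if it is asleep.
    void notify() {
      std::lock_guard<std::mutex> Lock(Mutex);
      NumUsers.fetch_add(1, std::memory_order_relaxed);
      CV.notify_all();
    }

    // A kernel completed: the worker sleeps once the count hits zero.
    void finish() { NumUsers.fetch_sub(1, std::memory_order_relaxed); }

    void shutDown() {
      {
        std::lock_guard<std::mutex> Lock(Mutex);
        Running.store(false);
        CV.notify_all();
      }
      Worker.join();
    }

    void run() {
      std::unique_lock<std::mutex> Lock(Mutex);
      for (;;) {
        // Futex-backed sleep until a kernel runs or we are shut down.
        CV.wait(Lock,
                [&] { return NumUsers.load() > 0 || !Running.load(); });
        if (!Running.load())
          return;
        Lock.unlock();
        while (Running.load() && NumUsers.load() > 0) {
          // Poll the device RPC buffers here.
        }
        Lock.lock();
      }
    }
  };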
---
 offload/plugins-nextgen/amdgpu/src/rtl.cpp    |  59 +++++----
 offload/plugins-nextgen/common/include/RPC.h  |  80 +++++++++++-
 .../common/src/PluginInterface.cpp            |   8 +-
 offload/plugins-nextgen/common/src/RPC.cpp    | 123 +++++++++++++-----
 .../cuda/dynamic_cuda/cuda.cpp                |   1 +
 .../plugins-nextgen/cuda/dynamic_cuda/cuda.h  |   3 +
 offload/plugins-nextgen/cuda/src/rtl.cpp      |  39 +++---
 offload/test/libc/server.c                    |  56 ++++++++
 8 files changed, 281 insertions(+), 88 deletions(-)
 create mode 100644 offload/test/libc/server.c

diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index 22c8079ab5812f..581c8e0c4fa26a 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -621,9 +621,9 @@ struct AMDGPUSignalTy {
   }
 
   /// Wait until the signal gets a zero value.
-  Error wait(const uint64_t ActiveTimeout = 0, RPCServerTy *RPCServer = nullptr,
+  Error wait(const uint64_t ActiveTimeout = 0,
              GenericDeviceTy *Device = nullptr) const {
-    if (ActiveTimeout && !RPCServer) {
+    if (ActiveTimeout) {
       hsa_signal_value_t Got = 1;
       Got = hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
                                       ActiveTimeout, HSA_WAIT_STATE_ACTIVE);
@@ -632,14 +632,11 @@ struct AMDGPUSignalTy {
     }
 
     // If there is an RPC device attached to this stream we run it as a server.
-    uint64_t Timeout = RPCServer ? 8192 : UINT64_MAX;
-    auto WaitState = RPCServer ? HSA_WAIT_STATE_ACTIVE : HSA_WAIT_STATE_BLOCKED;
+    uint64_t Timeout = UINT64_MAX;
+    auto WaitState = HSA_WAIT_STATE_BLOCKED;
     while (hsa_signal_wait_scacquire(HSASignal, HSA_SIGNAL_CONDITION_EQ, 0,
-                                     Timeout, WaitState) != 0) {
-      if (RPCServer && Device)
-        if (auto Err = RPCServer->runServer(*Device))
-          return Err;
-    }
+                                     Timeout, WaitState) != 0)
+      ;
     return Plugin::success();
   }
 
@@ -1048,11 +1045,6 @@ struct AMDGPUStreamTy {
   /// operation that was already finalized in a previous stream synchronize.
   uint32_t SyncCycle;
 
-  /// A pointer associated with an RPC server running on the given device. If
-  /// RPC is not being used this will be a null pointer. Otherwise, this
-  /// indicates that an RPC server is expected to be run on this stream.
-  RPCServerTy *RPCServer;
-
   /// Mutex to protect stream's management.
   mutable std::mutex Mutex;
 
@@ -1232,9 +1224,6 @@ struct AMDGPUStreamTy {
   /// Deinitialize the stream's signals.
   Error deinit() { return Plugin::success(); }
 
-  /// Attach an RPC server to this stream.
-  void setRPCServer(RPCServerTy *Server) { RPCServer = Server; }
-
   /// Push an asynchronous kernel to the stream. The kernel arguments must be
   /// placed in a special allocation for kernel args and kept alive until the
   /// kernel finalizes. Once the kernel is finished, the stream will release
@@ -1262,10 +1251,30 @@ struct AMDGPUStreamTy {
     if (auto Err = Slots[Curr].schedReleaseBuffer(KernelArgs, MemoryManager))
       return Err;
 
+    // If we are running an RPC server we want to wake up the server thread
+    // whenever there is a kernel running and let it sleep otherwise.
+    if (Device.getRPCServer())
+      Device.Plugin.getRPCServer().Thread->notify();
+
     // Push the kernel with the output signal and an input signal (optional)
-    return Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads, NumBlocks,
-                                   GroupSize, StackSize, OutputSignal,
-                                   InputSignal);
+    if (auto Err = Queue->pushKernelLaunch(Kernel, KernelArgs, NumThreads,
+                                           NumBlocks, GroupSize, StackSize,
+                                           OutputSignal, InputSignal))
+      return Err;
+
+    // Register a callback to indicate when the kernel is complete.
+    if (Device.getRPCServer()) {
+      if (auto Err = Slots[Curr].schedCallback(
+              [](void *Data) -> llvm::Error {
+                GenericPluginTy &Plugin =
+                    *reinterpret_cast<GenericPluginTy *>(Data);
+                Plugin.getRPCServer().Thread->finish();
+                return Error::success();
+              },
+              &Device.Plugin))
+        return Err;
+    }
+    return Plugin::success();
   }
 
   /// Push an asynchronous memory copy between pinned memory buffers.
@@ -1475,8 +1484,8 @@ struct AMDGPUStreamTy {
       return Plugin::success();
 
     // Wait until all previous operations on the stream have completed.
-    if (auto Err = Slots[last()].Signal->wait(StreamBusyWaitMicroseconds,
-                                              RPCServer, &Device))
+    if (auto Err =
+            Slots[last()].Signal->wait(StreamBusyWaitMicroseconds, &Device))
       return Err;
 
     // Reset the stream and perform all pending post actions.
@@ -3025,7 +3034,7 @@ AMDGPUStreamTy::AMDGPUStreamTy(AMDGPUDeviceTy &Device)
     : Agent(Device.getAgent()), Queue(nullptr),
       SignalManager(Device.getSignalManager()), Device(Device),
       // Initialize the std::deque with some empty positions.
-      Slots(32), NextSlot(0), SyncCycle(0), RPCServer(nullptr),
+      Slots(32), NextSlot(0), SyncCycle(0),
       StreamBusyWaitMicroseconds(Device.getStreamBusyWaitMicroseconds()),
       UseMultipleSdmaEngines(Device.useMultipleSdmaEngines()) {}
 
@@ -3378,10 +3387,6 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
   if (auto Err = AMDGPUDevice.getStream(AsyncInfoWrapper, Stream))
     return Err;
 
-  // If this kernel requires an RPC server we attach its pointer to the stream.
-  if (GenericDevice.getRPCServer())
-    Stream->setRPCServer(GenericDevice.getRPCServer());
-
   // Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
   if (ImplArgs &&
       getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {
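
The invariant on this launch path is that every successful push performs
one notify() before the kernel is enqueued and schedules exactly one
finish() from the stream's completion slot, so the server thread's user
count tracks the number of in-flight kernels. A condensed sketch of the
pairing, with hypothetical stand-ins for the plugin's stream types (the
real code goes through pushKernelLaunch and schedCallback as in the diff
above; the CUDA plugin below uses the same pairing):

  #include <functional>

  struct ServerThread {
    void notify(); // Increment the in-flight count, wake the worker.
    void finish(); // Decrement the count; the worker sleeps at zero.
  };

  struct Stream {
    bool push();                               // Enqueue the kernel.
    bool onComplete(std::function<void()> Fn); // Run Fn after the kernel.
  };

  // One notify() per launch, one finish() once the kernel is done.
  bool launchWithServer(Stream &S, ServerThread &T) {
    T.notify();
    if (!S.push())
      return false;
    return S.onComplete([&T] { T.finish(); });
  }
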
diff --git a/offload/plugins-nextgen/common/include/RPC.h b/offload/plugins-nextgen/common/include/RPC.h
index 5b9b7ffd086b57..f3a8e7555020d5 100644
--- a/offload/plugins-nextgen/common/include/RPC.h
+++ b/offload/plugins-nextgen/common/include/RPC.h
@@ -19,7 +19,11 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/Error.h"
 
+#include <atomic>
+#include <condition_variable>
 #include <cstdint>
+#include <mutex>
+#include <thread>
 
 namespace llvm::omp::target {
 namespace plugin {
@@ -37,6 +41,12 @@ struct RPCServerTy {
   /// Initializes the handles to the number of devices we may need to service.
   RPCServerTy(plugin::GenericPluginTy &Plugin);
 
+  /// Deinitialize the associated memory and resources.
+  llvm::Error shutDown();
+
+  /// Initialize the worker thread.
+  llvm::Error startThread();
+
   /// Check if this device image is using an RPC server. This checks for the
   /// presence of an externally visible symbol in the device image that will
   /// be present whenever RPC code is called.
@@ -51,17 +61,77 @@ struct RPCServerTy {
                          plugin::GenericGlobalHandlerTy &Handler,
                          plugin::DeviceImageTy &Image);
 
-  /// Runs the RPC server associated with the \p Device until the pending work
-  /// is cleared.
-  llvm::Error runServer(plugin::GenericDeviceTy &Device);
-
   /// Deinitialize the RPC server for the given device. This will free the
   /// memory associated with the RPC client buffer on that device.
   llvm::Error deinitDevice(plugin::GenericDeviceTy &Device);
 
 private:
   /// Array mapping each device's identifier to its RPC client buffer.
-  llvm::SmallVector<void *> Buffers;
+  std::unique_ptr<void *[]> Buffers;
+
+  /// Array of associated devices. These must be alive as long as the server is.
+  std::unique_ptr<plugin::GenericDeviceTy *[]> Devices;
+
+  /// A helper class for running the user thread that handles the RPC interface.
+  /// Because we only need to check the RPC server while any kernels are
+  /// working, we track submission / completion events to allow the thread to
+  /// sleep when it is not needed.
+  struct ServerThread {
+    std::thread Worker;
+
+    /// A boolean indicating whether or not the worker thread should continue.
+    std::atomic<bool> Running;
+
+    /// The number of currently executing kernels across all devices that need
+    /// the server thread to be running.
+    std::atomic<uint32_t> NumUsers;
+
+    /// The condition variable used to suspend the thread if no work is needed.
+    std::condition_variable CV;
+    std::mutex Mutex;
+
+    /// A reference to all the RPC interfaces that the server is handling.
+    llvm::ArrayRef<void *> Buffers;
+
+    /// A reference to the associated generic device for the buffer.
+    llvm::ArrayRef<plugin::GenericDeviceTy *> Devices;
+
+    /// Initialize the worker thread to run in the background.
+    ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[],
+                 size_t Length)
+        : Running(true), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length),
+          Devices(Devices, Length) {}
+
+    ~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }
+
+    /// Notify the worker thread that there is a user that needs it.
+    void notify() {
+      std::lock_guard<decltype(Mutex)> Lock(Mutex);
+      NumUsers.fetch_add(1, std::memory_order_relaxed);
+      CV.notify_all();
+    }
+
+    /// Indicate that one of the dependent users has finished.
+    void finish() {
+      [[maybe_unused]] uint32_t Old =
+          NumUsers.fetch_sub(1, std::memory_order_relaxed);
+      assert(Old > 0 && "Attempt to signal finish with no pending work");
+    }
+
+    /// Destroy the worker thread and wait.
+    void shutDown();
+
+    /// Initialize the worker thread.
+    void startThread();
+
+    /// Run the server thread to continuously check the RPC interface for work
+    /// to be done for the device.
+    void run();
+  };
+
+public:
+  /// Pointer to the server thread instance.
+  std::unique_ptr<ServerThread> Thread;
 };
 
 } // namespace llvm::omp::target
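
A note on the container change above: ServerThread holds non-owning
ArrayRefs over Buffers and Devices, so the backing storage must stay at a
fixed address for as long as the worker runs; heap arrays allocated once
at construction guarantee that. A reduced sketch of the aliasing,
assuming only llvm/ADT/ArrayRef.h:

  #include "llvm/ADT/ArrayRef.h"
  #include <memory>

  struct ServerState {
    // Owning storage, allocated once (zero-initialized), never resized.
    std::unique_ptr<void *[]> Buffers;
    // Non-owning view handed to the worker thread; valid exactly as long
    // as the storage above is.
    llvm::ArrayRef<void *> View;

    explicit ServerState(size_t NumDevices)
        : Buffers(std::make_unique<void *[]>(NumDevices)),
          View(Buffers.get(), NumDevices) {}
  };
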
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index 5cdf12176a0d66..010235a28ec8cb 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -1051,6 +1051,9 @@ Error GenericDeviceTy::setupRPCServer(GenericPluginTy &Plugin,
   if (auto Err = Server.initDevice(*this, Plugin.getGlobalHandler(), Image))
     return Err;
 
+  if (auto Err = Server.startThread())
+    return Err;
+
   RPCServer = &Server;
   DP("Running an RPC server on device %d\n", getDeviceId());
   return Plugin::success();
@@ -1624,8 +1627,11 @@ Error GenericPluginTy::deinit() {
   if (GlobalHandler)
     delete GlobalHandler;
 
-  if (RPCServer)
+  if (RPCServer) {
+    if (Error Err = RPCServer->shutDown())
+      return Err;
     delete RPCServer;
+  }
 
   if (RecordReplay)
     delete RecordReplay;
diff --git a/offload/plugins-nextgen/common/src/RPC.cpp b/offload/plugins-nextgen/common/src/RPC.cpp
index 66f98e68dc4429..25f1ee6a3a1257 100644
--- a/offload/plugins-nextgen/common/src/RPC.cpp
+++ b/offload/plugins-nextgen/common/src/RPC.cpp
@@ -21,8 +21,8 @@ using namespace omp;
 using namespace target;
 
 template <uint32_t NumLanes>
-rpc::Status handle_offload_opcodes(plugin::GenericDeviceTy &Device,
-                                   rpc::Server::Port &Port) {
+rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
+                                 rpc::Server::Port &Port) {
 
   switch (Port.get_opcode()) {
   case LIBC_MALLOC: {
@@ -62,21 +62,99 @@ rpc::Status handle_offload_opcodes(plugin::GenericDeviceTy &Device,
   return rpc::SUCCESS;
 }
 
-static rpc::Status handle_offload_opcodes(plugin::GenericDeviceTy &Device,
-                                          rpc::Server::Port &Port,
-                                          uint32_t NumLanes) {
+static rpc::Status handleOffloadOpcodes(plugin::GenericDeviceTy &Device,
+                                        rpc::Server::Port &Port,
+                                        uint32_t NumLanes) {
   if (NumLanes == 1)
-    return handle_offload_opcodes<1>(Device, Port);
+    return handleOffloadOpcodes<1>(Device, Port);
   else if (NumLanes == 32)
-    return handle_offload_opcodes<32>(Device, Port);
+    return handleOffloadOpcodes<32>(Device, Port);
   else if (NumLanes == 64)
-    return handle_offload_opcodes<64>(Device, Port);
+    return handleOffloadOpcodes<64>(Device, Port);
   else
     return rpc::ERROR;
 }
 
+static rpc::Status runServer(plugin::GenericDeviceTy &Device, void *Buffer) {
+  uint64_t NumPorts =
+      std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
+  rpc::Server Server(NumPorts, Buffer);
+
+  auto Port = Server.try_open(Device.getWarpSize());
+  if (!Port)
+    return rpc::SUCCESS;
+
+  rpc::Status Status =
+      handleOffloadOpcodes(Device, *Port, Device.getWarpSize());
+
+  // Let the `libc` library handle any other unhandled opcodes.
+#ifdef LIBOMPTARGET_RPC_SUPPORT
+  if (Status == rpc::UNHANDLED_OPCODE)
+    Status = handle_libc_opcodes(*Port, Device.getWarpSize());
+#endif
+
+  Port->close();
+
+  return Status;
+}
+
+void RPCServerTy::ServerThread::startThread() {
+  Worker = std::thread([this]() { run(); });
+}
+
+void RPCServerTy::ServerThread::shutDown() {
+  {
+    std::lock_guard<decltype(Mutex)> Lock(Mutex);
+    Running.store(false, std::memory_order_release);
+    CV.notify_all();
+  }
+  if (Worker.joinable())
+    Worker.join();
+}
+
+void RPCServerTy::ServerThread::run() {
+  std::unique_lock<decltype(Mutex)> Lock(Mutex);
+  for (;;) {
+    CV.wait(Lock, [&]() {
+      return NumUsers.load(std::memory_order_acquire) > 0 ||
+             !Running.load(std::memory_order_acquire);
+    });
+
+    if (!Running.load(std::memory_order_acquire))
+      return;
+
+    Lock.unlock();
+    while (NumUsers.load(std::memory_order_relaxed) > 0 &&
+           Running.load(std::memory_order_relaxed)) {
+      for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) {
+        if (!Buffer || !Device)
+          continue;
+
+        // If running the server failed, print a message but keep running.
+        if (runServer(*Device, Buffer) != rpc::SUCCESS)
+          FAILURE_MESSAGE("Unhandled or invalid RPC opcode!");
+      }
+    }
+    Lock.lock();
+  }
+}
+
 RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
-    : Buffers(Plugin.getNumDevices()) {}
+    : Buffers(std::make_unique<void *[]>(Plugin.getNumDevices())),
+      Devices(std::make_unique<plugin::GenericDeviceTy *[]>(
+          Plugin.getNumDevices())),
+      Thread(new ServerThread(Buffers.get(), Devices.get(),
+                              Plugin.getNumDevices())) {}
+
+llvm::Error RPCServerTy::startThread() {
+  Thread->startThread();
+  return Error::success();
+}
+
+llvm::Error RPCServerTy::shutDown() {
+  Thread->shutDown();
+  return Error::success();
+}
 
 llvm::Expected<bool>
 RPCServerTy::isDeviceUsingRPC(plugin::GenericDeviceTy &Device,
@@ -108,35 +186,14 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
                                    sizeof(rpc::Client), nullptr))
     return Err;
   Buffers[Device.getDeviceId()] = RPCBuffer;
-
-  return Error::success();
-}
-
-Error RPCServerTy::runServer(plugin::GenericDeviceTy &Device) {
-  uint64_t NumPorts =
-      std::min(Device.requestedRPCPortCount(), rpc::MAX_PORT_COUNT);
-  rpc::Server Server(NumPorts, Buffers[Device.getDeviceId()]);
-
-  auto Port = Server.try_open(Device.getWarpSize());
-  if (!Port)
-    return Error::success();
-
-  int Status = handle_offload_opcodes(Device, *Port, Device.getWarpSize());
-
-  // Let the `libc` library handle any other unhandled opcodes.
-#ifdef LIBOMPTARGET_RPC_SUPPORT
-  if (Status == rpc::UNHANDLED_OPCODE)
-    Status = handle_libc_opcodes(*Port, Device.getWarpSize());
-#endif
-
-  Port->close();
-  if (Status != rpc::SUCCESS)
-    return createStringError("RPC server given invalid opcode!");
+  Devices[Device.getDeviceId()] = &Device;
 
   return Error::success();
 }
 
 Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
   Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
+  Buffers[Device.getDeviceId()] = nullptr;
+  Devices[Device.getDeviceId()] = nullptr;
   return Error::success();
 }
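
The locking discipline above is what keeps the sleep race-free: notify()
and shutDown() write the counter and flag while holding the same mutex
the worker waits on, so a notification cannot land in the window between
the worker's predicate check and its call to wait(). A minimal
illustration of the lost-wakeup hazard the lock prevents (sketch, not the
patch's code):

  #include <atomic>
  #include <condition_variable>
  #include <cstdint>
  #include <mutex>

  // Without the lock in the waker, this interleaving loses the wakeup:
  //   worker: predicate sees NumUsers == 0, decides to sleep
  //   waker:  NumUsers.fetch_add(1); CV.notify_all(); // no one waiting
  //   worker: CV.wait(Lock);                          // sleeps forever
  // Holding the mutex while incrementing forces the store to land either
  // before the worker's predicate check or after the worker has
  // atomically released the lock inside wait(), so the notification
  // cannot be missed.
  void safeNotify(std::mutex &M, std::condition_variable &CV,
                  std::atomic<uint32_t> &NumUsers) {
    std::lock_guard<std::mutex> Lock(M);
    NumUsers.fetch_add(1, std::memory_order_relaxed);
    CV.notify_all();
  }
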
diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
index 5ec3adb9e4e3a1..7878499dbfcb7e 100644
--- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
+++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.cpp
@@ -63,6 +63,7 @@ DLWRAP(cuStreamCreate, 2)
 DLWRAP(cuStreamDestroy, 1)
 DLWRAP(cuStreamSynchronize, 1)
 DLWRAP(cuStreamQuery, 1)
+DLWRAP(cuStreamAddCallback, 4)
 DLWRAP(cuCtxSetCurrent, 1)
 DLWRAP(cuDevicePrimaryCtxRelease, 1)
 DLWRAP(cuDevicePrimaryCtxGetState, 3)
diff --git a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
index 16c8f7ad46c445..ad874735a25ed9 100644
--- a/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
+++ b/offload/plugins-nextgen/cuda/dynamic_cuda/cuda.h
@@ -286,6 +286,8 @@ static inline void *CU_LAUNCH_PARAM_END = (void *)0x00;
 static inline void *CU_LAUNCH_PARAM_BUFFER_POINTER = (void *)0x01;
 static inline void *CU_LAUNCH_PARAM_BUFFER_SIZE = (void *)0x02;
 
+typedef void (*CUstreamCallback)(CUstream, CUresult, void *);
+
 CUresult cuCtxGetDevice(CUdevice *);
 CUresult cuDeviceGet(CUdevice *, int);
 CUresult cuDeviceGetAttribute(int *, CUdevice_attribute, CUdevice);
@@ -326,6 +328,7 @@ CUresult cuStreamCreate(CUstream *, unsigned);
 CUresult cuStreamDestroy(CUstream);
 CUresult cuStreamSynchronize(CUstream);
 CUresult cuStreamQuery(CUstream);
+CUresult cuStreamAddCallback(CUstream, CUstreamCallback, void *, unsigned int);
 CUresult cuCtxSetCurrent(CUcontext);
 CUresult cuDevicePrimaryCtxRelease(CUdevice);
 CUresult cuDevicePrimaryCtxGetState(CUdevice, unsigned *, int *);
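
For context on the declaration added above: cuStreamAddCallback and
cuLaunchHostFunc (used in the launch path below) expect different
callback shapes. The former receives the stream and the status of prior
stream work; the latter receives only the user pointer. Matching function
shapes, sketched assuming this header's CUstream/CUresult types:

  // Shape for cuStreamAddCallback: stream, status, then user pointer.
  static void onStreamEvent(CUstream Stream, CUresult Status, void *Data) {
    // Status reports any error from previously enqueued stream work.
  }

  // Shape for cuLaunchHostFunc: user pointer only, invoked in stream
  // order once all prior work on the stream has completed.
  static void onKernelDone(void *Data) {}
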
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 9af71b06ce97d3..4eb57fb675c7cd 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -628,17 +628,7 @@ struct CUDADeviceTy : public GenericDeviceTy {
   Error synchronizeImpl(__tgt_async_info &AsyncInfo) override {
     CUstream Stream = reinterpret_cast<CUstream>(AsyncInfo.Queue);
     CUresult Res;
-    // If we have an RPC server running on this device we will continuously
-    // query it for work rather than blocking.
-    if (!getRPCServer()) {
-      Res = cuStreamSynchronize(Stream);
-    } else {
-      do {
-        Res = cuStreamQuery(Stream);
-        if (auto Err = getRPCServer()->runServer(*this))
-          return Err;
-      } while (Res == CUDA_ERROR_NOT_READY);
-    }
+    Res = cuStreamSynchronize(Stream);
 
     // Once the stream is synchronized, return it to stream pool and reset
     // AsyncInfo. This is to make sure the synchronization only works for its
@@ -823,17 +813,6 @@ struct CUDADeviceTy : public GenericDeviceTy {
     if (auto Err = getStream(AsyncInfoWrapper, Stream))
       return Err;
 
-    // If there is already pending work on the stream it could be waiting for
-    // someone to check the RPC server.
-    if (auto *RPCServer = getRPCServer()) {
-      CUresult Res = cuStreamQuery(Stream);
-      while (Res == CUDA_ERROR_NOT_READY) {
-        if (auto Err = RPCServer->runServer(*this))
-          return Err;
-        Res = cuStreamQuery(Stream);
-      }
-    }
-
     CUresult Res = cuMemcpyDtoHAsync(HstPtr, (CUdeviceptr)TgtPtr, Size, Stream);
     return Plugin::check(Res, "Error in cuMemcpyDtoHAsync: %s");
   }
@@ -1292,10 +1271,26 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
                     reinterpret_cast<void *>(&LaunchParams.Size),
                     CU_LAUNCH_PARAM_END};
 
+  // If we are running an RPC server we want to wake up the server thread
+  // whenever there is a kernel running and let it sleep otherwise.
+  if (GenericDevice.getRPCServer())
+    GenericDevice.Plugin.getRPCServer().Thread->notify();
+
   CUresult Res = cuLaunchKernel(Func, NumBlocks, /*gridDimY=*/1,
                                 /*gridDimZ=*/1, NumThreads,
                                 /*blockDimY=*/1, /*blockDimZ=*/1,
                                 MaxDynCGroupMem, Stream, nullptr, Config);
+
+  // Register a callback to indicate when the kernel is complete.
+  if (GenericDevice.getRPCServer())
+    cuLaunchHostFunc(
+        Stream,
+        [](void *Data) {
+          GenericPluginTy &Plugin = *reinterpret_cast<GenericPluginTy *>(Data);
+          Plugin.getRPCServer().Thread->finish();
+        },
+        &GenericDevice.Plugin);
+
   return Plugin::check(Res, "Error in cuLaunchKernel for '%s': %s", getName());
 }
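
One reason the host callback body stays this small: CUDA host functions
enqueued on a stream must not make CUDA API calls themselves, so the
callback is limited to touching host-side state, and decrementing the
atomic user count fits that restriction. Minimal shape of such a callback
(sketch, hypothetical names):

  #include <atomic>

  // Host-side state only; no CUDA calls are allowed inside the callback.
  static std::atomic<unsigned> InFlightKernels{0};

  // Matches the void(void *) shape cuLaunchHostFunc expects.
  static void onKernelComplete(void *) {
    InFlightKernels.fetch_sub(1, std::memory_order_relaxed);
  }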
 
diff --git a/offload/test/libc/server.c b/offload/test/libc/server.c
new file mode 100644
index 00000000000000..eb81294436426a
--- /dev/null
+++ b/offload/test/libc/server.c
@@ -0,0 +1,56 @@
+// RUN: %libomptarget-compile-run-and-check-generic
+
+// REQUIRES: libc
+
+#include <assert.h>
+#include <omp.h>
+#include <stdio.h>
+
+#pragma omp begin declare variant match(device = {kind(gpu)})
+// Extension provided by the 'libc' project.
+unsigned long long rpc_host_call(void *fn, void *args, size_t size);
+#pragma omp declare target to(rpc_host_call) device_type(nohost)
+#pragma omp end declare variant
+
+#pragma omp begin declare variant match(device = {kind(cpu)})
+// Dummy host implementation to make this work for all targets.
+unsigned long long rpc_host_call(void *fn, void *args, size_t size) {
+  return ((unsigned long long (*)(void *))fn)(args);
+}
+#pragma omp end declare variant
+
+long long foo(void *data) { return -1; }
+
+void *fn_ptr = NULL;
+#pragma omp declare target to(fn_ptr)
+
+int main() {
+  fn_ptr = (void *)&foo;
+#pragma omp target update to(fn_ptr)
+
+  for (int i = 0; i < 4; ++i) {
+#pragma omp target
+    {
+      long long res = rpc_host_call(fn_ptr, NULL, 0);
+      assert(res == -1 && "RPC call failed\n");
+    }
+
+    for (int j = 0; j < 128; ++j) {
+#pragma omp target nowait
+      {
+        long long res = rpc_host_call(fn_ptr, NULL, 0);
+        assert(res == -1 && "RPC call failed\n");
+      }
+    }
+#pragma omp taskwait
+
+#pragma omp target
+    {
+      long long res = rpc_host_call(fn_ptr, NULL, 0);
+      assert(res == -1 && "RPC call failed\n");
+    }
+  }
+
+  // CHECK: PASS
+  puts("PASS");
+}


