[llvm-branch-commits] [llvm] release/20.x: [Offload] Properly guard modifications to the RPC device array (#126790) (PR #126795)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Feb 11 13:04:04 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-offload

Author: None (llvmbot)

<details>
<summary>Changes</summary>

Backport baf7a3c1e561ff7e3f7da2261ce1012c4f2ba1c0

Requested by: @jhuber6

---
Full diff: https://github.com/llvm/llvm-project/pull/126795.diff


2 Files Affected:

- (modified) offload/plugins-nextgen/common/include/RPC.h (+9-3) 
- (modified) offload/plugins-nextgen/common/src/RPC.cpp (+4-1) 


``````````diff
diff --git a/offload/plugins-nextgen/common/include/RPC.h b/offload/plugins-nextgen/common/include/RPC.h
index 42fca4aa4aebc..08556f15a76bf 100644
--- a/offload/plugins-nextgen/common/include/RPC.h
+++ b/offload/plugins-nextgen/common/include/RPC.h
@@ -72,6 +72,9 @@ struct RPCServerTy {
   /// Array of associated devices. These must be alive as long as the server is.
   std::unique_ptr<plugin::GenericDeviceTy *[]> Devices;
 
+  /// Mutex that guards accesses to the buffers and device array.
+  std::mutex BufferMutex{};
+
   /// A helper class for running the user thread that handles the RPC interface.
   /// Because we only need to check the RPC server while any kernels are
   /// working, we track submission / completion events to allow the thread to
@@ -90,6 +93,9 @@ struct RPCServerTy {
     std::condition_variable CV;
     std::mutex Mutex;
 
+    /// A reference to the main server's mutex.
+    std::mutex &BufferMutex;
+
     /// A reference to all the RPC interfaces that the server is handling.
     llvm::ArrayRef<void *> Buffers;
 
@@ -98,9 +104,9 @@ struct RPCServerTy {
 
     /// Initialize the worker thread to run in the background.
     ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[],
-                 size_t Length)
-        : Running(false), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length),
-          Devices(Devices, Length) {}
+                 size_t Length, std::mutex &BufferMutex)
+        : Running(false), NumUsers(0), CV(), Mutex(), BufferMutex(BufferMutex),
+          Buffers(Buffers, Length), Devices(Devices, Length) {}
 
     ~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }
 
diff --git a/offload/plugins-nextgen/common/src/RPC.cpp b/offload/plugins-nextgen/common/src/RPC.cpp
index e6750a540b391..eb305736d6264 100644
--- a/offload/plugins-nextgen/common/src/RPC.cpp
+++ b/offload/plugins-nextgen/common/src/RPC.cpp
@@ -131,6 +131,7 @@ void RPCServerTy::ServerThread::run() {
     Lock.unlock();
     while (NumUsers.load(std::memory_order_relaxed) > 0 &&
            Running.load(std::memory_order_relaxed)) {
+      std::lock_guard<decltype(Mutex)> Lock(BufferMutex);
       for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) {
         if (!Buffer || !Device)
           continue;
@@ -149,7 +150,7 @@ RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
       Devices(std::make_unique<plugin::GenericDeviceTy *[]>(
           Plugin.getNumDevices())),
       Thread(new ServerThread(Buffers.get(), Devices.get(),
-                              Plugin.getNumDevices())) {}
+                              Plugin.getNumDevices(), BufferMutex)) {}
 
 llvm::Error RPCServerTy::startThread() {
   Thread->startThread();
@@ -190,6 +191,7 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
   if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client,
                                    sizeof(rpc::Client), nullptr))
     return Err;
+  std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
   Buffers[Device.getDeviceId()] = RPCBuffer;
   Devices[Device.getDeviceId()] = &Device;
 
@@ -197,6 +199,7 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
 }
 
 Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
+  std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
   Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
   Buffers[Device.getDeviceId()] = nullptr;
   Devices[Device.getDeviceId()] = nullptr;

``````````

</details>


https://github.com/llvm/llvm-project/pull/126795


More information about the llvm-branch-commits mailing list