[llvm-branch-commits] [llvm] release/20.x: [Offload] Properly guard modifications to the RPC device array (#126790) (PR #126795)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Feb 11 13:04:04 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-offload
Author: None (llvmbot)
<details>
<summary>Changes</summary>
Backport baf7a3c1e561ff7e3f7da2261ce1012c4f2ba1c0
Requested by: @jhuber6
---
Full diff: https://github.com/llvm/llvm-project/pull/126795.diff
2 Files Affected:
- (modified) offload/plugins-nextgen/common/include/RPC.h (+9-3)
- (modified) offload/plugins-nextgen/common/src/RPC.cpp (+4-1)
``````````diff
diff --git a/offload/plugins-nextgen/common/include/RPC.h b/offload/plugins-nextgen/common/include/RPC.h
index 42fca4aa4aebc..08556f15a76bf 100644
--- a/offload/plugins-nextgen/common/include/RPC.h
+++ b/offload/plugins-nextgen/common/include/RPC.h
@@ -72,6 +72,9 @@ struct RPCServerTy {
/// Array of associated devices. These must be alive as long as the server is.
std::unique_ptr<plugin::GenericDeviceTy *[]> Devices;
+ /// Mutex that guards accesses to the buffers and device array.
+ std::mutex BufferMutex{};
+
/// A helper class for running the user thread that handles the RPC interface.
/// Because we only need to check the RPC server while any kernels are
/// working, we track submission / completion events to allow the thread to
@@ -90,6 +93,9 @@ struct RPCServerTy {
std::condition_variable CV;
std::mutex Mutex;
+ /// A reference to the main server's mutex.
+ std::mutex &BufferMutex;
+
/// A reference to all the RPC interfaces that the server is handling.
llvm::ArrayRef<void *> Buffers;
@@ -98,9 +104,9 @@ struct RPCServerTy {
/// Initialize the worker thread to run in the background.
ServerThread(void *Buffers[], plugin::GenericDeviceTy *Devices[],
- size_t Length)
- : Running(false), NumUsers(0), CV(), Mutex(), Buffers(Buffers, Length),
- Devices(Devices, Length) {}
+ size_t Length, std::mutex &BufferMutex)
+ : Running(false), NumUsers(0), CV(), Mutex(), BufferMutex(BufferMutex),
+ Buffers(Buffers, Length), Devices(Devices, Length) {}
~ServerThread() { assert(!Running && "Thread not shut down explicitly\n"); }
diff --git a/offload/plugins-nextgen/common/src/RPC.cpp b/offload/plugins-nextgen/common/src/RPC.cpp
index e6750a540b391..eb305736d6264 100644
--- a/offload/plugins-nextgen/common/src/RPC.cpp
+++ b/offload/plugins-nextgen/common/src/RPC.cpp
@@ -131,6 +131,7 @@ void RPCServerTy::ServerThread::run() {
Lock.unlock();
while (NumUsers.load(std::memory_order_relaxed) > 0 &&
Running.load(std::memory_order_relaxed)) {
+ std::lock_guard<decltype(Mutex)> Lock(BufferMutex);
for (const auto &[Buffer, Device] : llvm::zip_equal(Buffers, Devices)) {
if (!Buffer || !Device)
continue;
@@ -149,7 +150,7 @@ RPCServerTy::RPCServerTy(plugin::GenericPluginTy &Plugin)
Devices(std::make_unique<plugin::GenericDeviceTy *[]>(
Plugin.getNumDevices())),
Thread(new ServerThread(Buffers.get(), Devices.get(),
- Plugin.getNumDevices())) {}
+ Plugin.getNumDevices(), BufferMutex)) {}
llvm::Error RPCServerTy::startThread() {
Thread->startThread();
@@ -190,6 +191,7 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
if (auto Err = Device.dataSubmit(ClientGlobal.getPtr(), &client,
sizeof(rpc::Client), nullptr))
return Err;
+ std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
Buffers[Device.getDeviceId()] = RPCBuffer;
Devices[Device.getDeviceId()] = &Device;
@@ -197,6 +199,7 @@ Error RPCServerTy::initDevice(plugin::GenericDeviceTy &Device,
}
Error RPCServerTy::deinitDevice(plugin::GenericDeviceTy &Device) {
+ std::lock_guard<decltype(BufferMutex)> Lock(BufferMutex);
Device.free(Buffers[Device.getDeviceId()], TARGET_ALLOC_HOST);
Buffers[Device.getDeviceId()] = nullptr;
Devices[Device.getDeviceId()] = nullptr;
``````````
</details>
https://github.com/llvm/llvm-project/pull/126795
More information about the llvm-branch-commits mailing list