[llvm] [Offload] Implement `olShutDown` (PR #144055)

Ross Brunton via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 25 07:22:43 PDT 2025


https://github.com/RossBrunton updated https://github.com/llvm/llvm-project/pull/144055

From f555d4969f8e70df484425529252effadf488161 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Thu, 12 Jun 2025 10:47:19 +0100
Subject: [PATCH] `olShutDown` was not properly calling deinit on the
 platforms, resulting in random segfaults on AMD devices.

---
 offload/liboffload/API/Common.td             |  2 +-
 offload/liboffload/src/OffloadImpl.cpp       | 35 +++++++++++++++++---
 offload/unittests/OffloadAPI/init/olInit.cpp | 12 +++++++
 3 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/offload/liboffload/API/Common.td b/offload/liboffload/API/Common.td
index 79c3bd46f1984..669dfd3cca7c6 100644
--- a/offload/liboffload/API/Common.td
+++ b/offload/liboffload/API/Common.td
@@ -176,7 +176,7 @@ def : Function {
   let desc = "Release the resources in use by Offload";
   let details = [
     "This decrements an internal reference count. When this reaches 0, all resources will be released",
-    "Subsequent API calls made after this are not valid"
+    "Subsequent API calls to methods other than `olInit` made after resources are released will return OL_ERRC_UNINITIALIZED"
   ];
   let params = [];
   let returns = [];
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index c2a35a245e2a7..46e10fb758764 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -96,6 +96,7 @@ struct AllocInfo {
 // Global shared state for liboffload
 struct OffloadContext;
 static OffloadContext *OffloadContextVal;
+std::mutex OffloadContextValMutex;
 struct OffloadContext {
   OffloadContext(OffloadContext &) = delete;
   OffloadContext(OffloadContext &&) = delete;
@@ -106,6 +107,7 @@ struct OffloadContext {
   bool ValidationEnabled = true;
   DenseMap<void *, AllocInfo> AllocInfoMap{};
   SmallVector<ol_platform_impl_t, 4> Platforms{};
+  size_t RefCount;
 
   ol_device_handle_t HostDevice() {
     // The host platform is always inserted last
@@ -186,15 +188,38 @@ void initPlugins() {
   OffloadContextVal = Context;
 }
 
-// TODO: We can properly reference count here and manage the resources in a more
-// clever way
 Error olInit_impl() {
-  static std::once_flag InitFlag;
-  std::call_once(InitFlag, initPlugins);
+  std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
+
+  if (!isOffloadInitialized())
+    initPlugins();
+
+  OffloadContext::get().RefCount ++;
 
   return Error::success();
 }
-Error olShutDown_impl() { return Error::success(); }
+
+Error olShutDown_impl() {
+  std::lock_guard<std::mutex> Lock{OffloadContextValMutex};
+
+  if (--OffloadContext::get().RefCount != 0)
+    return Error::success();
+
+  llvm::Error Result = Error::success();
+
+  for (auto &P : OffloadContext::get().Platforms) {
+    // Host plugin is nullptr and has no deinit
+    if (!P.Plugin)
+      continue;
+
+    if (auto Res = P.Plugin->deinit())
+      Result = llvm::joinErrors(std::move(Result), std::move(Res));
+  }
+  delete OffloadContextVal;
+  OffloadContextVal = nullptr;
+
+  return Result;
+}
 
 Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
                                   ol_platform_info_t PropName, size_t PropSize,
diff --git a/offload/unittests/OffloadAPI/init/olInit.cpp b/offload/unittests/OffloadAPI/init/olInit.cpp
index 8e27e77cd0fb5..508615152b4f1 100644
--- a/offload/unittests/OffloadAPI/init/olInit.cpp
+++ b/offload/unittests/OffloadAPI/init/olInit.cpp
@@ -15,8 +15,20 @@
 
 struct olInitTest : ::testing::Test {};
 
+TEST_F(olInitTest, Success) {
+  ASSERT_SUCCESS(olInit());
+  ASSERT_SUCCESS(olShutDown());
+}
+
 TEST_F(olInitTest, Uninitialized) {
   ASSERT_ERROR(OL_ERRC_UNINITIALIZED,
                olIterateDevices(
                    [](ol_device_handle_t, void *) { return false; }, nullptr));
 }
+
+TEST_F(olInitTest, RepeatedInit) {
+  for (size_t I = 0; I < 10; I++) {
+    ASSERT_SUCCESS(olInit());
+    ASSERT_SUCCESS(olShutDown());
+  }
+}



More information about the llvm-commits mailing list