[llvm] [Offload][NFC] use unique ptrs for platforms (PR #160888)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 26 06:52:18 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-offload

Author: Piotr Balcer (pbalcer)

<details>
<summary>Changes</summary>

Currently, devices store a raw pointer back to their owning platform. Platforms are stored by value directly inside a vector, so modifying this vector (e.g. growing it, which may reallocate its storage) risks invalidating all the platform pointers stored in devices.

This patch allocates each platform individually and changes devices to store a reference to their platform instead of a pointer. This is safe because each platform owns its devices (through its `Devices` list), so a platform is guaranteed to outlive every device it contains.


---
Full diff: https://github.com/llvm/llvm-project/pull/160888.diff


1 Files Affected:

- (modified) offload/liboffload/src/OffloadImpl.cpp (+38-36) 


``````````diff
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 08a2e25b97d85..70fc97c15511e 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -39,12 +39,28 @@ using namespace llvm::omp::target;
 using namespace llvm::omp::target::plugin;
 using namespace error;
 
+struct ol_platform_impl_t {
+  ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
+                     ol_platform_backend_t BackendType)
+      : Plugin(std::move(Plugin)), BackendType(BackendType) {}
+  std::unique_ptr<GenericPluginTy> Plugin;
+  llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
+  ol_platform_backend_t BackendType;
+
+  /// Complete all pending work for this platform and perform any needed
+  /// cleanup.
+  ///
+  /// After calling this function, no liboffload functions should be called with
+  /// this platform handle.
+  llvm::Error destroy();
+};
+
 // Handle type definitions. Ideally these would be 1:1 with the plugins, but
 // we add some additional data here for now to avoid churn in the plugin
 // interface.
 struct ol_device_impl_t {
   ol_device_impl_t(int DeviceNum, GenericDeviceTy *Device,
-                   ol_platform_handle_t Platform, InfoTreeNode &&DevInfo)
+                   ol_platform_impl_t &Platform, InfoTreeNode &&DevInfo)
       : DeviceNum(DeviceNum), Device(Device), Platform(Platform),
         Info(std::forward<InfoTreeNode>(DevInfo)) {}
 
@@ -55,7 +71,7 @@ struct ol_device_impl_t {
 
   int DeviceNum;
   GenericDeviceTy *Device;
-  ol_platform_handle_t Platform;
+  ol_platform_impl_t &Platform;
   InfoTreeNode Info;
 
   llvm::SmallVector<__tgt_async_info *> OutstandingQueues;
@@ -102,20 +118,8 @@ struct ol_device_impl_t {
   }
 };
 
-struct ol_platform_impl_t {
-  ol_platform_impl_t(std::unique_ptr<GenericPluginTy> Plugin,
-                     ol_platform_backend_t BackendType)
-      : Plugin(std::move(Plugin)), BackendType(BackendType) {}
-  std::unique_ptr<GenericPluginTy> Plugin;
-  llvm::SmallVector<std::unique_ptr<ol_device_impl_t>> Devices;
-  ol_platform_backend_t BackendType;
 
-  /// Complete all pending work for this platform and perform any needed
-  /// cleanup.
-  ///
-  /// After calling this function, no liboffload functions should be called with
-  /// this platform handle.
-  llvm::Error destroy() {
+llvm::Error ol_platform_impl_t::destroy() {
     llvm::Error Result = Plugin::success();
     for (auto &D : Devices)
       if (auto Err = D->destroy())
@@ -125,8 +129,7 @@ struct ol_platform_impl_t {
       Result = llvm::joinErrors(std::move(Result), std::move(Res));
 
     return Result;
-  }
-};
+}
 
 struct ol_queue_impl_t {
   ol_queue_impl_t(__tgt_async_info *AsyncInfo, ol_device_handle_t Device)
@@ -206,12 +209,12 @@ struct OffloadContext {
   // Partitioned list of memory base addresses. Each element in this list is a
   // key in AllocInfoMap
   llvm::SmallVector<void *> AllocBases{};
-  SmallVector<ol_platform_impl_t, 4> Platforms{};
+  SmallVector<std::unique_ptr<ol_platform_impl_t>, 4> Platforms{};
   size_t RefCount;
 
   ol_device_handle_t HostDevice() {
     // The host platform is always inserted last
-    return Platforms.back().Devices[0].get();
+    return Platforms.back()->Devices[0].get();
   }
 
   static OffloadContext &get() {
@@ -251,35 +254,34 @@ Error initPlugins(OffloadContext &Context) {
 #define PLUGIN_TARGET(Name)                                                    \
   do {                                                                         \
     if (StringRef(#Name) != "host")                                            \
-      Context.Platforms.emplace_back(ol_platform_impl_t{                       \
+      Context.Platforms.emplace_back(std::make_unique<ol_platform_impl_t>(     \
           std::unique_ptr<GenericPluginTy>(createPlugin_##Name()),             \
-          pluginNameToBackend(#Name)});                                        \
+          pluginNameToBackend(#Name)));                                        \
   } while (false);
 #include "Shared/Targets.def"
 
   // Preemptively initialize all devices in the plugin
   for (auto &Platform : Context.Platforms) {
-    auto Err = Platform.Plugin->init();
+    auto Err = Platform->Plugin->init();
     [[maybe_unused]] std::string InfoMsg = toString(std::move(Err));
-    for (auto DevNum = 0; DevNum < Platform.Plugin->number_of_devices();
+    for (auto DevNum = 0; DevNum < Platform->Plugin->number_of_devices();
          DevNum++) {
-      if (Platform.Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
-        auto Device = &Platform.Plugin->getDevice(DevNum);
+      if (Platform->Plugin->init_device(DevNum) == OFFLOAD_SUCCESS) {
+        auto Device = &Platform->Plugin->getDevice(DevNum);
         auto Info = Device->obtainInfoImpl();
         if (auto Err = Info.takeError())
           return Err;
-        Platform.Devices.emplace_back(std::make_unique<ol_device_impl_t>(
-            DevNum, Device, &Platform, std::move(*Info)));
+        Platform->Devices.emplace_back(std::make_unique<ol_device_impl_t>(
+            DevNum, Device, *Platform, std::move(*Info)));
       }
     }
   }
 
   // Add the special host device
   auto &HostPlatform = Context.Platforms.emplace_back(
-      ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
-  HostPlatform.Devices.emplace_back(
-      std::make_unique<ol_device_impl_t>(-1, nullptr, nullptr, InfoTreeNode{}));
-  Context.HostDevice()->Platform = &HostPlatform;
+      std::make_unique<ol_platform_impl_t>(nullptr, OL_PLATFORM_BACKEND_HOST));
+  HostPlatform->Devices.emplace_back(
+      std::make_unique<ol_device_impl_t>(-1, nullptr, *HostPlatform, InfoTreeNode{}));
 
   Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
   Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
@@ -316,10 +318,10 @@ Error olShutDown_impl() {
 
   for (auto &P : OldContext->Platforms) {
     // Host plugin is nullptr and has no deinit
-    if (!P.Plugin || !P.Plugin->is_initialized())
+    if (!P->Plugin || !P->Plugin->is_initialized())
       continue;
 
-    if (auto Res = P.destroy())
+    if (auto Res = P->destroy())
       Result = llvm::joinErrors(std::move(Result), std::move(Res));
   }
 
@@ -384,7 +386,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
   // These are not implemented by the plugin interface
   switch (PropName) {
   case OL_DEVICE_INFO_PLATFORM:
-    return Info.write<void *>(Device->Platform);
+    return Info.write<void *>(&Device->Platform);
 
   case OL_DEVICE_INFO_TYPE:
     return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_GPU);
@@ -517,7 +519,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
 
   switch (PropName) {
   case OL_DEVICE_INFO_PLATFORM:
-    return Info.write<void *>(Device->Platform);
+    return Info.write<void *>(&Device->Platform);
   case OL_DEVICE_INFO_TYPE:
     return Info.write<ol_device_type_t>(OL_DEVICE_TYPE_HOST);
   case OL_DEVICE_INFO_NAME:
@@ -595,7 +597,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device,
 
 Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) {
   for (auto &Platform : OffloadContext::get().Platforms) {
-    for (auto &Device : Platform.Devices) {
+    for (auto &Device : Platform->Devices) {
       if (!Callback(Device.get(), UserData)) {
         break;
       }

``````````

</details>


https://github.com/llvm/llvm-project/pull/160888


More information about the llvm-commits mailing list