[llvm] [Offload] Update allocations to include device (PR #154733)

Ross Brunton via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 21 04:34:21 PDT 2025


https://github.com/RossBrunton updated https://github.com/llvm/llvm-project/pull/154733

>From 5151703b01433e113c7d85a48be8e55888003ab4 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Thu, 21 Aug 2025 12:01:43 +0100
Subject: [PATCH 1/2] [Offload] Update allocations to include device

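Allocations on different devices may be returned at the same (aliasing)
address, so an address alone cannot identify a device allocation.
Allocations are now recorded together with their allocating device, and
`olMemFree` takes an additional device handle, which is required for
`OL_ALLOC_TYPE_DEVICE` allocations and optional for other types.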
---
 offload/liboffload/API/Memory.td              | 19 ++++++--
 offload/liboffload/src/OffloadImpl.cpp        | 44 +++++++++++--------
 .../OffloadAPI/kernel/olLaunchKernel.cpp      | 16 +++----
 .../OffloadAPI/memory/olMemAlloc.cpp          |  6 +--
 .../unittests/OffloadAPI/memory/olMemFree.cpp | 44 ++++++++++++++++---
 .../unittests/OffloadAPI/memory/olMemcpy.cpp  | 20 ++++-----
 .../OffloadAPI/queue/olLaunchHostFunction.cpp |  2 +-
 7 files changed, 102 insertions(+), 49 deletions(-)

diff --git a/offload/liboffload/API/Memory.td b/offload/liboffload/API/Memory.td
index 5f7158588bc77..a904364075ba8 100644
--- a/offload/liboffload/API/Memory.td
+++ b/offload/liboffload/API/Memory.td
@@ -14,15 +14,18 @@ def : Enum {
   let name = "ol_alloc_type_t";
   let desc = "Represents the type of allocation made with olMemAlloc.";
   let etors = [
-    Etor<"HOST", "Host allocation">,
-    Etor<"DEVICE", "Device allocation">,
-    Etor<"MANAGED", "Managed allocation">
+    Etor<"HOST", "Host allocation. Allocated on the host and visible to the host and all devices sharing the same platform.">,
+    Etor<"DEVICE", "Device allocation. Allocated on a specific device and visible only to that device.">,
+    Etor<"MANAGED", "Managed allocation. Allocated on a specific device and visible to the host and all devices sharing the same platform.">
   ];
 }
 
 def : Function {
   let name = "olMemAlloc";
   let desc = "Creates a memory allocation on the specified device.";
+  let details = [
+      "`DEVICE` allocations do not share the same address space as the host or other devices, so allocations on different devices may be returned at the same address. The `AllocationOut` pointer therefore cannot be used to uniquely identify the allocation in these cases.",
+  ];
   let params = [
     Param<"ol_device_handle_t", "Device", "handle of the device to allocate on", PARAM_IN>,
     Param<"ol_alloc_type_t", "Type", "type of the allocation", PARAM_IN>,
@@ -39,10 +42,18 @@ def : Function {
 def : Function {
   let name = "olMemFree";
   let desc = "Frees a memory allocation previously made by olMemAlloc.";
+  let details = [
+      "`Address` must be the beginning of the allocation.",
+      "`Device` must be provided for memory allocated as `OL_ALLOC_TYPE_DEVICE`, and may be provided for other types.",
+      "If `Device` is provided, it must match the device used to allocate the memory with `olMemAlloc`.",
+  ];
   let params = [
+    Param<"ol_device_handle_t", "Device", "handle of the device this allocation was allocated on", PARAM_IN_OPTIONAL>,
     Param<"void*", "Address", "address of the allocation to free", PARAM_IN>,
   ];
-  let returns = [];
+  let returns = [
+    Return<"OL_ERRC_NOT_FOUND", ["The address was not found in the list of allocations", "Device was not provided and the allocation at the address is of type OL_ALLOC_TYPE_DEVICE"]>
+  ];
 }
 
 def : Function {
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 49154337dd193..da48378547974 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -106,6 +106,7 @@ namespace llvm {
 namespace offload {
 
 struct AllocInfo {
+  void *Base;
   ol_device_handle_t Device;
   ol_alloc_type_t Type;
 };
@@ -124,8 +125,8 @@ struct OffloadContext {
 
   bool TracingEnabled = false;
   bool ValidationEnabled = true;
-  DenseMap<void *, AllocInfo> AllocInfoMap{};
-  std::mutex AllocInfoMapMutex{};
+  SmallVector<AllocInfo> AllocInfoList{};
+  std::mutex AllocInfoListMutex{};
   SmallVector<ol_platform_impl_t, 4> Platforms{};
   size_t RefCount;
 
@@ -535,30 +536,48 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
 
   *AllocationOut = *Alloc;
   {
-    std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
-    OffloadContext::get().AllocInfoMap.insert_or_assign(
-        *Alloc, AllocInfo{Device, Type});
+    std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoListMutex);
+    OffloadContext::get().AllocInfoList.emplace_back(
+        AllocInfo{*AllocationOut, Device, Type});
   }
   return Error::success();
 }
 
-Error olMemFree_impl(void *Address) {
-  ol_device_handle_t Device;
-  ol_alloc_type_t Type;
+Error olMemFree_impl(ol_device_handle_t Device, void *Address) {
+  AllocInfo Removed;
   {
-    std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
-    if (!OffloadContext::get().AllocInfoMap.contains(Address))
-      return createOffloadError(ErrorCode::INVALID_ARGUMENT,
-                                "address is not a known allocation");
+    std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoListMutex);
+
+    auto &List = OffloadContext::get().AllocInfoList;
+    // Addresses are only unique per device: allocations made on different
+    // devices may alias. Without a Device only non-DEVICE allocations can be
+    // matched; DEVICE entries are skipped so that they cannot shadow a
+    // host-visible allocation that happens to share the same address.
+    auto Entry =
+        std::find_if(List.begin(), List.end(), [&](const AllocInfo &Info) {
+          if (Info.Base != Address)
+            return false;
+          return Device ? Info.Device == Device
+                        : Info.Type != OL_ALLOC_TYPE_DEVICE;
+        });
+
+    if (Entry == List.end())
+      return Plugin::error(
+          ErrorCode::NOT_FOUND,
+          Device ? "could not find memory allocated by olMemAlloc"
+                 : "could not find memory allocated by olMemAlloc (specifying "
+                   "the Device parameter is required to free device memory)");
 
-    auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address);
-    Device = AllocInfo.Device;
-    Type = AllocInfo.Type;
-    OffloadContext::get().AllocInfoMap.erase(Address);
+    // Unordered swap-and-pop removal. Copy the tail into the hole before
+    // shrinking the list: if Entry is the last element, writing through it
+    // after pop_back would target an element that no longer exists.
+    Removed = *Entry;
+    *Entry = List.back();
+    List.pop_back();
   }
 
-  if (auto Res =
-          Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
+  if (auto Res = Removed.Device->Device->dataDelete(
+          Removed.Base, convertOlToPluginAllocTy(Removed.Type)))
     return Res;
 
   return Error::success();
diff --git a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
index 1dac8c50271b5..a7f4881bcc709 100644
--- a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
+++ b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
@@ -101,7 +101,7 @@ TEST_P(olLaunchKernelFooTest, Success) {
     ASSERT_EQ(Data[i], i);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
@@ -123,7 +123,7 @@ TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
       ASSERT_EQ(Data[i], i);
     }
 
-    ASSERT_SUCCESS(olMemFree(Mem));
+    ASSERT_SUCCESS(olMemFree(Device, Mem));
   });
 }
 
@@ -151,7 +151,7 @@ TEST_P(olLaunchKernelFooTest, SuccessSynchronous) {
     ASSERT_EQ(Data[i], i);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelLocalMemTest, Success) {
@@ -176,7 +176,7 @@ TEST_P(olLaunchKernelLocalMemTest, Success) {
   for (uint32_t i = 0; i < LaunchArgs.GroupSize.x * LaunchArgs.NumGroups.x; i++)
     ASSERT_EQ(Data[i], (i % 64) * 2);
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
@@ -199,7 +199,7 @@ TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
   for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
     ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
@@ -222,7 +222,7 @@ TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
   for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
     ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelGlobalTest, Success) {
@@ -245,7 +245,7 @@ TEST_P(olLaunchKernelGlobalTest, Success) {
     ASSERT_EQ(Data[i], i * 2);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelGlobalTest, InvalidNotAKernel) {
@@ -273,7 +273,7 @@ TEST_P(olLaunchKernelGlobalCtorTest, Success) {
     ASSERT_EQ(Data[i], i + 100);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelGlobalDtorTest, Success) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
index 00e428ec2abc7..c1d585d7271f3 100644
--- a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
@@ -17,21 +17,21 @@ TEST_P(olMemAllocTest, SuccessAllocManaged) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
   ASSERT_NE(Alloc, nullptr);
-  olMemFree(Alloc);
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
 TEST_P(olMemAllocTest, SuccessAllocHost) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
   ASSERT_NE(Alloc, nullptr);
-  olMemFree(Alloc);
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
 TEST_P(olMemAllocTest, SuccessAllocDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
   ASSERT_NE(Alloc, nullptr);
-  olMemFree(Alloc);
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
 TEST_P(olMemAllocTest, InvalidNullDevice) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemFree.cpp b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
index dfaf9bdef3189..8618e740f02bd 100644
--- a/offload/unittests/OffloadAPI/memory/olMemFree.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
@@ -16,24 +16,58 @@ OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeTest);
 TEST_P(olMemFreeTest, SuccessFreeManaged) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
+}
+
+TEST_P(olMemFreeTest, SuccessFreeManagedNull) {
+  void *Alloc = nullptr;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
+  ASSERT_SUCCESS(olMemFree(nullptr, Alloc));
 }
 
 TEST_P(olMemFreeTest, SuccessFreeHost) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
+}
+
+TEST_P(olMemFreeTest, SuccessFreeHostNull) {
+  void *Alloc = nullptr;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
+  ASSERT_SUCCESS(olMemFree(nullptr, Alloc));
 }
 
 TEST_P(olMemFreeTest, SuccessFreeDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
 TEST_P(olMemFreeTest, InvalidNullPtr) {
+  ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(Device, nullptr));
+}
+
+TEST_P(olMemFreeTest, InvalidFreeDeviceNull) {
+  void *Alloc = nullptr;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(nullptr, Alloc));
+}
+
+TEST_P(olMemFreeTest, InvalidFreeManagedWrongDevice) {
+  void *Alloc = nullptr;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+}
+
+TEST_P(olMemFreeTest, InvalidFreeHostWrongDevice) {
+  void *Alloc = nullptr;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+}
+
+
+TEST_P(olMemFreeTest, InvalidFreeDeviceWrongDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
-  ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(nullptr));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
 }
diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
index cc67d782ef403..d028099916848 100644
--- a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
@@ -46,7 +46,7 @@ TEST_P(olMemcpyTest, SuccessHtoD) {
   std::vector<uint8_t> Input(Size, 42);
   ASSERT_SUCCESS(olMemcpy(Queue, Alloc, Device, Input.data(), Host, Size));
   olSyncQueue(Queue);
-  olMemFree(Alloc);
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
 TEST_P(olMemcpyTest, SuccessDtoH) {
@@ -62,7 +62,7 @@ TEST_P(olMemcpyTest, SuccessDtoH) {
   for (uint8_t Val : Output) {
     ASSERT_EQ(Val, 42);
   }
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
 TEST_P(olMemcpyTest, SuccessDtoD) {
@@ -81,8 +81,8 @@ TEST_P(olMemcpyTest, SuccessDtoD) {
   for (uint8_t Val : Output) {
     ASSERT_EQ(Val, 42);
   }
-  ASSERT_SUCCESS(olMemFree(AllocA));
-  ASSERT_SUCCESS(olMemFree(AllocB));
+  ASSERT_SUCCESS(olMemFree(Device, AllocA));
+  ASSERT_SUCCESS(olMemFree(Device, AllocB));
 }
 
 TEST_P(olMemcpyTest, SuccessHtoHSync) {
@@ -110,7 +110,7 @@ TEST_P(olMemcpyTest, SuccessDtoHSync) {
   for (uint8_t Val : Output) {
     ASSERT_EQ(Val, 42);
   }
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
 TEST_P(olMemcpyTest, SuccessSizeZero) {
@@ -146,8 +146,8 @@ TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
   for (uint32_t I = 0; I < 64; I++)
     ASSERT_EQ(DestData[I], I);
 
-  ASSERT_SUCCESS(olMemFree(DestMem));
-  ASSERT_SUCCESS(olMemFree(SourceMem));
+  ASSERT_SUCCESS(olMemFree(Device, DestMem));
+  ASSERT_SUCCESS(olMemFree(Device, SourceMem));
 }
 
 TEST_P(olMemcpyGlobalTest, SuccessWrite) {
@@ -178,8 +178,8 @@ TEST_P(olMemcpyGlobalTest, SuccessWrite) {
   for (uint32_t I = 0; I < 64; I++)
     ASSERT_EQ(DestData[I], I);
 
-  ASSERT_SUCCESS(olMemFree(DestMem));
-  ASSERT_SUCCESS(olMemFree(SourceMem));
+  ASSERT_SUCCESS(olMemFree(Device, DestMem));
+  ASSERT_SUCCESS(olMemFree(Device, SourceMem));
 }
 
 TEST_P(olMemcpyGlobalTest, SuccessRead) {
@@ -199,5 +199,5 @@ TEST_P(olMemcpyGlobalTest, SuccessRead) {
   for (uint32_t I = 0; I < 64; I++)
     ASSERT_EQ(DestData[I], I * 2);
 
-  ASSERT_SUCCESS(olMemFree(DestMem));
+  ASSERT_SUCCESS(olMemFree(Device, DestMem));
 }
diff --git a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
index aa86750f6adf9..2f3fda6fb729b 100644
--- a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
+++ b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
@@ -93,7 +93,7 @@ TEST_P(olLaunchHostFunctionKernelTest, SuccessBlocking) {
   }
 
   ASSERT_SUCCESS(olDestroyQueue(Queue));
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchHostFunctionTest, InvalidNullCallback) {

>From 1bd8c66c04aeafece6ed538f26ad1880e53171d6 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Thu, 21 Aug 2025 12:34:10 +0100
Subject: [PATCH 2/2] Clang format

---
 offload/unittests/OffloadAPI/memory/olMemFree.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/offload/unittests/OffloadAPI/memory/olMemFree.cpp b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
index 8618e740f02bd..561ec84fd98a4 100644
--- a/offload/unittests/OffloadAPI/memory/olMemFree.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
@@ -56,18 +56,20 @@ TEST_P(olMemFreeTest, InvalidFreeDeviceNull) {
 TEST_P(olMemFreeTest, InvalidFreeManagedWrongDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
-  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND,
+               olMemFree(TestEnvironment::getHostDevice(), Alloc));
 }
 
 TEST_P(olMemFreeTest, InvalidFreeHostWrongDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
-  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND,
+               olMemFree(TestEnvironment::getHostDevice(), Alloc));
 }
 
-
 TEST_P(olMemFreeTest, InvalidFreeDeviceWrongDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
-  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND,
+               olMemFree(TestEnvironment::getHostDevice(), Alloc));
 }



More information about the llvm-commits mailing list