[llvm] [Offload] Update allocations to include device (PR #154733)

Ross Brunton via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 3 03:58:15 PDT 2025


https://github.com/RossBrunton updated https://github.com/llvm/llvm-project/pull/154733

>From 071ae75e8779ada3b153bf4e7a97b075ad307af1 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Thu, 21 Aug 2025 12:01:43 +0100
Subject: [PATCH 1/4] [Offload] Update allocations to include device

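Device allocations do not share an address space with the host or other
devices, so their addresses can alias and cannot identify an allocation on
their own. Allocations are now recorded together with their allocating
device, and `olMemFree` takes a device handle, which is required for
`OL_ALLOC_TYPE_DEVICE` allocations.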
---
 offload/liboffload/API/Memory.td              | 19 ++++++--
 offload/liboffload/src/OffloadImpl.cpp        | 46 +++++++++++--------
 .../OffloadAPI/kernel/olLaunchKernel.cpp      | 16 +++----
 .../OffloadAPI/memory/olMemAlloc.cpp          |  6 +--
 .../unittests/OffloadAPI/memory/olMemFree.cpp | 44 ++++++++++++++++--
 .../unittests/OffloadAPI/memory/olMemcpy.cpp  | 20 ++++----
 .../OffloadAPI/queue/olLaunchHostFunction.cpp |  2 +-
 7 files changed, 103 insertions(+), 50 deletions(-)

diff --git a/offload/liboffload/API/Memory.td b/offload/liboffload/API/Memory.td
index cc98b672a26a9..adca91dbbe2a3 100644
--- a/offload/liboffload/API/Memory.td
+++ b/offload/liboffload/API/Memory.td
@@ -13,14 +13,17 @@
 def ol_alloc_type_t : Enum {
   let desc = "Represents the type of allocation made with olMemAlloc.";
   let etors = [
-    Etor<"HOST", "Host allocation">,
-    Etor<"DEVICE", "Device allocation">,
-    Etor<"MANAGED", "Managed allocation">
+    Etor<"HOST", "Host allocation. Allocated on the host and visible to the host and all devices sharing the same platform.">,
+    Etor<"DEVICE", "Device allocation. Allocated on a specific device and visible only to that device.">,
+    Etor<"MANAGED", "Managed allocation. Allocated on a specific device and visible to the host and all devices sharing the same platform.">
   ];
 }
 
 def olMemAlloc : Function {
   let desc = "Creates a memory allocation on the specified device.";
+  let details = [
+      "`DEVICE` allocations do not share the same address space as the host or other devices. The `AllocationOut` pointer cannot be used to uniquely identify the allocation in these cases.",
+  ];
   let params = [
     Param<"ol_device_handle_t", "Device", "handle of the device to allocate on", PARAM_IN>,
     Param<"ol_alloc_type_t", "Type", "type of the allocation", PARAM_IN>,
@@ -36,10 +39,18 @@ def olMemAlloc : Function {
 
 def olMemFree : Function {
   let desc = "Frees a memory allocation previously made by olMemAlloc.";
+  let details = [
+      "`Address` must be the beginning of the allocation.",
+      "`Device` must be provided for memory allocated as `OL_ALLOC_TYPE_DEVICE`, and may be provided for other types.",
+      "If `Device` is provided, it must match the device used to allocate the memory with `olMemAlloc`.",
+  ];
   let params = [
+    Param<"ol_device_handle_t", "Device", "handle of the device this allocation was allocated on", PARAM_IN_OPTIONAL>,
     Param<"void*", "Address", "address of the allocation to free", PARAM_IN>,
   ];
-  let returns = [];
+  let returns = [
+    Return<"OL_ERRC_NOT_FOUND", ["The address was not found in the list of allocations", "`Device` is null and the allocation was made with `OL_ALLOC_TYPE_DEVICE`"]>
+  ];
 }
 
 def olMemcpy : Function {
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 7e8e297831f45..b20549250220e 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -183,6 +183,7 @@ namespace llvm {
 namespace offload {
 
 struct AllocInfo {
+  void *Base;
   ol_device_handle_t Device;
   ol_alloc_type_t Type;
 };
@@ -201,8 +202,8 @@ struct OffloadContext {
 
   bool TracingEnabled = false;
   bool ValidationEnabled = true;
-  DenseMap<void *, AllocInfo> AllocInfoMap{};
-  std::mutex AllocInfoMapMutex{};
+  SmallVector<AllocInfo> AllocInfoList{};
+  std::mutex AllocInfoListMutex{};
   SmallVector<ol_platform_impl_t, 4> Platforms{};
   size_t RefCount;
 
@@ -625,30 +626,54 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
 
   *AllocationOut = *Alloc;
   {
-    std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
-    OffloadContext::get().AllocInfoMap.insert_or_assign(
-        *Alloc, AllocInfo{Device, Type});
+    std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoListMutex);
+    OffloadContext::get().AllocInfoList.emplace_back(
+        AllocInfo{*AllocationOut, Device, Type});
   }
   return Error::success();
 }
 
-Error olMemFree_impl(void *Address) {
-  ol_device_handle_t Device;
-  ol_alloc_type_t Type;
+// Frees an allocation made by olMemAlloc. Address must match an allocation's
+// base; DEVICE allocations are only matched when Device is given.
+Error olMemFree_impl(ol_device_handle_t Device, void *Address) {
+  AllocInfo Removed;
   {
-    std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
-    if (!OffloadContext::get().AllocInfoMap.contains(Address))
-      return createOffloadError(ErrorCode::INVALID_ARGUMENT,
-                                "address is not a known allocation");
-
-    auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address);
-    Device = AllocInfo.Device;
-    Type = AllocInfo.Type;
-    OffloadContext::get().AllocInfoMap.erase(Address);
+    std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoListMutex);
+
+    auto &List = OffloadContext::get().AllocInfoList;
+    // Addresses of DEVICE allocations may alias other allocations, so a
+    // DEVICE allocation is only matched when its device is given explicitly.
+    auto Matches = [&](const AllocInfo &Info) {
+      if (Info.Base != Address)
+        return false;
+      if (Device)
+        return Info.Device == Device;
+      return Info.Type != OL_ALLOC_TYPE_DEVICE;
+    };
+    auto Entry = std::find_if(List.begin(), List.end(), Matches);
+
+    if (Entry == List.end()) {
+      // Distinguish a missing Device argument from an unknown address.
+      auto IsDeviceAlloc = [&](const AllocInfo &Info) {
+        return Info.Base == Address && Info.Type == OL_ALLOC_TYPE_DEVICE;
+      };
+      if (!Device && std::any_of(List.begin(), List.end(), IsDeviceAlloc))
+        return Plugin::error(ErrorCode::NOT_FOUND,
+                             "specifying the Device parameter is required to "
+                             "free device memory");
+      return Plugin::error(ErrorCode::NOT_FOUND,
+                           "could not find memory allocated by olMemAlloc");
+    }
+
+    // Unordered swap-and-pop. Copy the last element before popping so nothing
+    // is written through Entry after it may have become the end iterator.
+    Removed = *Entry;
+    *Entry = List.back();
+    List.pop_back();
   }
 
-  if (auto Res =
-          Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
+  if (auto Res = Removed.Device->Device->dataDelete(
+          Removed.Base, convertOlToPluginAllocTy(Removed.Type)))
     return Res;
 
   return Error::success();
diff --git a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
index 1dac8c50271b5..a7f4881bcc709 100644
--- a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
+++ b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
@@ -101,7 +101,7 @@ TEST_P(olLaunchKernelFooTest, Success) {
     ASSERT_EQ(Data[i], i);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
@@ -123,7 +123,7 @@ TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
       ASSERT_EQ(Data[i], i);
     }
 
-    ASSERT_SUCCESS(olMemFree(Mem));
+    ASSERT_SUCCESS(olMemFree(Device, Mem));
   });
 }
 
@@ -151,7 +151,7 @@ TEST_P(olLaunchKernelFooTest, SuccessSynchronous) {
     ASSERT_EQ(Data[i], i);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelLocalMemTest, Success) {
@@ -176,7 +176,7 @@ TEST_P(olLaunchKernelLocalMemTest, Success) {
   for (uint32_t i = 0; i < LaunchArgs.GroupSize.x * LaunchArgs.NumGroups.x; i++)
     ASSERT_EQ(Data[i], (i % 64) * 2);
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
@@ -199,7 +199,7 @@ TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
   for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
     ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
@@ -222,7 +222,7 @@ TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
   for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
     ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelGlobalTest, Success) {
@@ -245,7 +245,7 @@ TEST_P(olLaunchKernelGlobalTest, Success) {
     ASSERT_EQ(Data[i], i * 2);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelGlobalTest, InvalidNotAKernel) {
@@ -273,7 +273,7 @@ TEST_P(olLaunchKernelGlobalCtorTest, Success) {
     ASSERT_EQ(Data[i], i + 100);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchKernelGlobalDtorTest, Success) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
index 00e428ec2abc7..c1d585d7271f3 100644
--- a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
@@ -17,21 +17,21 @@ TEST_P(olMemAllocTest, SuccessAllocManaged) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
   ASSERT_NE(Alloc, nullptr);
-  olMemFree(Alloc);
+  olMemFree(Device, Alloc);
 }
 
 TEST_P(olMemAllocTest, SuccessAllocHost) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
   ASSERT_NE(Alloc, nullptr);
-  olMemFree(Alloc);
+  olMemFree(Device, Alloc);
 }
 
 TEST_P(olMemAllocTest, SuccessAllocDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
   ASSERT_NE(Alloc, nullptr);
-  olMemFree(Alloc);
+  olMemFree(Device, Alloc);
 }
 
 TEST_P(olMemAllocTest, InvalidNullDevice) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemFree.cpp b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
index dfaf9bdef3189..8618e740f02bd 100644
--- a/offload/unittests/OffloadAPI/memory/olMemFree.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
@@ -16,24 +16,58 @@ OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeTest);
 TEST_P(olMemFreeTest, SuccessFreeManaged) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
+}
+
+TEST_P(olMemFreeTest, SuccessFreeManagedNull) {
+  void *Alloc = nullptr;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
+  ASSERT_SUCCESS(olMemFree(nullptr, Alloc));
 }
 
 TEST_P(olMemFreeTest, SuccessFreeHost) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
+}
+
+TEST_P(olMemFreeTest, SuccessFreeHostNull) {
+  void *Alloc = nullptr;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
+  ASSERT_SUCCESS(olMemFree(nullptr, Alloc));
 }
 
 TEST_P(olMemFreeTest, SuccessFreeDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
 TEST_P(olMemFreeTest, InvalidNullPtr) {
+  ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(Device, nullptr));
+}
+
+TEST_P(olMemFreeTest, InvalidFreeDeviceNull) {
+  void *Alloc = nullptr;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(nullptr, Alloc));
+}
+
+TEST_P(olMemFreeTest, InvalidFreeManagedWrongDevice) {
+  void *Alloc = nullptr;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+}
+
+TEST_P(olMemFreeTest, InvalidFreeHostWrongDevice) {
+  void *Alloc = nullptr;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+}
+
+
+TEST_P(olMemFreeTest, InvalidFreeDeviceWrongDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
-  ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(nullptr));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
 }
diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
index cc67d782ef403..d028099916848 100644
--- a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
@@ -46,7 +46,7 @@ TEST_P(olMemcpyTest, SuccessHtoD) {
   std::vector<uint8_t> Input(Size, 42);
   ASSERT_SUCCESS(olMemcpy(Queue, Alloc, Device, Input.data(), Host, Size));
   olSyncQueue(Queue);
-  olMemFree(Alloc);
+  olMemFree(Device, Alloc);
 }
 
 TEST_P(olMemcpyTest, SuccessDtoH) {
@@ -62,7 +62,7 @@ TEST_P(olMemcpyTest, SuccessDtoH) {
   for (uint8_t Val : Output) {
     ASSERT_EQ(Val, 42);
   }
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
 TEST_P(olMemcpyTest, SuccessDtoD) {
@@ -81,8 +81,8 @@ TEST_P(olMemcpyTest, SuccessDtoD) {
   for (uint8_t Val : Output) {
     ASSERT_EQ(Val, 42);
   }
-  ASSERT_SUCCESS(olMemFree(AllocA));
-  ASSERT_SUCCESS(olMemFree(AllocB));
+  ASSERT_SUCCESS(olMemFree(Device, AllocA));
+  ASSERT_SUCCESS(olMemFree(Device, AllocB));
 }
 
 TEST_P(olMemcpyTest, SuccessHtoHSync) {
@@ -110,7 +110,7 @@ TEST_P(olMemcpyTest, SuccessDtoHSync) {
   for (uint8_t Val : Output) {
     ASSERT_EQ(Val, 42);
   }
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
 TEST_P(olMemcpyTest, SuccessSizeZero) {
@@ -146,8 +146,8 @@ TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
   for (uint32_t I = 0; I < 64; I++)
     ASSERT_EQ(DestData[I], I);
 
-  ASSERT_SUCCESS(olMemFree(DestMem));
-  ASSERT_SUCCESS(olMemFree(SourceMem));
+  ASSERT_SUCCESS(olMemFree(Device, DestMem));
+  ASSERT_SUCCESS(olMemFree(Device, SourceMem));
 }
 
 TEST_P(olMemcpyGlobalTest, SuccessWrite) {
@@ -178,8 +178,8 @@ TEST_P(olMemcpyGlobalTest, SuccessWrite) {
   for (uint32_t I = 0; I < 64; I++)
     ASSERT_EQ(DestData[I], I);
 
-  ASSERT_SUCCESS(olMemFree(DestMem));
-  ASSERT_SUCCESS(olMemFree(SourceMem));
+  ASSERT_SUCCESS(olMemFree(Device, DestMem));
+  ASSERT_SUCCESS(olMemFree(Device, SourceMem));
 }
 
 TEST_P(olMemcpyGlobalTest, SuccessRead) {
@@ -199,5 +199,5 @@ TEST_P(olMemcpyGlobalTest, SuccessRead) {
   for (uint32_t I = 0; I < 64; I++)
     ASSERT_EQ(DestData[I], I * 2);
 
-  ASSERT_SUCCESS(olMemFree(DestMem));
+  ASSERT_SUCCESS(olMemFree(Device, DestMem));
 }
diff --git a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
index aa86750f6adf9..2f3fda6fb729b 100644
--- a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
+++ b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
@@ -93,7 +93,7 @@ TEST_P(olLaunchHostFunctionKernelTest, SuccessBlocking) {
   }
 
   ASSERT_SUCCESS(olDestroyQueue(Queue));
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Device, Mem));
 }
 
 TEST_P(olLaunchHostFunctionTest, InvalidNullCallback) {

>From 5c51f00b6b3ad9883e26248b125251775ae02f2b Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Thu, 21 Aug 2025 12:34:10 +0100
Subject: [PATCH 2/4] Clang format

---
 offload/unittests/OffloadAPI/memory/olMemFree.cpp | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/offload/unittests/OffloadAPI/memory/olMemFree.cpp b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
index 8618e740f02bd..561ec84fd98a4 100644
--- a/offload/unittests/OffloadAPI/memory/olMemFree.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
@@ -56,18 +56,20 @@ TEST_P(olMemFreeTest, InvalidFreeDeviceNull) {
 TEST_P(olMemFreeTest, InvalidFreeManagedWrongDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
-  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND,
+               olMemFree(TestEnvironment::getHostDevice(), Alloc));
 }
 
 TEST_P(olMemFreeTest, InvalidFreeHostWrongDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
-  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND,
+               olMemFree(TestEnvironment::getHostDevice(), Alloc));
 }
 
-
 TEST_P(olMemFreeTest, InvalidFreeDeviceWrongDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
-  ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+  ASSERT_ERROR(OL_ERRC_NOT_FOUND,
+               olMemFree(TestEnvironment::getHostDevice(), Alloc));
 }

>From 4944a92575bdc0271fccedd422d048a6ed4b1ad3 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Fri, 22 Aug 2025 13:16:02 +0100
Subject: [PATCH 3/4] Clean up tests

---
 .../unittests/OffloadAPI/memory/olMemFree.cpp | 57 +++++++++----------
 1 file changed, 27 insertions(+), 30 deletions(-)

diff --git a/offload/unittests/OffloadAPI/memory/olMemFree.cpp b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
index 561ec84fd98a4..a68ae2514adfe 100644
--- a/offload/unittests/OffloadAPI/memory/olMemFree.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
@@ -10,66 +10,63 @@
 #include <OffloadAPI.h>
 #include <gtest/gtest.h>
 
-using olMemFreeTest = OffloadDeviceTest;
-OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeTest);
+template <ol_alloc_type_t Type> struct olMemFreeTestBase : OffloadDeviceTest {
+  void SetUp() override {
+    RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp());
+    ASSERT_SUCCESS(olMemAlloc(Device, Type, 0x1000, &Alloc));
+  }
 
-TEST_P(olMemFreeTest, SuccessFreeManaged) {
-  void *Alloc = nullptr;
-  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
+  void *Alloc;
+};
+
+struct olMemFreeDeviceTest : olMemFreeTestBase<OL_ALLOC_TYPE_DEVICE> {};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeDeviceTest);
+
+struct olMemFreeHostTest : olMemFreeTestBase<OL_ALLOC_TYPE_HOST> {};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeHostTest);
+
+struct olMemFreeManagedTest : olMemFreeTestBase<OL_ALLOC_TYPE_MANAGED> {};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeManagedTest);
+
+TEST_P(olMemFreeManagedTest, SuccessFree) {
   ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
-TEST_P(olMemFreeTest, SuccessFreeManagedNull) {
-  void *Alloc = nullptr;
-  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
+TEST_P(olMemFreeManagedTest, SuccessFreeNull) {
   ASSERT_SUCCESS(olMemFree(nullptr, Alloc));
 }
 
-TEST_P(olMemFreeTest, SuccessFreeHost) {
-  void *Alloc = nullptr;
-  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
+TEST_P(olMemFreeHostTest, SuccessFree) {
   ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
-TEST_P(olMemFreeTest, SuccessFreeHostNull) {
-  void *Alloc = nullptr;
-  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
+TEST_P(olMemFreeHostTest, SuccessFreeNull) {
   ASSERT_SUCCESS(olMemFree(nullptr, Alloc));
 }
 
-TEST_P(olMemFreeTest, SuccessFreeDevice) {
-  void *Alloc = nullptr;
-  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
+TEST_P(olMemFreeDeviceTest, SuccessFree) {
   ASSERT_SUCCESS(olMemFree(Device, Alloc));
 }
 
-TEST_P(olMemFreeTest, InvalidNullPtr) {
+TEST_P(olMemFreeDeviceTest, InvalidNullPtr) {
   ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(Device, nullptr));
 }
 
-TEST_P(olMemFreeTest, InvalidFreeDeviceNull) {
-  void *Alloc = nullptr;
-  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
+TEST_P(olMemFreeDeviceTest, InvalidNullDevice) {
   ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(nullptr, Alloc));
 }
 
-TEST_P(olMemFreeTest, InvalidFreeManagedWrongDevice) {
-  void *Alloc = nullptr;
-  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
+TEST_P(olMemFreeDeviceTest, InvalidFreeWrongDevice) {
   ASSERT_ERROR(OL_ERRC_NOT_FOUND,
                olMemFree(TestEnvironment::getHostDevice(), Alloc));
 }
 
-TEST_P(olMemFreeTest, InvalidFreeHostWrongDevice) {
-  void *Alloc = nullptr;
-  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
+TEST_P(olMemFreeHostTest, InvalidFreeWrongDevice) {
   ASSERT_ERROR(OL_ERRC_NOT_FOUND,
                olMemFree(TestEnvironment::getHostDevice(), Alloc));
 }
 
-TEST_P(olMemFreeTest, InvalidFreeDeviceWrongDevice) {
-  void *Alloc = nullptr;
-  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
+TEST_P(olMemFreeManagedTest, InvalidFreeWrongDevice) {
   ASSERT_ERROR(OL_ERRC_NOT_FOUND,
                olMemFree(TestEnvironment::getHostDevice(), Alloc));
 }

>From ef9904489bcd6f30019de83bdd1eba7fda5f46f5 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Wed, 3 Sep 2025 11:57:55 +0100
Subject: [PATCH 4/4] Update tests after rebase

---
 .../Conformance/include/mathtest/DeviceResources.hpp |  4 ++--
 .../unittests/Conformance/lib/DeviceResources.cpp    |  5 +++--
 offload/unittests/OffloadAPI/memory/olMemFill.cpp    | 12 ++++++------
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp b/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp
index 860448afa3a01..9650678b53066 100644
--- a/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp
+++ b/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp
@@ -29,7 +29,7 @@ class DeviceContext;
 
 namespace detail {
 
-void freeDeviceMemory(void *Address) noexcept;
+void freeDeviceMemory(ol_device_handle_t Device, void *Address) noexcept;
 } // namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -40,7 +40,7 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
 public:
   ~ManagedBuffer() noexcept {
     if (Address)
-      detail::freeDeviceMemory(Address);
+      detail::freeDeviceMemory(nullptr, Address);
   }
 
   ManagedBuffer(const ManagedBuffer &) = delete;
diff --git a/offload/unittests/Conformance/lib/DeviceResources.cpp b/offload/unittests/Conformance/lib/DeviceResources.cpp
index d1c7b90e751e6..29c9efa4852a1 100644
--- a/offload/unittests/Conformance/lib/DeviceResources.cpp
+++ b/offload/unittests/Conformance/lib/DeviceResources.cpp
@@ -24,9 +24,10 @@ using namespace mathtest;
 // Helpers
 //===----------------------------------------------------------------------===//
 
-void detail::freeDeviceMemory(void *Address) noexcept {
+void detail::freeDeviceMemory(ol_device_handle_t Device,
+                              void *Address) noexcept {
   if (Address)
-    OL_CHECK(olMemFree(Address));
+    OL_CHECK(olMemFree(Device, Address));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/offload/unittests/OffloadAPI/memory/olMemFill.cpp b/offload/unittests/OffloadAPI/memory/olMemFill.cpp
index a84ed3d78eccf..e22b0001ca838 100644
--- a/offload/unittests/OffloadAPI/memory/olMemFill.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemFill.cpp
@@ -39,7 +39,7 @@ struct olMemFillTest : OffloadQueueTest {
       ASSERT_EQ(AllocPtr[i], Pattern);
     }
 
-    olMemFree(Alloc);
+    olMemFree(Device, Alloc);
   }
 };
 OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFillTest);
@@ -92,7 +92,7 @@ TEST_P(olMemFillTest, SuccessLarge) {
     ASSERT_EQ(AllocPtr[i].B, UINT64_MAX);
   }
 
-  olMemFree(Alloc);
+  olMemFree(Device, Alloc);
 }
 
 TEST_P(olMemFillTest, SuccessLargeEnqueue) {
@@ -120,7 +120,7 @@ TEST_P(olMemFillTest, SuccessLargeEnqueue) {
     ASSERT_EQ(AllocPtr[i].B, UINT64_MAX);
   }
 
-  olMemFree(Alloc);
+  olMemFree(Device, Alloc);
 }
 
 TEST_P(olMemFillTest, SuccessLargeByteAligned) {
@@ -146,7 +146,7 @@ TEST_P(olMemFillTest, SuccessLargeByteAligned) {
     ASSERT_EQ(AllocPtr[i].C, 255);
   }
 
-  olMemFree(Alloc);
+  olMemFree(Device, Alloc);
 }
 
 TEST_P(olMemFillTest, SuccessLargeByteAlignedEnqueue) {
@@ -176,7 +176,7 @@ TEST_P(olMemFillTest, SuccessLargeByteAlignedEnqueue) {
     ASSERT_EQ(AllocPtr[i].C, 255);
   }
 
-  olMemFree(Alloc);
+  olMemFree(Device, Alloc);
 }
 
 TEST_P(olMemFillTest, InvalidPatternSize) {
@@ -189,5 +189,5 @@ TEST_P(olMemFillTest, InvalidPatternSize) {
                olMemFill(Queue, Alloc, sizeof(Pattern), &Pattern, Size));
 
   olSyncQueue(Queue);
-  olMemFree(Alloc);
+  olMemFree(Device, Alloc);
 }



More information about the llvm-commits mailing list