[llvm] [Offload] Update allocations to include device (PR #154733)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 21 04:04:18 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-offload
Author: Ross Brunton (RossBrunton)
<details>
<summary>Changes</summary>
To get around some memory issues with aliasing pointers, device
allocations need to be linked with their allocating device. `olMemFree`
now takes a device handle, which must be provided for device allocations
and may optionally be provided for host and managed allocations.
---
Full diff: https://github.com/llvm/llvm-project/pull/154733.diff
7 Files Affected:
- (modified) offload/liboffload/API/Memory.td (+15-4)
- (modified) offload/liboffload/src/OffloadImpl.cpp (+26-18)
- (modified) offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp (+8-8)
- (modified) offload/unittests/OffloadAPI/memory/olMemAlloc.cpp (+3-3)
- (modified) offload/unittests/OffloadAPI/memory/olMemFree.cpp (+39-5)
- (modified) offload/unittests/OffloadAPI/memory/olMemcpy.cpp (+10-10)
- (modified) offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp (+1-1)
``````````diff
diff --git a/offload/liboffload/API/Memory.td b/offload/liboffload/API/Memory.td
index 5f7158588bc77..a904364075ba8 100644
--- a/offload/liboffload/API/Memory.td
+++ b/offload/liboffload/API/Memory.td
@@ -14,15 +14,18 @@ def : Enum {
let name = "ol_alloc_type_t";
let desc = "Represents the type of allocation made with olMemAlloc.";
let etors = [
- Etor<"HOST", "Host allocation">,
- Etor<"DEVICE", "Device allocation">,
- Etor<"MANAGED", "Managed allocation">
+ Etor<"HOST", "Host allocation. Allocated on the host and visible to the host and all devices sharing the same platform.">,
+ Etor<"DEVICE", "Device allocation. Allocated on a specific device and visible only to that device.">,
+ Etor<"MANAGED", "Managed allocation. Allocated on a specific device and visible to the host and all devices sharing the same platform.">
];
}
def : Function {
let name = "olMemAlloc";
let desc = "Creates a memory allocation on the specified device.";
+ let details = [
+ "`DEVICE` allocations do not share the same address space as the host or other devices. The `AllocationOut` pointer cannot be used to uniquely identify the allocation in these cases.",
+ ];
let params = [
Param<"ol_device_handle_t", "Device", "handle of the device to allocate on", PARAM_IN>,
Param<"ol_alloc_type_t", "Type", "type of the allocation", PARAM_IN>,
@@ -39,10 +42,18 @@ def : Function {
def : Function {
let name = "olMemFree";
let desc = "Frees a memory allocation previously made by olMemAlloc.";
+ let details = [
+ "`Address` must be the beginning of the allocation.",
+ "`Device` must be provided for memory allocated as `OL_ALLOC_TYPE_DEVICE`, and may be provided for other types.",
+ "If `Device` is provided, it must match the device used to allocate the memory with `olMemAlloc`.",
+ ];
let params = [
+ Param<"ol_device_handle_t", "Device", "handle of the device this allocation was allocated on", PARAM_IN_OPTIONAL>,
Param<"void*", "Address", "address of the allocation to free", PARAM_IN>,
];
- let returns = [];
+ let returns = [
+ Return<"OL_ERRC_NOT_FOUND", ["The address was not found in the list of allocations"]>
+ ];
}
def : Function {
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 49154337dd193..da48378547974 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -106,6 +106,7 @@ namespace llvm {
namespace offload {
struct AllocInfo {
+ void *Base;
ol_device_handle_t Device;
ol_alloc_type_t Type;
};
@@ -124,8 +125,8 @@ struct OffloadContext {
bool TracingEnabled = false;
bool ValidationEnabled = true;
- DenseMap<void *, AllocInfo> AllocInfoMap{};
- std::mutex AllocInfoMapMutex{};
+ SmallVector<AllocInfo> AllocInfoList{};
+ std::mutex AllocInfoListMutex{};
SmallVector<ol_platform_impl_t, 4> Platforms{};
size_t RefCount;
@@ -535,30 +536,37 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
*AllocationOut = *Alloc;
{
- std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
- OffloadContext::get().AllocInfoMap.insert_or_assign(
- *Alloc, AllocInfo{Device, Type});
+ std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoListMutex);
+ OffloadContext::get().AllocInfoList.emplace_back(
+ AllocInfo{*AllocationOut, Device, Type});
}
return Error::success();
}
-Error olMemFree_impl(void *Address) {
- ol_device_handle_t Device;
- ol_alloc_type_t Type;
+Error olMemFree_impl(ol_device_handle_t Device, void *Address) {
+ AllocInfo Removed;
{
- std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
- if (!OffloadContext::get().AllocInfoMap.contains(Address))
- return createOffloadError(ErrorCode::INVALID_ARGUMENT,
- "address is not a known allocation");
+ std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoListMutex);
+
+ auto &List = OffloadContext::get().AllocInfoList;
+ auto Entry = std::find_if(List.begin(), List.end(), [&](AllocInfo &Entry) {
+ return Address == Entry.Base && (!Device || Entry.Device == Device);
+ });
+
+ if (Entry == List.end())
+ return Plugin::error(ErrorCode::NOT_FOUND,
+ "could not find memory allocated by olMemAlloc");
+ if (!Device && Entry->Type == OL_ALLOC_TYPE_DEVICE)
+ return Plugin::error(
+ ErrorCode::NOT_FOUND,
+ "specifying the Device parameter is required to query device memory");
- auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address);
- Device = AllocInfo.Device;
- Type = AllocInfo.Type;
- OffloadContext::get().AllocInfoMap.erase(Address);
+ Removed = std::move(*Entry);
+ *Entry = List.pop_back_val();
}
- if (auto Res =
- Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
+ if (auto Res = Removed.Device->Device->dataDelete(
+ Removed.Base, convertOlToPluginAllocTy(Removed.Type)))
return Res;
return Error::success();
diff --git a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
index 1dac8c50271b5..a7f4881bcc709 100644
--- a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
+++ b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
@@ -101,7 +101,7 @@ TEST_P(olLaunchKernelFooTest, Success) {
ASSERT_EQ(Data[i], i);
}
- ASSERT_SUCCESS(olMemFree(Mem));
+ ASSERT_SUCCESS(olMemFree(Device, Mem));
}
TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
@@ -123,7 +123,7 @@ TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
ASSERT_EQ(Data[i], i);
}
- ASSERT_SUCCESS(olMemFree(Mem));
+ ASSERT_SUCCESS(olMemFree(Device, Mem));
});
}
@@ -151,7 +151,7 @@ TEST_P(olLaunchKernelFooTest, SuccessSynchronous) {
ASSERT_EQ(Data[i], i);
}
- ASSERT_SUCCESS(olMemFree(Mem));
+ ASSERT_SUCCESS(olMemFree(Device, Mem));
}
TEST_P(olLaunchKernelLocalMemTest, Success) {
@@ -176,7 +176,7 @@ TEST_P(olLaunchKernelLocalMemTest, Success) {
for (uint32_t i = 0; i < LaunchArgs.GroupSize.x * LaunchArgs.NumGroups.x; i++)
ASSERT_EQ(Data[i], (i % 64) * 2);
- ASSERT_SUCCESS(olMemFree(Mem));
+ ASSERT_SUCCESS(olMemFree(Device, Mem));
}
TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
@@ -199,7 +199,7 @@ TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);
- ASSERT_SUCCESS(olMemFree(Mem));
+ ASSERT_SUCCESS(olMemFree(Device, Mem));
}
TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
@@ -222,7 +222,7 @@ TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);
- ASSERT_SUCCESS(olMemFree(Mem));
+ ASSERT_SUCCESS(olMemFree(Device, Mem));
}
TEST_P(olLaunchKernelGlobalTest, Success) {
@@ -245,7 +245,7 @@ TEST_P(olLaunchKernelGlobalTest, Success) {
ASSERT_EQ(Data[i], i * 2);
}
- ASSERT_SUCCESS(olMemFree(Mem));
+ ASSERT_SUCCESS(olMemFree(Device, Mem));
}
TEST_P(olLaunchKernelGlobalTest, InvalidNotAKernel) {
@@ -273,7 +273,7 @@ TEST_P(olLaunchKernelGlobalCtorTest, Success) {
ASSERT_EQ(Data[i], i + 100);
}
- ASSERT_SUCCESS(olMemFree(Mem));
+ ASSERT_SUCCESS(olMemFree(Device, Mem));
}
TEST_P(olLaunchKernelGlobalDtorTest, Success) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
index 00e428ec2abc7..c1d585d7271f3 100644
--- a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
@@ -17,21 +17,21 @@ TEST_P(olMemAllocTest, SuccessAllocManaged) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
ASSERT_NE(Alloc, nullptr);
- olMemFree(Alloc);
+ olMemFree(Device, Alloc);
}
TEST_P(olMemAllocTest, SuccessAllocHost) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
ASSERT_NE(Alloc, nullptr);
- olMemFree(Alloc);
+ olMemFree(Device, Alloc);
}
TEST_P(olMemAllocTest, SuccessAllocDevice) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
ASSERT_NE(Alloc, nullptr);
- olMemFree(Alloc);
+ olMemFree(Device, Alloc);
}
TEST_P(olMemAllocTest, InvalidNullDevice) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemFree.cpp b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
index dfaf9bdef3189..8618e740f02bd 100644
--- a/offload/unittests/OffloadAPI/memory/olMemFree.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
@@ -16,24 +16,58 @@ OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeTest);
TEST_P(olMemFreeTest, SuccessFreeManaged) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
- ASSERT_SUCCESS(olMemFree(Alloc));
+ ASSERT_SUCCESS(olMemFree(Device, Alloc));
+}
+
+TEST_P(olMemFreeTest, SuccessFreeManagedNull) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
+ ASSERT_SUCCESS(olMemFree(nullptr, Alloc));
}
TEST_P(olMemFreeTest, SuccessFreeHost) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
- ASSERT_SUCCESS(olMemFree(Alloc));
+ ASSERT_SUCCESS(olMemFree(Device, Alloc));
+}
+
+TEST_P(olMemFreeTest, SuccessFreeHostNull) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
+ ASSERT_SUCCESS(olMemFree(nullptr, Alloc));
}
TEST_P(olMemFreeTest, SuccessFreeDevice) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
- ASSERT_SUCCESS(olMemFree(Alloc));
+ ASSERT_SUCCESS(olMemFree(Device, Alloc));
}
TEST_P(olMemFreeTest, InvalidNullPtr) {
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(Device, nullptr));
+}
+
+TEST_P(olMemFreeTest, InvalidFreeDeviceNull) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
+ ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(nullptr, Alloc));
+}
+
+TEST_P(olMemFreeTest, InvalidFreeManagedWrongDevice) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
+ ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+}
+
+TEST_P(olMemFreeTest, InvalidFreeHostWrongDevice) {
+ void *Alloc = nullptr;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
+ ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
+}
+
+
+TEST_P(olMemFreeTest, InvalidFreeDeviceWrongDevice) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
- ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(nullptr));
- ASSERT_SUCCESS(olMemFree(Alloc));
+ ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(TestEnvironment::getHostDevice(), Alloc));
}
diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
index cc67d782ef403..d028099916848 100644
--- a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
@@ -46,7 +46,7 @@ TEST_P(olMemcpyTest, SuccessHtoD) {
std::vector<uint8_t> Input(Size, 42);
ASSERT_SUCCESS(olMemcpy(Queue, Alloc, Device, Input.data(), Host, Size));
olSyncQueue(Queue);
- olMemFree(Alloc);
+ olMemFree(Device, Alloc);
}
TEST_P(olMemcpyTest, SuccessDtoH) {
@@ -62,7 +62,7 @@ TEST_P(olMemcpyTest, SuccessDtoH) {
for (uint8_t Val : Output) {
ASSERT_EQ(Val, 42);
}
- ASSERT_SUCCESS(olMemFree(Alloc));
+ ASSERT_SUCCESS(olMemFree(Device, Alloc));
}
TEST_P(olMemcpyTest, SuccessDtoD) {
@@ -81,8 +81,8 @@ TEST_P(olMemcpyTest, SuccessDtoD) {
for (uint8_t Val : Output) {
ASSERT_EQ(Val, 42);
}
- ASSERT_SUCCESS(olMemFree(AllocA));
- ASSERT_SUCCESS(olMemFree(AllocB));
+ ASSERT_SUCCESS(olMemFree(Device, AllocA));
+ ASSERT_SUCCESS(olMemFree(Device, AllocB));
}
TEST_P(olMemcpyTest, SuccessHtoHSync) {
@@ -110,7 +110,7 @@ TEST_P(olMemcpyTest, SuccessDtoHSync) {
for (uint8_t Val : Output) {
ASSERT_EQ(Val, 42);
}
- ASSERT_SUCCESS(olMemFree(Alloc));
+ ASSERT_SUCCESS(olMemFree(Device, Alloc));
}
TEST_P(olMemcpyTest, SuccessSizeZero) {
@@ -146,8 +146,8 @@ TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
for (uint32_t I = 0; I < 64; I++)
ASSERT_EQ(DestData[I], I);
- ASSERT_SUCCESS(olMemFree(DestMem));
- ASSERT_SUCCESS(olMemFree(SourceMem));
+ ASSERT_SUCCESS(olMemFree(Device, DestMem));
+ ASSERT_SUCCESS(olMemFree(Device, SourceMem));
}
TEST_P(olMemcpyGlobalTest, SuccessWrite) {
@@ -178,8 +178,8 @@ TEST_P(olMemcpyGlobalTest, SuccessWrite) {
for (uint32_t I = 0; I < 64; I++)
ASSERT_EQ(DestData[I], I);
- ASSERT_SUCCESS(olMemFree(DestMem));
- ASSERT_SUCCESS(olMemFree(SourceMem));
+ ASSERT_SUCCESS(olMemFree(Device, DestMem));
+ ASSERT_SUCCESS(olMemFree(Device, SourceMem));
}
TEST_P(olMemcpyGlobalTest, SuccessRead) {
@@ -199,5 +199,5 @@ TEST_P(olMemcpyGlobalTest, SuccessRead) {
for (uint32_t I = 0; I < 64; I++)
ASSERT_EQ(DestData[I], I * 2);
- ASSERT_SUCCESS(olMemFree(DestMem));
+ ASSERT_SUCCESS(olMemFree(Device, DestMem));
}
diff --git a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
index aa86750f6adf9..2f3fda6fb729b 100644
--- a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
+++ b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
@@ -93,7 +93,7 @@ TEST_P(olLaunchHostFunctionKernelTest, SuccessBlocking) {
}
ASSERT_SUCCESS(olDestroyQueue(Queue));
- ASSERT_SUCCESS(olMemFree(Mem));
+ ASSERT_SUCCESS(olMemFree(Device, Mem));
}
TEST_P(olLaunchHostFunctionTest, InvalidNullCallback) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/154733
More information about the llvm-commits mailing list