[llvm] [Offload] Add global variable address/size queries (PR #147972)
Ross Brunton via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 11 07:30:54 PDT 2025
https://github.com/RossBrunton updated https://github.com/llvm/llvm-project/pull/147972
>From 69089e4d4f443cd5e81fb9aaed546faa0cd0d45b Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Thu, 10 Jul 2025 15:34:17 +0100
Subject: [PATCH] [Offload] Add global variable address/size queries
Add two new symbol info types for getting the bounds of a global
variable, as well as a number of tests for reading from and writing to it.
---
offload/liboffload/API/Symbol.td | 4 +-
offload/liboffload/src/OffloadImpl.cpp | 19 ++++
offload/tools/offload-tblgen/PrintGen.cpp | 8 +-
.../unittests/OffloadAPI/memory/olMemcpy.cpp | 105 ++++++++++++++++++
.../OffloadAPI/symbol/olGetSymbolInfo.cpp | 28 +++++
.../OffloadAPI/symbol/olGetSymbolInfoSize.cpp | 14 +++
6 files changed, 175 insertions(+), 3 deletions(-)
diff --git a/offload/liboffload/API/Symbol.td b/offload/liboffload/API/Symbol.td
index 9317c71df1f10..2e94d703809e7 100644
--- a/offload/liboffload/API/Symbol.td
+++ b/offload/liboffload/API/Symbol.td
@@ -39,7 +39,9 @@ def : Enum {
let desc = "Supported symbol info.";
let is_typed = 1;
let etors = [
- TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">
+ TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">,
+ TaggedEtor<"GLOBAL_VARIABLE_ADDRESS", "void *", "The address in memory for this global variable.">,
+ TaggedEtor<"GLOBAL_VARIABLE_SIZE", "size_t", "The size in bytes for this global variable.">,
];
}
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 6d98c33ffb8da..17a2b00cb7140 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -753,9 +753,28 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
void *PropValue, size_t *PropSizeRet) {
InfoWriter Info(PropSize, PropValue, PropSizeRet);
+ auto CheckKind = [&](ol_symbol_kind_t Required) {
+ if (Symbol->Kind != Required) {
+ std::string ErrBuffer;
+ llvm::raw_string_ostream(ErrBuffer)
+ << PropName << ": Expected a symbol of Kind " << Required
+ << " but given a symbol of Kind " << Symbol->Kind;
+ return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());
+ }
+ return Plugin::success();
+ };
+
switch (PropName) {
case OL_SYMBOL_INFO_KIND:
return Info.write<ol_symbol_kind_t>(Symbol->Kind);
+ case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
+ if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
+ return Err;
+ return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr());
+ case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE:
+ if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
+ return Err;
+ return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize());
default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"olGetSymbolInfo enum '%i' is invalid", PropName);
diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp
index d1189688a90a3..89d7c820426cf 100644
--- a/offload/tools/offload-tblgen/PrintGen.cpp
+++ b/offload/tools/offload-tblgen/PrintGen.cpp
@@ -74,8 +74,12 @@ inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_
if (Type == "char[]") {
OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n");
} else {
- OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n",
- Type);
+ if (Type == "void *")
+ OS << formatv(TAB_2 "void * const * const tptr = (void * "
+ "const * const)ptr;\n");
+ else
+ OS << formatv(
+ TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", Type);
// TODO: Handle other cases here
OS << TAB_2 "os << (const void *)tptr << \" (\";\n";
if (Type.ends_with("*")) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
index c1762b451b81d..c1fb6df9bad0d 100644
--- a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
@@ -13,6 +13,32 @@
using olMemcpyTest = OffloadQueueTest;
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyTest);
+struct olMemcpyGlobalTest : OffloadGlobalTest {
+ void SetUp() override {
+ RETURN_ON_FATAL_FAILURE(OffloadGlobalTest::SetUp());
+ ASSERT_SUCCESS(
+ olGetSymbol(Program, "read", OL_SYMBOL_KIND_KERNEL, &ReadKernel));
+ ASSERT_SUCCESS(
+ olGetSymbol(Program, "write", OL_SYMBOL_KIND_KERNEL, &WriteKernel));
+ ASSERT_SUCCESS(olCreateQueue(Device, &Queue));
+ ASSERT_SUCCESS(olGetSymbolInfo(
+ Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, sizeof(Addr), &Addr));
+
+ LaunchArgs.Dimensions = 1;
+ LaunchArgs.GroupSize = {64, 1, 1};
+ LaunchArgs.NumGroups = {1, 1, 1};
+
+ LaunchArgs.DynSharedMemory = 0;
+ }
+
+ ol_kernel_launch_size_args_t LaunchArgs{};
+ void *Addr;
+ ol_symbol_handle_t ReadKernel;
+ ol_symbol_handle_t WriteKernel;
+ ol_queue_handle_t Queue;
+};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyGlobalTest);
+
TEST_P(olMemcpyTest, SuccessHtoD) {
constexpr size_t Size = 1024;
void *Alloc;
@@ -105,3 +131,82 @@ TEST_P(olMemcpyTest, SuccessSizeZero) {
ASSERT_SUCCESS(
olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, 0, nullptr));
}
+
+TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
+ void *SourceMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ 64 * sizeof(uint32_t), &SourceMem));
+ uint32_t *SourceData = (uint32_t *)SourceMem;
+ for (auto I = 0; I < 64; I++)
+ SourceData[I] = I;
+
+ void *DestMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ 64 * sizeof(uint32_t), &DestMem));
+
+ ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+
+ uint32_t *DestData = (uint32_t *)DestMem;
+ for (uint32_t I = 0; I < 64; I++)
+ ASSERT_EQ(DestData[I], I);
+
+ ASSERT_SUCCESS(olMemFree(DestMem));
+ ASSERT_SUCCESS(olMemFree(SourceMem));
+}
+
+TEST_P(olMemcpyGlobalTest, SuccessWrite) {
+ void *SourceMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ LaunchArgs.GroupSize.x * sizeof(uint32_t),
+ &SourceMem));
+ uint32_t *SourceData = (uint32_t *)SourceMem;
+ for (auto I = 0; I < 64; I++)
+ SourceData[I] = I;
+
+ void *DestMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ LaunchArgs.GroupSize.x * sizeof(uint32_t),
+ &DestMem));
+ struct {
+ void *Mem;
+ } Args{DestMem};
+
+ ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ ASSERT_SUCCESS(olLaunchKernel(Queue, Device, ReadKernel, &Args, sizeof(Args),
+ &LaunchArgs, nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+
+ uint32_t *DestData = (uint32_t *)DestMem;
+ for (uint32_t I = 0; I < 64; I++)
+ ASSERT_EQ(DestData[I], I);
+
+ ASSERT_SUCCESS(olMemFree(DestMem));
+ ASSERT_SUCCESS(olMemFree(SourceMem));
+}
+
+TEST_P(olMemcpyGlobalTest, SuccessRead) {
+ void *DestMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ LaunchArgs.GroupSize.x * sizeof(uint32_t),
+ &DestMem));
+
+ ASSERT_SUCCESS(olLaunchKernel(Queue, Device, WriteKernel, nullptr, 0,
+ &LaunchArgs, nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+
+ uint32_t *DestData = (uint32_t *)DestMem;
+ for (uint32_t I = 0; I < 64; I++)
+ ASSERT_EQ(DestData[I], I * 2);
+
+ ASSERT_SUCCESS(olMemFree(DestMem));
+}
diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
index 100a374430372..ed8f4716974cd 100644
--- a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
+++ b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
@@ -30,6 +30,34 @@ TEST_P(olGetSymbolInfoGlobalTest, SuccessKind) {
ASSERT_EQ(RetrievedKind, OL_SYMBOL_KIND_GLOBAL_VARIABLE);
}
+TEST_P(olGetSymbolInfoKernelTest, InvalidAddress) {
+ void *RetrievedAddr;
+ ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
+ olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
+ sizeof(RetrievedAddr), &RetrievedAddr));
+}
+
+TEST_P(olGetSymbolInfoGlobalTest, SuccessAddress) {
+ void *RetrievedAddr = nullptr;
+ ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
+ sizeof(RetrievedAddr), &RetrievedAddr));
+ ASSERT_NE(RetrievedAddr, nullptr);
+}
+
+TEST_P(olGetSymbolInfoKernelTest, InvalidSize) {
+ size_t RetrievedSize;
+ ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
+ olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
+ sizeof(RetrievedSize), &RetrievedSize));
+}
+
+TEST_P(olGetSymbolInfoGlobalTest, SuccessSize) {
+ size_t RetrievedSize = 0;
+ ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
+ sizeof(RetrievedSize), &RetrievedSize));
+ ASSERT_EQ(RetrievedSize, 64 * sizeof(uint32_t));
+}
+
TEST_P(olGetSymbolInfoKernelTest, InvalidNullHandle) {
ol_symbol_kind_t RetrievedKind;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
index aa7a061a9ef7a..ec011865cc6ad 100644
--- a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
+++ b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
@@ -28,6 +28,20 @@ TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessKind) {
ASSERT_EQ(Size, sizeof(ol_symbol_kind_t));
}
+TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessAddress) {
+ size_t Size = 0;
+ ASSERT_SUCCESS(olGetSymbolInfoSize(
+ Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, &Size));
+ ASSERT_EQ(Size, sizeof(void *));
+}
+
+TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessSize) {
+ size_t Size = 0;
+ ASSERT_SUCCESS(
+ olGetSymbolInfoSize(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, &Size));
+ ASSERT_EQ(Size, sizeof(size_t));
+}
+
TEST_P(olGetSymbolInfoSizeKernelTest, InvalidNullHandle) {
size_t Size = 0;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
More information about the llvm-commits mailing list