[llvm] [Offload] Add `olGetProgramGlobal` (PR #147774)
Ross Brunton via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 9 09:12:11 PDT 2025
https://github.com/RossBrunton created https://github.com/llvm/llvm-project/pull/147774
This entry point allows looking up a global in a program by name and getting its size and address.
>From 05d6648bd8984878a16d8fc58319ce5a4cde0b90 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Wed, 9 Jul 2025 12:08:34 +0100
Subject: [PATCH 1/3] [Offload] Allow querying the size of globals
The `GlobalTy` helper has been extended to make both the Size and Ptr be
optional. Now `getGlobalMetadataFromDevice`/`Image` is able to write the
size of the global to the struct, instead of just verifying it.
---
offload/plugins-nextgen/amdgpu/src/rtl.cpp | 5 ++-
.../common/include/GlobalHandler.h | 38 ++++++++++++++-----
offload/plugins-nextgen/cuda/src/rtl.cpp | 4 +-
3 files changed, 34 insertions(+), 13 deletions(-)
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index 832c31c43b5d2..7e72564f35848 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -3089,15 +3089,16 @@ struct AMDGPUGlobalHandlerTy final : public GenericGlobalHandlerTy {
}
// Check the size of the symbol.
- if (SymbolSize != DeviceGlobal.getSize())
+ if (DeviceGlobal.hasSize() && SymbolSize != DeviceGlobal.getSize())
return Plugin::error(
ErrorCode::INVALID_BINARY,
"failed to load global '%s' due to size mismatch (%zu != %zu)",
DeviceGlobal.getName().data(), SymbolSize,
(size_t)DeviceGlobal.getSize());
- // Store the symbol address on the device global metadata.
+ // Store the symbol address and size on the device global metadata.
DeviceGlobal.setPtr(reinterpret_cast<void *>(SymbolAddr));
+ DeviceGlobal.setSize(SymbolSize);
return Plugin::success();
}
diff --git a/offload/plugins-nextgen/common/include/GlobalHandler.h b/offload/plugins-nextgen/common/include/GlobalHandler.h
index 5d6109df49da5..26d63189a143b 100644
--- a/offload/plugins-nextgen/common/include/GlobalHandler.h
+++ b/offload/plugins-nextgen/common/include/GlobalHandler.h
@@ -37,20 +37,33 @@ using namespace llvm::object;
/// Common abstraction for globals that live on the host and device.
/// It simply encapsulates the symbol name, symbol size, and symbol address
/// (which might be host or device depending on the context).
+/// Both size and address may be absent, and can be populated with
+// getGlobalMetadataFromDevice/Image.
class GlobalTy {
// NOTE: Maybe we can have a pointer to the offload entry name instead of
// holding a private copy of the name as a std::string.
std::string Name;
- uint32_t Size;
- void *Ptr;
+ std::optional<uint32_t> Size;
+ std::optional<void *> Ptr;
public:
- GlobalTy(const std::string &Name, uint32_t Size, void *Ptr = nullptr)
+ GlobalTy(const std::string &Name) : Name(Name) {}
+ GlobalTy(const std::string &Name, uint32_t Size) : Name(Name), Size(Size) {}
+ GlobalTy(const std::string &Name, uint32_t Size, void *Ptr)
: Name(Name), Size(Size), Ptr(Ptr) {}
const std::string &getName() const { return Name; }
- uint32_t getSize() const { return Size; }
- void *getPtr() const { return Ptr; }
+ uint32_t getSize() const {
+ assert(hasSize() && "Size not initialised");
+ return *Size;
+ }
+ void *getPtr() const {
+ assert(hasPtr() && "Ptr not initialised");
+ return *Ptr;
+ }
+
+ bool hasSize() const { return Size.has_value(); }
+ bool hasPtr() const { return Ptr.has_value(); }
void setSize(int32_t S) { Size = S; }
void setPtr(void *P) { Ptr = P; }
@@ -139,8 +152,11 @@ class GenericGlobalHandlerTy {
bool isSymbolInImage(GenericDeviceTy &Device, DeviceImageTy &Image,
StringRef SymName);
- /// Get the address and size of a global in the image. Address and size are
- /// return in \p ImageGlobal, the global name is passed in \p ImageGlobal.
+ /// Get the address and size of a global in the image. Address is
+ /// returned in \p ImageGlobal and the global name is passed in \p
+ /// ImageGlobal. If no size is present in \p ImageGlobal, then the size of the
+ /// global will be stored there. If it is present, it will be validated
+ /// against the real size of the global.
Error getGlobalMetadataFromImage(GenericDeviceTy &Device,
DeviceImageTy &Image, GlobalTy &ImageGlobal);
@@ -149,9 +165,11 @@ class GenericGlobalHandlerTy {
Error readGlobalFromImage(GenericDeviceTy &Device, DeviceImageTy &Image,
const GlobalTy &HostGlobal);
- /// Get the address and size of a global from the device. Address is return in
- /// \p DeviceGlobal, the global name and expected size are passed in
- /// \p DeviceGlobal.
+ /// Get the address and size of a global from the device. Address is
+ /// returned in \p DeviceGlobal and the global name is passed in \p
+ /// DeviceGlobal. If no size is present in \p DeviceGlobal, then the size
+ /// of the global will be stored there. If it is present, it will be
+ /// validated against the real size of the global.
virtual Error getGlobalMetadataFromDevice(GenericDeviceTy &Device,
DeviceImageTy &Image,
GlobalTy &DeviceGlobal) = 0;
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 53089df2d0f0d..16418bea91958 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -1355,13 +1355,15 @@ class CUDAGlobalHandlerTy final : public GenericGlobalHandlerTy {
GlobalName))
return Err;
- if (CUSize != DeviceGlobal.getSize())
+ if (DeviceGlobal.hasSize() && CUSize != DeviceGlobal.getSize())
return Plugin::error(
ErrorCode::INVALID_BINARY,
"failed to load global '%s' due to size mismatch (%zu != %zu)",
GlobalName, CUSize, (size_t)DeviceGlobal.getSize());
DeviceGlobal.setPtr(reinterpret_cast<void *>(CUPtr));
+ DeviceGlobal.setSize(CUSize);
+
return Plugin::success();
}
};
>From 69bcb87dce588b8bda65720c950512f92d8d1645 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Wed, 9 Jul 2025 15:32:18 +0100
Subject: [PATCH 2/3] Just use 0/nullptr as None
---
offload/plugins-nextgen/amdgpu/src/rtl.cpp | 2 +-
.../common/include/GlobalHandler.h | 25 ++++++-------------
offload/plugins-nextgen/cuda/src/rtl.cpp | 2 +-
3 files changed, 9 insertions(+), 20 deletions(-)
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index 7e72564f35848..12c7cc62905c9 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -3089,7 +3089,7 @@ struct AMDGPUGlobalHandlerTy final : public GenericGlobalHandlerTy {
}
// Check the size of the symbol.
- if (DeviceGlobal.hasSize() && SymbolSize != DeviceGlobal.getSize())
+ if (DeviceGlobal.getSize() && SymbolSize != DeviceGlobal.getSize())
return Plugin::error(
ErrorCode::INVALID_BINARY,
"failed to load global '%s' due to size mismatch (%zu != %zu)",
diff --git a/offload/plugins-nextgen/common/include/GlobalHandler.h b/offload/plugins-nextgen/common/include/GlobalHandler.h
index 26d63189a143b..af7dac66ca85d 100644
--- a/offload/plugins-nextgen/common/include/GlobalHandler.h
+++ b/offload/plugins-nextgen/common/include/GlobalHandler.h
@@ -37,33 +37,22 @@ using namespace llvm::object;
/// Common abstraction for globals that live on the host and device.
/// It simply encapsulates the symbol name, symbol size, and symbol address
/// (which might be host or device depending on the context).
-/// Both size and address may be absent, and can be populated with
-// getGlobalMetadataFromDevice/Image.
+/// Both size and address may be absent (signified by 0/nullptr), and can be
+/// populated with getGlobalMetadataFromDevice/Image.
class GlobalTy {
// NOTE: Maybe we can have a pointer to the offload entry name instead of
// holding a private copy of the name as a std::string.
std::string Name;
- std::optional<uint32_t> Size;
- std::optional<void *> Ptr;
+ uint32_t Size;
+ void *Ptr;
public:
- GlobalTy(const std::string &Name) : Name(Name) {}
- GlobalTy(const std::string &Name, uint32_t Size) : Name(Name), Size(Size) {}
- GlobalTy(const std::string &Name, uint32_t Size, void *Ptr)
+ GlobalTy(const std::string &Name, uint32_t Size = 0, void *Ptr = nullptr)
: Name(Name), Size(Size), Ptr(Ptr) {}
const std::string &getName() const { return Name; }
- uint32_t getSize() const {
- assert(hasSize() && "Size not initialised");
- return *Size;
- }
- void *getPtr() const {
- assert(hasPtr() && "Ptr not initialised");
- return *Ptr;
- }
-
- bool hasSize() const { return Size.has_value(); }
- bool hasPtr() const { return Ptr.has_value(); }
+ uint32_t getSize() const { return Size; }
+ void *getPtr() const { return Ptr; }
void setSize(int32_t S) { Size = S; }
void setPtr(void *P) { Ptr = P; }
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 16418bea91958..15193de6ae430 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -1355,7 +1355,7 @@ class CUDAGlobalHandlerTy final : public GenericGlobalHandlerTy {
GlobalName))
return Err;
- if (DeviceGlobal.hasSize() && CUSize != DeviceGlobal.getSize())
+ if (DeviceGlobal.getSize() && CUSize != DeviceGlobal.getSize())
return Plugin::error(
ErrorCode::INVALID_BINARY,
"failed to load global '%s' due to size mismatch (%zu != %zu)",
>From 74961150794d57025afb000780d2c20568ff2646 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Wed, 9 Jul 2025 17:10:09 +0100
Subject: [PATCH 3/3] [Offload] Add `olGetProgramGlobal`
This entry point allows looking up a global in a program by name and
getting its size and address.
---
offload/liboffload/API/Program.td | 15 ++
offload/liboffload/src/OffloadImpl.cpp | 17 ++
offload/unittests/OffloadAPI/CMakeLists.txt | 3 +-
.../unittests/OffloadAPI/device_code/global.c | 1 +
.../OffloadAPI/program/olGetProgramGlobal.cpp | 178 ++++++++++++++++++
5 files changed, 213 insertions(+), 1 deletion(-)
create mode 100644 offload/unittests/OffloadAPI/program/olGetProgramGlobal.cpp
diff --git a/offload/liboffload/API/Program.td b/offload/liboffload/API/Program.td
index 0476fa1f7c27a..0523bc5660a1d 100644
--- a/offload/liboffload/API/Program.td
+++ b/offload/liboffload/API/Program.td
@@ -34,3 +34,18 @@ def : Function {
];
let returns = [];
}
+
+def : Function { // API entry point: look up a global of a program by name.
+ let name = "olGetProgramGlobal";
+ let desc = "Return the device address and/or size for the global variable specified by `GlobalName` in the given program.";
+ let details = [
+ "This pointer may be used by olMemcpy to copy memory to and from that global."
+ ];
+ let params = [
+ Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>,
+ Param<"const char*", "GlobalName", "null-terminated name of the global variable in the program", PARAM_IN>,
+ Param<"void **", "Address", "output pointer for the resolved address", PARAM_OUT_OPTIONAL>,
+ Param<"size_t*", "Size", "output pointer for the resolved size in bytes", PARAM_OUT_OPTIONAL>
+ ];
+ let returns = [];
+}
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index f9da638436705..2e96d5aaa7afe 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -597,6 +597,23 @@ Error olDestroyProgram_impl(ol_program_handle_t Program) {
return olDestroy(Program);
}
+Error olGetProgramGlobal_impl(ol_program_handle_t Program,
+ const char *GlobalName, void **Address,
+ size_t *Size) {
+ auto &Device = Program->Image->getDevice();
+ GlobalTy Global{GlobalName}; // Name only: Size 0 and null Ptr mean "unknown".
+ if (auto Res = Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice(
+ Device, *Program->Image, Global))
+ return Res;
+ // Address and Size are optional outputs; only write those the caller passed.
+ if (Address)
+ *Address = Global.getPtr();
+ if (Size)
+ *Size = Global.getSize();
+
+ return Error::success();
+}
+
Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName,
ol_kernel_handle_t *Kernel) {
diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt
index 05e862865ed33..27afb82c58362 100644
--- a/offload/unittests/OffloadAPI/CMakeLists.txt
+++ b/offload/unittests/OffloadAPI/CMakeLists.txt
@@ -31,7 +31,8 @@ add_offload_unittest("platform"
add_offload_unittest("program"
program/olCreateProgram.cpp
- program/olDestroyProgram.cpp)
+ program/olDestroyProgram.cpp
+ program/olGetProgramGlobal.cpp)
add_offload_unittest("queue"
queue/olCreateQueue.cpp
diff --git a/offload/unittests/OffloadAPI/device_code/global.c b/offload/unittests/OffloadAPI/device_code/global.c
index b30e406fb98c7..9f27f9424324f 100644
--- a/offload/unittests/OffloadAPI/device_code/global.c
+++ b/offload/unittests/OffloadAPI/device_code/global.c
@@ -1,6 +1,7 @@
#include <gpuintrin.h>
#include <stdint.h>
+[[gnu::visibility("default")]]
uint32_t global[64];
__gpu_kernel void write() {
diff --git a/offload/unittests/OffloadAPI/program/olGetProgramGlobal.cpp b/offload/unittests/OffloadAPI/program/olGetProgramGlobal.cpp
new file mode 100644
index 0000000000000..8dbd9fe08b2b2
--- /dev/null
+++ b/offload/unittests/OffloadAPI/program/olGetProgramGlobal.cpp
@@ -0,0 +1,178 @@
+//===------- Offload API tests - olGetProgramGlobal -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+/// Fixture that loads the "global" device binary and creates a program from it.
+struct olGetProgramGlobalTest : OffloadQueueTest {
+  void SetUp() override {
+    RETURN_ON_FATAL_FAILURE(OffloadQueueTest::SetUp());
+    ASSERT_TRUE(TestEnvironment::loadDeviceBinary("global", Device, DeviceBin));
+    // The size is unsigned, so only a strict comparison can catch an empty image.
+    ASSERT_GT(DeviceBin->getBufferSize(), 0lu);
+    ASSERT_SUCCESS(olCreateProgram(Device, DeviceBin->getBufferStart(),
+                                   DeviceBin->getBufferSize(), &Program));
+  }
+  void TearDown() override {
+    if (Program)
+      olDestroyProgram(Program);
+    RETURN_ON_FATAL_FAILURE(OffloadQueueTest::TearDown());
+  }
+
+  std::unique_ptr<llvm::MemoryBuffer> DeviceBin;
+  ol_program_handle_t Program = nullptr;
+};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetProgramGlobalTest);
+
+struct olGetProgramGlobalKernelTest : olGetProgramGlobalTest {
+ void SetUp() override {
+ RETURN_ON_FATAL_FAILURE(olGetProgramGlobalTest::SetUp());
+ // Look up the "read" and "write" kernels of the program by name.
+ ASSERT_SUCCESS(olGetKernel(Program, "read", &ReadKernel));
+ ASSERT_SUCCESS(olGetKernel(Program, "write", &WriteKernel));
+ // A single one-dimensional group of size 64; the device global has 64 elements.
+ LaunchArgs.Dimensions = 1;
+ LaunchArgs.GroupSize = {64, 1, 1};
+ LaunchArgs.NumGroups = {1, 1, 1};
+ // No dynamic shared memory is requested.
+ LaunchArgs.DynSharedMemory = 0;
+ }
+
+ ol_kernel_handle_t ReadKernel = nullptr;
+ ol_kernel_handle_t WriteKernel = nullptr;
+ ol_kernel_launch_size_args_t LaunchArgs{};
+};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetProgramGlobalKernelTest);
+
+TEST_P(olGetProgramGlobalTest, SuccessGetAddr) {
+ void *Addr = nullptr;
+ ASSERT_SUCCESS(olGetProgramGlobal(Program, "global", &Addr, nullptr));
+ // Only the address was requested (Size is nullptr); it must have been written.
+ ASSERT_NE(Addr, nullptr);
+}
+
+TEST_P(olGetProgramGlobalTest, SuccessGetSize) {
+ size_t Size = 0;
+ ASSERT_SUCCESS(olGetProgramGlobal(Program, "global", nullptr, &Size));
+ // Only the size was requested; the device code declares uint32_t global[64].
+ ASSERT_EQ(Size, 64 * sizeof(uint32_t));
+}
+
+TEST_P(olGetProgramGlobalTest, SuccessGetBoth) {
+ void *Addr = nullptr;
+ size_t Size = 0;
+ ASSERT_SUCCESS(olGetProgramGlobal(Program, "global", &Addr, &Size));
+ // Both optional outputs were requested, so both must be populated.
+ ASSERT_EQ(Size, 64 * sizeof(uint32_t));
+ ASSERT_NE(Addr, nullptr);
+}
+
+TEST_P(olGetProgramGlobalTest, InvalidNullHandle) { // Null program handle.
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
+ olGetProgramGlobal(nullptr, "global", nullptr, nullptr));
+}
+
+TEST_P(olGetProgramGlobalTest, InvalidNullString) { // Null global name.
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
+ olGetProgramGlobal(Program, nullptr, nullptr, nullptr));
+}
+
+TEST_P(olGetProgramGlobalTest, InvalidGlobalName) { // Unknown global name.
+ ASSERT_ERROR(OL_ERRC_NOT_FOUND,
+ olGetProgramGlobal(Program, "nosuchglobal", nullptr, nullptr));
+}
+
+TEST_P(olGetProgramGlobalTest, SuccessRoundTrip) {
+ void *Addr = nullptr;
+ ASSERT_SUCCESS(olGetProgramGlobal(Program, "global", &Addr, nullptr));
+ // Source buffer (managed allocation) filled with 0..63 from the host.
+ void *SourceMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ 64 * sizeof(uint32_t), &SourceMem));
+ uint32_t *SourceData = (uint32_t *)SourceMem;
+ for (auto I = 0; I < 64; I++)
+ SourceData[I] = I;
+ // Destination buffer of the same size to read the global back into.
+ void *DestMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ 64 * sizeof(uint32_t), &DestMem));
+ // Copy source -> global, then global -> destination, waiting after each copy.
+ ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ // The data read back must match what was written through the global.
+ uint32_t *DestData = (uint32_t *)DestMem;
+ for (uint32_t I = 0; I < 64; I++)
+ ASSERT_EQ(DestData[I], I);
+
+ ASSERT_SUCCESS(olMemFree(DestMem));
+ ASSERT_SUCCESS(olMemFree(SourceMem));
+}
+
+TEST_P(olGetProgramGlobalKernelTest, SuccessWriteGlobal) {
+ void *Addr = nullptr;
+ ASSERT_SUCCESS(olGetProgramGlobal(Program, "global", &Addr, nullptr));
+ // Source buffer (managed allocation) filled with 0..63 from the host.
+ void *SourceMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ LaunchArgs.GroupSize.x * sizeof(uint32_t),
+ &SourceMem));
+ uint32_t *SourceData = (uint32_t *)SourceMem;
+ for (auto I = 0; I < 64; I++)
+ SourceData[I] = I;
+ // Destination buffer, passed to ReadKernel as its only argument.
+ void *DestMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ LaunchArgs.GroupSize.x * sizeof(uint32_t),
+ &DestMem));
+ struct {
+ void *Mem;
+ } Args{DestMem};
+ // Write the global from the host, then launch ReadKernel on the same queue.
+ ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ ASSERT_SUCCESS(olLaunchKernel(Queue, Device, ReadKernel, &Args, sizeof(Args),
+ &LaunchArgs, nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ // Expect DestMem to hold the written values (kernel source not shown here).
+ uint32_t *DestData = (uint32_t *)DestMem;
+ for (uint32_t I = 0; I < 64; I++)
+ ASSERT_EQ(DestData[I], I);
+
+ ASSERT_SUCCESS(olMemFree(DestMem));
+ ASSERT_SUCCESS(olMemFree(SourceMem));
+}
+
+TEST_P(olGetProgramGlobalKernelTest, SuccessReadGlobal) {
+ void *Addr = nullptr;
+ ASSERT_SUCCESS(olGetProgramGlobal(Program, "global", &Addr, nullptr));
+ // Destination buffer (managed allocation) for the contents of the global.
+ void *DestMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ LaunchArgs.GroupSize.x * sizeof(uint32_t),
+ &DestMem));
+ // Run WriteKernel (no arguments), then copy the global back into DestMem.
+ ASSERT_SUCCESS(olLaunchKernel(Queue, Device, WriteKernel, nullptr, 0,
+ &LaunchArgs, nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ // The test expects element I to equal I * 2 once WriteKernel has run.
+ uint32_t *DestData = (uint32_t *)DestMem;
+ for (uint32_t I = 0; I < 64; I++)
+ ASSERT_EQ(DestData[I], I * 2);
+
+ ASSERT_SUCCESS(olMemFree(DestMem));
+}
More information about the llvm-commits
mailing list