[llvm] [Offload] `olEnqueueHostCallback` (PR #152482)
Ross Brunton via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 7 05:08:31 PDT 2025
https://github.com/RossBrunton created https://github.com/llvm/llvm-project/pull/152482
Add an `olEnqueueHostCallback` method that allows enqueueing host work
to the stream.
From 6c90bf709d480a98dfb28af517ea1203a1716644 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Thu, 7 Aug 2025 12:54:52 +0100
Subject: [PATCH] [Offload] `olEnqueueHostCallback`
Add an `olEnqueueHostCallback` method that allows enqueueing host work
to the stream.
---
offload/liboffload/API/APIDefs.td | 10 +++-
offload/liboffload/API/Queue.td | 26 ++++++++++
offload/liboffload/src/OffloadImpl.cpp | 7 +++
offload/plugins-nextgen/amdgpu/src/rtl.cpp | 48 +++++++++++++++++++
.../common/include/PluginInterface.h | 7 +++
.../common/src/PluginInterface.cpp | 10 ++++
offload/plugins-nextgen/cuda/src/rtl.cpp | 13 +++++
offload/plugins-nextgen/host/src/rtl.cpp | 6 +++
offload/unittests/OffloadAPI/CMakeLists.txt | 3 +-
.../queue/olEnqueueHostCallback.cpp | 48 +++++++++++++++++++
10 files changed, 176 insertions(+), 2 deletions(-)
create mode 100644 offload/unittests/OffloadAPI/queue/olEnqueueHostCallback.cpp
diff --git a/offload/liboffload/API/APIDefs.td b/offload/liboffload/API/APIDefs.td
index 640932dcf8464..bd4cbbaa546b2 100644
--- a/offload/liboffload/API/APIDefs.td
+++ b/offload/liboffload/API/APIDefs.td
@@ -31,6 +31,13 @@ class IsHandleType<string Type> {
!ne(!find(Type, "_handle_t", !sub(!size(Type), 9)), -1));
}
+// Does the type end with '_cb_t'?
+// Used to recognise callback (function pointer) typedefs by their name.
+class IsCallbackType<string Type> {
+ // size("_cb_t") == 5
+ // A type shorter than the suffix cannot match; testing the length first also
+ // keeps the `!find` start offset (size - 5) from going negative. The search
+ // starts at that offset, and `!find` yields -1 when the suffix is absent.
+ bit ret = !if(!lt(!size(Type), 5), 0,
+ !ne(!find(Type, "_cb_t", !sub(!size(Type), 5)), -1));
+}
+
// Does the type end with '*'?
class IsPointerType<string Type> {
bit ret = !ne(!find(Type, "*", !sub(!size(Type), 1)), -1);
@@ -58,6 +65,7 @@ class Param<string Type, string Name, string Desc, bits<3> Flags = 0> {
TypeInfo type_info = TypeInfo<"", "">;
bit IsHandle = IsHandleType<type>.ret;
bit IsPointer = IsPointerType<type>.ret;
+ bit IsCallback = IsCallbackType<type>.ret;
}
// A parameter whose range is described by other parameters in the function.
@@ -81,7 +89,7 @@ class ShouldCheckHandle<Param P> {
}
class ShouldCheckPointer<Param P> {
- bit ret = !and(P.IsPointer, !eq(!and(PARAM_OPTIONAL, P.flags), 0));
+ // Callback parameters get the same check as pointer parameters, unless the
+ // PARAM_OPTIONAL flag is set.
+ bit ret = !and(!or(P.IsPointer, P.IsCallback), !eq(!and(PARAM_OPTIONAL, P.flags), 0));
}
// For a list of returns that contains a specific return code, find and append
diff --git a/offload/liboffload/API/Queue.td b/offload/liboffload/API/Queue.td
index 7e4c5b8c68be2..3ab7ed32ba5fe 100644
--- a/offload/liboffload/API/Queue.td
+++ b/offload/liboffload/API/Queue.td
@@ -107,3 +107,29 @@ def : Function {
Return<"OL_ERRC_INVALID_QUEUE">
];
}
+
+// Signature of the host function accepted by `olEnqueueHostCallback`: it takes
+// the user data pointer and returns nothing. The `_cb_t` suffix of the name is
+// what `IsCallbackType` in APIDefs.td matches on.
+def : FptrTypedef {
+ let name = "ol_queue_callback_cb_t";
+ let desc = "Callback function for use by `olEnqueueHostCallback`.";
+ let params = [
+ Param<"void *", "UserData", "user specified data passed into `olEnqueueHostCallback`.", PARAM_IN>,
+ ];
+ let return = "void";
+}
+
+// API entry that queues a host-side function call on a queue.
+// NOTE(review): the third `details` sentence is ambiguous ("after it is
+// started"). Presumably it means that work submitted after the callback does
+// not begin until the callback has returned - confirm and consider rewording.
+def : Function {
+ let name = "olEnqueueHostCallback";
+ let desc = "Enqueue a callback function on the host.";
+ let details = [
+ "The provided function will be called from the same process as the one that called `olEnqueueHostCallback`.",
+ "The callback will not run until all previous work submitted to the queue has completed.",
+ "The callback must return before any work submitted to the queue after it is started.",
+ "The callback must not call any liboffload API functions or any backend specific functions (such as Cuda or HSA library functions).",
+ ];
+ // Queue and Callback are required inputs; UserData is flagged optional.
+ let params = [
+ Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN>,
+ Param<"ol_queue_callback_cb_t", "Callback", "the callback function to call on the host", PARAM_IN>,
+ Param<"void *", "UserData", "a pointer that will be passed verbatim to the callback function", PARAM_IN_OPTIONAL>,
+ ];
+ // No function-specific return codes are declared here.
+ let returns = [];
+}
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 272a12ab59a06..8fdfafbcaee60 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -830,5 +830,12 @@ Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
return olGetSymbolInfoImplDetail(Symbol, PropName, 0, nullptr, PropSizeRet);
}
+/// Implementation of `olEnqueueHostCallback`: forwards \p Callback and
+/// \p UserData to the underlying device object (`Queue->Device->Device`)
+/// together with the queue's AsyncInfo, and returns whatever that call returns.
+/// No argument validation is done in this function.
+/// NOTE(review): presumably null handles/callbacks are rejected before this is
+/// reached (the unit tests expect INVALID_NULL_* errors) - confirm.
+Error olEnqueueHostCallback_impl(ol_queue_handle_t Queue,
+ ol_queue_callback_cb_t Callback,
+ void *UserData) {
+ return Queue->Device->Device->enqueueHostCallback(Callback, UserData,
+ Queue->AsyncInfo);
+}
+
} // namespace offload
} // namespace llvm
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index 852c0e99b2266..829bb0732b28b 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -1063,6 +1063,20 @@ struct AMDGPUStreamTy {
/// Indicate to spread data transfers across all available SDMAs
bool UseMultipleSdmaEngines;
+ /// Wrapper function for implementing host callbacks.
+ /// Waits on \p InputSignal when one is given, then calls \p Callback with
+ /// \p UserData, and finally signals \p OutputSignal. A failed wait is
+ /// reported as a fatal internal error. pushHostCallback runs this function
+ /// on a detached std::thread.
+ static void CallbackWrapper(AMDGPUSignalTy *InputSignal,
+ AMDGPUSignalTy *OutputSignal,
+ void (*Callback)(void *), void *UserData) {
+ // A null input signal means there is nothing to wait for.
+ if (InputSignal)
+ if (auto Err = InputSignal->wait())
+ // Wait shouldn't report an error
+ reportFatalInternalError(std::move(Err));
+
+ Callback(UserData);
+
+ // Signal the output only after the callback has returned.
+ OutputSignal->signal();
+ }
+
/// Return the current number of asynchronous operations on the stream.
uint32_t size() const { return NextSlot; }
@@ -1495,6 +1509,31 @@ struct AMDGPUStreamTy {
OutputSignal->get());
}
+ /// Push a host callback onto the stream. Takes an output signal from the
+ /// signal manager, consumes a stream slot while holding `Mutex` to obtain the
+ /// input signal, and then starts a detached thread that runs CallbackWrapper
+ /// with both signals, \p Callback and \p UserData. Returns the error from
+ /// the signal manager if no signal could be obtained, success otherwise.
+ Error pushHostCallback(void (*Callback)(void *), void *UserData) {
+ // Retrieve an available signal for the operation's output.
+ AMDGPUSignalTy *OutputSignal = nullptr;
+ if (auto Err = SignalManager.getResource(OutputSignal))
+ return Err;
+ OutputSignal->reset();
+ OutputSignal->increaseUseCount();
+
+ // Assigned inside the locked scope below; the lock is released before the
+ // thread is created.
+ AMDGPUSignalTy *InputSignal;
+ {
+ std::lock_guard<std::mutex> Lock(Mutex);
+
+ // Consume stream slot and compute dependencies.
+ InputSignal = consume(OutputSignal).second;
+ }
+
+ // "Leaking" the thread here is consistent with other work added to the
+ // queue. The input and output signals will remain valid until the output is
+ // signaled.
+ std::thread(CallbackWrapper, InputSignal, OutputSignal, Callback, UserData)
+ .detach();
+
+ return Plugin::success();
+ }
+
/// Synchronize with the stream. The current thread waits until all operations
/// are finalized and it performs the pending post actions (i.e., releasing
/// intermediate buffers).
@@ -2554,6 +2593,15 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
return Plugin::success();
}
+ /// Enqueue a call to \p Callback with \p UserData on the stream associated
+ /// with \p AsyncInfo. Returns an error if the stream cannot be obtained or
+ /// if pushing the callback onto it fails.
+ Error enqueueHostCallbackImpl(void (*Callback)(void *), void *UserData,
+ AsyncInfoWrapperTy &AsyncInfo) override {
+ AMDGPUStreamTy *Stream = nullptr;
+ if (auto Err = getStream(AsyncInfo, Stream))
+ return Err;
+
+ return Stream->pushHostCallback(Callback, UserData);
+ }
+
/// Create an event.
Error createEventImpl(void **EventPtrStorage) override {
AMDGPUEventTy **Event = reinterpret_cast<AMDGPUEventTy **>(EventPtrStorage);
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index 1d64193c17f6b..aad8c83d92e84 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -946,6 +946,13 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
Error initDeviceInfo(__tgt_device_info *DeviceInfo);
virtual Error initDeviceInfoImpl(__tgt_device_info *DeviceInfo) = 0;
+ /// Enqueue a call to \p Callback with \p UserData on the queue represented
+ /// by \p AsyncInfo. enqueueHostCallback delegates to the plugin-specific
+ /// enqueueHostCallbackImpl, which is pure virtual here.
+ Error enqueueHostCallback(void (*Callback)(void *), void *UserData,
+ __tgt_async_info *AsyncInfo);
+ virtual Error enqueueHostCallbackImpl(void (*Callback)(void *),
+ void *UserData,
+ AsyncInfoWrapperTy &AsyncInfo) = 0;
+
/// Create an event.
Error createEvent(void **EventPtrStorage);
virtual Error createEventImpl(void **EventPtrStorage) = 0;
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index bcc91798f3f90..80577b47a3dc6 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -1582,6 +1582,16 @@ Error GenericDeviceTy::initAsyncInfo(__tgt_async_info **AsyncInfoPtr) {
return Err;
}
+/// Wrap \p AsyncInfo in an AsyncInfoWrapperTy, dispatch to the plugin's
+/// enqueueHostCallbackImpl, and pass the resulting error through the wrapper's
+/// finalize() before returning it.
+Error GenericDeviceTy::enqueueHostCallback(void (*Callback)(void *),
+ void *UserData,
+ __tgt_async_info *AsyncInfo) {
+ AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo);
+
+ auto Err = enqueueHostCallbackImpl(Callback, UserData, AsyncInfoWrapper);
+ AsyncInfoWrapper.finalize(Err);
+ return Err;
+}
+
Error GenericDeviceTy::initDeviceInfo(__tgt_device_info *DeviceInfo) {
assert(DeviceInfo && "Invalid device info");
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 7649fd9285bb5..0d0a80f62a116 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -875,6 +875,19 @@ struct CUDADeviceTy : public GenericDeviceTy {
return Plugin::success();
}
+ /// Enqueue a call to \p Callback with \p UserData on the CUDA stream
+ /// associated with \p AsyncInfo, using cuLaunchHostFunc. Returns an error if
+ /// the context cannot be set, the stream cannot be obtained, or the driver
+ /// call fails.
+ Error enqueueHostCallbackImpl(void (*Callback)(void *), void *UserData,
+ AsyncInfoWrapperTy &AsyncInfo) override {
+ if (auto Err = setContext())
+ return Err;
+
+ CUstream Stream;
+ if (auto Err = getStream(AsyncInfo, Stream))
+ return Err;
+
+ CUresult Res = cuLaunchHostFunc(Stream, Callback, UserData);
+ // Name the driver entry point that was actually called in the message.
+ return Plugin::check(Res, "error in cuLaunchHostFunc: %s");
+ }
+
/// Create an event.
Error createEventImpl(void **EventPtrStorage) override {
CUevent *Event = reinterpret_cast<CUevent *>(EventPtrStorage);
diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp
index 9abc3507f6e68..da8c92a6ce93d 100644
--- a/offload/plugins-nextgen/host/src/rtl.cpp
+++ b/offload/plugins-nextgen/host/src/rtl.cpp
@@ -319,6 +319,12 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
"initDeviceInfoImpl not supported");
}
+ /// The host device runs \p Callback with \p UserData immediately on the
+ /// calling thread and then reports success; \p AsyncInfo is not used.
+ Error enqueueHostCallbackImpl(void (*Callback)(void *), void *UserData,
+ AsyncInfoWrapperTy &AsyncInfo) override {
+ Callback(UserData);
+ return Plugin::success();
+ }
+
/// This plugin does not support the event API. Do nothing without failing.
Error createEventImpl(void **EventPtrStorage) override {
*EventPtrStorage = nullptr;
diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt
index 8f0267eb39bdf..f3f1b4db2656a 100644
--- a/offload/unittests/OffloadAPI/CMakeLists.txt
+++ b/offload/unittests/OffloadAPI/CMakeLists.txt
@@ -41,7 +41,8 @@ add_offload_unittest("queue"
queue/olDestroyQueue.cpp
queue/olGetQueueInfo.cpp
queue/olGetQueueInfoSize.cpp
- queue/olWaitEvents.cpp)
+ queue/olWaitEvents.cpp
+ queue/olEnqueueHostCallback.cpp)
add_offload_unittest("symbol"
symbol/olGetSymbol.cpp
diff --git a/offload/unittests/OffloadAPI/queue/olEnqueueHostCallback.cpp b/offload/unittests/OffloadAPI/queue/olEnqueueHostCallback.cpp
new file mode 100644
index 0000000000000..27dbfe2f111ce
--- /dev/null
+++ b/offload/unittests/OffloadAPI/queue/olEnqueueHostCallback.cpp
@@ -0,0 +1,48 @@
+//===------- Offload API tests - olEnqueueHostCallback --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+// Fixture for the olEnqueueHostCallback tests; adds nothing on top of
+// OffloadQueueTest.
+struct olEnqueueHostCallbackTest : OffloadQueueTest {};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olEnqueueHostCallbackTest);
+
+// Enqueueing a no-op callback with null user data succeeds. The queue is not
+// synchronized in this test, so only the enqueue result is checked.
+TEST_P(olEnqueueHostCallbackTest, Success) {
+ ASSERT_SUCCESS(olEnqueueHostCallback(Queue, [](void *) {}, nullptr));
+}
+
+// Enqueue one callback per element Buff[2]..Buff[15]; each writes the sum of
+// the two preceding elements into its own slot. Every callback reads values
+// written by the ones enqueued before it, so the final check relies on the
+// callbacks having run in submission order.
+TEST_P(olEnqueueHostCallbackTest, SuccessSequence) {
+ // The first two elements are seeded with 1; the rest start at zero.
+ uint32_t Buff[16] = {1, 1};
+
+ for (auto BuffPtr = &Buff[2]; BuffPtr != &Buff[16]; BuffPtr++) {
+ ASSERT_SUCCESS(olEnqueueHostCallback(
+ Queue,
+ [](void *BuffPtr) {
+ // The user data points at the element this callback must fill in.
+ uint32_t *AsU32 = reinterpret_cast<uint32_t *>(BuffPtr);
+ AsU32[0] = AsU32[-1] + AsU32[-2];
+ },
+ BuffPtr));
+ }
+
+ // Synchronize the queue before the buffer is inspected.
+ ASSERT_SUCCESS(olSyncQueue(Queue));
+
+ for (uint32_t i = 2; i < 16; i++) {
+ ASSERT_EQ(Buff[i], Buff[i - 1] + Buff[i - 2]);
+ }
+}
+
+// A null callback must be rejected with OL_ERRC_INVALID_NULL_POINTER.
+TEST_P(olEnqueueHostCallbackTest, InvalidNullCallback) {
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
+ olEnqueueHostCallback(Queue, nullptr, nullptr));
+}
+
+// A null queue handle must be rejected with OL_ERRC_INVALID_NULL_HANDLE.
+TEST_P(olEnqueueHostCallbackTest, InvalidNullQueue) {
+ ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
+ olEnqueueHostCallback(nullptr, [](void *) {}, nullptr));
+}
More information about the llvm-commits
mailing list