[llvm] [Offload] `olEnqueueHostCallback` (PR #152482)

Ross Brunton via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 7 05:08:31 PDT 2025


https://github.com/RossBrunton created https://github.com/llvm/llvm-project/pull/152482

Add an `olEnqueueHostCallback` method that allows enqueueing host work
to the stream.


>From 6c90bf709d480a98dfb28af517ea1203a1716644 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Thu, 7 Aug 2025 12:54:52 +0100
Subject: [PATCH] [Offload] `olEnqueueHostCallback`

Add an `olEnqueueHostCallback` method that allows enqueueing host work
to the stream.
---
 offload/liboffload/API/APIDefs.td             | 10 +++-
 offload/liboffload/API/Queue.td               | 26 ++++++++++
 offload/liboffload/src/OffloadImpl.cpp        |  7 +++
 offload/plugins-nextgen/amdgpu/src/rtl.cpp    | 48 +++++++++++++++++++
 .../common/include/PluginInterface.h          |  7 +++
 .../common/src/PluginInterface.cpp            | 10 ++++
 offload/plugins-nextgen/cuda/src/rtl.cpp      | 13 +++++
 offload/plugins-nextgen/host/src/rtl.cpp      |  6 +++
 offload/unittests/OffloadAPI/CMakeLists.txt   |  3 +-
 .../queue/olEnqueueHostCallback.cpp           | 48 +++++++++++++++++++
 10 files changed, 176 insertions(+), 2 deletions(-)
 create mode 100644 offload/unittests/OffloadAPI/queue/olEnqueueHostCallback.cpp

diff --git a/offload/liboffload/API/APIDefs.td b/offload/liboffload/API/APIDefs.td
index 640932dcf8464..bd4cbbaa546b2 100644
--- a/offload/liboffload/API/APIDefs.td
+++ b/offload/liboffload/API/APIDefs.td
@@ -31,6 +31,13 @@ class IsHandleType<string Type> {
                 !ne(!find(Type, "_handle_t", !sub(!size(Type), 9)), -1));
 }
 
+// Does the type end with '_cb_t'?
+class IsCallbackType<string Type> {
+  // size("_cb_t") == 5; searching from size - 5 can only match a suffix.
+  bit ret = !if(!lt(!size(Type), 5), 0,
+                !ne(!find(Type, "_cb_t", !sub(!size(Type), 5)), -1));
+}
+
 // Does the type end with '*'?
 class IsPointerType<string Type> {
   bit ret = !ne(!find(Type, "*", !sub(!size(Type), 1)), -1);
@@ -58,6 +65,7 @@ class Param<string Type, string Name, string Desc, bits<3> Flags = 0> {
   TypeInfo type_info = TypeInfo<"", "">;
   bit IsHandle = IsHandleType<type>.ret;
   bit IsPointer = IsPointerType<type>.ret;
+  bit IsCallback = IsCallbackType<type>.ret;
 }
 
 // A parameter whose range is described by other parameters in the function.
@@ -81,7 +89,7 @@ class ShouldCheckHandle<Param P> {
 }
 
 class ShouldCheckPointer<Param P> {
-  bit ret = !and(P.IsPointer, !eq(!and(PARAM_OPTIONAL, P.flags), 0));
+  bit ret = !and(!or(P.IsPointer, P.IsCallback), !eq(!and(PARAM_OPTIONAL, P.flags), 0));
 }
 
 // For a list of returns that contains a specific return code, find and append
diff --git a/offload/liboffload/API/Queue.td b/offload/liboffload/API/Queue.td
index 7e4c5b8c68be2..3ab7ed32ba5fe 100644
--- a/offload/liboffload/API/Queue.td
+++ b/offload/liboffload/API/Queue.td
@@ -107,3 +107,29 @@ def : Function {
     Return<"OL_ERRC_INVALID_QUEUE">
   ];
 }
+
+def : FptrTypedef {
+  let name = "ol_queue_callback_cb_t";
+  let desc = "Callback function for use by `olEnqueueHostCallback`.";
+  let params = [
+    Param<"void *", "UserData", "user specified data passed into `olEnqueueHostCallback`.", PARAM_IN>,
+  ];
+  let return = "void";
+}
+
+def : Function {
+  let name = "olEnqueueHostCallback";
+  let desc = "Enqueue a callback function on the host.";
+  let details = [
+    "The provided function will be called from the same process as the one that called `olEnqueueHostCallback`.",
+    "The callback will not run until all previous work submitted to the queue has completed.",
+    "The callback must return before any work submitted to the queue after it is started.",
+    "The callback must not call any liboffload API functions or any backend specific functions (such as Cuda or HSA library functions).",
+  ];
+  let params = [
+    Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN>,
+    Param<"ol_queue_callback_cb_t", "Callback", "the callback function to call on the host", PARAM_IN>,
+    Param<"void *", "UserData", "a pointer that will be passed verbatim to the callback function", PARAM_IN_OPTIONAL>,
+  ];
+  let returns = [];
+}
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 272a12ab59a06..8fdfafbcaee60 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -830,5 +830,12 @@ Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
   return olGetSymbolInfoImplDetail(Symbol, PropName, 0, nullptr, PropSizeRet);
 }
 
+Error olEnqueueHostCallback_impl(ol_queue_handle_t Queue,
+                                 ol_queue_callback_cb_t Callback,
+                                 void *UserData) {
+  return Queue->Device->Device->enqueueHostCallback(Callback, UserData,
+                                                    Queue->AsyncInfo);
+}
+
 } // namespace offload
 } // namespace llvm
diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
index 852c0e99b2266..829bb0732b28b 100644
--- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp
+++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp
@@ -1063,6 +1063,20 @@ struct AMDGPUStreamTy {
   /// Indicate to spread data transfers across all available SDMAs
   bool UseMultipleSdmaEngines;
 
+  /// Host-callback thread body: wait on input, run callback, signal output.
+  static void CallbackWrapper(AMDGPUSignalTy *InputSignal,
+                              AMDGPUSignalTy *OutputSignal,
+                              void (*Callback)(void *), void *UserData) {
+    if (InputSignal)
+      if (auto Err = InputSignal->wait())
+        // Wait shouldn't report an error; any failure is treated as fatal.
+        reportFatalInternalError(std::move(Err));
+    // The input signal (if any) has been waited on; run the user callback.
+    Callback(UserData);
+    // Mark this callback as finished on its output signal.
+    OutputSignal->signal();
+  }
+
   /// Return the current number of asynchronous operations on the stream.
   uint32_t size() const { return NextSlot; }
 
@@ -1495,6 +1509,31 @@ struct AMDGPUStreamTy {
                                    OutputSignal->get());
   }
 
+  Error pushHostCallback(void (*Callback)(void *), void *UserData) {
+    // Retrieve an available signal for the operation's output.
+    AMDGPUSignalTy *OutputSignal = nullptr;
+    if (auto Err = SignalManager.getResource(OutputSignal))
+      return Err;
+    OutputSignal->reset();
+    OutputSignal->increaseUseCount();
+    // The input signal (may be null) is obtained below while holding Mutex.
+    AMDGPUSignalTy *InputSignal;
+    {
+      std::lock_guard<std::mutex> Lock(Mutex);
+
+      // Consume stream slot and compute dependencies.
+      InputSignal = consume(OutputSignal).second;
+    }
+    // The callback runs on a detached thread; this function does not wait on it.
+    // "Leaking" the thread here is consistent with other work added to the
+    // queue. The input and output signals will remain valid until the output is
+    // signaled.
+    std::thread(CallbackWrapper, InputSignal, OutputSignal, Callback, UserData)
+        .detach();
+
+    return Plugin::success();
+  }
+
   /// Synchronize with the stream. The current thread waits until all operations
   /// are finalized and it performs the pending post actions (i.e., releasing
   /// intermediate buffers).
@@ -2554,6 +2593,15 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
     return Plugin::success();
   }
 
+  /// Push \p Callback (called with \p UserData) onto \p AsyncInfo's stream.
+  Error enqueueHostCallbackImpl(void (*Callback)(void *), void *UserData,
+                                AsyncInfoWrapperTy &AsyncInfo) override {
+    AMDGPUStreamTy *Stream = nullptr;
+    if (auto Err = getStream(AsyncInfo, Stream))
+      return Err;
+    return Stream->pushHostCallback(Callback, UserData);
+  }
+
   /// Create an event.
   Error createEventImpl(void **EventPtrStorage) override {
     AMDGPUEventTy **Event = reinterpret_cast<AMDGPUEventTy **>(EventPtrStorage);
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index 1d64193c17f6b..aad8c83d92e84 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -946,6 +946,13 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
   Error initDeviceInfo(__tgt_device_info *DeviceInfo);
   virtual Error initDeviceInfoImpl(__tgt_device_info *DeviceInfo) = 0;
 
+  /// Enqueue a host callback on \p AsyncInfo; it is called with \p UserData.
+  Error enqueueHostCallback(void (*Callback)(void *), void *UserData,
+                            __tgt_async_info *AsyncInfo);
+  virtual Error enqueueHostCallbackImpl(void (*Callback)(void *),
+                                        void *UserData,
+                                        AsyncInfoWrapperTy &AsyncInfo) = 0;
+
   /// Create an event.
   Error createEvent(void **EventPtrStorage);
   virtual Error createEventImpl(void **EventPtrStorage) = 0;
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index bcc91798f3f90..80577b47a3dc6 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -1582,6 +1582,16 @@ Error GenericDeviceTy::initAsyncInfo(__tgt_async_info **AsyncInfoPtr) {
   return Err;
 }
 
+Error GenericDeviceTy::enqueueHostCallback(void (*Callback)(void *),
+                                           void *UserData,
+                                           __tgt_async_info *AsyncInfo) {
+  AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo);
+  // Dispatch to the plugin implementation, then finalize the wrapper with Err.
+  auto Err = enqueueHostCallbackImpl(Callback, UserData, AsyncInfoWrapper);
+  AsyncInfoWrapper.finalize(Err);
+  return Err;
+}
+
 Error GenericDeviceTy::initDeviceInfo(__tgt_device_info *DeviceInfo) {
   assert(DeviceInfo && "Invalid device info");
 
diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp
index 7649fd9285bb5..0d0a80f62a116 100644
--- a/offload/plugins-nextgen/cuda/src/rtl.cpp
+++ b/offload/plugins-nextgen/cuda/src/rtl.cpp
@@ -875,6 +875,19 @@ struct CUDADeviceTy : public GenericDeviceTy {
     return Plugin::success();
   }
 
+  /// Enqueue \p Callback, to be called with \p UserData, on the CUDA stream
+  /// obtained from \p AsyncInfo. Returns any context, stream or launch error.
+  Error enqueueHostCallbackImpl(void (*Callback)(void *), void *UserData,
+                                AsyncInfoWrapperTy &AsyncInfo) override {
+    if (auto Err = setContext())
+      return Err;
+    CUstream Stream;
+    if (auto Err = getStream(AsyncInfo, Stream))
+      return Err;
+    CUresult Res = cuLaunchHostFunc(Stream, Callback, UserData);
+    return Plugin::check(Res, "error in cuLaunchHostFunc: %s");
+  }
+
   /// Create an event.
   Error createEventImpl(void **EventPtrStorage) override {
     CUevent *Event = reinterpret_cast<CUevent *>(EventPtrStorage);
diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp
index 9abc3507f6e68..da8c92a6ce93d 100644
--- a/offload/plugins-nextgen/host/src/rtl.cpp
+++ b/offload/plugins-nextgen/host/src/rtl.cpp
@@ -319,6 +319,12 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
                          "initDeviceInfoImpl not supported");
   }
 
+  Error enqueueHostCallbackImpl(void (*Callback)(void *), void *UserData,
+                                AsyncInfoWrapperTy &AsyncInfo) override {
+    Callback(UserData); // Invoked directly, before this function returns.
+    return Plugin::success();
+  }
+
   /// This plugin does not support the event API. Do nothing without failing.
   Error createEventImpl(void **EventPtrStorage) override {
     *EventPtrStorage = nullptr;
diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt
index 8f0267eb39bdf..f3f1b4db2656a 100644
--- a/offload/unittests/OffloadAPI/CMakeLists.txt
+++ b/offload/unittests/OffloadAPI/CMakeLists.txt
@@ -41,7 +41,8 @@ add_offload_unittest("queue"
     queue/olDestroyQueue.cpp
     queue/olGetQueueInfo.cpp
     queue/olGetQueueInfoSize.cpp
-    queue/olWaitEvents.cpp)
+    queue/olWaitEvents.cpp
+    queue/olEnqueueHostCallback.cpp)
 
 add_offload_unittest("symbol"
     symbol/olGetSymbol.cpp
diff --git a/offload/unittests/OffloadAPI/queue/olEnqueueHostCallback.cpp b/offload/unittests/OffloadAPI/queue/olEnqueueHostCallback.cpp
new file mode 100644
index 0000000000000..27dbfe2f111ce
--- /dev/null
+++ b/offload/unittests/OffloadAPI/queue/olEnqueueHostCallback.cpp
@@ -0,0 +1,48 @@
+//===------- Offload API tests - olEnqueueHostCallback --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "../common/Fixtures.hpp"
+#include <OffloadAPI.h>
+#include <gtest/gtest.h>
+
+struct olEnqueueHostCallbackTest : OffloadQueueTest {};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olEnqueueHostCallbackTest);
+
+TEST_P(olEnqueueHostCallbackTest, Success) {
+  ASSERT_SUCCESS(olEnqueueHostCallback(Queue, [](void *) {}, nullptr));
+}
+
+TEST_P(olEnqueueHostCallbackTest, SuccessSequence) {
+  uint32_t Buff[16] = {1, 1};
+  // Each callback fills one slot with the sum of the two slots before it.
+  for (auto BuffPtr = &Buff[2]; BuffPtr != &Buff[16]; BuffPtr++) {
+    ASSERT_SUCCESS(olEnqueueHostCallback(
+        Queue,
+        [](void *BuffPtr) {
+          uint32_t *AsU32 = reinterpret_cast<uint32_t *>(BuffPtr);
+          AsU32[0] = AsU32[-1] + AsU32[-2];
+        },
+        BuffPtr));
+  }
+  // Synchronize the queue before checking the buffer contents.
+  ASSERT_SUCCESS(olSyncQueue(Queue));
+
+  for (uint32_t i = 2; i < 16; i++) {
+    ASSERT_EQ(Buff[i], Buff[i - 1] + Buff[i - 2]);
+  }
+}
+
+TEST_P(olEnqueueHostCallbackTest, InvalidNullCallback) {
+  ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
+               olEnqueueHostCallback(Queue, nullptr, nullptr));
+}
+
+TEST_P(olEnqueueHostCallbackTest, InvalidNullQueue) {
+  ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
+               olEnqueueHostCallback(nullptr, [](void *) {}, nullptr));
+}



More information about the llvm-commits mailing list