[llvm] [orc-rt] Introduce WrapperFunction APIs. (PR #157091)
Lang Hames via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 5 05:27:26 PDT 2025
https://github.com/lhames created https://github.com/llvm/llvm-project/pull/157091
Introduces the following key APIs:
`orc_rt_WrapperFunction` defines the signature of an ORC asynchronous wrapper function:
```
typedef void (*orc_rt_WrapperFunctionReturn)(
orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionBuffer ResultBytes);
typedef void (*orc_rt_WrapperFunction)(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes);
```
A wrapper function takes a reference to the session object, a context pointer for the call being made, and a pointer to an orc_rt_WrapperFunctionReturn function that can be used to send the result bytes.
The `orc_rt::WrapperFunction` utility simplifies the writing of wrapper functions whose arguments and return values are serialized/deserialized using an abstract serialization utility.
The `orc_rt::SPSWrapperFunction` utility provides a specialized version of `orc_rt::WrapperFunction` that uses SPS serialization.
>From 1ea5d81a05c2e19c834ec11a059f1f6486770006 Mon Sep 17 00:00:00 2001
From: Lang Hames <lhames at gmail.com>
Date: Thu, 4 Sep 2025 21:21:40 +1000
Subject: [PATCH] [orc-rt] Introduce WrapperFunction APIs.
Introduces the following key APIs:
`orc_rt_WrapperFunction` defines the signature of an ORC asynchronous wrapper
function:
```
typedef void (*orc_rt_WrapperFunctionReturn)(
orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionBuffer ResultBytes);
typedef void (*orc_rt_WrapperFunction)(orc_rt_SessionRef Session, void *CallCtx,
orc_rt_WrapperFunctionReturn Return,
orc_rt_WrapperFunctionBuffer ArgBytes);
```
A wrapper function takes a reference to the session object, a context pointer
for the call being made, and a pointer to an orc_rt_WrapperFunctionReturn
function that can be used to send the result bytes.
The `orc_rt::WrapperFunction` utility simplifies the writing of wrapper
functions whose arguments and return values are serialized/deserialized using
an abstract serialization utility.
The `orc_rt::SPSWrapperFunction` utility provides a specialized version of
`orc_rt::WrapperFunction` that uses SPS serialization.
---
orc-rt/include/CMakeLists.txt | 2 +
orc-rt/include/orc-rt-c/CoreTypes.h | 28 ++++
orc-rt/include/orc-rt-c/WrapperFunction.h | 20 +++
orc-rt/include/orc-rt/SPSWrapperFunction.h | 89 +++++++++++
orc-rt/include/orc-rt/WrapperFunction.h | 160 ++++++++++++++++++++
orc-rt/unittests/CMakeLists.txt | 1 +
orc-rt/unittests/SPSWrapperFunctionTest.cpp | 109 +++++++++++++
7 files changed, 409 insertions(+)
create mode 100644 orc-rt/include/orc-rt-c/CoreTypes.h
create mode 100644 orc-rt/include/orc-rt/SPSWrapperFunction.h
create mode 100644 orc-rt/unittests/SPSWrapperFunctionTest.cpp
diff --git a/orc-rt/include/CMakeLists.txt b/orc-rt/include/CMakeLists.txt
index 07a7e52061d6c..67fe060c4b25b 100644
--- a/orc-rt/include/CMakeLists.txt
+++ b/orc-rt/include/CMakeLists.txt
@@ -1,4 +1,5 @@
set(ORC_RT_HEADERS
+ orc-rt-c/CoreTypes.h
orc-rt-c/ExternC.h
orc-rt-c/WrapperFunction.h
orc-rt-c/orc-rt.h
@@ -13,6 +14,7 @@ set(ORC_RT_HEADERS
orc-rt/RTTI.h
orc-rt/WrapperFunction.h
orc-rt/SimplePackedSerialization.h
+ orc-rt/SPSWrapperFunction.h
orc-rt/bind.h
orc-rt/bit.h
orc-rt/move_only_function.h
diff --git a/orc-rt/include/orc-rt-c/CoreTypes.h b/orc-rt/include/orc-rt-c/CoreTypes.h
new file mode 100644
index 0000000000000..9b3fdbea41498
--- /dev/null
+++ b/orc-rt/include/orc-rt-c/CoreTypes.h
@@ -0,0 +1,28 @@
+/*===-- CoreTypes.h - Essential types for the ORC Runtime C APIs --*- C -*-===*\
+|* *|
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
+|* Exceptions. *|
+|* See https://llvm.org/LICENSE.txt for license information. *|
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
+|* *|
+|*===----------------------------------------------------------------------===*|
+|* *|
+|* Defines core types for the ORC runtime. *|
+|* *|
+\*===----------------------------------------------------------------------===*/
+
+#ifndef ORC_RT_C_CORETYPES_H
+#define ORC_RT_C_CORETYPES_H
+
+#include "orc-rt-c/ExternC.h"
+
+ORC_RT_C_EXTERN_C_BEGIN
+
+/**
+ * A reference to an orc_rt::Session instance.
+ */
+typedef struct orc_rt_OpaqueSession *orc_rt_SessionRef;
+
+ORC_RT_C_EXTERN_C_END
+
+#endif /* ORC_RT_C_CORETYPES_H */
diff --git a/orc-rt/include/orc-rt-c/WrapperFunction.h b/orc-rt/include/orc-rt-c/WrapperFunction.h
index b7dbc16978233..34bcdeffef9ee 100644
--- a/orc-rt/include/orc-rt-c/WrapperFunction.h
+++ b/orc-rt/include/orc-rt-c/WrapperFunction.h
@@ -14,6 +14,7 @@
#ifndef ORC_RT_C_WRAPPERFUNCTION_H
#define ORC_RT_C_WRAPPERFUNCTION_H
+#include "orc-rt-c/CoreTypes.h"
#include "orc-rt-c/ExternC.h"
#include <assert.h>
@@ -49,6 +50,25 @@ typedef struct {
size_t Size;
} orc_rt_WrapperFunctionBuffer;
+/**
+ * Asynchronous return function for an orc-rt wrapper function.
+ */
+typedef void (*orc_rt_WrapperFunctionReturn)(
+ orc_rt_SessionRef Session, void *CallCtx,
+ orc_rt_WrapperFunctionBuffer ResultBytes);
+
+/**
+ * orc-rt wrapper function prototype.
+ *
+ * ArgBytes contains the serialized arguments for the wrapper function.
+ * Session holds a reference to the session object.
+ * CallCtx holds a pointer to the context object for this particular call.
+ * Return holds a pointer to the return function.
+ */
+typedef void (*orc_rt_WrapperFunction)(orc_rt_SessionRef Session, void *CallCtx,
+ orc_rt_WrapperFunctionReturn Return,
+ orc_rt_WrapperFunctionBuffer ArgBytes);
+
/**
* Zero-initialize an orc_rt_WrapperFunctionBuffer.
*/
diff --git a/orc-rt/include/orc-rt/SPSWrapperFunction.h b/orc-rt/include/orc-rt/SPSWrapperFunction.h
new file mode 100644
index 0000000000000..d08176f676289
--- /dev/null
+++ b/orc-rt/include/orc-rt/SPSWrapperFunction.h
@@ -0,0 +1,89 @@
+//===--- SPSWrapperFunction.h -- SPS-serializing Wrapper utils --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utilities for calling / handling wrapper functions that use SPS
+// serialization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ORC_RT_SPSWRAPPERFUNCTION_H
+#define ORC_RT_SPSWRAPPERFUNCTION_H
+
+#include "orc-rt/SimplePackedSerialization.h"
+#include "orc-rt/WrapperFunction.h"
+
+namespace orc_rt {
+namespace detail {
+
+template <typename... SPSArgTs> struct WFSPSSerializer {
+ template <typename... ArgTs>
+ std::optional<WrapperFunctionBuffer> operator()(const ArgTs &...Args) {
+ auto R =
+ WrapperFunctionBuffer::allocate(SPSArgList<SPSArgTs...>::size(Args...));
+ SPSOutputBuffer OB(R.data(), R.size());
+ if (!SPSArgList<SPSArgTs...>::serialize(OB, Args...))
+ return std::nullopt;
+ return std::move(R);
+ }
+};
+
+template <typename... SPSArgTs> struct WFSPSDeserializer {
+ template <typename... ArgTs>
+ bool operator()(WrapperFunctionBuffer &ArgBytes, ArgTs &...Args) {
+ assert(!ArgBytes.getOutOfBandError() &&
+ "Should not attempt to deserialize out-of-band error");
+ SPSInputBuffer IB(ArgBytes.data(), ArgBytes.size());
+ return SPSArgList<SPSArgTs...>::deserialize(IB, Args...);
+ }
+};
+
+} // namespace detail
+
+template <typename SPSSig> struct WrapperFunctionSPSSerializer;
+
+template <typename SPSRetT, typename... SPSArgTs>
+struct WrapperFunctionSPSSerializer<SPSRetT(SPSArgTs...)> {
+ static detail::WFSPSSerializer<SPSArgTs...> argumentSerializer() noexcept {
+ return {};
+ }
+ static detail::WFSPSDeserializer<SPSArgTs...>
+ argumentDeserializer() noexcept {
+ return {};
+ }
+ static detail::WFSPSSerializer<SPSRetT> resultSerializer() noexcept {
+ return {};
+ }
+ static detail::WFSPSDeserializer<SPSRetT> resultDeserializer() noexcept {
+ return {};
+ }
+};
+
+/// Provides call and handle utilities to simplify writing and invocation of
+/// wrapper functions that use SimplePackedSerialization to serialize and
+/// deserialize their arguments and return values.
+template <typename SPSSig> struct SPSWrapperFunction {
+ template <typename Caller, typename ResultHandler, typename... ArgTs>
+ static void call(Caller &&C, ResultHandler &&RH, ArgTs &&...Args) {
+ WrapperFunction::call(
+ std::forward<Caller>(C), WrapperFunctionSPSSerializer<SPSSig>(),
+ std::forward<ResultHandler>(RH), std::forward<ArgTs>(Args)...);
+ }
+
+ template <typename Handler>
+ static void handle(orc_rt_SessionRef Session, void *CallCtx,
+ orc_rt_WrapperFunctionReturn Return,
+ WrapperFunctionBuffer ArgBytes, Handler &&H) {
+ WrapperFunction::handle(Session, CallCtx, Return, std::move(ArgBytes),
+ WrapperFunctionSPSSerializer<SPSSig>(),
+ std::forward<Handler>(H));
+ }
+};
+
+} // namespace orc_rt
+
+#endif // ORC_RT_SPSWRAPPERFUNCTION_H
diff --git a/orc-rt/include/orc-rt/WrapperFunction.h b/orc-rt/include/orc-rt/WrapperFunction.h
index eb64cf64450e7..24b149cbe15f3 100644
--- a/orc-rt/include/orc-rt/WrapperFunction.h
+++ b/orc-rt/include/orc-rt/WrapperFunction.h
@@ -14,6 +14,8 @@
#define ORC_RT_WRAPPERFUNCTION_H
#include "orc-rt-c/WrapperFunction.h"
+#include "orc-rt/Error.h"
+#include "orc-rt/bind.h"
#include <utility>
@@ -98,6 +100,164 @@ class WrapperFunctionBuffer {
orc_rt_WrapperFunctionBuffer B;
};
+namespace detail {
+
+template <typename C>
+struct WFCallableTraits
+ : public WFCallableTraits<
+ decltype(&std::remove_cv_t<std::remove_reference_t<C>>::operator())> {
+};
+
+template <typename RetT> struct WFCallableTraits<RetT()> {
+ typedef void HeadArgType;
+};
+
+template <typename RetT, typename ArgT, typename... ArgTs>
+struct WFCallableTraits<RetT(ArgT, ArgTs...)> {
+ typedef ArgT HeadArgType;
+ typedef std::tuple<ArgTs...> TailArgTuple;
+};
+
+template <typename ClassT, typename RetT, typename... ArgTs>
+struct WFCallableTraits<RetT (ClassT::*)(ArgTs...)>
+ : public WFCallableTraits<RetT(ArgTs...)> {};
+
+template <typename ClassT, typename RetT, typename... ArgTs>
+struct WFCallableTraits<RetT (ClassT::*)(ArgTs...) const>
+ : public WFCallableTraits<RetT(ArgTs...)> {};
+
+template <typename Serializer> class StructuredYieldBase {
+public:
+ StructuredYieldBase(orc_rt_SessionRef Session, void *CallCtx,
+ orc_rt_WrapperFunctionReturn Return, Serializer &&S)
+ : Session(Session), CallCtx(CallCtx), Return(Return),
+ S(std::forward<Serializer>(S)) {}
+
+protected:
+ orc_rt_SessionRef Session;
+ void *CallCtx;
+ orc_rt_WrapperFunctionReturn Return;
+ std::decay_t<Serializer> S;
+};
+
+template <typename RetT, typename Serializer>
+class StructuredYield : public StructuredYieldBase<Serializer> {
+public:
+ using StructuredYieldBase<Serializer>::StructuredYieldBase;
+ void operator()(RetT &&R) {
+ if (auto ResultBytes = this->S.resultSerializer()(std::forward<RetT>(R)))
+ this->Return(this->Session, this->CallCtx, ResultBytes->release());
+ else
+ this->Return(this->Session, this->CallCtx,
+ WrapperFunctionBuffer::createOutOfBandError(
+ "Could not serialize wrapper function result data")
+ .release());
+ }
+};
+
+template <typename Serializer>
+class StructuredYield<void, Serializer>
+ : public StructuredYieldBase<Serializer> {
+public:
+ using StructuredYieldBase<Serializer>::StructuredYieldBase;
+ void operator()() {
+ this->Return(this->Session, this->CallCtx,
+ WrapperFunctionBuffer().release());
+ }
+};
+
+template <typename T, typename Serializer> struct ResultDeserializer;
+
+template <typename T, typename Serializer>
+struct ResultDeserializer<Expected<T>, Serializer> {
+ static Expected<T> deserialize(WrapperFunctionBuffer ResultBytes,
+ Serializer &S) {
+ T Val;
+ if (S.resultDeserializer()(ResultBytes, Val))
+ return std::move(Val);
+ else
+ return make_error<StringError>("Could not deserialize result");
+ }
+};
+
+template <typename Serializer> struct ResultDeserializer<Error, Serializer> {
+ static Error deserialize(WrapperFunctionBuffer ResultBytes, Serializer &S) {
+ assert(ResultBytes.empty());
+ return Error::success();
+ }
+};
+
+} // namespace detail
+
+/// Provides call and handle utilities to simplify writing and invocation of
+/// wrapper functions in C++.
+struct WrapperFunction {
+
+ /// Make a call to a wrapper function.
+ ///
+ /// This utility serializes and deserializes arguments and return values
+ /// (using the given Serializer), and calls the wrapper function via the
+ /// given Caller object.
+ template <typename Caller, typename Serializer, typename ResultHandler,
+ typename... ArgTs>
+ static void call(Caller &&C, Serializer &&S, ResultHandler &&RH,
+ ArgTs &&...Args) {
+ typedef detail::WFCallableTraits<ResultHandler> ResultHandlerTraits;
+ static_assert(
+ std::tuple_size_v<typename ResultHandlerTraits::TailArgTuple> == 0,
+ "Expected one argument to result-handler");
+ typedef typename ResultHandlerTraits::HeadArgType ResultType;
+
+ if (auto ArgBytes = S.argumentSerializer()(std::forward<ArgTs>(Args)...)) {
+ C(
+ [RH = std::move(RH),
+ S = std::move(S)](orc_rt_SessionRef Session,
+ WrapperFunctionBuffer ResultBytes) mutable {
+ if (const char *ErrMsg = ResultBytes.getOutOfBandError())
+ RH(make_error<StringError>(ErrMsg));
+ else
+ RH(detail::ResultDeserializer<
+ ResultType, Serializer>::deserialize(std::move(ResultBytes),
+ S));
+ },
+ std::move(*ArgBytes));
+ } else
+ RH(make_error<StringError>(
+ "Could not serialize wrapper function call arguments"));
+ }
+
+ /// Simplifies implementation of wrapper functions in C++.
+ ///
+ /// This utility deserializes and serializes arguments and return values
+ /// (using the given Serializer), and calls the given handler.
+ template <typename Serializer, typename Handler>
+ static void handle(orc_rt_SessionRef Session, void *CallCtx,
+ orc_rt_WrapperFunctionReturn Return,
+ WrapperFunctionBuffer ArgBytes, Serializer &&S,
+ Handler &&H) {
+ typedef detail::WFCallableTraits<Handler> HandlerTraits;
+ typedef typename HandlerTraits::HeadArgType Yield;
+ typedef typename HandlerTraits::TailArgTuple ArgTuple;
+ typedef typename detail::WFCallableTraits<Yield>::HeadArgType RetType;
+
+ if (ArgBytes.getOutOfBandError())
+ return Return(Session, CallCtx, ArgBytes.release());
+
+ ArgTuple Args;
+ if (std::apply(bind_front(S.argumentDeserializer(), std::move(ArgBytes)),
+ Args))
+ std::apply(bind_front(std::forward<Handler>(H),
+ detail::StructuredYield<RetType, Serializer>(
+ Session, CallCtx, Return, std::move(S))),
+ std::move(Args));
+ else
+ Return(Session, CallCtx,
+ WrapperFunctionBuffer::createOutOfBandError(
+ "Could not deserialize wrapper function arg data")
+ .release());
+ }
+};
+
} // namespace orc_rt
#endif // ORC_RT_WRAPPERFUNCTION_H
diff --git a/orc-rt/unittests/CMakeLists.txt b/orc-rt/unittests/CMakeLists.txt
index 55e089a539725..7bf53ca9826e2 100644
--- a/orc-rt/unittests/CMakeLists.txt
+++ b/orc-rt/unittests/CMakeLists.txt
@@ -22,6 +22,7 @@ add_orc_rt_unittest(CoreTests
MemoryFlagsTest.cpp
RTTITest.cpp
SimplePackedSerializationTest.cpp
+ SPSWrapperFunctionTest.cpp
WrapperFunctionBufferTest.cpp
bind-test.cpp
bit-test.cpp
diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
new file mode 100644
index 0000000000000..919ec2cebd69b
--- /dev/null
+++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
@@ -0,0 +1,109 @@
+//===-- SPSWrapperFunctionTest.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Test SPSWrapperFunction and associated utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "orc-rt/SPSWrapperFunction.h"
+#include "orc-rt/WrapperFunction.h"
+#include "orc-rt/move_only_function.h"
+
+#include "gtest/gtest.h"
+
+using namespace orc_rt;
+
+/// Make calls and call result handlers directly on the current thread.
+class DirectCaller {
+private:
+ class DirectResultSender {
+ public:
+ virtual ~DirectResultSender() {}
+ virtual void send(orc_rt_SessionRef Session,
+ WrapperFunctionBuffer ResultBytes) = 0;
+ static void send(orc_rt_SessionRef Session, void *CallCtx,
+ orc_rt_WrapperFunctionBuffer ResultBytes) {
+ std::unique_ptr<DirectResultSender>(
+ reinterpret_cast<DirectResultSender *>(CallCtx))
+ ->send(Session, ResultBytes);
+ }
+ };
+
+ template <typename ImplFn>
+ class DirectResultSenderImpl : public DirectResultSender {
+ public:
+ DirectResultSenderImpl(ImplFn &&Fn) : Fn(std::forward<ImplFn>(Fn)) {}
+ void send(orc_rt_SessionRef Session,
+ WrapperFunctionBuffer ResultBytes) override {
+ Fn(Session, std::move(ResultBytes));
+ }
+
+ private:
+ std::decay_t<ImplFn> Fn;
+ };
+
+ template <typename ImplFn>
+ static std::unique_ptr<DirectResultSender>
+ makeDirectResultSender(ImplFn &&Fn) {
+ return std::make_unique<DirectResultSenderImpl<ImplFn>>(
+ std::forward<ImplFn>(Fn));
+ }
+
+public:
+ DirectCaller(orc_rt_SessionRef Session, orc_rt_WrapperFunction Fn)
+ : Session(Session), Fn(Fn) {}
+
+ template <typename HandleResultFn>
+ void operator()(HandleResultFn &&HandleResult,
+ WrapperFunctionBuffer ArgBytes) {
+ auto DR =
+ makeDirectResultSender(std::forward<HandleResultFn>(HandleResult));
+ Fn(Session, reinterpret_cast<void *>(DR.release()),
+ DirectResultSender::send, ArgBytes.release());
+ }
+
+private:
+ orc_rt_SessionRef Session;
+ orc_rt_WrapperFunction Fn;
+};
+
+static void void_noop_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
+ orc_rt_WrapperFunctionReturn Return,
+ orc_rt_WrapperFunctionBuffer ArgBytes) {
+ SPSWrapperFunction<void()>::handle(
+ Session, CallCtx, Return, ArgBytes,
+ [](move_only_function<void()> Return) { Return(); });
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestVoidNoop) {
+ bool Ran = false;
+ SPSWrapperFunction<void()>::call(DirectCaller(nullptr, void_noop_sps_wrapper),
+ [&](Error Err) {
+ cantFail(std::move(Err));
+ Ran = true;
+ });
+ EXPECT_TRUE(Ran);
+}
+
+static void add_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
+ orc_rt_WrapperFunctionReturn Return,
+ orc_rt_WrapperFunctionBuffer ArgBytes) {
+ SPSWrapperFunction<int32_t(int32_t, int32_t)>::handle(
+ Session, CallCtx, Return, ArgBytes,
+ [](move_only_function<void(int32_t)> Return, int32_t X, int32_t Y) {
+ Return(X + Y);
+ });
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestAdd) {
+ int32_t Result = 0;
+ SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
+ DirectCaller(nullptr, add_sps_wrapper),
+ [&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
+ EXPECT_EQ(Result, 42);
+}
More information about the llvm-commits
mailing list