[llvm] [orc-rt] Introduce WrapperFunction APIs. (PR #157091)

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 5 05:27:26 PDT 2025


https://github.com/lhames created https://github.com/llvm/llvm-project/pull/157091

Introduces the following key APIs:

`orc_rt_WrapperFunction` defines the signature of an ORC asynchronous wrapper function:

```
typedef void (*orc_rt_WrapperFunctionReturn)(
    orc_rt_SessionRef Session, void *CallCtx,
    orc_rt_WrapperFunctionBuffer ResultBytes);

typedef void (*orc_rt_WrapperFunction)(orc_rt_SessionRef Session, void *CallCtx,
                                       orc_rt_WrapperFunctionReturn Return,
                                       orc_rt_WrapperFunctionBuffer ArgBytes);
```

A wrapper function takes a reference to the session object, a context pointer for the call being made, a pointer to an orc_rt_WrapperFunctionReturn function that can be used to send the result bytes, and a buffer (ArgBytes) holding the serialized call arguments.

The `orc_rt::WrapperFunction` utility simplifies the writing of wrapper functions whose arguments and return values are serialized/deserialized using an abstract serialization utility.

The `orc_rt::SPSWrapperFunction` utility provides a specialized version of `orc_rt::WrapperFunction` that uses SPS serialization.

>From 1ea5d81a05c2e19c834ec11a059f1f6486770006 Mon Sep 17 00:00:00 2001
From: Lang Hames <lhames at gmail.com>
Date: Thu, 4 Sep 2025 21:21:40 +1000
Subject: [PATCH] [orc-rt] Introduce WrapperFunction APIs.

Introduces the following key APIs:

`orc_rt_WrapperFunction` defines the signature of an ORC asynchronous wrapper
function:

```
typedef void (*orc_rt_WrapperFunctionReturn)(
    orc_rt_SessionRef Session, void *CallCtx,
    orc_rt_WrapperFunctionBuffer ResultBytes);

typedef void (*orc_rt_WrapperFunction)(orc_rt_SessionRef Session, void *CallCtx,
                                       orc_rt_WrapperFunctionReturn Return,
                                       orc_rt_WrapperFunctionBuffer ArgBytes);
```

A wrapper function takes a reference to the session object, a context pointer
for the call being made, a pointer to an orc_rt_WrapperFunctionReturn function
that can be used to send the result bytes, and a buffer (ArgBytes) holding the
serialized call arguments.

The `orc_rt::WrapperFunction` utility simplifies the writing of wrapper
functions whose arguments and return values are serialized/deserialized using
an abstract serialization utility.

The `orc_rt::SPSWrapperFunction` utility provides a specialized version of
`orc_rt::WrapperFunction` that uses SPS serialization.
---
 orc-rt/include/CMakeLists.txt               |   2 +
 orc-rt/include/orc-rt-c/CoreTypes.h         |  28 ++++
 orc-rt/include/orc-rt-c/WrapperFunction.h   |  20 +++
 orc-rt/include/orc-rt/SPSWrapperFunction.h  |  89 +++++++++++
 orc-rt/include/orc-rt/WrapperFunction.h     | 160 ++++++++++++++++++++
 orc-rt/unittests/CMakeLists.txt             |   1 +
 orc-rt/unittests/SPSWrapperFunctionTest.cpp | 109 +++++++++++++
 7 files changed, 409 insertions(+)
 create mode 100644 orc-rt/include/orc-rt-c/CoreTypes.h
 create mode 100644 orc-rt/include/orc-rt/SPSWrapperFunction.h
 create mode 100644 orc-rt/unittests/SPSWrapperFunctionTest.cpp

diff --git a/orc-rt/include/CMakeLists.txt b/orc-rt/include/CMakeLists.txt
index 07a7e52061d6c..67fe060c4b25b 100644
--- a/orc-rt/include/CMakeLists.txt
+++ b/orc-rt/include/CMakeLists.txt
@@ -1,4 +1,5 @@
 set(ORC_RT_HEADERS
+    orc-rt-c/CoreTypes.h
     orc-rt-c/ExternC.h
     orc-rt-c/WrapperFunction.h
     orc-rt-c/orc-rt.h
@@ -13,6 +14,7 @@ set(ORC_RT_HEADERS
     orc-rt/RTTI.h
     orc-rt/WrapperFunction.h
     orc-rt/SimplePackedSerialization.h
+    orc-rt/SPSWrapperFunction.h
     orc-rt/bind.h
     orc-rt/bit.h
     orc-rt/move_only_function.h
diff --git a/orc-rt/include/orc-rt-c/CoreTypes.h b/orc-rt/include/orc-rt-c/CoreTypes.h
new file mode 100644
index 0000000000000..9b3fdbea41498
--- /dev/null
+++ b/orc-rt/include/orc-rt-c/CoreTypes.h
@@ -0,0 +1,28 @@
+/*===-- CoreTypes.h - Essential types for the ORC Runtime C APIs --*- C -*-===*\
+|*                                                                            *|
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM          *|
+|* Exceptions.                                                                *|
+|* See https://llvm.org/LICENSE.txt for license information.                  *|
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception                    *|
+|*                                                                            *|
+|*===----------------------------------------------------------------------===*|
+|*                                                                            *|
+|* Defines core types for the ORC runtime.                                    *|
+|*                                                                            *|
+\*===----------------------------------------------------------------------===*/
+
+#ifndef ORC_RT_C_CORETYPES_H
+#define ORC_RT_C_CORETYPES_H
+
+#include "orc-rt-c/ExternC.h"
+
+ORC_RT_C_EXTERN_C_BEGIN
+
+/**
+ * Opaque handle to an orc_rt::Session; the pointee type is not defined here.
+ */
+typedef struct orc_rt_OpaqueSession *orc_rt_SessionRef;
+
+ORC_RT_C_EXTERN_C_END
+
+#endif /* ORC_RT_C_CORETYPES_H */
diff --git a/orc-rt/include/orc-rt-c/WrapperFunction.h b/orc-rt/include/orc-rt-c/WrapperFunction.h
index b7dbc16978233..34bcdeffef9ee 100644
--- a/orc-rt/include/orc-rt-c/WrapperFunction.h
+++ b/orc-rt/include/orc-rt-c/WrapperFunction.h
@@ -14,6 +14,7 @@
 #ifndef ORC_RT_C_WRAPPERFUNCTION_H
 #define ORC_RT_C_WRAPPERFUNCTION_H
 
+#include "orc-rt-c/CoreTypes.h"
 #include "orc-rt-c/ExternC.h"
 
 #include <assert.h>
@@ -49,6 +50,25 @@ typedef struct {
   size_t Size;
 } orc_rt_WrapperFunctionBuffer;
 
+/**
+ * Asynchronous return function; pass back the call's Session and CallCtx.
+ */
+typedef void (*orc_rt_WrapperFunctionReturn)(
+    orc_rt_SessionRef Session, void *CallCtx,
+    orc_rt_WrapperFunctionBuffer ResultBytes);
+
+/**
+ * orc-rt wrapper function prototype.
+ *
+ * Session holds a reference to the session object.
+ * CallCtx holds a pointer to the context object for this particular call.
+ * Return is called (with Session and CallCtx) to send the result bytes.
+ * ArgBytes contains the serialized arguments for the wrapper function.
+ */
+typedef void (*orc_rt_WrapperFunction)(orc_rt_SessionRef Session, void *CallCtx,
+                                       orc_rt_WrapperFunctionReturn Return,
+                                       orc_rt_WrapperFunctionBuffer ArgBytes);
+
 /**
  * Zero-initialize an orc_rt_WrapperFunctionBuffer.
  */
diff --git a/orc-rt/include/orc-rt/SPSWrapperFunction.h b/orc-rt/include/orc-rt/SPSWrapperFunction.h
new file mode 100644
index 0000000000000..d08176f676289
--- /dev/null
+++ b/orc-rt/include/orc-rt/SPSWrapperFunction.h
@@ -0,0 +1,89 @@
+//===--- SPSWrapperFunction.h -- SPS-serializing Wrapper utils --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utilities for calling / handling wrapper functions that use SPS
+// serialization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ORC_RT_SPSWRAPPERFUNCTION_H
+#define ORC_RT_SPSWRAPPERFUNCTION_H
+
+#include "orc-rt/SimplePackedSerialization.h"
+#include "orc-rt/WrapperFunction.h"
+
+namespace orc_rt {
+namespace detail {
+
+/// Serializes an argument pack via SPSArgList<SPSArgTs...> (nullopt on error).
+template <typename... SPSArgTs> struct WFSPSSerializer {
+  template <typename... ArgTs>
+  std::optional<WrapperFunctionBuffer> operator()(const ArgTs &...Args) {
+    using AL = SPSArgList<SPSArgTs...>;
+    auto Buf = WrapperFunctionBuffer::allocate(AL::size(Args...));
+    SPSOutputBuffer OB(Buf.data(), Buf.size());
+    return AL::serialize(OB, Args...) ? std::make_optional(std::move(Buf))
+                                      : std::nullopt;
+  }
+};
+/// Deserializes ArgBytes into Args via SPSArgList<SPSArgTs...>.
+template <typename... SPSArgTs> struct WFSPSDeserializer {
+  template <typename... ArgTs>
+  bool operator()(WrapperFunctionBuffer &ArgBytes, ArgTs &...Args) {
+    assert(!ArgBytes.getOutOfBandError() &&
+           "Should not attempt to deserialize out-of-band error");
+    SPSInputBuffer In(ArgBytes.data(), ArgBytes.size());
+    return SPSArgList<SPSArgTs...>::deserialize(In, Args...);
+  }
+};
+
+} // namespace detail
+
+/// Serializer bundle for WrapperFunction, selected by an SPS function type.
+template <typename SPSSig> struct WrapperFunctionSPSSerializer;
+/// SPSRetT(SPSArgTs...): arguments use SPSArgTs..., results use SPSRetT.
+template <typename SPSRetT, typename... SPSArgTs>
+struct WrapperFunctionSPSSerializer<SPSRetT(SPSArgTs...)> {
+  static auto argumentSerializer() noexcept {
+    return detail::WFSPSSerializer<SPSArgTs...>();
+  }
+  static auto argumentDeserializer() noexcept {
+    return detail::WFSPSDeserializer<SPSArgTs...>();
+  }
+  static auto resultSerializer() noexcept {
+    return detail::WFSPSSerializer<SPSRetT>();
+  }
+  static auto resultDeserializer() noexcept {
+    return detail::WFSPSDeserializer<SPSRetT>();
+  }
+};
+
+/// Provides call and handle utilities for wrapper functions whose arguments
+/// and results are encoded with SimplePackedSerialization per SPSSig. call()
+/// serializes Args, invokes C, and delivers the result (or an Error) to RH.
+template <typename SPSSig> struct SPSWrapperFunction {
+  template <typename Caller, typename ResultHandler, typename... ArgTs>
+  static void call(Caller &&C, ResultHandler &&RH, ArgTs &&...Args) {
+    WrapperFunction::call(
+        std::forward<Caller>(C), WrapperFunctionSPSSerializer<SPSSig>(),
+        std::forward<ResultHandler>(RH), std::forward<ArgTs>(Args)...);
+  }
+  /// Decode ArgBytes, run H with a yield callback, and reply through Return.
+  template <typename Handler>
+  static void handle(orc_rt_SessionRef Session, void *CallCtx,
+                     orc_rt_WrapperFunctionReturn Return,
+                     WrapperFunctionBuffer ArgBytes, Handler &&H) {
+    WrapperFunction::handle(Session, CallCtx, Return, std::move(ArgBytes),
+                            WrapperFunctionSPSSerializer<SPSSig>(),
+                            std::forward<Handler>(H));
+  }
+};
+
+} // namespace orc_rt
+
+#endif // ORC_RT_SPSWRAPPERFUNCTION_H
diff --git a/orc-rt/include/orc-rt/WrapperFunction.h b/orc-rt/include/orc-rt/WrapperFunction.h
index eb64cf64450e7..24b149cbe15f3 100644
--- a/orc-rt/include/orc-rt/WrapperFunction.h
+++ b/orc-rt/include/orc-rt/WrapperFunction.h
@@ -14,6 +14,8 @@
 #define ORC_RT_WRAPPERFUNCTION_H
 
 #include "orc-rt-c/WrapperFunction.h"
+#include "orc-rt/Error.h"
+#include "orc-rt/bind.h"
 
 #include <utility>
 
@@ -98,6 +100,164 @@ class WrapperFunctionBuffer {
   orc_rt_WrapperFunctionBuffer B;
 };
 
+namespace detail {
+
+/// Generic callables: use the signature of their (non-overloaded) operator().
+template <typename C>
+struct WFCallableTraits
+    : WFCallableTraits<
+          decltype(&std::remove_cv_t<std::remove_reference_t<C>>::operator())> {
+};
+/// Nullary signatures have no head argument (and no tail tuple).
+template <typename RetT> struct WFCallableTraits<RetT()> {
+  using HeadArgType = void;
+};
+/// Split a signature into its first argument and a tuple of the remainder.
+template <typename RetT, typename ArgT, typename... ArgTs>
+struct WFCallableTraits<RetT(ArgT, ArgTs...)> {
+  using HeadArgType = ArgT;
+  using TailArgTuple = std::tuple<ArgTs...>;
+};
+/// Pointers to member functions (e.g. a lambda's operator()): drop the class.
+template <typename ClassT, typename RetT, typename... ArgTs>
+struct WFCallableTraits<RetT (ClassT::*)(ArgTs...)>
+    : WFCallableTraits<RetT(ArgTs...)> {};
+template <typename ClassT, typename RetT, typename... ArgTs>
+struct WFCallableTraits<RetT (ClassT::*)(ArgTs...) const>
+    : WFCallableTraits<RetT(ArgTs...)> {};
+
+/// Shared state for yield callbacks that send a wrapper function's result.
+template <typename Serializer> class StructuredYieldBase {
+protected:
+  orc_rt_SessionRef Session;           // Passed back to Return.
+  void *CallCtx;                       // Passed back to Return.
+  orc_rt_WrapperFunctionReturn Return; // Receives the result bytes.
+  std::decay_t<Serializer> S;          // Supplies resultSerializer().
+public:
+  StructuredYieldBase(orc_rt_SessionRef Session, void *CallCtx,
+                      orc_rt_WrapperFunctionReturn Return, Serializer &&S)
+      : Session(Session), CallCtx(CallCtx), Return(Return),
+        S(std::forward<Serializer>(S)) {}
+};
+
+/// Yield callback for non-void results: serializes R and sends it via Return.
+template <typename RetT, typename Serializer>
+class StructuredYield : public StructuredYieldBase<Serializer> {
+public:
+  using StructuredYieldBase<Serializer>::StructuredYieldBase;
+  void operator()(RetT &&R) {
+    auto Bytes = this->S.resultSerializer()(std::forward<RetT>(R));
+    this->Return(this->Session, this->CallCtx,
+                 Bytes ? Bytes->release()
+                       : WrapperFunctionBuffer::createOutOfBandError(
+                             "Could not serialize wrapper function result data")
+                             .release());
+  }
+};
+
+/// Yield callback for void results: sends an empty buffer through Return.
+template <typename Serializer> class StructuredYield<void, Serializer>
+    : public StructuredYieldBase<Serializer> {
+public:
+  using StructuredYieldBase<Serializer>::StructuredYieldBase;
+  void operator()() {
+    auto Empty = WrapperFunctionBuffer();
+    this->Return(this->Session, this->CallCtx, Empty.release());
+  }
+};
+
+/// Turns a result buffer into the argument type of a call's result handler.
+template <typename T, typename Serializer> struct ResultDeserializer;
+/// Expected<T>: decode a T, or produce an error if deserialization fails.
+template <typename T, typename Serializer>
+struct ResultDeserializer<Expected<T>, Serializer> {
+  static Expected<T> deserialize(WrapperFunctionBuffer ResultBytes,
+                                 Serializer &S) {
+    T Result;
+    if (!S.resultDeserializer()(ResultBytes, Result))
+      return make_error<StringError>("Could not deserialize result");
+    return std::move(Result);
+  }
+};
+/// Error (void wrapper functions): the result buffer is expected to be empty.
+template <typename Serializer> struct ResultDeserializer<Error, Serializer> {
+  static Error deserialize(WrapperFunctionBuffer ResultBytes, Serializer &S) {
+    assert(ResultBytes.empty());
+    return Error::success();
+  }
+};
+
+} // namespace detail
+
+/// Provides call and handle utilities to simplify writing and invocation of
+/// wrapper functions in C++.
+struct WrapperFunction {
+
+  /// Make a call to a wrapper function.
+  ///
+  /// Serializes Args (using the given Serializer), invokes the wrapper function
+  /// via Caller, and passes the deserialized result (or an Error) to RH.
+  /// RH and S are forwarded: lvalues are copied into the callback, not moved.
+  template <typename Caller, typename Serializer, typename ResultHandler,
+            typename... ArgTs>
+  static void call(Caller &&C, Serializer &&S, ResultHandler &&RH,
+                   ArgTs &&...Args) {
+    typedef detail::WFCallableTraits<ResultHandler> ResultHandlerTraits;
+    static_assert(
+        std::tuple_size_v<typename ResultHandlerTraits::TailArgTuple> == 0,
+        "Expected one argument to result-handler");
+    typedef typename ResultHandlerTraits::HeadArgType ResultType;
+
+    if (auto ArgBytes = S.argumentSerializer()(std::forward<ArgTs>(Args)...)) {
+      C(
+          [RH = std::forward<ResultHandler>(RH),
+           S = std::forward<Serializer>(S)](
+              orc_rt_SessionRef Session,
+              WrapperFunctionBuffer ResultBytes) mutable {
+            if (const char *ErrMsg = ResultBytes.getOutOfBandError())
+              RH(make_error<StringError>(ErrMsg));
+            else
+              RH(detail::ResultDeserializer<ResultType, Serializer>::
+                     deserialize(std::move(ResultBytes), S));
+          },
+          std::move(*ArgBytes));
+    } else
+      RH(make_error<StringError>(
+          "Could not serialize wrapper function call arguments"));
+  }
+
+  /// Simplifies implementation of wrapper functions in C++: deserializes
+  /// ArgBytes with S, calls H with a yield callback plus the decoded args, and
+  /// reports failures via Return. S is forwarded, so lvalue serializers work.
+  template <typename Serializer, typename Handler>
+  static void handle(orc_rt_SessionRef Session, void *CallCtx,
+                     orc_rt_WrapperFunctionReturn Return,
+                     WrapperFunctionBuffer ArgBytes, Serializer &&S,
+                     Handler &&H) {
+    typedef detail::WFCallableTraits<Handler> HandlerTraits;
+    typedef typename HandlerTraits::HeadArgType Yield;
+    typedef typename HandlerTraits::TailArgTuple ArgTuple;
+    typedef typename detail::WFCallableTraits<Yield>::HeadArgType RetType;
+
+    if (ArgBytes.getOutOfBandError())
+      return Return(Session, CallCtx, ArgBytes.release());
+
+    ArgTuple Args;
+    if (!std::apply(bind_front(S.argumentDeserializer(), std::move(ArgBytes)),
+                    Args))
+      return Return(Session, CallCtx,
+                    WrapperFunctionBuffer::createOutOfBandError(
+                        "Could not deserialize wrapper function arg data")
+                        .release());
+
+    std::apply(bind_front(std::forward<Handler>(H),
+                          detail::StructuredYield<RetType, Serializer>(
+                              Session, CallCtx, Return,
+                              std::forward<Serializer>(S))),
+               std::move(Args));
+  }
+};
+
 } // namespace orc_rt
 
 #endif // ORC_RT_WRAPPERFUNCTION_H
diff --git a/orc-rt/unittests/CMakeLists.txt b/orc-rt/unittests/CMakeLists.txt
index 55e089a539725..7bf53ca9826e2 100644
--- a/orc-rt/unittests/CMakeLists.txt
+++ b/orc-rt/unittests/CMakeLists.txt
@@ -22,6 +22,7 @@ add_orc_rt_unittest(CoreTests
   MemoryFlagsTest.cpp
   RTTITest.cpp
   SimplePackedSerializationTest.cpp
+  SPSWrapperFunctionTest.cpp
   WrapperFunctionBufferTest.cpp
   bind-test.cpp
   bit-test.cpp
diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
new file mode 100644
index 0000000000000..919ec2cebd69b
--- /dev/null
+++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
@@ -0,0 +1,109 @@
+//===-- SPSWrapperFunctionTest.cpp ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Test SPSWrapperFunction and associated utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#include "orc-rt/SPSWrapperFunction.h"
+#include "orc-rt/WrapperFunction.h"
+#include "orc-rt/move_only_function.h"
+
+#include "gtest/gtest.h"
+
+using namespace orc_rt;
+
+/// Make calls and call result handlers directly on the current thread.
+class DirectCaller {
+private:
+  class DirectResultSender {
+  public:
+    virtual ~DirectResultSender() = default;
+    virtual void send(orc_rt_SessionRef Session,
+                      WrapperFunctionBuffer ResultBytes) = 0;
+    static void send(orc_rt_SessionRef Session, void *CallCtx,
+                     orc_rt_WrapperFunctionBuffer ResultBytes) {
+      std::unique_ptr<DirectResultSender> Sender(
+          static_cast<DirectResultSender *>(CallCtx)); // Reclaim ownership.
+      Sender->send(Session, ResultBytes);
+    }
+  };
+
+  template <typename ImplFn>
+  class DirectResultSenderImpl : public DirectResultSender {
+  public:
+    DirectResultSenderImpl(ImplFn &&Fn) : Fn(std::forward<ImplFn>(Fn)) {}
+    void send(orc_rt_SessionRef Session,
+              WrapperFunctionBuffer ResultBytes) override {
+      Fn(Session, std::move(ResultBytes));
+    }
+
+  private:
+    std::decay_t<ImplFn> Fn;
+  };
+
+  template <typename ImplFn>
+  static std::unique_ptr<DirectResultSender>
+  makeDirectResultSender(ImplFn &&Fn) {
+    return std::make_unique<DirectResultSenderImpl<ImplFn>>(
+        std::forward<ImplFn>(Fn));
+  }
+
+public:
+  DirectCaller(orc_rt_SessionRef Session, orc_rt_WrapperFunction Fn)
+      : Session(Session), Fn(Fn) {}
+
+  template <typename HandleResultFn>
+  void operator()(HandleResultFn &&HandleResult,
+                  WrapperFunctionBuffer ArgBytes) {
+    auto DR =
+        makeDirectResultSender(std::forward<HandleResultFn>(HandleResult));
+    Fn(Session, static_cast<void *>(DR.release()), DirectResultSender::send,
+       ArgBytes.release());
+  }
+
+private:
+  orc_rt_SessionRef Session;
+  orc_rt_WrapperFunction Fn;
+};
+
+static void void_noop_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
+                                  orc_rt_WrapperFunctionReturn Return,
+                                  orc_rt_WrapperFunctionBuffer ArgBytes) {
+  auto Noop = [](move_only_function<void()> Yield) { Yield(); };
+  SPSWrapperFunction<void()>::handle(Session, CallCtx, Return, ArgBytes,
+                                     std::move(Noop));
+}
+// A void() wrapper should report success (Error::success()) to the caller.
+TEST(SPSWrapperFunctionUtilsTest, TestVoidNoop) {
+  bool Ran = false;
+  DirectCaller Caller(nullptr, void_noop_sps_wrapper);
+  SPSWrapperFunction<void()>::call(std::move(Caller), [&](Error Err) {
+    cantFail(std::move(Err));
+    Ran = true;
+  });
+  EXPECT_TRUE(Ran);
+}
+
+static void add_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
+                            orc_rt_WrapperFunctionReturn Return,
+                            orc_rt_WrapperFunctionBuffer ArgBytes) {
+  auto Add = [](move_only_function<void(int32_t)> Yield, int32_t LHS,
+                int32_t RHS) { Yield(LHS + RHS); };
+  SPSWrapperFunction<int32_t(int32_t, int32_t)>::handle(
+      Session, CallCtx, Return, ArgBytes, std::move(Add));
+}
+
+// Round-trips two int32_t arguments and an int32_t result through SPS.
+TEST(SPSWrapperFunctionUtilsTest, TestAdd) {
+  int32_t Sum = 0;
+  auto OnResult = [&](Expected<int32_t> R) { Sum = cantFail(std::move(R)); };
+  SPSWrapperFunction<int32_t(int32_t, int32_t)>::call(
+      DirectCaller(nullptr, add_sps_wrapper), std::move(OnResult), 41, 1);
+  EXPECT_EQ(Sum, 42);
+}



More information about the llvm-commits mailing list