[llvm] [orc-rt] Add transparent SPS conversion for error/expected types. (PR #161768)

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 2 20:08:50 PDT 2025


https://github.com/lhames created https://github.com/llvm/llvm-project/pull/161768

This commit aims to reduce boilerplate by adding transparent conversion between Error/Expected types and their SPS-serializable counterparts (SPSSerializableError/SPSSerializableExpected). This allows SPSWrapperFunction calls and handlers to be written in terms of Error/Expected directly.

This functionality can also be extended to transparently convert between other types. This may be used in the future to provide conversion between ExecutorAddr and native pointer types.

From fbfe3a4b553f13ecec1d7043da9b279914176736 Mon Sep 17 00:00:00 2001
From: Lang Hames <lhames at gmail.com>
Date: Fri, 3 Oct 2025 11:32:03 +1000
Subject: [PATCH] [orc-rt] Add transparent SPS conversion for error/expected
 types.

This commit aims to reduce boilerplate by adding transparent conversion between
Error/Expected types and their SPS-serializable counterparts
(SPSSerializableError/SPSSerializableExpected). This allows SPSWrapperFunction
calls and handlers to be written in terms of Error/Expected directly.

This functionality can also be extended to transparently convert between other
types. This may be used in the future to provide conversion between
ExecutorAddr and native pointer types.
---
 orc-rt/include/orc-rt/SPSWrapperFunction.h  | 59 ++++++++++++++--
 orc-rt/include/orc-rt/WrapperFunction.h     |  3 +-
 orc-rt/unittests/SPSWrapperFunctionTest.cpp | 74 +++++++++++++++++++++
 3 files changed, 129 insertions(+), 7 deletions(-)

diff --git a/orc-rt/include/orc-rt/SPSWrapperFunction.h b/orc-rt/include/orc-rt/SPSWrapperFunction.h
index 3ea6406b69a37..14a3d8e3d6ad6 100644
--- a/orc-rt/include/orc-rt/SPSWrapperFunction.h
+++ b/orc-rt/include/orc-rt/SPSWrapperFunction.h
@@ -21,8 +21,10 @@ namespace orc_rt {
 namespace detail {
 
 template <typename... SPSArgTs> struct WFSPSHelper {
-  template <typename... ArgTs>
-  std::optional<WrapperFunctionBuffer> serialize(const ArgTs &...Args) {
+private:
+  template <typename... SerializableArgTs>
+  std::optional<WrapperFunctionBuffer>
+  serializeImpl(const SerializableArgTs &...Args) {
     auto R =
         WrapperFunctionBuffer::allocate(SPSArgList<SPSArgTs...>::size(Args...));
     SPSOutputBuffer OB(R.data(), R.size());
@@ -31,16 +33,61 @@ template <typename... SPSArgTs> struct WFSPSHelper {
     return std::move(R);
   }
 
+  template <typename T> static const T &toSerializable(const T &Arg) noexcept {
+    return Arg;
+  }
+
+  static SPSSerializableError toSerializable(Error Err) noexcept {
+    return SPSSerializableError(std::move(Err));
+  }
+
+  template <typename T>
+  static SPSSerializableExpected<T> toSerializable(Expected<T> Arg) noexcept {
+    return SPSSerializableExpected<T>(std::move(Arg));
+  }
+
+  template <typename... Ts> struct DeserializableTuple;
+
+  template <typename... Ts> struct DeserializableTuple<std::tuple<Ts...>> {
+    typedef std::tuple<
+        std::decay_t<decltype(toSerializable(std::declval<Ts>()))>...>
+        type;
+  };
+
+  template <typename... Ts>
+  using DeserializableTuple_t = typename DeserializableTuple<Ts...>::type;
+
+  template <typename T> static T fromSerializable(T &&Arg) noexcept {
+    return Arg;
+  }
+
+  static Error fromSerializable(SPSSerializableError Err) noexcept {
+    return Err.toError();
+  }
+
+  template <typename T>
+  static Expected<T> fromSerializable(SPSSerializableExpected<T> Val) noexcept {
+    return Val.toExpected();
+  }
+
+public:
+  template <typename... ArgTs>
+  std::optional<WrapperFunctionBuffer> serialize(ArgTs &&...Args) {
+    return serializeImpl(toSerializable(std::forward<ArgTs>(Args))...);
+  }
+
   template <typename ArgTuple>
   std::optional<ArgTuple> deserialize(WrapperFunctionBuffer ArgBytes) {
     assert(!ArgBytes.getOutOfBandError() &&
            "Should not attempt to deserialize out-of-band error");
     SPSInputBuffer IB(ArgBytes.data(), ArgBytes.size());
-    ArgTuple Args;
-    if (!SPSSerializationTraits<SPSTuple<SPSArgTs...>, ArgTuple>::deserialize(
-            IB, Args))
+    DeserializableTuple_t<ArgTuple> Args;
+    if (!SPSSerializationTraits<SPSTuple<SPSArgTs...>,
+                                decltype(Args)>::deserialize(IB, Args))
       return std::nullopt;
-    return Args;
+    return std::apply(
+        [](auto &&...A) { return ArgTuple(fromSerializable(A)...); },
+        std::move(Args));
   }
 };
 
diff --git a/orc-rt/include/orc-rt/WrapperFunction.h b/orc-rt/include/orc-rt/WrapperFunction.h
index 233c3b21e041d..ca165db7188b4 100644
--- a/orc-rt/include/orc-rt/WrapperFunction.h
+++ b/orc-rt/include/orc-rt/WrapperFunction.h
@@ -168,7 +168,8 @@ struct ResultDeserializer<std::tuple<Expected<T>>, Serializer> {
                                  Serializer &S) {
     if (auto Val = S.result().template deserialize<std::tuple<T>>(
             std::move(ResultBytes)))
-      return std::move(std::get<0>(*Val));
+      return Expected<T>(std::move(std::get<0>(*Val)),
+                         ForceExpectedSuccessValue());
     else
       return make_error<StringError>("Could not deserialize result");
   }
diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
index 0b65515120b7f..c0c86ff8715ce 100644
--- a/orc-rt/unittests/SPSWrapperFunctionTest.cpp
+++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
@@ -144,3 +144,77 @@ TEST(SPSWrapperFunctionUtilsTest, TestBinaryOpViaFunctionPointer) {
       [&](Expected<int32_t> R) { Result = cantFail(std::move(R)); }, 41, 1);
   EXPECT_EQ(Result, 42);
 }
+
+static void improbable_feat_sps_wrapper(orc_rt_SessionRef Session,
+                                        void *CallCtx,
+                                        orc_rt_WrapperFunctionReturn Return,
+                                        orc_rt_WrapperFunctionBuffer ArgBytes) {
+  SPSWrapperFunction<SPSError(bool)>::handle(
+      Session, CallCtx, Return, ArgBytes,
+      [](move_only_function<void(Error)> Return, bool LuckyHat) {
+        if (LuckyHat)
+          Return(Error::success());
+        else
+          Return(make_error<StringError>("crushed by boulder"));
+      });
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorSuccessCase) {
+  bool DidRun = false;
+  SPSWrapperFunction<SPSError(bool)>::call(
+      DirectCaller(nullptr, improbable_feat_sps_wrapper),
+      [&](Expected<Error> E) {
+        DidRun = true;
+        cantFail(cantFail(std::move(E)));
+      },
+      true);
+
+  EXPECT_TRUE(DidRun);
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningErrorFailureCase) {
+  std::string ErrMsg;
+  SPSWrapperFunction<SPSError(bool)>::call(
+      DirectCaller(nullptr, improbable_feat_sps_wrapper),
+      [&](Expected<Error> E) { ErrMsg = toString(cantFail(std::move(E))); },
+      false);
+
+  EXPECT_EQ(ErrMsg, "crushed by boulder");
+}
+
+static void halve_number_sps_wrapper(orc_rt_SessionRef Session, void *CallCtx,
+                                     orc_rt_WrapperFunctionReturn Return,
+                                     orc_rt_WrapperFunctionBuffer ArgBytes) {
+  SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::handle(
+      Session, CallCtx, Return, ArgBytes,
+      [](move_only_function<void(Expected<int32_t>)> Return, int N) {
+        if (N % 2 == 0)
+          Return(N >> 1);
+        else
+          Return(make_error<StringError>("N is not a multiple of 2"));
+      });
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedSuccessCase) {
+  int32_t Result = 0;
+  SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::call(
+      DirectCaller(nullptr, halve_number_sps_wrapper),
+      [&](Expected<Expected<int32_t>> R) {
+        Result = cantFail(cantFail(std::move(R)));
+      },
+      2);
+
+  EXPECT_EQ(Result, 1);
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) {
+  std::string ErrMsg;
+  SPSWrapperFunction<SPSExpected<int32_t>(int32_t)>::call(
+      DirectCaller(nullptr, halve_number_sps_wrapper),
+      [&](Expected<Expected<int32_t>> R) {
+        ErrMsg = toString(cantFail(std::move(R)).takeError());
+      },
+      3);
+
+  EXPECT_EQ(ErrMsg, "N is not a multiple of 2");
+}



More information about the llvm-commits mailing list