[llvm] [orc-rt] WrapperFunction::handle: add by-ref args, minimize temporaries. (PR #161999)

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Sat Oct 4 23:44:02 PDT 2025


https://github.com/lhames created https://github.com/llvm/llvm-project/pull/161999

This adds support for WrapperFunction::handle handlers that take their arguments by lvalue, const lvalue, or rvalue reference, in addition to by value.

This commit also reduces the number of temporary objects created to support SPS-transparent conversion in SPSWrapperFunction.

>From b918ed917b128c6baa4d33d4a79e9f19f71172cd Mon Sep 17 00:00:00 2001
From: Lang Hames <lhames at gmail.com>
Date: Sun, 5 Oct 2025 16:59:35 +1100
Subject: [PATCH] [orc-rt] WrapperFunction::handle: add by-ref args, minimize
 temporaries.

This adds support for WrapperFunction::handle handlers that take their
arguments by lvalue, const lvalue, or rvalue reference, in addition to
by value.

This commit also reduces the number of temporary objects created to support
SPS-transparent conversion in SPSWrapperFunction.
---
 orc-rt/include/orc-rt/SPSWrapperFunction.h  |  9 ++-
 orc-rt/include/orc-rt/WrapperFunction.h     | 27 +++++--
 orc-rt/unittests/SPSWrapperFunctionTest.cpp | 79 +++++++++++++++++++++
 3 files changed, 107 insertions(+), 8 deletions(-)

diff --git a/orc-rt/include/orc-rt/SPSWrapperFunction.h b/orc-rt/include/orc-rt/SPSWrapperFunction.h
index 14a3d8e3d6ad6..3ed3295731780 100644
--- a/orc-rt/include/orc-rt/SPSWrapperFunction.h
+++ b/orc-rt/include/orc-rt/SPSWrapperFunction.h
@@ -57,8 +57,8 @@ template <typename... SPSArgTs> struct WFSPSHelper {
   template <typename... Ts>
   using DeserializableTuple_t = typename DeserializableTuple<Ts...>::type;
 
-  template <typename T> static T fromSerializable(T &&Arg) noexcept {
-    return Arg;
+  template <typename T> static T &&fromSerializable(T &&Arg) noexcept {
+    return std::forward<T>(Arg);
   }
 
   static Error fromSerializable(SPSSerializableError Err) noexcept {
@@ -86,7 +86,10 @@ template <typename... SPSArgTs> struct WFSPSHelper {
                                 decltype(Args)>::deserialize(IB, Args))
       return std::nullopt;
     return std::apply(
-        [](auto &&...A) { return ArgTuple(fromSerializable(A)...); },
+        [](auto &&...A) {
+          return std::optional<ArgTuple>(std::in_place,
+                                         std::move(fromSerializable(A))...);
+        },
         std::move(Args));
   }
 };
diff --git a/orc-rt/include/orc-rt/WrapperFunction.h b/orc-rt/include/orc-rt/WrapperFunction.h
index ca165db7188b4..47e770f0bfbf7 100644
--- a/orc-rt/include/orc-rt/WrapperFunction.h
+++ b/orc-rt/include/orc-rt/WrapperFunction.h
@@ -111,7 +111,23 @@ struct WFHandlerTraitsImpl {
   static_assert(std::is_void_v<RetT>,
                 "Async wrapper function handler must return void");
   typedef ReturnT YieldType;
-  typedef std::tuple<ArgTs...> ArgTupleType;
+  typedef std::tuple<std::decay_t<ArgTs>...> ArgTupleType;
+
+  // Forwards arguments based on the parameter types of the handler (ArgTs).
+  // Single-use: Fn is invoked as an rvalue so that bound state (e.g. the
+  // StructuredYield bound via bind_front) can be moved rather than copied.
+  template <typename FnT> class ForwardArgsAsRequested {
+    FnT Fn;
+  public:
+    explicit ForwardArgsAsRequested(FnT &&F) : Fn(std::move(F)) {}
+    void operator()(ArgTs &...A) { std::move(Fn)(std::forward<ArgTs>(A)...); }
+  };
+
+  // Wraps rvalue callable Fn (stored by decayed value) in an argument adapter.
+  template <typename FnT> static auto forwardArgsAsRequested(FnT &&Fn) {
+    using AdapterT = ForwardArgsAsRequested<std::decay_t<FnT>>;
+    return AdapterT(std::forward<FnT>(Fn));
+  }
 };
 
 template <typename C>
@@ -244,10 +260,11 @@ struct WrapperFunction {
 
     if (auto Args =
             S.arguments().template deserialize<ArgTuple>(std::move(ArgBytes)))
-      std::apply(bind_front(std::forward<Handler>(H),
-                            detail::StructuredYield<RetTupleType, Serializer>(
-                                Session, CallCtx, Return, std::move(S))),
-                 std::move(*Args));
+      std::apply(HandlerTraits::forwardArgsAsRequested(bind_front(
+                     std::forward<Handler>(H),
+                     detail::StructuredYield<RetTupleType, Serializer>(
+                         Session, CallCtx, Return, std::move(S)))),
+                 *Args);
     else
       Return(Session, CallCtx,
              WrapperFunctionBuffer::createOutOfBandError(
diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
index c0c86ff8715ce..32aaa61639dbb 100644
--- a/orc-rt/unittests/SPSWrapperFunctionTest.cpp
+++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
@@ -10,6 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "CommonTestUtils.h"
+
 #include "orc-rt/SPSWrapperFunction.h"
 #include "orc-rt/WrapperFunction.h"
 #include "orc-rt/move_only_function.h"
@@ -218,3 +220,80 @@ TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) {
 
   EXPECT_EQ(ErrMsg, "N is not a multiple of 2");
 }
+
+template <size_t N> struct SPSOpCounter {};
+namespace orc_rt {
+// SPSOpCounter<N> <-> OpCounter<N>: no payload, so every operation is a no-op.
+template <size_t N>
+class SPSSerializationTraits<SPSOpCounter<N>, OpCounter<N>> {
+public:
+  static size_t size(const OpCounter<N> &) { return 0; }
+  static bool serialize(SPSOutputBuffer &, const OpCounter<N> &) {
+    return true;
+  }
+  static bool deserialize(SPSInputBuffer &, OpCounter<N> &) { return true; }
+};
+} // namespace orc_rt
+
+static void
+handle_with_reference_types_sps_wrapper(orc_rt_SessionRef Session,
+                                        void *CallCtx,
+                                        orc_rt_WrapperFunctionReturn Return,
+                                        orc_rt_WrapperFunctionBuffer ArgBytes) {
+  // Handler params cover by-value, by-ref, by-const-ref and by-rvalue-ref.
+  auto H = [](move_only_function<void()> Reply, OpCounter<0>,
+              OpCounter<1> &, const OpCounter<2> &,
+              OpCounter<3> &&) { Reply(); };
+  SPSWrapperFunction<void(SPSOpCounter<0>, SPSOpCounter<1>, SPSOpCounter<2>,
+                          SPSOpCounter<3>)>::handle(Session, CallCtx, Return,
+                                                    ArgBytes, std::move(H));
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestHandlerWithReferences) {
+  // Handlers may take by-value, by-ref, by-const-ref and by-rvalue-ref
+  // arguments; check each works and count the resulting moves and copies.
+  OpCounter<0>::reset();
+  OpCounter<1>::reset();
+  OpCounter<2>::reset();
+  OpCounter<3>::reset();
+
+  bool DidRun = false;
+  SPSWrapperFunction<void(SPSOpCounter<0>, SPSOpCounter<1>, SPSOpCounter<2>,
+                          SPSOpCounter<3>)>::
+      call(
+          DirectCaller(nullptr, handle_with_reference_types_sps_wrapper),
+          [&](Error R) {
+            cantFail(std::move(R));
+            DidRun = true;
+          },
+          OpCounter<0>(), OpCounter<1>(), OpCounter<2>(), OpCounter<3>());
+
+  EXPECT_TRUE(DidRun);
+
+  // Two default constructions per parameter: one for the argument passed to
+  // call, and one for the object the handler side deserializes into.
+  EXPECT_EQ(OpCounter<0>::defaultConstructions(), 2U);
+  EXPECT_EQ(OpCounter<1>::defaultConstructions(), 2U);
+  EXPECT_EQ(OpCounter<2>::defaultConstructions(), 2U);
+  EXPECT_EQ(OpCounter<3>::defaultConstructions(), 2U);
+
+  // Pass-by-value: we expect two moves (one for SPS transparent conversion,
+  // one to move the value into the parameter), and no copies.
+  EXPECT_EQ(OpCounter<0>::moves(), 2U);
+  EXPECT_EQ(OpCounter<0>::copies(), 0U);
+
+  // Pass-by-lvalue-reference: we expect one move (for SPS transparent
+  // conversion), no copies.
+  EXPECT_EQ(OpCounter<1>::moves(), 1U);
+  EXPECT_EQ(OpCounter<1>::copies(), 0U);
+
+  // Pass-by-const-lvalue-reference: we expect one move (for SPS transparent
+  // conversion), no copies.
+  EXPECT_EQ(OpCounter<2>::moves(), 1U);
+  EXPECT_EQ(OpCounter<2>::copies(), 0U);
+
+  // Pass-by-rvalue-reference: we expect one move (for SPS transparent
+  // conversion), no copies.
+  EXPECT_EQ(OpCounter<3>::moves(), 1U);
+  EXPECT_EQ(OpCounter<3>::copies(), 0U);
+}



More information about the llvm-commits mailing list