[llvm] [orc-rt] WrapperFunction::handle: add by-ref args, minimize temporaries. (PR #161999)
Lang Hames via llvm-commits
llvm-commits at lists.llvm.org
Sat Oct 4 23:44:02 PDT 2025
https://github.com/lhames created https://github.com/llvm/llvm-project/pull/161999
This adds support for WrapperFunction::handle handlers that take their arguments by reference (T&, const T&, or T&&), rather than only by value.
This commit also reduces the number of temporary objects created to support SPS-transparent conversion in SPSWrapperFunction.
>From b918ed917b128c6baa4d33d4a79e9f19f71172cd Mon Sep 17 00:00:00 2001
From: Lang Hames <lhames at gmail.com>
Date: Sun, 5 Oct 2025 16:59:35 +1100
Subject: [PATCH] [orc-rt] WrapperFunction::handle: add by-ref args, minimize
temporaries.
This adds support for WrapperFunction::handle handlers that take their
arguments by reference (T&, const T&, or T&&), rather than only by value.
This commit also reduces the number of temporary objects created to support
SPS-transparent conversion in SPSWrapperFunction.
---
orc-rt/include/orc-rt/SPSWrapperFunction.h | 9 ++-
orc-rt/include/orc-rt/WrapperFunction.h | 27 +++++--
orc-rt/unittests/SPSWrapperFunctionTest.cpp | 79 +++++++++++++++++++++
3 files changed, 107 insertions(+), 8 deletions(-)
diff --git a/orc-rt/include/orc-rt/SPSWrapperFunction.h b/orc-rt/include/orc-rt/SPSWrapperFunction.h
index 14a3d8e3d6ad6..3ed3295731780 100644
--- a/orc-rt/include/orc-rt/SPSWrapperFunction.h
+++ b/orc-rt/include/orc-rt/SPSWrapperFunction.h
@@ -57,8 +57,8 @@ template <typename... SPSArgTs> struct WFSPSHelper {
template <typename... Ts>
using DeserializableTuple_t = typename DeserializableTuple<Ts...>::type;
- template <typename T> static T fromSerializable(T &&Arg) noexcept {
- return Arg;
+ // Pass-through overload for argument types that have no dedicated
+ // fromSerializable overload (e.g. SPSSerializableError -> Error). It returns
+ // a reference to Arg (an lvalue reference when Arg is an lvalue) instead of a
+ // copy. The caller can then move the value straight into its final
+ // destination without creating an extra temporary.
+ template <typename T> static T &&fromSerializable(T &&Arg) noexcept {
+ return std::forward<T>(Arg);
}
static Error fromSerializable(SPSSerializableError Err) noexcept {
@@ -86,7 +86,10 @@ template <typename... SPSArgTs> struct WFSPSHelper {
decltype(Args)>::deserialize(IB, Args))
return std::nullopt;
return std::apply(
- [](auto &&...A) { return ArgTuple(fromSerializable(A)...); },
+ [](auto &&...A) {
+ return std::optional<ArgTuple>(std::in_place,
+ std::move(fromSerializable(A))...);
+ },
std::move(Args));
}
};
diff --git a/orc-rt/include/orc-rt/WrapperFunction.h b/orc-rt/include/orc-rt/WrapperFunction.h
index ca165db7188b4..47e770f0bfbf7 100644
--- a/orc-rt/include/orc-rt/WrapperFunction.h
+++ b/orc-rt/include/orc-rt/WrapperFunction.h
@@ -111,7 +111,23 @@ struct WFHandlerTraitsImpl {
static_assert(std::is_void_v<RetT>,
"Async wrapper function handler must return void");
typedef ReturnT YieldType;
- typedef std::tuple<ArgTs...> ArgTupleType;
+ typedef std::tuple<std::decay_t<ArgTs>...> ArgTupleType;
+
+ // Adapts a handler callable so it can be invoked via std::apply on an
+ // ArgTupleType lvalue, whose elements are std::decay_t<ArgTs>. Each tuple
+ // element is forwarded according to the handler's declared parameter type:
+ //  - T or T&&: the element is passed as an rvalue (moved from).
+ //  - T& or const T&: the element is passed by lvalue reference (no move or
+ //    copy).
+ template <typename FnT> class ForwardArgsAsRequested {
+ public:
+ // Both rvalue and lvalue callables are accepted. forwardArgsAsRequested
+ // decays FnT, so an lvalue callable must be able to reach the const& form.
+ explicit ForwardArgsAsRequested(FnT &&F) : Fn(std::move(F)) {}
+ explicit ForwardArgsAsRequested(const FnT &F) : Fn(F) {}
+
+ // `ArgTs &` is an lvalue reference for every parameter kind above (for a
+ // U&& parameter it collapses to U&), so this binds directly to the
+ // non-const elements of ArgTupleType.
+ void operator()(ArgTs &...Args) { Fn(std::forward<ArgTs>(Args)...); }
+
+ private:
+ FnT Fn;
+ };
+
+ // Wraps Fn in a ForwardArgsAsRequested adapter. An rvalue Fn is moved into
+ // the adapter; an lvalue Fn is copied.
+ template <typename FnT>
+ static ForwardArgsAsRequested<std::decay_t<FnT>>
+ forwardArgsAsRequested(FnT &&Fn) {
+ return ForwardArgsAsRequested<std::decay_t<FnT>>(std::forward<FnT>(Fn));
+ }
};
template <typename C>
@@ -244,10 +260,11 @@ struct WrapperFunction {
if (auto Args =
S.arguments().template deserialize<ArgTuple>(std::move(ArgBytes)))
- std::apply(bind_front(std::forward<Handler>(H),
- detail::StructuredYield<RetTupleType, Serializer>(
- Session, CallCtx, Return, std::move(S))),
- std::move(*Args));
+ std::apply(HandlerTraits::forwardArgsAsRequested(bind_front(
+ std::forward<Handler>(H),
+ detail::StructuredYield<RetTupleType, Serializer>(
+ Session, CallCtx, Return, std::move(S)))),
+ *Args);
else
Return(Session, CallCtx,
WrapperFunctionBuffer::createOutOfBandError(
diff --git a/orc-rt/unittests/SPSWrapperFunctionTest.cpp b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
index c0c86ff8715ce..32aaa61639dbb 100644
--- a/orc-rt/unittests/SPSWrapperFunctionTest.cpp
+++ b/orc-rt/unittests/SPSWrapperFunctionTest.cpp
@@ -10,6 +10,8 @@
//
//===----------------------------------------------------------------------===//
+#include "CommonTestUtils.h"
+
#include "orc-rt/SPSWrapperFunction.h"
#include "orc-rt/WrapperFunction.h"
#include "orc-rt/move_only_function.h"
@@ -218,3 +220,80 @@ TEST(SPSWrapperFunctionUtilsTest, TestFunctionReturningExpectedFailureCase) {
EXPECT_EQ(ErrMsg, "N is not a multiple of 2");
}
+
+// SPS tag type for OpCounter<N>. The traits below write and read zero bytes,
+// so a round trip through SPS transfers no data. The test therefore only
+// observes the constructions, moves and copies made by the wrapper machinery.
+template <size_t N> struct SPSOpCounter {};
+
+namespace orc_rt {
+template <size_t N>
+class SPSSerializationTraits<SPSOpCounter<N>, OpCounter<N>> {
+public:
+ static size_t size(const OpCounter<N> &O) { return 0; }
+ static bool serialize(SPSOutputBuffer &OB, const OpCounter<N> &O) {
+ return true;
+ }
+ // Nothing to read: O is left exactly as the caller constructed it.
+ static bool deserialize(SPSInputBuffer &IB, OpCounter<N> &O) { return true; }
+};
+} // namespace orc_rt
+
+// Wrapper-function entry point whose handler takes one argument of each
+// parameter kind: by value, by lvalue reference, by const lvalue reference,
+// and by rvalue reference. Used by TestHandlerWithReferences.
+static void
+handle_with_reference_types_sps_wrapper(orc_rt_SessionRef Session,
+ void *CallCtx,
+ orc_rt_WrapperFunctionReturn Return,
+ orc_rt_WrapperFunctionBuffer ArgBytes) {
+ using SPSSig = void(SPSOpCounter<0>, SPSOpCounter<1>, SPSOpCounter<2>,
+ SPSOpCounter<3>);
+ // The handler ignores its arguments and immediately invokes Reply() with no
+ // value. Reply is named so that it does not shadow this function's Return.
+ SPSWrapperFunction<SPSSig>::handle(
+ Session, CallCtx, Return, ArgBytes,
+ [](move_only_function<void()> Reply, OpCounter<0>, OpCounter<1> &,
+ const OpCounter<2> &, OpCounter<3> &&) { Reply(); });
+}
+
+TEST(SPSWrapperFunctionUtilsTest, TestHandlerWithReferences) {
+ // Test that we can handle by-value, by-ref, by-const-ref, and by-rvalue-ref
+ // arguments. Also check that each kind of argument incurs the expected
+ // number of default constructions, moves, and copies.
+ OpCounter<0>::reset();
+ OpCounter<1>::reset();
+ OpCounter<2>::reset();
+ OpCounter<3>::reset();
+
+ bool DidRun = false;
+ SPSWrapperFunction<void(SPSOpCounter<0>, SPSOpCounter<1>, SPSOpCounter<2>,
+ SPSOpCounter<3>)>::
+ call(
+ DirectCaller(nullptr, handle_with_reference_types_sps_wrapper),
+ [&](Error R) {
+ cantFail(std::move(R));
+ DidRun = true;
+ },
+ OpCounter<0>(), OpCounter<1>(), OpCounter<2>(), OpCounter<3>());
+
+ EXPECT_TRUE(DidRun);
+
+ // We expect two default constructions for each parameter: one for the
+ // argument to call, and one for the object to deserialize into.
+ EXPECT_EQ(OpCounter<0>::defaultConstructions(), 2U);
+ EXPECT_EQ(OpCounter<1>::defaultConstructions(), 2U);
+ EXPECT_EQ(OpCounter<2>::defaultConstructions(), 2U);
+ EXPECT_EQ(OpCounter<3>::defaultConstructions(), 2U);
+
+ // Pass-by-value: we expect two moves (one for SPS-transparent conversion,
+ // one to move the value into the handler's by-value parameter), and no
+ // copies.
+ EXPECT_EQ(OpCounter<0>::moves(), 2U);
+ EXPECT_EQ(OpCounter<0>::copies(), 0U);
+
+ // Pass-by-lvalue-reference: we expect one move (for SPS-transparent
+ // conversion), no copies.
+ EXPECT_EQ(OpCounter<1>::moves(), 1U);
+ EXPECT_EQ(OpCounter<1>::copies(), 0U);
+
+ // Pass-by-const-lvalue-reference: we expect one move (for SPS-transparent
+ // conversion), no copies.
+ EXPECT_EQ(OpCounter<2>::moves(), 1U);
+ EXPECT_EQ(OpCounter<2>::copies(), 0U);
+
+ // Pass-by-rvalue-reference: we expect one move (for SPS-transparent
+ // conversion), no copies.
+ EXPECT_EQ(OpCounter<3>::moves(), 1U);
+ EXPECT_EQ(OpCounter<3>::copies(), 0U);
+}
More information about the llvm-commits
mailing list