[llvm] [orc-rt] Add CallableTraitsHelper, refactor WrapperFunction to use it. (PR #161761)

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 2 18:45:49 PDT 2025


https://github.com/lhames created https://github.com/llvm/llvm-project/pull/161761

CallableTraitsHelper identifies the return type and argument types of a callable type and passes those to an implementation class template to operate on.

The CallableArgInfo utility uses CallableTraitsHelper to provide typedefs for the return type and argument types (as a tuple) of a callable type.

In WrapperFunction.h, the detail::WFCallableTraits utility is rewritten in terms of CallableTraitsHelper (and renamed to WFHandlerTraits).

>From 568a5d69172745de29c0b3e7912f20febc81bc8e Mon Sep 17 00:00:00 2001
From: Lang Hames <lhames at gmail.com>
Date: Sun, 14 Sep 2025 11:44:56 +1000
Subject: [PATCH] [orc-rt] Add CallableTraitsHelper, refactor WrapperFunction
 to use it.

CallableTraitsHelper identifies the return type and argument types of a
callable type and passes those to an implementation class template to
operate on.

The CallableArgInfo utility uses CallableTraitsHelper to provide typedefs for
the return type and argument types (as a tuple) of a callable type.

In WrapperFunction.h, the detail::WFCallableTraits utility is rewritten in
terms of CallableTraitsHelper (and renamed to WFHandlerTraits).
---
 orc-rt/include/orc-rt/CallableTraitsHelper.h  | 74 ++++++++++++++++++
 orc-rt/include/orc-rt/WrapperFunction.h       | 77 ++++++++-----------
 orc-rt/unittests/CMakeLists.txt               |  1 +
 orc-rt/unittests/CallableTraitsHelperTest.cpp | 69 +++++++++++++++++
 4 files changed, 176 insertions(+), 45 deletions(-)
 create mode 100644 orc-rt/include/orc-rt/CallableTraitsHelper.h
 create mode 100644 orc-rt/unittests/CallableTraitsHelperTest.cpp

diff --git a/orc-rt/include/orc-rt/CallableTraitsHelper.h b/orc-rt/include/orc-rt/CallableTraitsHelper.h
new file mode 100644
index 0000000000000..12d7d5672c73a
--- /dev/null
+++ b/orc-rt/include/orc-rt/CallableTraitsHelper.h
@@ -0,0 +1,74 @@
+//===- CallableTraitsHelper.h - Callable arg/ret type extractor -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// CallableTraitsHelper API.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ORC_RT_CALLABLETRAITSHELPER_H
+#define ORC_RT_CALLABLETRAITSHELPER_H
+
+#include <tuple>
+#include <type_traits>
+
+namespace orc_rt {
+
+/// CallableTraitsHelper takes an implementation class template ImplT and some
+/// callable type C and passes the return and argument types of C to the ImplT
+/// class template.
+///
+/// This can be used to simplify the implementation of classes that need to
+/// operate on callable types.
+template <template <typename...> typename ImplT, typename C>
+struct CallableTraitsHelper
+    : public CallableTraitsHelper<
+          ImplT,
+          decltype(&std::remove_cv_t<std::remove_reference_t<C>>::operator())> {
+};
+
+template <template <typename...> typename ImplT, typename RetT,
+          typename... ArgTs>
+struct CallableTraitsHelper<ImplT, RetT(ArgTs...)>
+    : public ImplT<RetT, ArgTs...> {};
+
+template <template <typename...> typename ImplT, typename RetT,
+          typename... ArgTs>
+struct CallableTraitsHelper<ImplT, RetT (*)(ArgTs...)>
+    : public CallableTraitsHelper<ImplT, RetT(ArgTs...)> {};
+
+template <template <typename...> typename ImplT, typename RetT,
+          typename... ArgTs>
+struct CallableTraitsHelper<ImplT, RetT (&)(ArgTs...)>
+    : public CallableTraitsHelper<ImplT, RetT(ArgTs...)> {};
+
+template <template <typename...> typename ImplT, typename ClassT, typename RetT,
+          typename... ArgTs>
+struct CallableTraitsHelper<ImplT, RetT (ClassT::*)(ArgTs...)>
+    : public CallableTraitsHelper<ImplT, RetT(ArgTs...)> {};
+
+template <template <typename...> typename ImplT, typename ClassT, typename RetT,
+          typename... ArgTs>
+struct CallableTraitsHelper<ImplT, RetT (ClassT::*)(ArgTs...) const>
+    : public CallableTraitsHelper<ImplT, RetT(ArgTs...)> {};
+
+namespace detail {
+template <typename RetT, typename... ArgTs> struct CallableArgInfoImpl {
+  typedef RetT return_type;
+  typedef std::tuple<ArgTs...> args_tuple_type;
+};
+} // namespace detail
+
+/// CallableArgInfo provides typedefs for the return type and argument types
+/// (as a tuple) of the given callable type.
+template <typename Callable>
+struct CallableArgInfo
+    : public CallableTraitsHelper<detail::CallableArgInfoImpl, Callable> {};
+
+} // namespace orc_rt
+
+#endif // ORC_RT_CALLABLETRAITSHELPER_H
diff --git a/orc-rt/include/orc-rt/WrapperFunction.h b/orc-rt/include/orc-rt/WrapperFunction.h
index bedc097707e38..41960d24de04e 100644
--- a/orc-rt/include/orc-rt/WrapperFunction.h
+++ b/orc-rt/include/orc-rt/WrapperFunction.h
@@ -14,6 +14,7 @@
 #define ORC_RT_WRAPPERFUNCTION_H
 
 #include "orc-rt-c/WrapperFunction.h"
+#include "orc-rt/CallableTraitsHelper.h"
 #include "orc-rt/Error.h"
 #include "orc-rt/bind.h"
 
@@ -105,37 +106,16 @@ class WrapperFunctionBuffer {
 
 namespace detail {
 
-template <typename C>
-struct WFCallableTraits
-    : public WFCallableTraits<
-          decltype(&std::remove_cv_t<std::remove_reference_t<C>>::operator())> {
-};
-
-template <typename RetT> struct WFCallableTraits<RetT()> {
-  typedef void HeadArgType;
+template <typename RetT, typename ReturnT, typename... ArgTs>
+struct WFHandlerTraitsImpl {
+  static_assert(std::is_void_v<RetT>,
+                "Async wrapper function handler must return void");
+  typedef ReturnT YieldType;
+  typedef std::tuple<ArgTs...> ArgTupleType;
 };
 
-template <typename RetT, typename ArgT, typename... ArgTs>
-struct WFCallableTraits<RetT(ArgT, ArgTs...)> {
-  typedef ArgT HeadArgType;
-  typedef std::tuple<ArgTs...> TailArgTuple;
-};
-
-template <typename RetT, typename... ArgTs>
-struct WFCallableTraits<RetT (*)(ArgTs...)>
-    : public WFCallableTraits<RetT(ArgTs...)> {};
-
-template <typename RetT, typename... ArgTs>
-struct WFCallableTraits<RetT (&)(ArgTs...)>
-    : public WFCallableTraits<RetT(ArgTs...)> {};
-
-template <typename ClassT, typename RetT, typename... ArgTs>
-struct WFCallableTraits<RetT (ClassT::*)(ArgTs...)>
-    : public WFCallableTraits<RetT(ArgTs...)> {};
-
-template <typename ClassT, typename RetT, typename... ArgTs>
-struct WFCallableTraits<RetT (ClassT::*)(ArgTs...) const>
-    : public WFCallableTraits<RetT(ArgTs...)> {};
+template <typename C>
+using WFHandlerTraits = CallableTraitsHelper<WFHandlerTraitsImpl, C>;
 
 template <typename Serializer> class StructuredYieldBase {
 public:
@@ -151,8 +131,11 @@ template <typename Serializer> class StructuredYieldBase {
   std::decay_t<Serializer> S;
 };
 
+template <typename RetT, typename Serializer> class StructuredYield;
+
 template <typename RetT, typename Serializer>
-class StructuredYield : public StructuredYieldBase<Serializer> {
+class StructuredYield<std::tuple<RetT>, Serializer>
+    : public StructuredYieldBase<Serializer> {
 public:
   using StructuredYieldBase<Serializer>::StructuredYieldBase;
   void operator()(RetT &&R) {
@@ -167,7 +150,7 @@ class StructuredYield : public StructuredYieldBase<Serializer> {
 };
 
 template <typename Serializer>
-class StructuredYield<void, Serializer>
+class StructuredYield<std::tuple<>, Serializer>
     : public StructuredYieldBase<Serializer> {
 public:
   using StructuredYieldBase<Serializer>::StructuredYieldBase;
@@ -180,7 +163,7 @@ class StructuredYield<void, Serializer>
 template <typename T, typename Serializer> struct ResultDeserializer;
 
 template <typename T, typename Serializer>
-struct ResultDeserializer<Expected<T>, Serializer> {
+struct ResultDeserializer<std::tuple<Expected<T>>, Serializer> {
   static Expected<T> deserialize(WrapperFunctionBuffer ResultBytes,
                                  Serializer &S) {
     T Val;
@@ -191,7 +174,8 @@ struct ResultDeserializer<Expected<T>, Serializer> {
   }
 };
 
-template <typename Serializer> struct ResultDeserializer<Error, Serializer> {
+template <typename Serializer>
+struct ResultDeserializer<std::tuple<Error>, Serializer> {
   static Error deserialize(WrapperFunctionBuffer ResultBytes, Serializer &S) {
     assert(ResultBytes.empty());
     return Error::success();
@@ -213,11 +197,13 @@ struct WrapperFunction {
             typename... ArgTs>
   static void call(Caller &&C, Serializer &&S, ResultHandler &&RH,
                    ArgTs &&...Args) {
-    typedef detail::WFCallableTraits<ResultHandler> ResultHandlerTraits;
+    typedef CallableArgInfo<ResultHandler> ResultHandlerTraits;
+    static_assert(std::is_void_v<typename ResultHandlerTraits::return_type>,
+                  "Result handler should return void");
     static_assert(
-        std::tuple_size_v<typename ResultHandlerTraits::TailArgTuple> == 0,
-        "Expected one argument to result-handler");
-    typedef typename ResultHandlerTraits::HeadArgType ResultType;
+        std::tuple_size_v<typename ResultHandlerTraits::args_tuple_type> == 1,
+        "Result-handler should have exactly one argument");
+    typedef typename ResultHandlerTraits::args_tuple_type ResultTupleType;
 
     if (auto ArgBytes = S.argumentSerializer()(std::forward<ArgTs>(Args)...)) {
       C(
@@ -227,9 +213,8 @@ struct WrapperFunction {
             if (const char *ErrMsg = ResultBytes.getOutOfBandError())
               RH(make_error<StringError>(ErrMsg));
             else
-              RH(detail::ResultDeserializer<
-                  ResultType, Serializer>::deserialize(std::move(ResultBytes),
-                                                       S));
+              RH(detail::ResultDeserializer<ResultTupleType, Serializer>::
+                     deserialize(std::move(ResultBytes), S));
           },
           std::move(*ArgBytes));
     } else
@@ -246,10 +231,12 @@ struct WrapperFunction {
                      orc_rt_WrapperFunctionReturn Return,
                      WrapperFunctionBuffer ArgBytes, Serializer &&S,
                      Handler &&H) {
-    typedef detail::WFCallableTraits<Handler> HandlerTraits;
-    typedef typename HandlerTraits::HeadArgType Yield;
-    typedef typename HandlerTraits::TailArgTuple ArgTuple;
-    typedef typename detail::WFCallableTraits<Yield>::HeadArgType RetType;
+    typedef detail::WFHandlerTraits<Handler> HandlerTraits;
+    typedef typename HandlerTraits::ArgTupleType ArgTuple;
+    typedef typename HandlerTraits::YieldType Yield;
+    static_assert(std::is_void_v<typename CallableArgInfo<Yield>::return_type>,
+                  "Return callback must return void");
+    typedef typename CallableArgInfo<Yield>::args_tuple_type RetTupleType;
 
     if (ArgBytes.getOutOfBandError())
       return Return(Session, CallCtx, ArgBytes.release());
@@ -258,7 +245,7 @@ struct WrapperFunction {
     if (std::apply(bind_front(S.argumentDeserializer(), std::move(ArgBytes)),
                    Args))
       std::apply(bind_front(std::forward<Handler>(H),
-                            detail::StructuredYield<RetType, Serializer>(
+                            detail::StructuredYield<RetTupleType, Serializer>(
                                 Session, CallCtx, Return, std::move(S))),
                  std::move(Args));
     else
diff --git a/orc-rt/unittests/CMakeLists.txt b/orc-rt/unittests/CMakeLists.txt
index f29fb1e3b9d03..54c453de3a97f 100644
--- a/orc-rt/unittests/CMakeLists.txt
+++ b/orc-rt/unittests/CMakeLists.txt
@@ -14,6 +14,7 @@ endfunction()
 add_orc_rt_unittest(CoreTests
   AllocActionTest.cpp
   BitmaskEnumTest.cpp
+  CallableTraitsHelperTest.cpp
   CommonTestUtils.cpp
   ErrorTest.cpp
   ExecutorAddressTest.cpp
diff --git a/orc-rt/unittests/CallableTraitsHelperTest.cpp b/orc-rt/unittests/CallableTraitsHelperTest.cpp
new file mode 100644
index 0000000000000..1db39162a9106
--- /dev/null
+++ b/orc-rt/unittests/CallableTraitsHelperTest.cpp
@@ -0,0 +1,69 @@
+//===- CallableTraitsHelperTest.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Tests for orc-rt's CallableTraitsHelper.h APIs.
+//
+// NOTE: All tests in this file are testing compile-time functionality, so the
+//       tests at runtime all end up being noops. That's fine -- those are
+//       cheap.
+//===----------------------------------------------------------------------===//
+
+#include "orc-rt/CallableTraitsHelper.h"
+#include "gtest/gtest.h"
+
+using namespace orc_rt;
+
+static void freeVoidVoid() {}
+
+TEST(CallableTraitsHelperTest, FreeVoidVoid) {
+  (void)freeVoidVoid;
+  typedef CallableArgInfo<decltype(freeVoidVoid)> CAI;
+  static_assert(std::is_void_v<CAI::return_type>);
+  static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<>>);
+}
+
+static int freeBinaryOp(int, float) { return 0; }
+
+TEST(CallableTraitsHelperTest, FreeBinaryOp) {
+  (void)freeBinaryOp;
+  typedef CallableArgInfo<decltype(freeBinaryOp)> CAI;
+  static_assert(std::is_same_v<CAI::return_type, int>);
+  static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<int, float>>);
+}
+
+TEST(CallableTraitsHelperTest, VoidVoidObj) {
+  auto VoidVoid = []() {};
+  typedef CallableArgInfo<decltype(VoidVoid)> CAI;
+  static_assert(std::is_void_v<CAI::return_type>);
+  static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<>>);
+}
+
+TEST(CallableTraitsHelperTest, BinaryOpObj) {
+  auto BinaryOp = [](int X, float Y) -> int { return X + Y; };
+  typedef CallableArgInfo<decltype(BinaryOp)> CAI;
+  static_assert(std::is_same_v<CAI::return_type, int>);
+  static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<int, float>>);
+}
+
+TEST(CallableTraitsHelperTest, PreservesLValueRef) {
+  auto RefOp = [](int &) {};
+  typedef CallableArgInfo<decltype(RefOp)> CAI;
+  static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<int &>>);
+}
+
+TEST(CallableTraitsHelperTest, PreservesLValueRefConstness) {
+  auto RefOp = [](const int &) {};
+  typedef CallableArgInfo<decltype(RefOp)> CAI;
+  static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<const int &>>);
+}
+
+TEST(CallableTraitsHelperTest, PreservesRValueRef) {
+  auto RefOp = [](int &&) {};
+  typedef CallableArgInfo<decltype(RefOp)> CAI;
+  static_assert(std::is_same_v<CAI::args_tuple_type, std::tuple<int &&>>);
+}



More information about the llvm-commits mailing list