[compiler-rt] dc8e5e1 - [ORC-RT] Add a WrapperFunctionCall utility.

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 28 17:07:33 PDT 2021


Author: Lang Hames
Date: 2021-10-28T17:07:28-07:00
New Revision: dc8e5e1dc03dc49ecde6a19f06e33b3631141b6a

URL: https://github.com/llvm/llvm-project/commit/dc8e5e1dc03dc49ecde6a19f06e33b3631141b6a
DIFF: https://github.com/llvm/llvm-project/commit/dc8e5e1dc03dc49ecde6a19f06e33b3631141b6a.diff

LOG: [ORC-RT] Add a WrapperFunctionCall utility.

WrapperFunctionCall represents a call to a wrapper function as a pair of a
target function (as an ExecutorAddr) and an argument buffer range (as an
ExecutorAddrRange). WrapperFunctionCall instances can be serialized via
SPS for sending to remote machines (only the argument buffer address range
is copied, not any buffer content).

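This utility will simplify the implementation of JITLinkMemoryManager
allocation actions in the ORC runtime. This commit also replaces
detail::serializeViaSPSToWrapperFunctionResult with the public
WrapperFunctionResult::fromSPSArgs helper.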

Added: 
    

Modified: 
    compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp
    compiler-rt/lib/orc/wrapper_function_utils.h

Removed: 
    


################################################################################
diff  --git a/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp b/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp
index 23d32a041a91d..fafc2a4b18e9b 100644
--- a/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp
+++ b/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp
@@ -127,3 +127,51 @@ TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) {
       (void *)&addMethodWrapper, Result, ExecutorAddr::fromPtr(&AddObj), 2));
   EXPECT_EQ(Result, (int32_t)3);
 }
+
+// Raw (non-SPS) wrapper: returns a 1-byte result holding the sum of ArgData.
+static __orc_rt_CWrapperFunctionResult sumArrayRawWrapper(const char *ArgData,
+                                                          size_t ArgSize) {
+  auto WFR = WrapperFunctionResult::allocate(1); // Single-byte result buffer.
+  *WFR.data() = 0;
+  for (unsigned I = 0; I != ArgSize; ++I)
+    *WFR.data() += ArgData[I]; // Sum is accumulated in a single char.
+  return WFR.release();
+}
+
+TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) {
+  {
+    // Raw call: summing bytes {1, 2, 3, 4} should return one byte: 10.
+    char A[] = {1, 2, 3, 4}; // Sum is 10.
+
+    WrapperFunctionCall WFC{ExecutorAddr::fromPtr(sumArrayRawWrapper),
+                            ExecutorAddrRange(ExecutorAddr::fromPtr(A),
+                                              ExecutorAddrDiff(sizeof(A)))};
+
+    WrapperFunctionResult WFR(WFC.run());
+    EXPECT_EQ(WFR.size(), 1U);
+    EXPECT_EQ(WFR.data()[0], 10);
+  }
+
+  {
+    // Void SPS call with an empty argument range should succeed.
+    WrapperFunctionCall WFC{ExecutorAddr::fromPtr(voidNoopWrapper),
+                            ExecutorAddrRange()}; // No argument data.
+    auto Err = WFC.runWithSPSRet();
+    EXPECT_FALSE(!!Err);
+  }
+
+  {
+    // SPS call: args (2, 4) to addWrapper should return int32_t 6.
+    auto ArgWFR =
+        WrapperFunctionResult::fromSPSArgs<SPSArgList<int32_t, int32_t>>(2, 4);
+    WrapperFunctionCall WFC{
+        ExecutorAddr::fromPtr(addWrapper),
+        ExecutorAddrRange(ExecutorAddr::fromPtr(ArgWFR.data()),
+                          ExecutorAddrDiff(ArgWFR.size()))}; // Borrows ArgWFR.
+
+    int32_t Result = 0;
+    auto Err = WFC.runWithSPSRet<int32_t>(Result);
+    EXPECT_FALSE(!!Err);
+    EXPECT_EQ(Result, 6);
+  }
+}

diff  --git a/compiler-rt/lib/orc/wrapper_function_utils.h b/compiler-rt/lib/orc/wrapper_function_utils.h
index cf92ad890cd17..23385e1bd7944 100644
--- a/compiler-rt/lib/orc/wrapper_function_utils.h
+++ b/compiler-rt/lib/orc/wrapper_function_utils.h
@@ -104,6 +104,16 @@ class WrapperFunctionResult {
     return createOutOfBandError(Msg.c_str());
   }
 
+  template <typename SPSArgListT, typename... ArgTs> // SPS-encode Args.
+  static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
+    auto Result = allocate(SPSArgListT::size(Args...)); // Exact encoded size.
+    SPSOutputBuffer OB(Result.data(), Result.size());
+    if (!SPSArgListT::serialize(OB, Args...)) // On failure, report OOB error.
+      return createOutOfBandError(
+          "Error serializing arguments to blob in call");
+    return Result;
+  }
+
   /// If this value is an out-of-band error then this returns the error message,
   /// otherwise returns nullptr.
   const char *getOutOfBandError() const {
@@ -116,17 +126,6 @@ class WrapperFunctionResult {
 
 namespace detail {
 
-template <typename SPSArgListT, typename... ArgTs>
-WrapperFunctionResult
-serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
-  auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...));
-  SPSOutputBuffer OB(Result.data(), Result.size());
-  if (!SPSArgListT::serialize(OB, Args...))
-    return WrapperFunctionResult::createOutOfBandError(
-        "Error serializing arguments to blob in call");
-  return Result;
-}
-
 template <typename RetT> class WrapperFunctionHandlerCaller {
 public:
   template <typename HandlerT, typename ArgTupleT, std::size_t... I>
@@ -212,15 +211,14 @@ class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
 public:
   static WrapperFunctionResult serialize(RetT Result) {
-    return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
-        Result);
+    return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
   }
 };
 
 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
 public:
   static WrapperFunctionResult serialize(Error Err) {
-    return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
+    return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
         toSPSSerializable(std::move(Err)));
   }
 };
@@ -229,7 +227,7 @@ template <typename SPSRetTagT, typename T>
 class ResultSerializer<SPSRetTagT, Expected<T>> {
 public:
   static WrapperFunctionResult serialize(Expected<T> E) {
-    return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
+    return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
         toSPSSerializable(std::move(E)));
   }
 };
@@ -304,8 +302,7 @@ class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
       return make_error<StringError>("__orc_rt_jit_dispatch not set");
 
     auto ArgBuffer =
-        detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
-            Args...);
+        WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
     if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
       return make_error<StringError>(ErrMsg);
 
@@ -397,6 +394,64 @@ makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
   return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
 }
 
+/// A call to a wrapper function: target address plus argument-buffer range.
+struct WrapperFunctionCall {
+  ExecutorAddr Func;          // Address of the wrapper function to call.
+  ExecutorAddrRange ArgData;  // Argument bytes passed to Func (not owned).
+
+  WrapperFunctionCall() = default;
+  WrapperFunctionCall(ExecutorAddr Func, ExecutorAddrRange ArgData)
+      : Func(Func), ArgData(ArgData) {}
+
+  /// Call Func(ArgData.Start, ArgData.size()) in-process; return its result.
+  WrapperFunctionResult run() const {
+    using FnT = __orc_rt_CWrapperFunctionResult (*)(const char *, size_t);
+    return WrapperFunctionResult(Func.toPtr<FnT>()(
+        ArgData.Start.toPtr<const char *>(),
+        static_cast<size_t>(ArgData.size().getValue())));
+  }
+
+  /// Run, SPS-decode the result into RetVal; fails on OOB or decode error.
+  template <typename SPSRetT, typename RetT>
+  Error runWithSPSRet(RetT &RetVal) const {
+    auto WFR = run();
+    if (const char *ErrMsg = WFR.getOutOfBandError())
+      return make_error<StringError>(ErrMsg); // Out-of-band error.
+    SPSInputBuffer IB(WFR.data(), WFR.size());
+    if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
+      return make_error<StringError>("Could not deserialize result from "
+                                     "serialized wrapper function call");
+    return Error::success();
+  }
+
+  /// Overload for SPS functions returning void (result decoded as SPSEmpty).
+  Error runWithSPSRet() const {
+    SPSEmpty E;
+    return runWithSPSRet<SPSEmpty>(E);
+  }
+};
+
+class SPSWrapperFunctionCall {}; // SPS tag type for WrapperFunctionCall.
+/// Encodes a WrapperFunctionCall as (Func, ArgData); arg bytes aren't copied.
+template <>
+class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
+public:
+  static size_t size(const WrapperFunctionCall &WFC) {
+    return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::size(WFC.Func,
+                                                                   WFC.ArgData);
+  }
+
+  static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
+    return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::serialize(
+        OB, WFC.Func, WFC.ArgData);
+  }
+
+  static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
+    return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::deserialize(
+        IB, WFC.Func, WFC.ArgData);
+  }
+};
+
 } // end namespace __orc_rt
 
 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H


        


More information about the llvm-commits mailing list