[compiler-rt] 0ede1b9 - [ORC-RT] Update WrapperFunctionCall for 089acf25223.

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Sat Jan 15 18:48:19 PST 2022


Author: Lang Hames
Date: 2022-01-16T13:48:11+11:00
New Revision: 0ede1b906d4dc949b297d4f5d94ae9f4fc84a2b3

URL: https://github.com/llvm/llvm-project/commit/0ede1b906d4dc949b297d4f5d94ae9f4fc84a2b3
DIFF: https://github.com/llvm/llvm-project/commit/0ede1b906d4dc949b297d4f5d94ae9f4fc84a2b3.diff

LOG: [ORC-RT] Update WrapperFunctionCall for 089acf25223.

089acf25223 updated WrapperFunctionCall to carry arbitrary argument payloads
(rather than plain address ranges). This commit implements the corresponding
update for the ORC runtime.

Added: 
    

Modified: 
    compiler-rt/lib/orc/macho_platform.cpp
    compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp
    compiler-rt/lib/orc/wrapper_function_utils.h

Removed: 
    


################################################################################
diff  --git a/compiler-rt/lib/orc/macho_platform.cpp b/compiler-rt/lib/orc/macho_platform.cpp
index 1a21921a7c5bd..e2666aa452c83 100644
--- a/compiler-rt/lib/orc/macho_platform.cpp
+++ b/compiler-rt/lib/orc/macho_platform.cpp
@@ -568,7 +568,7 @@ void destroyMachOTLVMgr(void *MachOTLVMgr) {
 
 Error runWrapperFunctionCalls(std::vector<WrapperFunctionCall> WFCs) {
   for (auto &WFC : WFCs)
-    if (auto Err = WFC.runWithSPSRet())
+    if (auto Err = WFC.runWithSPSRet<void>())
       return Err;
   return Error::success();
 }

diff  --git a/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp b/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp
index fafc2a4b18e9b..031238307b4f0 100644
--- a/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp
+++ b/compiler-rt/lib/orc/unittests/wrapper_function_utils_test.cpp
@@ -128,24 +128,29 @@ TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) {
   EXPECT_EQ(Result, (int32_t)3);
 }
 
-// A non-SPS wrapper function that calculates the sum of a byte array.
-static __orc_rt_CWrapperFunctionResult sumArrayRawWrapper(const char *ArgData,
-                                                          size_t ArgSize) {
-  auto WFR = WrapperFunctionResult::allocate(1);
-  *WFR.data() = 0;
-  for (unsigned I = 0; I != ArgSize; ++I)
-    *WFR.data() += ArgData[I];
-  return WFR.release();
+static __orc_rt_CWrapperFunctionResult sumArrayWrapper(const char *ArgData,
+                                                       size_t ArgSize) {
+  return WrapperFunction<int8_t(SPSExecutorAddrRange)>::handle(
+             ArgData, ArgSize,
+             [](ExecutorAddrRange R) {
+               int8_t Sum = 0;
+               for (char C : R.toSpan<char>())
+                 Sum += C;
+               return Sum;
+             })
+      .release();
 }
 
 TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) {
   {
-    // Check raw wrapper function calls.
+    // Check wrapper function calls.
     char A[] = {1, 2, 3, 4};
 
-    WrapperFunctionCall WFC{ExecutorAddr::fromPtr(sumArrayRawWrapper),
-                            ExecutorAddrRange(ExecutorAddr::fromPtr(A),
-                                              ExecutorAddrDiff(sizeof(A)))};
+    auto WFC =
+        cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
+            ExecutorAddr::fromPtr(sumArrayWrapper),
+            ExecutorAddrRange(ExecutorAddr::fromPtr(A),
+                              ExecutorAddrDiff(sizeof(A)))));
 
     WrapperFunctionResult WFR(WFC.run());
     EXPECT_EQ(WFR.size(), 1U);
@@ -154,20 +159,18 @@ TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) {
 
   {
     // Check calls to void functions.
-    WrapperFunctionCall WFC{ExecutorAddr::fromPtr(voidNoopWrapper),
-                            ExecutorAddrRange()};
-    auto Err = WFC.runWithSPSRet();
+    auto WFC =
+        cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
+            ExecutorAddr::fromPtr(voidNoopWrapper), ExecutorAddrRange()));
+    auto Err = WFC.runWithSPSRet<void>();
     EXPECT_FALSE(!!Err);
   }
 
   {
     // Check calls with arguments and return values.
-    auto ArgWFR =
-        WrapperFunctionResult::fromSPSArgs<SPSArgList<int32_t, int32_t>>(2, 4);
-    WrapperFunctionCall WFC{
-        ExecutorAddr::fromPtr(addWrapper),
-        ExecutorAddrRange(ExecutorAddr::fromPtr(ArgWFR.data()),
-                          ExecutorAddrDiff(ArgWFR.size()))};
+    auto WFC =
+        cantFail(WrapperFunctionCall::Create<SPSArgList<int32_t, int32_t>>(
+            ExecutorAddr::fromPtr(addWrapper), 2, 4));
 
     int32_t Result = 0;
     auto Err = WFC.runWithSPSRet<int32_t>(Result);

diff  --git a/compiler-rt/lib/orc/wrapper_function_utils.h b/compiler-rt/lib/orc/wrapper_function_utils.h
index 23385e1bd7944..02ea37393276f 100644
--- a/compiler-rt/lib/orc/wrapper_function_utils.h
+++ b/compiler-rt/lib/orc/wrapper_function_utils.h
@@ -395,25 +395,53 @@ makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
 }
 
 /// Represents a call to a wrapper function.
-struct WrapperFunctionCall {
-  ExecutorAddr Func;
-  ExecutorAddrRange ArgData;
+class WrapperFunctionCall {
+public:
+  // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a
+  // smallvector.
+  using ArgDataBufferType = std::vector<char>;
+
+  /// Create a WrapperFunctionCall using the given SPS serializer to serialize
+  /// the arguments.
+  template <typename SPSSerializer, typename... ArgTs>
+  static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
+                                              const ArgTs &...Args) {
+    ArgDataBufferType ArgData;
+    ArgData.resize(SPSSerializer::size(Args...));
+    SPSOutputBuffer OB(&ArgData[0], ArgData.size());
+    if (SPSSerializer::serialize(OB, Args...))
+      return WrapperFunctionCall(FnAddr, std::move(ArgData));
+    return make_error<StringError>("Cannot serialize arguments for "
+                                   "AllocActionCall");
+  }
 
   WrapperFunctionCall() = default;
-  WrapperFunctionCall(ExecutorAddr Func, ExecutorAddrRange ArgData)
-      : Func(Func), ArgData(ArgData) {}
 
-  /// Run and return result as WrapperFunctionResult.
-  WrapperFunctionResult run() {
-    WrapperFunctionResult WFR(
-        Func.toPtr<__orc_rt_CWrapperFunctionResult (*)(const char *, size_t)>()(
-            ArgData.Start.toPtr<const char *>(),
-            static_cast<size_t>(ArgData.size().getValue())));
-    return WFR;
+  /// Create a WrapperFunctionCall from a target function and arg buffer.
+  WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
+      : FnAddr(FnAddr), ArgData(std::move(ArgData)) {}
+
+  /// Returns the address to be called.
+  const ExecutorAddr &getCallee() const { return FnAddr; }
+
+  /// Returns the argument data.
+  const ArgDataBufferType &getArgData() const { return ArgData; }
+
+  /// WrapperFunctionCalls convert to true if the callee is non-null.
+  explicit operator bool() const { return !!FnAddr; }
+
+  /// Run call returning raw WrapperFunctionResult.
+  WrapperFunctionResult run() const {
+    using FnTy =
+        __orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize);
+    return WrapperFunctionResult(
+        FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
   }
 
   /// Run call and deserialize result using SPS.
-  template <typename SPSRetT, typename RetT> Error runWithSPSRet(RetT &RetVal) {
+  template <typename SPSRetT, typename RetT>
+  std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
+  runWithSPSRet(RetT &RetVal) const {
     auto WFR = run();
     if (const char *ErrMsg = WFR.getOutOfBandError())
       return make_error<StringError>(ErrMsg);
@@ -425,30 +453,49 @@ struct WrapperFunctionCall {
   }
 
   /// Overload for SPS functions returning void.
-  Error runWithSPSRet() {
+  template <typename SPSRetT>
+  std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
+  runWithSPSRet() const {
     SPSEmpty E;
     return runWithSPSRet<SPSEmpty>(E);
   }
+
+  /// Run call and deserialize an SPSError result. SPSError returns and
+  /// deserialization failures are merged into the returned error.
+  Error runWithSPSRetErrorMerged() const {
+    detail::SPSSerializableError RetErr;
+    if (auto Err = runWithSPSRet<SPSError>(RetErr))
+      return Err;
+    return detail::fromSPSSerializable(std::move(RetErr));
+  }
+
+private:
+  ExecutorAddr FnAddr;
+  std::vector<char> ArgData;
 };
 
-class SPSWrapperFunctionCall {};
+using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
 
 template <>
 class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
 public:
   static size_t size(const WrapperFunctionCall &WFC) {
-    return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::size(WFC.Func,
-                                                                   WFC.ArgData);
+    return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size(
+        WFC.getCallee(), WFC.getArgData());
   }
 
   static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
-    return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::serialize(
-        OB, WFC.Func, WFC.ArgData);
+    return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize(
+        OB, WFC.getCallee(), WFC.getArgData());
   }
 
   static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
-    return SPSArgList<SPSExecutorAddr, SPSExecutorAddrRange>::deserialize(
-        IB, WFC.Func, WFC.ArgData);
+    ExecutorAddr FnAddr;
+    WrapperFunctionCall::ArgDataBufferType ArgData;
+    if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))
+      return false;
+    WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
+    return true;
   }
 };
 


        


More information about the llvm-commits mailing list