[compiler-rt] 462251b - [ORC-RT] Replace FnTag arg of WrapperFunction::call with generic dispatch arg.

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 9 22:12:59 PDT 2024


Author: Lang Hames
Date: 2024-09-10T15:11:40+10:00
New Revision: 462251b80b7ba51dc6c2ef3676cf50ee92867d34

URL: https://github.com/llvm/llvm-project/commit/462251b80b7ba51dc6c2ef3676cf50ee92867d34
DIFF: https://github.com/llvm/llvm-project/commit/462251b80b7ba51dc6c2ef3676cf50ee92867d34.diff

LOG: [ORC-RT] Replace FnTag arg of WrapperFunction::call with generic dispatch arg.

This decouples function argument serialization / deserialization from the
function call dispatch mechanism. This will eventually allow us to replace the
existing __orc_rt_jit_dispatch function with a system that supports pre-linking
parts of the ORC runtime into the executor.

Added: 
    compiler-rt/lib/orc/jit_dispatch.h

Modified: 
    compiler-rt/lib/orc/coff_platform.cpp
    compiler-rt/lib/orc/elfnix_platform.cpp
    compiler-rt/lib/orc/macho_platform.cpp
    compiler-rt/lib/orc/wrapper_function_utils.h

Removed: 
    


################################################################################
diff  --git a/compiler-rt/lib/orc/coff_platform.cpp b/compiler-rt/lib/orc/coff_platform.cpp
index 346d896f6c9484..49b805a0ec7d31 100644
--- a/compiler-rt/lib/orc/coff_platform.cpp
+++ b/compiler-rt/lib/orc/coff_platform.cpp
@@ -17,6 +17,7 @@
 
 #include "debug.h"
 #include "error.h"
+#include "jit_dispatch.h"
 #include "wrapper_function_utils.h"
 
 #include <array>
@@ -315,9 +316,9 @@ Error COFFPlatformRuntimeState::dlopenFull(JITDylibState &JDS) {
   // Call back to the JIT to push the initializers.
   Expected<COFFJITDylibDepInfoMap> DepInfoMap((COFFJITDylibDepInfoMap()));
   if (auto Err = WrapperFunction<SPSExpected<SPSCOFFJITDylibDepInfoMap>(
-          SPSExecutorAddr)>::call(&__orc_rt_coff_push_initializers_tag,
-                                  DepInfoMap,
-                                  ExecutorAddr::fromPtr(JDS.Header)))
+          SPSExecutorAddr)>::
+          call(JITDispatch(&__orc_rt_coff_push_initializers_tag), DepInfoMap,
+               ExecutorAddr::fromPtr(JDS.Header)))
     return Err;
   if (!DepInfoMap)
     return DepInfoMap.takeError();
@@ -445,10 +446,9 @@ COFFPlatformRuntimeState::lookupSymbolInJITDylib(void *header,
                                                  std::string_view Sym) {
   Expected<ExecutorAddr> Result((ExecutorAddr()));
   if (auto Err = WrapperFunction<SPSExpected<SPSExecutorAddr>(
-          SPSExecutorAddr, SPSString)>::call(&__orc_rt_coff_symbol_lookup_tag,
-                                             Result,
-                                             ExecutorAddr::fromPtr(header),
-                                             Sym))
+          SPSExecutorAddr,
+          SPSString)>::call(JITDispatch(&__orc_rt_coff_symbol_lookup_tag),
+                            Result, ExecutorAddr::fromPtr(header), Sym))
     return std::move(Err);
   return Result;
 }

diff  --git a/compiler-rt/lib/orc/elfnix_platform.cpp b/compiler-rt/lib/orc/elfnix_platform.cpp
index dc6af65dc996a0..3f1ba4ac4ea93e 100644
--- a/compiler-rt/lib/orc/elfnix_platform.cpp
+++ b/compiler-rt/lib/orc/elfnix_platform.cpp
@@ -14,6 +14,7 @@
 #include "common.h"
 #include "compiler.h"
 #include "error.h"
+#include "jit_dispatch.h"
 #include "wrapper_function_utils.h"
 
 #include <algorithm>
@@ -352,10 +353,9 @@ ELFNixPlatformRuntimeState::lookupSymbolInJITDylib(void *DSOHandle,
                                                    std::string_view Sym) {
   Expected<ExecutorAddr> Result((ExecutorAddr()));
   if (auto Err = WrapperFunction<SPSExpected<SPSExecutorAddr>(
-          SPSExecutorAddr, SPSString)>::call(&__orc_rt_elfnix_symbol_lookup_tag,
-                                             Result,
-                                             ExecutorAddr::fromPtr(DSOHandle),
-                                             Sym))
+          SPSExecutorAddr,
+          SPSString)>::call(JITDispatch(&__orc_rt_elfnix_symbol_lookup_tag),
+                            Result, ExecutorAddr::fromPtr(DSOHandle), Sym))
     return std::move(Err);
   return Result;
 }
@@ -368,8 +368,9 @@ ELFNixPlatformRuntimeState::getJITDylibInitializersByName(
   std::string PathStr(Path.data(), Path.size());
   if (auto Err =
           WrapperFunction<SPSExpected<SPSELFNixJITDylibInitializerSequence>(
-              SPSString)>::call(&__orc_rt_elfnix_get_initializers_tag, Result,
-                                Path))
+              SPSString)>::
+              call(JITDispatch(&__orc_rt_elfnix_get_initializers_tag), Result,
+                   Path))
     return std::move(Err);
   return Result;
 }

diff  --git a/compiler-rt/lib/orc/jit_dispatch.h b/compiler-rt/lib/orc/jit_dispatch.h
new file mode 100644
index 00000000000000..9b2329fa1e4fc7
--- /dev/null
+++ b/compiler-rt/lib/orc/jit_dispatch.h
@@ -0,0 +1,50 @@
+//===------ jit_dispatch.h - Call back to an ORC controller -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file is a part of the ORC runtime support library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ORC_RT_JIT_DISPATCH_H
+#define ORC_RT_JIT_DISPATCH_H
+
+#include "common.h"
+#include "wrapper_function_utils.h"
+
+namespace orc_rt {
+
+class JITDispatch {
+public:
+  JITDispatch(const void *FnTag) : FnTag(FnTag) {}
+
+  WrapperFunctionResult operator()(const char *ArgData, size_t ArgSize) {
+    // Since the functions cannot be zero/unresolved on Windows, the following
+    // reference taking would always be non-zero, thus generating a compiler
+    // warning otherwise.
+#if !defined(_WIN32)
+    if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))
+      return WrapperFunctionResult::createOutOfBandError(
+                 "__orc_rt_jit_dispatch_ctx not set")
+          .release();
+    if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))
+      return WrapperFunctionResult::createOutOfBandError(
+                 "__orc_rt_jit_dispatch not set")
+          .release();
+#endif
+
+    return __orc_rt_jit_dispatch(&__orc_rt_jit_dispatch_ctx, FnTag, ArgData,
+                                 ArgSize);
+  }
+
+private:
+  const void *FnTag;
+};
+
+} // namespace orc_rt
+
+#endif // ORC_RT_JIT_DISPATCH_H

diff  --git a/compiler-rt/lib/orc/macho_platform.cpp b/compiler-rt/lib/orc/macho_platform.cpp
index 1974d3f0ef33f8..c092545b2a3677 100644
--- a/compiler-rt/lib/orc/macho_platform.cpp
+++ b/compiler-rt/lib/orc/macho_platform.cpp
@@ -16,6 +16,7 @@
 #include "debug.h"
 #include "error.h"
 #include "interval_map.h"
+#include "jit_dispatch.h"
 #include "wrapper_function_utils.h"
 
 #include <algorithm>
@@ -915,7 +916,7 @@ Error MachOPlatformRuntimeState::requestPushSymbols(
   Error OpErr = Error::success();
   if (auto Err = WrapperFunction<SPSError(
           SPSExecutorAddr, SPSSequence<SPSTuple<SPSString, bool>>)>::
-          call(&__orc_rt_macho_push_symbols_tag, OpErr,
+          call(JITDispatch(&__orc_rt_macho_push_symbols_tag), OpErr,
                ExecutorAddr::fromPtr(JDS.Header), Symbols)) {
     cantFail(std::move(OpErr));
     return std::move(Err);
@@ -1145,8 +1146,9 @@ Error MachOPlatformRuntimeState::dlopenFull(
   // Unlock so that we can accept the initializer update.
   JDStatesLock.unlock();
   if (auto Err = WrapperFunction<SPSExpected<SPSMachOJITDylibDepInfoMap>(
-          SPSExecutorAddr)>::call(&__orc_rt_macho_push_initializers_tag,
-                                  DepInfo, ExecutorAddr::fromPtr(JDS.Header)))
+          SPSExecutorAddr)>::
+          call(JITDispatch(&__orc_rt_macho_push_initializers_tag), DepInfo,
+               ExecutorAddr::fromPtr(JDS.Header)))
     return Err;
   JDStatesLock.lock();
 

diff  --git a/compiler-rt/lib/orc/wrapper_function_utils.h b/compiler-rt/lib/orc/wrapper_function_utils.h
index e65aac0fe4e53e..d5a709a046210e 100644
--- a/compiler-rt/lib/orc/wrapper_function_utils.h
+++ b/compiler-rt/lib/orc/wrapper_function_utils.h
@@ -13,10 +13,9 @@
 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H
 
-#include "orc_rt/c_api.h"
-#include "common.h"
 #include "error.h"
 #include "executor_address.h"
+#include "orc_rt/c_api.h"
 #include "simple_packed_serialization.h"
 #include <type_traits>
 
@@ -288,30 +287,22 @@ class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
   using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
 
 public:
-  template <typename RetT, typename... ArgTs>
-  static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) {
+  template <typename DispatchFn, typename RetT, typename... ArgTs>
+  static Error call(DispatchFn &&Dispatch, RetT &Result, const ArgTs &...Args) {
 
     // RetT might be an Error or Expected value. Set the checked flag now:
     // we don't want the user to have to check the unused result if this
     // operation fails.
     detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
 
-    // Since the functions cannot be zero/unresolved on Windows, the following
-    // reference taking would always be non-zero, thus generating a compiler
-    // warning otherwise.
-#if !defined(_WIN32)
-    if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))
-      return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set");
-    if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))
-      return make_error<StringError>("__orc_rt_jit_dispatch not set");
-#endif
     auto ArgBuffer =
         WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
     if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
       return make_error<StringError>(ErrMsg);
 
-    WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch(
-        &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size());
+    WrapperFunctionResult ResultBuffer =
+        Dispatch(ArgBuffer.data(), ArgBuffer.size());
+
     if (auto ErrMsg = ResultBuffer.getOutOfBandError())
       return make_error<StringError>(ErrMsg);
 


        


More information about the llvm-commits mailing list