[Mlir-commits] [mlir] 84dc9b4 - [mlir:JitRunner] Use custom shared library init/destroy functions if available

Eugene Zhulenev llvmlistbot at llvm.org
Fri Jan 8 07:14:28 PST 2021


Author: Eugene Zhulenev
Date: 2021-01-08T07:14:21-08:00
New Revision: 84dc9b451bfd62474f44dd1af0e4955a0110d523

URL: https://github.com/llvm/llvm-project/commit/84dc9b451bfd62474f44dd1af0e4955a0110d523
DIFF: https://github.com/llvm/llvm-project/commit/84dc9b451bfd62474f44dd1af0e4955a0110d523.diff

LOG: [mlir:JitRunner] Use custom shared library init/destroy functions if available

Use custom mlir runner init/destroy functions to safely init and destroy shared libraries loaded by the JitRunner.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D94270

Added: 
    

Modified: 
    mlir/lib/ExecutionEngine/AsyncRuntime.cpp
    mlir/lib/ExecutionEngine/CMakeLists.txt
    mlir/lib/ExecutionEngine/JitRunner.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 45bdcb3733b8..080be9ca029a 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -24,6 +24,8 @@
 #include <thread>
 #include <vector>
 
+#include "llvm/ADT/StringMap.h"
+
 using namespace mlir::runtime;
 
 //===----------------------------------------------------------------------===//
@@ -109,9 +111,20 @@ class RefCounted {
 } // namespace
 
 // Returns the default per-process instance of an async runtime.
-static AsyncRuntime *getDefaultAsyncRuntimeInstance() {
+static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
   static auto runtime = std::make_unique<AsyncRuntime>();
-  return runtime.get();
+  return runtime;
+}
+
+// Destroys the default async runtime instance. Afterwards
+// getDefaultAsyncRuntime() returns nullptr, so this must only be called once
+// no more async tokens, values or groups will be created.
+static void resetDefaultAsyncRuntime() {
+  getDefaultAsyncRuntimeInstance().reset();
+}
+
+static AsyncRuntime *getDefaultAsyncRuntime() {
+  return getDefaultAsyncRuntimeInstance().get();
 }
 
 // Async token provides a mechanism to signal asynchronous operation completion.
@@ -184,19 +197,19 @@ extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
 
 // Creates a new `async.token` in not-ready state.
 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
-  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
+  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
   return token;
 }
 
 // Creates a new `async.value` in not-ready state.
 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
-  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size);
+  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
   return value;
 }
 
 // Create a new `async.group` in empty state.
 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
-  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
+  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
   return group;
 }
 
@@ -342,4 +355,64 @@ extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
   std::cout << "Current thread id: " << thisId << std::endl;
 }
 
+//===----------------------------------------------------------------------===//
+// MLIR Runner (JitRunner) dynamic library integration.
+//===----------------------------------------------------------------------===//
+
+// Export symbols for the MLIR runner integration. All other symbols are hidden.
+#ifdef _WIN32
+#define API __declspec(dllexport)
+#else
+#define API __attribute__((visibility("default")))
+#endif
+
+// Called by the JitRunner after it loads this library. Registers the async
+// runtime API functions in `exportSymbols` under their C names; the runner
+// makes them available to JIT-compiled code, since this library is built with
+// hidden default symbol visibility.
+extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
+  auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
+    assert(exportSymbols.count(name) == 0 && "symbol already exists");
+    exportSymbols[name] = reinterpret_cast<void *>(ptr);
+  };
+
+  exportSymbol("mlirAsyncRuntimeAddRef",
+               &mlir::runtime::mlirAsyncRuntimeAddRef);
+  exportSymbol("mlirAsyncRuntimeDropRef",
+               &mlir::runtime::mlirAsyncRuntimeDropRef);
+  exportSymbol("mlirAsyncRuntimeExecute",
+               &mlir::runtime::mlirAsyncRuntimeExecute);
+  exportSymbol("mlirAsyncRuntimeGetValueStorage",
+               &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
+  exportSymbol("mlirAsyncRuntimeCreateToken",
+               &mlir::runtime::mlirAsyncRuntimeCreateToken);
+  exportSymbol("mlirAsyncRuntimeCreateValue",
+               &mlir::runtime::mlirAsyncRuntimeCreateValue);
+  exportSymbol("mlirAsyncRuntimeEmplaceToken",
+               &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
+  exportSymbol("mlirAsyncRuntimeEmplaceValue",
+               &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
+  exportSymbol("mlirAsyncRuntimeAwaitToken",
+               &mlir::runtime::mlirAsyncRuntimeAwaitToken);
+  exportSymbol("mlirAsyncRuntimeAwaitValue",
+               &mlir::runtime::mlirAsyncRuntimeAwaitValue);
+  exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
+               &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
+  exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
+               &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
+  exportSymbol("mlirAsyncRuntimeCreateGroup",
+               &mlir::runtime::mlirAsyncRuntimeCreateGroup);
+  exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
+               &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
+  exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
+               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
+  exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
+               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
+  exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
+               &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
+}
+
+// Called by the JitRunner before shutdown: destroys the default async runtime.
+extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
+
 #endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS

diff  --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index 47dbe45d8138..7d86811fe4fd 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -111,4 +111,8 @@ add_mlir_library(mlir_async_runtime
   mlir_c_runner_utils_static
   ${LLVM_PTHREAD_LIB}
 )
+# Hide all symbols by default so that only symbols explicitly marked with
+# default visibility (e.g. via the `API` macro in AsyncRuntime.cpp) are
+# exported from the shared library.
+set_property(TARGET mlir_async_runtime PROPERTY CXX_VISIBILITY_PRESET hidden)
 target_compile_definitions(mlir_async_runtime PRIVATE mlir_async_runtime_EXPORTS)

diff  --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index c7548b0d8a85..d2a9336030e3 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -155,17 +155,80 @@ static Error compileAndExecute(Options &options, ModuleOp module,
   if (auto clOptLevel = getCommandLineOptLevel(options))
     jitCodeGenOptLevel =
         static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
+
+  // If shared library implements custom mlir-runner library init and destroy
+  // functions, we'll use them to register the library with the execution
+  // engine. Otherwise we'll pass library directly to the execution engine.
   SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
                                  options.clSharedLibs.end());
+
+  // Libraries that we'll pass to the ExecutionEngine for loading.
+  SmallVector<StringRef, 4> executionEngineLibs;
+
+  using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
+  using MlirRunnerDestroyFn = void (*)();
+
+  llvm::StringMap<void *> exportSymbols;
+  SmallVector<MlirRunnerDestroyFn> destroyFns;
+
+  // Handle libraries that do support mlir-runner init/destroy callbacks.
+  for (auto libPath : libs) {
+    // StringRef is not guaranteed to be null-terminated, so materialize a
+    // std::string before passing the path to the C-string based API.
+    auto lib =
+        llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.str().c_str());
+    void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
+    void *destroySym = lib.getAddressOfSymbol("__mlir_runner_destroy");
+
+    // Library does not support mlir runner, load it with ExecutionEngine.
+    if (!initSym || !destroySym) {
+      executionEngineLibs.push_back(libPath);
+      continue;
+    }
+
+    auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
+    initFn(exportSymbols);
+
+    auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySym);
+    destroyFns.push_back(destroyFn);
+  }
+
+  // Runs all dynamic library destroy callbacks to prepare for the shutdown.
+  auto runDestroyFns = [&destroyFns] {
+    for (MlirRunnerDestroyFn destroy : destroyFns)
+      destroy();
+  };
+
+  // Build a runtime symbol map from the config and exported symbols.
+  auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
+    auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
+                                             : llvm::orc::SymbolMap();
+    for (auto &exportSymbol : exportSymbols)
+      symbolMap[interner(exportSymbol.getKey())] =
+          llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
+    return symbolMap;
+  };
+
   auto expectedEngine = mlir::ExecutionEngine::create(
       module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
-      libs);
-  if (!expectedEngine)
+      executionEngineLibs);
+  if (!expectedEngine) {
+    // The engine was never created; tear the libraries down before returning.
+    runDestroyFns();
     return expectedEngine.takeError();
+  }
 
   auto engine = std::move(*expectedEngine);
-  if (config.runtimeSymbolMap)
-    engine->registerSymbols(config.runtimeSymbolMap);
+
+  // Run the destroy callbacks on every exit path from here on, including
+  // errors. The guard is declared after `engine`, so the callbacks run while
+  // the engine is still alive.
+  struct RunDestroyFnsOnExit {
+    decltype(runDestroyFns) &run;
+    ~RunDestroyFnsOnExit() { run(); }
+  } runDestroyFnsOnExit{runDestroyFns};
+
+  engine->registerSymbols(runtimeSymbolMap);
 
   auto expectedFPtr = engine->lookup(entryPoint);
   if (!expectedFPtr)


        


More information about the Mlir-commits mailing list