[llvm-branch-commits] [mlir] 84dc9b4 - [mlir:JitRunner] Use custom shared library init/destroy functions if available
Eugene Zhulenev via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 8 07:19:02 PST 2021
Author: Eugene Zhulenev
Date: 2021-01-08T07:14:21-08:00
New Revision: 84dc9b451bfd62474f44dd1af0e4955a0110d523
URL: https://github.com/llvm/llvm-project/commit/84dc9b451bfd62474f44dd1af0e4955a0110d523
DIFF: https://github.com/llvm/llvm-project/commit/84dc9b451bfd62474f44dd1af0e4955a0110d523.diff
LOG: [mlir:JitRunner] Use custom shared library init/destroy functions if available
Use custom mlir runner init/destroy functions to safely init and destroy shared libraries loaded by the JitRunner.
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D94270
Added:
Modified:
mlir/lib/ExecutionEngine/AsyncRuntime.cpp
mlir/lib/ExecutionEngine/CMakeLists.txt
mlir/lib/ExecutionEngine/JitRunner.cpp
Removed:
################################################################################
diff --git a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
index 45bdcb3733b8..080be9ca029a 100644
--- a/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/AsyncRuntime.cpp
@@ -24,6 +24,8 @@
#include <thread>
#include <vector>
+#include "llvm/ADT/StringMap.h"
+
using namespace mlir::runtime;
//===----------------------------------------------------------------------===//
@@ -109,9 +111,17 @@ class RefCounted {
} // namespace
// Returns the default per-process instance of an async runtime.
-static AsyncRuntime *getDefaultAsyncRuntimeInstance() {
+static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
static auto runtime = std::make_unique<AsyncRuntime>();
- return runtime.get();
+ return runtime;
+}
+
+static void resetDefaultAsyncRuntime() {
+ return getDefaultAsyncRuntimeInstance().reset();
+}
+
+static AsyncRuntime *getDefaultAsyncRuntime() {
+ return getDefaultAsyncRuntimeInstance().get();
}
// Async token provides a mechanism to signal asynchronous operation completion.
@@ -184,19 +194,19 @@ extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
// Creates a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
- AsyncToken *token = new AsyncToken(getDefaultAsyncRuntimeInstance());
+ AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
return token;
}
// Creates a new `async.value` in not-ready state.
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int32_t size) {
- AsyncValue *value = new AsyncValue(getDefaultAsyncRuntimeInstance(), size);
+ AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
return value;
}
// Create a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup() {
- AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntimeInstance());
+ AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime());
return group;
}
@@ -342,4 +352,55 @@ extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
std::cout << "Current thread id: " << thisId << std::endl;
}
+//===----------------------------------------------------------------------===//
+// MLIR Runner (JitRunner) dynamic library integration.
+//===----------------------------------------------------------------------===//
+
+// Export symbols for the MLIR runner integration. All other symbols are hidden.
+#define API __attribute__((visibility("default")))
+
+extern "C" API void __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
+ auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
+ assert(exportSymbols.count(name) == 0 && "symbol already exists");
+ exportSymbols[name] = reinterpret_cast<void *>(ptr);
+ };
+
+ exportSymbol("mlirAsyncRuntimeAddRef",
+ &mlir::runtime::mlirAsyncRuntimeAddRef);
+ exportSymbol("mlirAsyncRuntimeDropRef",
+ &mlir::runtime::mlirAsyncRuntimeDropRef);
+ exportSymbol("mlirAsyncRuntimeExecute",
+ &mlir::runtime::mlirAsyncRuntimeExecute);
+ exportSymbol("mlirAsyncRuntimeGetValueStorage",
+ &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
+ exportSymbol("mlirAsyncRuntimeCreateToken",
+ &mlir::runtime::mlirAsyncRuntimeCreateToken);
+ exportSymbol("mlirAsyncRuntimeCreateValue",
+ &mlir::runtime::mlirAsyncRuntimeCreateValue);
+ exportSymbol("mlirAsyncRuntimeEmplaceToken",
+ &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
+ exportSymbol("mlirAsyncRuntimeEmplaceValue",
+ &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
+ exportSymbol("mlirAsyncRuntimeAwaitToken",
+ &mlir::runtime::mlirAsyncRuntimeAwaitToken);
+ exportSymbol("mlirAsyncRuntimeAwaitValue",
+ &mlir::runtime::mlirAsyncRuntimeAwaitValue);
+ exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
+ &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
+ exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
+ &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
+ exportSymbol("mlirAsyncRuntimeCreateGroup",
+ &mlir::runtime::mlirAsyncRuntimeCreateGroup);
+ exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
+ &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
+ exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
+ &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
+ exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
+ &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
+ exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
+ &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
+}
+
+extern "C" API void __mlir_runner_destroy() { resetDefaultAsyncRuntime(); }
+
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index 47dbe45d8138..7d86811fe4fd 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -111,4 +111,5 @@ add_mlir_library(mlir_async_runtime
mlir_c_runner_utils_static
${LLVM_PTHREAD_LIB}
)
+set_property(TARGET mlir_async_runtime PROPERTY CXX_VISIBILITY_PRESET hidden)
target_compile_definitions(mlir_async_runtime PRIVATE mlir_async_runtime_EXPORTS)
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index c7548b0d8a85..d2a9336030e3 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -155,17 +155,59 @@ static Error compileAndExecute(Options &options, ModuleOp module,
if (auto clOptLevel = getCommandLineOptLevel(options))
jitCodeGenOptLevel =
static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
+
+ // If a shared library implements the custom mlir-runner init and destroy
+ // functions, we use them to register its symbols with the execution engine.
+ // Otherwise we pass the library directly to the execution engine.
SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
options.clSharedLibs.end());
+
+ // Libraries that we'll pass to the ExecutionEngine for loading.
+ SmallVector<StringRef, 4> executionEngineLibs;
+
+ using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
+ using MlirRunnerDestroyFn = void (*)();
+
+ llvm::StringMap<void *> exportSymbols;
+ SmallVector<MlirRunnerDestroyFn> destroyFns;
+
+ // Handle libraries that do support mlir-runner init/destroy callbacks.
+ for (auto libPath : libs) {
+ auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.data());
+ void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
+ void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
+
+ // Library lacks the mlir-runner callbacks; let ExecutionEngine load it.
+ if (!initSym || !destroySim) {
+ executionEngineLibs.push_back(libPath);
+ continue;
+ }
+
+ auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
+ initFn(exportSymbols);
+
+ auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
+ destroyFns.push_back(destroyFn);
+ }
+
+ // Build a runtime symbol map from the config and exported symbols.
+ auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
+ auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
+ : llvm::orc::SymbolMap();
+ for (auto &exportSymbol : exportSymbols)
+ symbolMap[interner(exportSymbol.getKey())] =
+ llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
+ return symbolMap;
+ };
+
auto expectedEngine = mlir::ExecutionEngine::create(
module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
- libs);
+ executionEngineLibs);
if (!expectedEngine)
return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
- if (config.runtimeSymbolMap)
- engine->registerSymbols(config.runtimeSymbolMap);
+ engine->registerSymbols(runtimeSymbolMap);
auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
@@ -179,6 +221,9 @@ static Error compileAndExecute(Options &options, ModuleOp module,
void (*fptr)(void **) = *expectedFPtr;
(*fptr)(args);
+ // Run all dynamic library destroy callbacks to prepare for the shutdown.
+ llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
+
return Error::success();
}
More information about the llvm-branch-commits
mailing list