[Mlir-commits] [mlir] f6c9f6e - [mlir] JitRunner: add a config option to register symbols with ExecutionEngine at runtime
Eugene Zhulenev
llvmlistbot at llvm.org
Tue Oct 27 15:57:43 PDT 2020
Author: Eugene Zhulenev
Date: 2020-10-27T15:57:34-07:00
New Revision: f6c9f6eccda44a1fa5a57652a73f9ebf6595f5a6
URL: https://github.com/llvm/llvm-project/commit/f6c9f6eccda44a1fa5a57652a73f9ebf6595f5a6
DIFF: https://github.com/llvm/llvm-project/commit/f6c9f6eccda44a1fa5a57652a73f9ebf6595f5a6.diff
LOG: [mlir] JitRunner: add a config option to register symbols with ExecutionEngine at runtime
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D90264
Added:
Modified:
mlir/include/mlir/ExecutionEngine/JitRunner.h
mlir/lib/ExecutionEngine/JitRunner.cpp
mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/JitRunner.h b/mlir/include/mlir/ExecutionEngine/JitRunner.h
index 2b7518c8cde2..43c9f9699fee 100644
--- a/mlir/include/mlir/ExecutionEngine/JitRunner.h
+++ b/mlir/include/mlir/ExecutionEngine/JitRunner.h
@@ -18,29 +18,42 @@
#ifndef MLIR_SUPPORT_JITRUNNER_H_
#define MLIR_SUPPORT_JITRUNNER_H_
-#include "mlir/IR/Module.h"
-
#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/Module.h"
+#include "llvm/ExecutionEngine/Orc/Core.h"
-namespace mlir {
+namespace llvm {
+class Module;
+class LLVMContext;
-using TranslationCallback = llvm::function_ref<std::unique_ptr<llvm::Module>(
- ModuleOp, llvm::LLVMContext &)>;
+namespace orc {
+class MangleAndInterner;
+} // namespace orc
+} // namespace llvm
+
+namespace mlir {
class ModuleOp;
struct LogicalResult;
+/// Optional hooks that customize JitRunnerMain. A null (default) member means
+/// the corresponding hook is not provided.
+///
+/// Every member is a non-owning llvm::function_ref: it only stores the address
+/// of the callable it was bound to, so that callable must outlive the
+/// JitRunnerMain call. Bind a function by name (`config.mlirTransformer = fn;`).
+/// Binding `&fn` or a lambda expression binds a temporary that is destroyed at
+/// the end of the assignment, leaving the function_ref dangling.
+struct JitRunnerConfig {
+ /// MLIR transformer applied after parsing the input into MLIR IR and before
+ /// passing the MLIR module to the ExecutionEngine. If it returns failure,
+ /// JitRunnerMain exits with EXIT_FAILURE. Skipped when null.
+ llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer = nullptr;
+
+ /// A custom function that is passed to ExecutionEngine. It processes the MLIR
+ /// module and creates the LLVM IR module. It is forwarded to ExecutionEngine
+ /// as-is, including when null.
+ llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
+ llvm::LLVMContext &)>
+ llvmModuleBuilder = nullptr;
+
+ /// A callback to register symbols with ExecutionEngine at runtime. When set,
+ /// it is handed to ExecutionEngine::registerSymbols after the engine is
+ /// created and before the entry point is looked up. Skipped when null.
+ /// NOTE: the member is spelled `runtimesymbolMap` (lower-case 's'); keep the
+ /// spelling for source compatibility with existing callers.
+ llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
+ runtimesymbolMap = nullptr;
+};
+
// Entry point for all CPU runners. Expects the common argc/argv arguments for
-// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
-/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
-/// passing the MLIR module to the ExecutionEngine.
-/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
-/// It processes MLIR module and creates LLVM IR module.
-int JitRunnerMain(
- int argc, char **argv,
- llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
- TranslationCallback llvmModuleBuilder = nullptr);
+// standard C++ main functions.
+// `config` carries optional hooks (see JitRunnerConfig); the default config
+// installs none. Returns the process exit code.
+int JitRunnerMain(int argc, char **argv, JitRunnerConfig config = {});
} // namespace mlir
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index c1bb1e68a4ca..d9e0ff63c125 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -92,6 +92,23 @@ struct Options {
"object-filename",
llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
};
+
+/// Callbacks threaded from JitRunnerMain down to compileAndExecute. All members
+/// are non-owning llvm::function_refs; the callables they refer to must outlive
+/// the compileAndExecute* call that receives this config.
+struct CompileAndExecuteConfig {
+ /// LLVM module transformer that is passed to ExecutionEngine.
+ llvm::function_ref<llvm::Error(llvm::Module *)> transformer;
+
+ /// A custom function that is passed to ExecutionEngine. It processes the MLIR
+ /// module and creates the LLVM IR module.
+ llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
+ llvm::LLVMContext &)>
+ llvmModuleBuilder;
+
+ /// A custom function that is passed to ExecutionEngine to register symbols at
+ /// runtime. Skipped when null.
+ llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
+ runtimeSymbolMap;
+};
+
} // end anonymous namespace
static OwningModuleRef parseMLIRInput(StringRef inputFilename,
@@ -131,11 +148,9 @@ static Optional<unsigned> getCommandLineOptLevel(Options &options) {
}
// JIT-compile the given module and run "entryPoint" with "args" as arguments.
-static Error
-compileAndExecute(Options &options, ModuleOp module,
- TranslationCallback llvmModuleBuilder, StringRef entryPoint,
- std::function<llvm::Error(llvm::Module *)> transformer,
- void **args) {
+// `config` provides the LLVM transformer, the optional MLIR-to-LLVM module
+// builder and the optional runtime symbol callback used by the ExecutionEngine.
+static Error compileAndExecute(Options &options, ModuleOp module,
+ StringRef entryPoint,
+ CompileAndExecuteConfig config, void **args) {
Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
if (auto clOptLevel = getCommandLineOptLevel(options))
jitCodeGenOptLevel =
@@ -143,11 +158,15 @@ compileAndExecute(Options &options, ModuleOp module,
SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
options.clSharedLibs.end());
auto expectedEngine = mlir::ExecutionEngine::create(
- module, llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs);
+ module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
+ libs);
if (!expectedEngine)
return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
+ // Optionally register caller-provided symbols with the engine; this happens
+ // before the entry point is looked up below.
+ if (config.runtimeSymbolMap)
+ engine->registerSymbols(config.runtimeSymbolMap);
+
auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
@@ -163,16 +182,14 @@ compileAndExecute(Options &options, ModuleOp module,
return Error::success();
}
-static Error compileAndExecuteVoidFunction(
- Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
- StringRef entryPoint,
- std::function<llvm::Error(llvm::Module *)> transformer) {
+// Runs `entryPoint` as a function with no result. Fails with "entry point not
+// found" when the symbol is missing or `empty()`; otherwise forwards a single
+// null pointer as the packed argument array to compileAndExecute.
+static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
+ StringRef entryPoint,
+ CompileAndExecuteConfig config) {
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
if (!mainFunction || mainFunction.empty())
return make_string_error("entry point not found");
void *empty = nullptr;
- return compileAndExecute(options, module, llvmModuleBuilder, entryPoint,
- transformer, &empty);
+ return compileAndExecute(options, module, entryPoint, config, &empty);
}
template <typename Type>
@@ -196,10 +213,9 @@ Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
return Error::success();
}
+// Runs `entryPoint`, which must be defined (not external) and return a single
+// value compatible with `Type`, then prints that value for testing.
template <typename Type>
-Error compileAndExecuteSingleReturnFunction(
- Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
- StringRef entryPoint,
- std::function<llvm::Error(llvm::Module *)> transformer) {
+Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
+ StringRef entryPoint,
+ CompileAndExecuteConfig config) {
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
if (!mainFunction || mainFunction.isExternal())
return make_string_error("entry point not found");
@@ -215,8 +231,8 @@ Error compileAndExecuteSingleReturnFunction(
void *data;
} data;
data.data = &res;
- if (auto error = compileAndExecute(options, module, llvmModuleBuilder,
- entryPoint, transformer, (void **)&data))
+ if (auto error = compileAndExecute(options, module, entryPoint, config,
+ (void **)&data))
return error;
// Intentional printing of the output so we can test.
@@ -226,15 +242,8 @@ Error compileAndExecuteSingleReturnFunction(
}
/// Entry point for all CPU runners. Expects the common argc/argv arguments for
-/// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
-/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
-/// passing the MLIR module to the ExecutionEngine.
-/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
-/// It processes MLIR module and creates LLVM IR module.
-int mlir::JitRunnerMain(
- int argc, char **argv,
- function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
- TranslationCallback llvmModuleBuilder) {
+/// standard C++ main functions.
+/// `config` supplies optional hooks (see JitRunnerConfig); null members are
+/// skipped or forwarded as null.
+int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) {
// Create the options struct containing the command line options for the
// runner. This must come before the command line options are parsed.
Options options;
@@ -274,8 +283,8 @@ int mlir::JitRunnerMain(
return 1;
}
- if (mlirTransformer)
- if (failed(mlirTransformer(m.get())))
+ if (config.mlirTransformer)
+ if (failed(config.mlirTransformer(m.get())))
return EXIT_FAILURE;
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
@@ -292,10 +301,14 @@ int mlir::JitRunnerMain(
auto transformer = mlir::makeLLVMPassesTransformer(
passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
+ // Bundle the hooks for compileAndExecute*. These are non-owning
+ // function_refs: `transformer` is a local that outlives the
+ // compileAndExecuteFn call below, and the callables behind `config` must be
+ // kept alive by JitRunnerMain's caller.
+ CompileAndExecuteConfig compileAndExecuteConfig;
+ compileAndExecuteConfig.transformer = transformer;
+ compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
+ // JitRunnerConfig spells this member `runtimesymbolMap` (lower-case 's').
+ compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
+
// Get the function used to compile and execute the module.
using CompileAndExecuteFnT =
- Error (*)(Options &, ModuleOp, TranslationCallback, StringRef,
- std::function<llvm::Error(llvm::Module *)>);
+ Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
auto compileAndExecuteFn =
StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
.Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
@@ -304,11 +317,11 @@ int mlir::JitRunnerMain(
.Case("void", compileAndExecuteVoidFunction)
.Default(nullptr);
- Error error =
- compileAndExecuteFn
- ? compileAndExecuteFn(options, m.get(), llvmModuleBuilder,
- options.mainFuncName.getValue(), transformer)
- : make_string_error("unsupported function type");
+ Error error = compileAndExecuteFn
+ ? compileAndExecuteFn(options, m.get(),
+ options.mainFuncName.getValue(),
+ compileAndExecuteConfig)
+ : make_string_error("unsupported function type");
int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error),
diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
index 7667908c39b3..a2661c167af3 100644
--- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
+++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
@@ -24,5 +24,5 @@ int main(int argc, char **argv) {
llvm::InitializeNativeTargetAsmPrinter();
mlir::initializeLLVMPasses();
- return mlir::JitRunnerMain(argc, argv, nullptr);
+ // No custom hooks: rely on the default-constructed JitRunnerConfig.
+ return mlir::JitRunnerMain(argc, argv);
}
diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
index be00646bd0ce..cfffaaa13126 100644
--- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
+++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
@@ -136,5 +136,9 @@ int main(int argc, char **argv) {
LLVMInitializeNVPTXAsmPrinter();
mlir::initializeLLVMPasses();
- return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
+
+ mlir::JitRunnerConfig jitRunnerConfig;
+ // JitRunnerConfig members are non-owning llvm::function_refs. Bind the
+ // function itself: `&runMLIRPasses` is a temporary pointer that is destroyed
+ // at the end of the assignment, which would leave mlirTransformer dangling.
+ jitRunnerConfig.mlirTransformer = runMLIRPasses;
+
+ return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
}
diff --git a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
index 9979801023b2..cc0f503f9a50 100644
--- a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
+++ b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
@@ -86,5 +86,9 @@ int main(int argc, char **argv) {
llvm::InitializeNativeTargetAsmPrinter();
mlir::initializeLLVMPasses();
- return mlir::JitRunnerMain(argc, argv, &runMLIRPasses, &convertMLIRModule);
+ mlir::JitRunnerConfig jitRunnerConfig;
+ // JitRunnerConfig members are non-owning llvm::function_refs. Bind the
+ // functions themselves: `&fn` is a temporary pointer that is destroyed at the
+ // end of the assignment, which would leave the function_ref dangling.
+ jitRunnerConfig.mlirTransformer = runMLIRPasses;
+ jitRunnerConfig.llvmModuleBuilder = convertMLIRModule;
+
+ return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
}
diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 905d2e422115..322f9491bb42 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -58,5 +58,8 @@ int main(int argc, char **argv) {
llvm::InitializeNativeTargetAsmPrinter();
mlir::initializeLLVMPasses();
- return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
+ mlir::JitRunnerConfig jitRunnerConfig;
+ // JitRunnerConfig members are non-owning llvm::function_refs. Bind the
+ // function itself: `&runMLIRPasses` is a temporary pointer that is destroyed
+ // at the end of the assignment, which would leave mlirTransformer dangling.
+ jitRunnerConfig.mlirTransformer = runMLIRPasses;
+
+ return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
}
More information about the Mlir-commits
mailing list