[Mlir-commits] [mlir] [MLIR][JitRunner] Correctly register symbol map (PR #90381)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Apr 28 00:59:16 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Menooker (Menooker)
<details>
<summary>Changes</summary>
Fixes an issue in JitRunner where the `runtimeSymbolMap` passed by the user was ignored. This change adds a call to `registerSymbols` on the ExecutionEngine, so the user-provided symbols are registered before the entry point is looked up.
---
Full diff: https://github.com/llvm/llvm-project/pull/90381.diff
3 Files Affected:
- (modified) mlir/lib/ExecutionEngine/JitRunner.cpp (+2)
- (modified) mlir/unittests/ExecutionEngine/CMakeLists.txt (+1)
- (modified) mlir/unittests/ExecutionEngine/Invoke.cpp (+63)
``````````diff
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index cf462ddf6f17ca..af5cd2f0784124 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -201,6 +201,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
+ if (config.runtimeSymbolMap)
+ engine->registerSymbols(config.runtimeSymbolMap);
auto expectedFPtr = engine->lookupPacked(entryPoint);
if (!expectedFPtr)
diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt
index 383e172aa3f667..c1b840e991f72e 100644
--- a/mlir/unittests/ExecutionEngine/CMakeLists.txt
+++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt
@@ -11,6 +11,7 @@ target_link_libraries(MLIRExecutionEngineTests
MLIRExecutionEngine
MLIRMemRefToLLVM
MLIRReconcileUnrealizedCasts
+ MLIRJitRunner
${dialect_libs}
)
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index ff87fc9fad805a..ca132605274ee6 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/JitRunner.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
@@ -295,4 +296,66 @@ TEST(NativeMemRefJit, MAYBE_JITCallback) {
ASSERT_EQ(elt, coefficient * count++);
}
+namespace {
+struct TestFile {
+ SmallString<256> filename;
+ std::error_code ec;
+ TestFile(StringRef filename, StringRef contents) : filename{filename} {
+ llvm::raw_fd_ostream outf{filename, ec, llvm::sys::fs::OF_None};
+ if (!ec) {
+ outf << contents;
+ }
+ }
+ ~TestFile() {
+ if (!ec) {
+ llvm::sys::fs::remove(filename);
+ }
+ }
+};
+static int32_t callbackval = 0;
+static void intcallback(int32_t v) { callbackval = v; }
+} // namespace
+
+TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(JitRunnerSymbol)) {
+ std::string moduleStr = R"mlir(
+module {
+ llvm.func @callback(i32) attributes {sym_visibility = "private"}
+ llvm.func @caller_for_callback() attributes {llvm.emit_c_interface} {
+ %0 = llvm.mlir.constant(114514 : i32) : i32
+ llvm.call @callback(%0) : (i32) -> ()
+ llvm.return
+ }
+ llvm.func @_mlir_ciface_caller_for_callback() attributes {llvm.emit_c_interface} {
+ llvm.call @caller_for_callback() : () -> ()
+ llvm.return
+ }
+}
+ )mlir";
+ char path[] = "mlir_executionengine_jitrunnertest.mlir";
+ TestFile testfile{path, moduleStr};
+ ASSERT_TRUE(!testfile.ec);
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ registerBuiltinDialectTranslation(registry);
+ registerLLVMDialectTranslation(registry);
+ mlir::JitRunnerConfig config;
+ config.runtimeSymbolMap = [](llvm::orc::MangleAndInterner interner) {
+ llvm::orc::SymbolMap symbolMap;
+ symbolMap[interner("callback")] = {
+ llvm::orc::ExecutorAddr::fromPtr(intcallback),
+ llvm::JITSymbolFlags::Exported};
+ return symbolMap;
+ };
+
+ char S_toolname[] = "test-runner";
+ char S_e[] = "-e";
+ char S_main[] = "caller_for_callback";
+ char S_entryPoint[] = "-entry-point-result=void";
+ char *argv[] = {S_toolname, path, S_e, S_main, S_entryPoint};
+ int exitcode = mlir::JitRunnerMain(sizeof(argv) / sizeof(argv[0]), argv,
+ registry, config);
+ ASSERT_EQ(exitcode, 0);
+ ASSERT_EQ(callbackval, 114514);
+}
+
#endif // _WIN32
``````````
</details>
https://github.com/llvm/llvm-project/pull/90381
More information about the Mlir-commits
mailing list