[Mlir-commits] [mlir] [MLIR][JitRunner] Correctly register symbol map (PR #90381)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 28 00:59:16 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Menooker (Menooker)

<details>
<summary>Changes</summary>

Fixed an issue in JitRunner where it ignored the `runtimeSymbolMap` passed by the user. Added code to call `registerSymbols` on the ExecutionEngine.

---
Full diff: https://github.com/llvm/llvm-project/pull/90381.diff


3 Files Affected:

- (modified) mlir/lib/ExecutionEngine/JitRunner.cpp (+2) 
- (modified) mlir/unittests/ExecutionEngine/CMakeLists.txt (+1) 
- (modified) mlir/unittests/ExecutionEngine/Invoke.cpp (+63) 


``````````diff
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index cf462ddf6f17ca..af5cd2f0784124 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -201,6 +201,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
     return expectedEngine.takeError();
 
   auto engine = std::move(*expectedEngine);
+  if (config.runtimeSymbolMap)
+    engine->registerSymbols(config.runtimeSymbolMap);
 
   auto expectedFPtr = engine->lookupPacked(entryPoint);
   if (!expectedFPtr)
diff --git a/mlir/unittests/ExecutionEngine/CMakeLists.txt b/mlir/unittests/ExecutionEngine/CMakeLists.txt
index 383e172aa3f667..c1b840e991f72e 100644
--- a/mlir/unittests/ExecutionEngine/CMakeLists.txt
+++ b/mlir/unittests/ExecutionEngine/CMakeLists.txt
@@ -11,6 +11,7 @@ target_link_libraries(MLIRExecutionEngineTests
   MLIRExecutionEngine
   MLIRMemRefToLLVM
   MLIRReconcileUnrealizedCasts
+  MLIRJitRunner
   ${dialect_libs}
 
 )
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index ff87fc9fad805a..ca132605274ee6 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/ExecutionEngine/CRunnerUtils.h"
 #include "mlir/ExecutionEngine/ExecutionEngine.h"
+#include "mlir/ExecutionEngine/JitRunner.h"
 #include "mlir/ExecutionEngine/MemRefUtils.h"
 #include "mlir/ExecutionEngine/RunnerUtils.h"
 #include "mlir/IR/MLIRContext.h"
@@ -295,4 +296,66 @@ TEST(NativeMemRefJit, MAYBE_JITCallback) {
     ASSERT_EQ(elt, coefficient * count++);
 }
 
+namespace {
+struct TestFile { // Scratch file: written in the ctor, removed in the dtor.
+  SmallString<256> filename;
+  std::error_code ec; // Set if the file could not be opened for writing.
+  TestFile(StringRef path, StringRef contents) : filename{path} {
+    llvm::raw_fd_ostream outf{path, ec, llvm::sys::fs::OF_None};
+    if (!ec)
+      outf << contents;
+  }
+  TestFile(const TestFile &) = delete; // Owns the file; never copy it.
+  TestFile &operator=(const TestFile &) = delete;
+  ~TestFile() {
+    if (!ec)
+      llvm::sys::fs::remove(filename);
+  }
+};
+static int32_t callbackval = 0; // Last value received by intcallback.
+static void intcallback(int32_t v) { callbackval = v; }
+} // namespace
+
+TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(JitRunnerSymbol)) {
+  std::string moduleStr = R"mlir(
+module {
+  llvm.func @callback(i32) attributes {sym_visibility = "private"}
+  llvm.func @caller_for_callback() attributes {llvm.emit_c_interface} {
+    %0 = llvm.mlir.constant(114514 : i32) : i32
+    llvm.call @callback(%0) : (i32) -> ()
+    llvm.return
+  }
+  llvm.func @_mlir_ciface_caller_for_callback() attributes {llvm.emit_c_interface} {
+    llvm.call @caller_for_callback() : () -> ()
+    llvm.return
+  }
+}
+  )mlir";
+  char path[] = "mlir_executionengine_jitrunnertest.mlir";
+  TestFile testfile{path, moduleStr};
+  ASSERT_FALSE(testfile.ec) << testfile.ec.message();
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerBuiltinDialectTranslation(registry);
+  registerLLVMDialectTranslation(registry);
+  // runtimesymbolMap is a non-owning function_ref: keep the lambda alive.
+  auto symbolMapFn = [](llvm::orc::MangleAndInterner interner) {
+    llvm::orc::SymbolMap symbolMap;
+    symbolMap[interner("callback")] = {
+        llvm::orc::ExecutorAddr::fromPtr(intcallback),
+        llvm::JITSymbolFlags::Exported};
+    return symbolMap;
+  };
+  mlir::JitRunnerConfig config;
+  config.runtimesymbolMap = symbolMapFn;
+  char S_toolname[] = "test-runner";
+  char S_e[] = "-e";
+  char S_main[] = "caller_for_callback";
+  char S_entryPoint[] = "-entry-point-result=void";
+  char *argv[] = {S_toolname, path, S_e, S_main, S_entryPoint};
+  callbackval = 0; // A stale value from an earlier run must not pass the test.
+  ASSERT_EQ(mlir::JitRunnerMain(std::size(argv), argv, registry, config), 0);
+  ASSERT_EQ(callbackval, 114514);
+}
+
 #endif // _WIN32

``````````

</details>


https://github.com/llvm/llvm-project/pull/90381


More information about the Mlir-commits mailing list