[Mlir-commits] [mlir] [MLIR] Split ExecutionEngine Initialization out of ctor into an explicit method call (PR #153524)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 13 19:21:17 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-execution-engine

@llvm/pr-subscribers-mlir

Author: Shenghang Tsai (jackalcooper)

<details>
<summary>Changes</summary>

Retry landing https://github.com/llvm/llvm-project/pull/153373
## Major changes from previous attempt
- Remove the C API test, because no existing C API tests deal with sanitizer exemptions.
- Update `mlir/docs/Dialects/GPU.md` to reflect the new behavior: GPU binaries are loaded in global constructors instead of at the call site.
- Skip the C++ test under more sanitizers (ASan, HWASan, UBSan).

---
Full diff: https://github.com/llvm/llvm-project/pull/153524.diff


8 Files Affected:

- (modified) mlir/docs/Dialects/GPU.md (+10-1) 
- (modified) mlir/include/mlir-c/ExecutionEngine.h (+7) 
- (modified) mlir/include/mlir/ExecutionEngine/ExecutionEngine.h (+9) 
- (modified) mlir/lib/Bindings/Python/ExecutionEngineModule.cpp (+12-1) 
- (modified) mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp (+6-3) 
- (modified) mlir/lib/ExecutionEngine/ExecutionEngine.cpp (+12-8) 
- (modified) mlir/lib/ExecutionEngine/JitRunner.cpp (+2) 
- (modified) mlir/unittests/ExecutionEngine/Invoke.cpp (+55) 


``````````diff
diff --git a/mlir/docs/Dialects/GPU.md b/mlir/docs/Dialects/GPU.md
index 94b053daa1615..a06183ad39dce 100644
--- a/mlir/docs/Dialects/GPU.md
+++ b/mlir/docs/Dialects/GPU.md
@@ -193,10 +193,19 @@ llvm.func @foo() {
 // mlir-translate --mlir-to-llvmir:
 @binary_bin_cst = internal constant [6 x i8] c"AMDGPU", align 8
 @binary_func_kernel_name = private unnamed_addr constant [7 x i8] c"func\00", align 1
+@binary_module = internal global ptr null
+@llvm.global_ctors = appending global [1 x {i32, ptr, ptr}] [{i32 123, ptr @binary_load, ptr null}]
+@llvm.global_dtors = appending global [1 x {i32, ptr, ptr}] [{i32 123, ptr @binary_unload, ptr null}]
+define internal void @binary_load() section ".text.startup" {
+entry:
+  %0 = call ptr @mgpuModuleLoad(ptr @binary_bin_cst)
+  store ptr %0, ptr @binary_module
+  ...
+}
 ...
 define void @foo() {
   ...
-  %module = call ptr @mgpuModuleLoad(ptr @binary_bin_cst)
+  %module = load ptr, ptr @binary_module, align 8
   %kernel = call ptr @mgpuModuleGetFunction(ptr %module, ptr @binary_func_kernel_name)
   call void @mgpuLaunchKernel(ptr %kernel, ...) ; Launch the kernel
   ...
diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h
index 99cddc5c2598d..1a58d68533f24 100644
--- a/mlir/include/mlir-c/ExecutionEngine.h
+++ b/mlir/include/mlir-c/ExecutionEngine.h
@@ -46,6 +46,13 @@ MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(
     MlirModule op, int optLevel, int numPaths,
     const MlirStringRef *sharedLibPaths, bool enableObjectDump);
 
+/// Initialize the ExecutionEngine by running the global constructors from
+/// `llvm.mlir.global_ctors`, e.g. to load a kernel binary compiled from a
+/// `gpu.module`. All symbols they use must already be resolvable, via
+/// `mlirExecutionEngineRegisterSymbol` or shared libraries. Repeated calls are
+/// no-ops; constructors are currently not run on AArch64 (issue #71963).
+MLIR_CAPI_EXPORTED void mlirExecutionEngineInitialize(MlirExecutionEngine jit);
+
 /// Destroy an ExecutionEngine instance.
 MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit);
 
diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
index 96ccebcd5685e..5bd71d68d253a 100644
--- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
+++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
@@ -227,6 +227,13 @@ class ExecutionEngine {
       llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
           symbolMap);
 
+  /// Initialize the ExecutionEngine by running the global constructors from
+  /// `llvm.mlir.global_ctors`, e.g. to load a kernel binary compiled from a
+  /// `gpu.module`. Symbols they use must be resolvable beforehand, via
+  /// `registerSymbols` or shared libraries. `invokePacked` calls this first;
+  /// repeated calls are no-ops. Ctors are not run on AArch64 (issue #71963).
+  void initialize();
+
 private:
   /// Ordering of llvmContext and jit is important for destruction purposes: the
   /// jit must be destroyed before the context.
@@ -250,6 +257,8 @@ class ExecutionEngine {
   /// Destroy functions in the libraries loaded by the ExecutionEngine that are
   /// called when this ExecutionEngine is destructed.
   SmallVector<LibraryDestroyFn> destroyFns;
+
+  bool isInitialized = false;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index 81dada3553622..4f7a4a628e246 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir-c/ExecutionEngine.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace nb = nanobind;
 using namespace mlir;
@@ -124,6 +124,17 @@ NB_MODULE(_mlirExecutionEngine, m) {
           },
           nb::arg("name"), nb::arg("callback"),
           "Register `callback` as the runtime symbol `name`.")
+      .def(
+          "initialize",
+          [](PyExecutionEngine &executionEngine) {
+            mlirExecutionEngineInitialize(executionEngine.get());
+          },
+          "Initialize the ExecutionEngine. Global constructors specified by "
+          "`llvm.mlir.global_ctors` will be run. One common scenario is that "
+          "kernel binary compiled from `gpu.module` gets loaded during "
+          "initialization. Make sure all symbols are resolvable before "
+          "initialization by calling `raw_register_runtime` or including "
+          "shared libraries.")
       .def(
           "dump_to_object_file",
           [](PyExecutionEngine &executionEngine, const std::string &fileName) {
diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
index 306cebd236be9..2dbb993b1640f 100644
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
   return wrap(jitOrError->release());
 }
 
+extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) {
+  // Thin forwarder to the C++ API; repeated calls are no-ops there.
+  auto *engine = unwrap(jit);
+  engine->initialize();
+}
+
 extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
   delete (unwrap(jit));
 }
@@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
                                                   void *sym) {
   unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
     llvm::orc::SymbolMap symbolMap;
-    symbolMap[interner(unwrap(name))] =
-        { llvm::orc::ExecutorAddr::fromPtr(sym),
-          llvm::JITSymbolFlags::Exported };
+    symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym),
+                                         llvm::JITSymbolFlags::Exported};
     return symbolMap;
   });
 }
diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index f704fbfbe8fff..52162a43aeae3 100644
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) {
   }
   // Compilation is lazy and it doesn't populate object cache unless requested.
   // In case object dump is requested before cache is populated, we need to
-  // force compilation manually. 
+  // force compilation manually.
   if (cache->isEmpty()) {
     for (std::string &functionName : functionNames) {
       auto result = lookupPacked(functionName);
@@ -400,13 +400,6 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
     return symbolMap;
   };
   engine->registerSymbols(runtimeSymbolMap);
-
-  // Execute the global constructors from the module being processed.
-  // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
-  // crash for AArch64 see related issue #71963.
-  if (!engine->jit->getTargetTriple().isAArch64())
-    cantFail(engine->jit->initialize(engine->jit->getMainJITDylib()));
-
   return std::move(engine);
 }
 
@@ -442,6 +435,7 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const {
 
 Error ExecutionEngine::invokePacked(StringRef name,
                                     MutableArrayRef<void *> args) {
+  initialize();
   auto expectedFPtr = lookupPacked(name);
   if (!expectedFPtr)
     return expectedFPtr.takeError();
@@ -451,3 +445,13 @@ Error ExecutionEngine::invokePacked(StringRef name,
 
   return Error::success();
 }
+
+void ExecutionEngine::initialize() {
+  // Global constructors are run at most once per engine.
+  if (isInitialized)
+    return;
+
+  // TODO: Allow JIT initialize for AArch64. Currently a bug causes a crash on
+  // AArch64; see related issue #71963.
+  const bool runGlobalCtors = !jit->getTargetTriple().isAArch64();
+  if (runGlobalCtors)
+    cantFail(jit->initialize(jit->getMainJITDylib()));
+
+  // Flip the flag even when the ctors were skipped, so later calls (including
+  // the implicit one from invokePacked) return immediately.
+  isInitialized = true;
+}
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index 2107df37d1997..0ada4cc96570a 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -202,6 +202,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
 
   auto engine = std::move(*expectedEngine);
 
+  engine->initialize();
+
   auto expectedFPtr = engine->lookupPacked(entryPoint);
   if (!expectedFPtr)
     return expectedFPtr.takeError();
diff --git a/mlir/unittests/ExecutionEngine/Invoke.cpp b/mlir/unittests/ExecutionEngine/Invoke.cpp
index 312b10f28143f..8d5f99ec71c13 100644
--- a/mlir/unittests/ExecutionEngine/Invoke.cpp
+++ b/mlir/unittests/ExecutionEngine/Invoke.cpp
@@ -322,4 +322,59 @@ TEST(NativeMemRefJit, MAYBE_JITCallback) {
     ASSERT_EQ(elt, coefficient * count++);
 }
 
+// Number of times `initCallback` has run. The test resets it on entry so its
+// assertions do not depend on test order or `--gtest_repeat`.
+static int initCnt = 0;
+// A helper function that will be called during the JIT's initialization.
+static void initCallback() { initCnt += 1; }
+
+// Disabled under MSan, ASan, HWASan and UBSan builds.
+#if __has_feature(memory_sanitizer) || __has_feature(address_sanitizer) ||     \
+    __has_feature(hwaddress_sanitizer) ||                                      \
+    __has_feature(undefined_behavior_sanitizer)
+#define MAYBE_JITCallbackInGlobalCtor DISABLED_JITCallbackInGlobalCtor
+#else
+#define MAYBE_JITCallbackInGlobalCtor SKIP_WITHOUT_JIT(JITCallbackInGlobalCtor)
+#endif
+TEST(MLIRExecutionEngine, MAYBE_JITCallbackInGlobalCtor) {
+  // Start from a clean counter in case this test body runs more than once.
+  initCnt = 0;
+  std::string moduleStr = R"mlir(
+  llvm.mlir.global_ctors ctors = [@ctor], priorities = [0 : i32], data = [#llvm.zero]
+  llvm.func @ctor() {
+    func.call @init_callback() : () -> ()
+    llvm.return
+  }
+  func.func private @init_callback() attributes { llvm.emit_c_interface }
+  )mlir";
+
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerBuiltinDialectTranslation(registry);
+  registerLLVMDialectTranslation(registry);
+  MLIRContext context(registry);
+  auto module = parseSourceString<ModuleOp>(moduleStr, &context);
+  ASSERT_TRUE(!!module);
+  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
+  ExecutionEngineOptions jitOptions;
+  auto jitOrError = ExecutionEngine::create(*module, jitOptions);
+  ASSERT_TRUE(!!jitOrError);
+  // validate initialization is not run on construction
+  ASSERT_EQ(initCnt, 0);
+  auto jit = std::move(jitOrError.get());
+  // Define any extra symbols so they're available at initialization.
+  jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
+    llvm::orc::SymbolMap symbolMap;
+    symbolMap[interner("_mlir_ciface_init_callback")] = {
+        llvm::orc::ExecutorAddr::fromPtr(initCallback),
+        llvm::JITSymbolFlags::Exported};
+    return symbolMap;
+  });
+  jit->initialize();
+  // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
+  // crash for AArch64 see related issue #71963.
+  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
+  ASSERT_TRUE(!!tmBuilderOrError);
+  if (!tmBuilderOrError->getTargetTriple().isAArch64()) {
+    // validate the side effect of initialization
+    ASSERT_EQ(initCnt, 1);
+    // next initialization should be noop
+    jit->initialize();
+    ASSERT_EQ(initCnt, 1);
+  }
+}
+
 #endif // _WIN32

``````````

</details>


https://github.com/llvm/llvm-project/pull/153524


More information about the Mlir-commits mailing list