[llvm] [mlir] [mlir][gpu] Change GPU modules to globals (PR #135478)

Christian Sigg via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 11 23:18:40 PDT 2025


https://github.com/chsigg created https://github.com/llvm/llvm-project/pull/135478

Load/unload GPU modules in global ctors/dtors instead of every time a kernel is launched.

Loading GPU modules is a heavy-weight operation and synchronizes the GPU context. Now that the modules are loaded ahead of time, asynchronously launched kernels can run concurrently, see https://discourse.llvm.org/t/how-to-lower-the-combination-of-async-gpu-ops-in-gpu-dialect.

The implementations of `embedBinary()` and `launchKernel()` use slightly different mechanics at the moment, but I prefer not to change the latter more than necessary as part of this PR. I will prepare a follow-up NFC change for `launchKernel()` to align them again.

>From 1e66a9b5a7b555ce003dcf2e7bfde346ddf7144f Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Sat, 12 Apr 2025 08:08:54 +0200
Subject: [PATCH] Load/unload GPU modules in global ctors/dtors instead of each
 time when launching a kernel.

Loading GPU modules is a heavy-weight operation and synchronizes the GPU context. Now that the modules are loaded ahead of time, asynchronously launched kernels can run concurrently, see https://discourse.llvm.org/t/how-to-lower-the-combination-of-async-gpu-ops-in-gpu-dialect.
---
 .../LLVMIR/Dialect/GPU/SelectObjectAttr.cpp   | 258 +++++++++---------
 .../GPU/CUDA/concurrent-kernels.mlir          |  48 ++++
 mlir/test/Target/LLVMIR/gpu.mlir              |  71 +++--
 .../llvm-project-overlay/mlir/BUILD.bazel     |   1 +
 4 files changed, 212 insertions(+), 166 deletions(-)
 create mode 100644 mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir

diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
index 8d4a0bcf8adbf..d3216d9ad17eb 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp
@@ -18,11 +18,13 @@
 #include "mlir/Target/LLVMIR/Export.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
 
 using namespace mlir;
 
@@ -31,9 +33,13 @@ namespace {
 class SelectObjectAttrImpl
     : public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
           SelectObjectAttrImpl> {
+  // Returns the object of `op` selected for embedding; null signals failure.
+  gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
+
 public:
   // Translates a `gpu.binary`, embedding the binary into a host LLVM module as
-  // global binary string.
+  // global binary string, plus a global ctor/dtor pair that loads the binary
+  // into a global module object and unloads it again.
   LogicalResult embedBinary(Attribute attribute, Operation *operation,
                             llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) const;
@@ -45,23 +51,9 @@ class SelectObjectAttrImpl
                              Operation *binaryOperation,
                              llvm::IRBuilderBase &builder,
                              LLVM::ModuleTranslation &moduleTranslation) const;
-
-  // Returns the selected object for embedding.
-  gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
 };
-// Returns an identifier for the global string holding the binary.
-std::string getBinaryIdentifier(StringRef binaryName) {
-  return binaryName.str() + "_bin_cst";
-}
 } // namespace
 
-void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
-    DialectRegistry &registry) {
-  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
-    SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
-  });
-}
-
 gpu::ObjectAttr
 SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
   ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();
 gpu::ObjectAttr
 SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
   ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();
@@ -96,6 +88,110 @@ SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
   return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
 }
 
+// Returns the name of the global variable that holds the loaded module object
+// of the `gpu.binary` named `moduleName`.
+// NOTE: the returned Twine does not own its storage; consume it within the
+// calling expression or materialize it with `.str()`, never store it.
+static Twine getModuleIdentifier(StringRef moduleName) {
+  return moduleName + "_module";
+}
+
+namespace llvm {
+// Priority shared by the global ctor that loads a GPU module and the global
+// dtor that unloads it.
+static constexpr int kGpuModuleCtorDtorPriority = 123;
+
+// Embeds `object` into `module` as the internal global `<moduleName>_binary`
+// and registers two internal functions: the global ctor `<moduleName>_load`,
+// which loads the binary (via `mgpuModuleLoadJIT` for assembly, otherwise
+// `mgpuModuleLoad`) and stores the result in the global
+// `<moduleName>_module`, and the global dtor `<moduleName>_unload`, which
+// passes that global's value to `mgpuModuleUnload`.
+static LogicalResult embedBinaryImpl(StringRef moduleName,
+                                     gpu::ObjectAttr object, Module &module) {
+  // Embed the object as a global string.
+  // Add null for assembly output for JIT paths that expect null-terminated
+  // strings.
+  bool addNull = (object.getFormat() == gpu::CompilationTarget::Assembly);
+  StringRef serializedStr = object.getObject().getValue();
+  Constant *serializedCst =
+      ConstantDataArray::getString(module.getContext(), serializedStr, addNull);
+  GlobalVariable *serializedObj =
+      new GlobalVariable(module, serializedCst->getType(), true,
+                         GlobalValue::LinkageTypes::InternalLinkage,
+                         serializedCst, moduleName + "_binary");
+  serializedObj->setAlignment(MaybeAlign(8));
+  serializedObj->setUnnamedAddr(GlobalValue::UnnamedAddr::None);
+
+  // Default JIT optimization level.
+  auto optLevel = APInt::getZero(32);
+
+  if (DictionaryAttr objectProps = object.getProperties()) {
+    if (auto section = dyn_cast_or_null<StringAttr>(
+            objectProps.get(gpu::elfSectionName))) {
+      serializedObj->setSection(section.getValue());
+    }
+    // Check if there's an optimization level embedded in the object. The
+    // attribute may have any integer width (e.g. `O = 3` parses as i64), but
+    // `mgpuModuleLoadJIT` takes an i32, so normalize the width.
+    if (auto optAttr = dyn_cast_or_null<IntegerAttr>(objectProps.get("O")))
+      optLevel = optAttr.getValue().sextOrTrunc(32);
+  }
+
+  IRBuilder<> builder(module.getContext());
+  auto i32Ty = builder.getInt32Ty();
+  auto i64Ty = builder.getInt64Ty();
+  auto ptrTy = builder.getPtrTy(0);
+  auto voidTy = builder.getVoidTy();
+
+  // Embed the module as a global object.
+  auto *modulePtr = new GlobalVariable(
+      module, ptrTy, /*isConstant=*/false, GlobalValue::InternalLinkage,
+      /*Initializer=*/ConstantPointerNull::get(ptrTy),
+      getModuleIdentifier(moduleName));
+
+  // Global ctor: load the binary and store the module object in `modulePtr`.
+  auto *loadFn = Function::Create(FunctionType::get(voidTy, /*IsVarArg=*/false),
+                                  GlobalValue::InternalLinkage,
+                                  moduleName + "_load", module);
+  loadFn->setSection(".text.startup");
+  auto *loadBlock = BasicBlock::Create(module.getContext(), "entry", loadFn);
+  builder.SetInsertPoint(loadBlock);
+  Value *moduleObj = [&]() -> Value * {
+    if (object.getFormat() == gpu::CompilationTarget::Assembly) {
+      FunctionCallee moduleLoadFn = module.getOrInsertFunction(
+          "mgpuModuleLoadJIT", FunctionType::get(ptrTy, {ptrTy, i32Ty}, false));
+      Constant *optValue = ConstantInt::get(i32Ty, optLevel);
+      return builder.CreateCall(moduleLoadFn, {serializedObj, optValue});
+    }
+    FunctionCallee moduleLoadFn = module.getOrInsertFunction(
+        "mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false));
+    // Only assembly gets a trailing null, so the size is that of the object.
+    Constant *binarySize = ConstantInt::get(i64Ty, serializedStr.size());
+    return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize});
+  }();
+  builder.CreateStore(moduleObj, modulePtr);
+  builder.CreateRetVoid();
+  appendToGlobalCtors(module, loadFn, kGpuModuleCtorDtorPriority);
+
+  // Global dtor: unload the module object stored by the ctor.
+  auto *unloadFn = Function::Create(
+      FunctionType::get(voidTy, /*IsVarArg=*/false),
+      GlobalValue::InternalLinkage, moduleName + "_unload", module);
+  unloadFn->setSection(".text.startup");
+  auto *unloadBlock =
+      BasicBlock::Create(module.getContext(), "entry", unloadFn);
+  builder.SetInsertPoint(unloadBlock);
+  FunctionCallee moduleUnloadFn = module.getOrInsertFunction(
+      "mgpuModuleUnload", FunctionType::get(voidTy, ptrTy, false));
+  builder.CreateCall(moduleUnloadFn, builder.CreateLoad(ptrTy, modulePtr));
+  builder.CreateRetVoid();
+  appendToGlobalDtors(module, unloadFn, kGpuModuleCtorDtorPriority);
+
+  return success();
+}
+} // namespace llvm
+
 LogicalResult SelectObjectAttrImpl::embedBinary(
     Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
     LLVM::ModuleTranslation &moduleTranslation) const {
@@ -113,29 +193,16 @@ LogicalResult SelectObjectAttrImpl::embedBinary(
   if (!object)
     return failure();
 
-  llvm::Module *module = moduleTranslation.getLLVMModule();
-
-  // Embed the object as a global string.
-  // Add null for assembly output for JIT paths that expect null-terminated
-  // strings.
-  bool addNull = (object.getFormat() == gpu::CompilationTarget::Assembly);
-  llvm::Constant *binary = llvm::ConstantDataArray::getString(
-      builder.getContext(), object.getObject().getValue(), addNull);
-  llvm::GlobalVariable *serializedObj =
-      new llvm::GlobalVariable(*module, binary->getType(), true,
-                               llvm::GlobalValue::LinkageTypes::InternalLinkage,
-                               binary, getBinaryIdentifier(op.getName()));
-
-  if (object.getProperties()) {
-    if (auto section = mlir::dyn_cast_or_null<mlir::StringAttr>(
-            object.getProperties().get(gpu::elfSectionName))) {
-      serializedObj->setSection(section.getValue());
-    }
-  }
-  serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
-  serializedObj->setAlignment(llvm::MaybeAlign(8));
-  serializedObj->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
-  return success();
+  // Reject a malformed JIT optimization level here, where the diagnostic can
+  // be attached to `op`; embedBinaryImpl would silently ignore a non-integer.
+  if (DictionaryAttr objectProps = object.getProperties()) {
+    Attribute optAttr = objectProps.get("O");
+    if (optAttr && !isa<IntegerAttr>(optAttr))
+      return op.emitError("the optimization level must be an integer");
+  }
+
+  return embedBinaryImpl(op.getName(), object,
+                         *moduleTranslation.getLLVMModule());
 }
 
 namespace llvm {
@@ -153,15 +212,6 @@ class LaunchKernel {
   // Get the module function callee.
   FunctionCallee getModuleFunctionFn();
 
-  // Get the module load callee.
-  FunctionCallee getModuleLoadFn();
-
-  // Get the module load JIT callee.
-  FunctionCallee getModuleLoadJITFn();
-
-  // Get the module unload callee.
-  FunctionCallee getModuleUnloadFn();
-
   // Get the stream create callee.
   FunctionCallee getStreamCreateFn();
 
@@ -261,24 +311,6 @@ llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
       FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
 }
 
-llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
-  return module.getOrInsertFunction(
-      "mgpuModuleLoad",
-      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false));
-}
-
-llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
-  return module.getOrInsertFunction(
-      "mgpuModuleLoadJIT",
-      FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false));
-}
-
-llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
-  return module.getOrInsertFunction(
-      "mgpuModuleUnload",
-      FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
-}
-
 llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
   return module.getOrInsertFunction("mgpuStreamCreate",
                                     FunctionType::get(ptrTy, false));
@@ -301,9 +333,12 @@ llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
 llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
                                                          StringRef kernelName) {
   std::string globalName =
-      std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));
+      formatv("{0}_{1}_name", moduleName, kernelName).str();
 
-  if (GlobalVariable *gv = module.getGlobalVariable(globalName))
+  // The name globals are created with private linkage, so the lookup must
+  // allow internal globals; otherwise every launch would emit another copy.
+  if (GlobalVariable *gv =
+          module.getGlobalVariable(globalName, /*AllowInternal=*/true))
     return gv;
 
   return builder.CreateGlobalString(kernelName, globalName);
@@ -346,16 +378,13 @@ llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
 }
 
 // Emits LLVM IR to launch a kernel function:
-// %0 = call %binarygetter
-// %1 = call %moduleLoad(%0)
-// %2 = <see generateKernelNameConstant>
-// %3 = call %moduleGetFunction(%1, %2)
-// %4 = call %streamCreate()
-// %5 = <see generateParamsArray>
-// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
-// call %streamSynchronize(%4)
-// call %streamDestroy(%4)
-// call %moduleUnload(%1)
+// %1 = load %global_module_object  (stored by the global ctor)
+// %2 = call @mgpuModuleGetFunction(%1, %global_kernel_name)
+// %3 = call @mgpuStreamCreate()  (synchronous launch; else the async stream)
+// %4 = <see createKernelArgArray()>
+// call @mgpuLaunchKernel(%2, ..., %3, %4, ...)  (or @mgpuLaunchClusterKernel)
+// call @mgpuStreamSynchronize(%3)  (synchronous launch only)
+// call @mgpuStreamDestroy(%3)  (synchronous launch only)
 llvm::LogicalResult
 llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                        mlir::gpu::ObjectAttr object) {
@@ -385,58 +414,33 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
   // Create the argument array.
   Value *argArray = createKernelArgArray(op);
 
-  // Default JIT optimization level.
-  llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
-  // Check if there's an optimization level embedded in the object.
-  DictionaryAttr objectProps = object.getProperties();
-  mlir::Attribute optAttr;
-  if (objectProps && (optAttr = objectProps.get("O"))) {
-    auto optLevel = dyn_cast<IntegerAttr>(optAttr);
-    if (!optLevel)
-      return op.emitError("the optimization level must be an integer");
-    optV = llvm::ConstantInt::get(i32Ty, optLevel.getValue());
-  }
-
-  // Load the kernel module.
-  StringRef moduleName = op.getKernelModuleName().getValue();
-  std::string binaryIdentifier = getBinaryIdentifier(moduleName);
-  Value *binary = module.getGlobalVariable(binaryIdentifier, true);
-  if (!binary)
-    return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
-
-  auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
-  if (!binaryVar)
-    return op.emitError() << "Binary is not a global variable: "
-                          << binaryIdentifier;
-  llvm::Constant *binaryInit = binaryVar->getInitializer();
-  auto binaryDataSeq =
-      dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
-  if (!binaryDataSeq)
-    return op.emitError() << "Couldn't find binary data array: "
-                          << binaryIdentifier;
-  llvm::Constant *binarySize =
-      llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
-                                        binaryDataSeq->getElementByteSize());
-
-  Value *moduleObject =
-      object.getFormat() == gpu::CompilationTarget::Assembly
-          ? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
-          : builder.CreateCall(getModuleLoadFn(), {binary, binarySize});
-
   // Load the kernel function.
-  Value *moduleFunction = builder.CreateCall(
-      getModuleFunctionFn(),
-      {moduleObject,
-       getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
+  StringRef moduleName = op.getKernelModuleName().getValue();
+  // Materialize the name: a Twine must not be stored in a local variable.
+  std::string moduleIdentifier = getModuleIdentifier(moduleName).str();
+  Value *modulePtr =
+      module.getGlobalVariable(moduleIdentifier, /*AllowInternal=*/true);
+  if (!modulePtr)
+    return op.emitError() << "Couldn't find the binary: " << moduleIdentifier;
+  // The module object was stored into this global by the module's global ctor.
+  Value *moduleObj = builder.CreateLoad(ptrTy, modulePtr);
+  Value *functionName = getOrCreateFunctionName(moduleName, op.getKernelName());
+  Value *moduleFunction =
+      builder.CreateCall(getModuleFunctionFn(), {moduleObj, functionName});
 
   // Get the stream to use for execution. If there's no async object then create
   // a stream to make a synchronous kernel launch.
   Value *stream = nullptr;
-  bool handleStream = false;
+  // Sync & destroy the stream, for synchronous launches. The scope exit emits
+  // these calls on return from this function; it is released for async ones.
+  auto destroyStream = make_scope_exit([&]() {
+    builder.CreateCall(getStreamSyncFn(), {stream});
+    builder.CreateCall(getStreamDestroyFn(), {stream});
+  });
   if (mlir::Value asyncObject = op.getAsyncObject()) {
     stream = llvmValue(asyncObject);
+    destroyStream.release();
   } else {
-    handleStream = true;
     stream = builder.CreateCall(getStreamCreateFn(), {});
   }
 
@@ -462,14 +462,14 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
                                           argArray, nullPtr, paramsCount}));
   }
 
-  // Sync & destroy the stream, for synchronous launches.
-  if (handleStream) {
-    builder.CreateCall(getStreamSyncFn(), {stream});
-    builder.CreateCall(getStreamDestroyFn(), {stream});
-  }
-
-  // Unload the kernel module.
-  builder.CreateCall(getModuleUnloadFn(), {moduleObject});
-
   return success();
 }
+
+// Attaches the `SelectObjectAttrImpl` offloading translation model to
+// `SelectObjectAttr` once the GPU dialect is loaded into a context.
+void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *) {
+    SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
+  });
+}
diff --git a/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir b/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir
new file mode 100644
index 0000000000000..80cc6d6bf91dd
--- /dev/null
+++ b/mlir/test/Integration/GPU/CUDA/concurrent-kernels.mlir
@@ -0,0 +1,48 @@
+// Tests that two asynchronously launched kernels run concurrently. Each kernel
+// bumps a shared counter, then spins until it reaches 2 (hangs if serialized).
+//
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void
+
+module attributes {gpu.container_module} {
+    gpu.module @kernels {
+        gpu.func @kernel(%memref: memref<i32>) kernel {
+            %c0 = arith.constant 0 : i32
+            %c1 = arith.constant 1 : i32
+            %c2 = arith.constant 2 : i32
+            %block = memref.atomic_rmw addi %c1, %memref[] : (i32, memref<i32>) -> i32
+            scf.while: () -> () {
+                %value = memref.atomic_rmw addi %c0, %memref[] : (i32, memref<i32>) -> i32
+                %cond = arith.cmpi slt, %value, %c2 : i32
+                scf.condition(%cond)
+            } do {
+                scf.yield
+            }
+            gpu.return
+        }
+    }
+
+    func.func @main() {
+        %memref = gpu.alloc host_shared () : memref<i32>
+        %c0 = arith.constant 0 : i32
+        memref.store %c0, %memref[] : memref<i32>
+
+        %0 = gpu.wait async
+        %1 = gpu.wait async
+        %c1 = arith.constant 1 : index
+        %2 = gpu.launch_func async [%0] @kernels::@kernel
+            blocks in (%c1, %c1, %c1)
+            threads in (%c1, %c1, %c1)
+            args(%memref: memref<i32>)
+        %3 = gpu.launch_func async [%1] @kernels::@kernel
+            blocks in (%c1, %c1, %c1)
+            threads in (%c1, %c1, %c1)
+            args(%memref: memref<i32>)
+        gpu.wait [%2, %3]
+        return
+    }
+}
diff --git a/mlir/test/Target/LLVMIR/gpu.mlir b/mlir/test/Target/LLVMIR/gpu.mlir
index 6b7e7fcc71960..0d29a95b12266 100644
--- a/mlir/test/Target/LLVMIR/gpu.mlir
+++ b/mlir/test/Target/LLVMIR/gpu.mlir
@@ -3,8 +3,11 @@
 // Checking the translation of the `gpu.binary` & `gpu.launch_fun` ops.
 module attributes {gpu.container_module} {
   // CHECK: [[ARGS_TY:%.*]] = type { i32, i32 }
-  // CHECK: @kernel_module_bin_cst = internal constant [4 x i8] c"BLOB", align 8
-  // CHECK: @kernel_module_kernel_kernel_name = private unnamed_addr constant [7 x i8] c"kernel\00", align 1
+  // CHECK-DAG: @kernel_module_binary = internal constant [4 x i8] c"BLOB", align 8
+  // CHECK-DAG: @kernel_module_module = internal global ptr null
+  // CHECK-DAG: @llvm.global_ctors = appending global {{.*}} @kernel_module_load
+  // CHECK-DAG: @llvm.global_dtors = appending global {{.*}} @kernel_module_unload
+  // CHECK-DAG: @kernel_module_kernel_name = private unnamed_addr constant [7 x i8] c"kernel\00", align 1
   gpu.binary @kernel_module  [#gpu.object<#nvvm.target, "BLOB">]
   llvm.func @foo() {
     // CHECK: [[ARGS:%.*]] = alloca %{{.*}}, align 8
@@ -17,26 +20,32 @@ module attributes {gpu.container_module} {
     // CHECK: store i32 32, ptr [[ARG1]], align 4
     // CHECK: %{{.*}} = getelementptr ptr, ptr [[ARGS_ARRAY]], i32 1
     // CHECK: store ptr [[ARG1]], ptr %{{.*}}, align 8
-    // CHECK: [[MODULE:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst, i64 4)
-    // CHECK: [[FUNC:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[MODULE]], ptr @kernel_module_kernel_kernel_name)
+    // CHECK: [[MODULE:%.*]] = load ptr, ptr @kernel_module_module
+    // CHECK: [[FUNC:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[MODULE]], ptr @kernel_module_kernel_name)
     // CHECK: [[STREAM:%.*]] = call ptr @mgpuStreamCreate()
     // CHECK: call void @mgpuLaunchKernel(ptr [[FUNC]], i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i32 256, ptr [[STREAM]], ptr [[ARGS_ARRAY]], ptr null, i64 2)
     // CHECK: call void @mgpuStreamSynchronize(ptr [[STREAM]])
     // CHECK: call void @mgpuStreamDestroy(ptr [[STREAM]])
-    // CHECK: call void @mgpuModuleUnload(ptr [[MODULE]])
     %0 = llvm.mlir.constant(8 : index) : i64
     %1 = llvm.mlir.constant(32 : i32) : i32
     %2 = llvm.mlir.constant(256 : i32) : i32
     gpu.launch_func @kernel_module::@kernel blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %2 args(%1 : i32, %1 : i32)
     llvm.return
   }
+  // CHECK: @kernel_module_load() section ".text.startup"
+  // CHECK: [[MODULE:%.*]] = call ptr @mgpuModuleLoad
+  // CHECK: store ptr [[MODULE]], ptr @kernel_module_module
+  // The global dtor unloads the module object that the global ctor stored.
+  // CHECK: @kernel_module_unload() section ".text.startup"
+  // CHECK: [[MODULE:%.*]] = load ptr, ptr @kernel_module_module
+  // CHECK: call void @mgpuModuleUnload(ptr [[MODULE]])
 }
 
 // -----
 
 // Checking the correct selection of the second object using an index as a selector.
 module {
-  // CHECK: @kernel_module_bin_cst = internal constant [1 x i8] c"1", align 8
+  // CHECK: @kernel_module_binary = internal constant [1 x i8] c"1", align 8
   gpu.binary @kernel_module <#gpu.select_object<1>> [#gpu.object<#nvvm.target, "0">, #gpu.object<#nvvm.target, "1">]
 }
 
@@ -44,7 +53,7 @@ module {
 
 // Checking the correct selection of the second object using a target as a selector.
 module {
-  // CHECK: @kernel_module_bin_cst = internal constant [6 x i8] c"AMDGPU", align 8
+  // CHECK: @kernel_module_binary = internal constant [6 x i8] c"AMDGPU", align 8
   gpu.binary @kernel_module <#gpu.select_object<#rocdl.target>> [#gpu.object<#nvvm.target, "NVPTX">, #gpu.object<#rocdl.target, "AMDGPU">]
 }
 
@@ -52,52 +61,42 @@ module {
 
 // Checking the correct selection of the second object using a target as a selector.
 module {
-  // CHECK: @kernel_module_bin_cst = internal constant [4 x i8] c"BLOB", align 8
+  // CHECK: @kernel_module_binary = internal constant [4 x i8] c"BLOB", align 8
   gpu.binary @kernel_module <#gpu.select_object<#spirv.target_env<#spirv.vce<v1.0, [Addresses, Int64, Kernel], []>, api=OpenCL, #spirv.resource_limits<>>>> [#gpu.object<#nvvm.target, "NVPTX">, #gpu.object<#spirv.target_env<#spirv.vce<v1.0, [Addresses, Int64, Kernel], []>, api=OpenCL, #spirv.resource_limits<>>, "BLOB">]
 }
 
 // -----
 // Checking the translation of `gpu.launch_fun` with an async dependency.
 module attributes {gpu.container_module} {
-  // CHECK: @kernel_module_bin_cst = internal constant [4 x i8] c"BLOB", align 8
   gpu.binary @kernel_module  [#gpu.object<#rocdl.target, "BLOB">]
-  llvm.func @foo() {
+  llvm.func @foo(%stream : !llvm.ptr) {
     %0 = llvm.mlir.constant(8 : index) : i64
-    // CHECK: = call ptr @mgpuStreamCreate()
-    // CHECK-NEXT: = alloca {{.*}}, align 8
-    // CHECK-NEXT: [[ARGS:%.*]] = alloca ptr, i64 0, align 8
-    // CHECK-NEXT: [[MODULE:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst, i64 4)
-    // CHECK-NEXT: [[FUNC:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[MODULE]], ptr @kernel_module_kernel_kernel_name)
-    // CHECK-NEXT: call void @mgpuLaunchKernel(ptr [[FUNC]], i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i32 0, ptr {{.*}}, ptr [[ARGS]], ptr null, i64 0)
-    // CHECK-NEXT: call void @mgpuModuleUnload(ptr [[MODULE]])
-    // CHECK-NEXT: call void @mgpuStreamSynchronize(ptr %{{.*}})
-    // CHECK-NEXT: call void @mgpuStreamDestroy(ptr %{{.*}})
-    %1 = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
-    gpu.launch_func <%1 : !llvm.ptr> @kernel_module::@kernel blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64
-    llvm.call @mgpuStreamSynchronize(%1) : (!llvm.ptr) -> ()
-    llvm.call @mgpuStreamDestroy(%1) : (!llvm.ptr) -> ()
+    // CHECK-NOT: @mgpuStreamCreate
+    // CHECK: call void @mgpuLaunchKernel
+    gpu.launch_func <%stream : !llvm.ptr> @kernel_module::@kernel blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64
+    // CHECK-NOT: @mgpuStreamSynchronize
+    // CHECK-NOT: @mgpuStreamDestroy
     llvm.return
   }
-  llvm.func @mgpuStreamCreate() -> !llvm.ptr
-  llvm.func @mgpuStreamSynchronize(!llvm.ptr)
-  llvm.func @mgpuStreamDestroy(!llvm.ptr)
 }
 
 // -----
 
 // Test cluster/block/thread syntax.
 module attributes {gpu.container_module} {
-  // CHECK: @kernel_module_bin_cst = internal constant [4 x i8] c"BLOB", align 8
   gpu.binary @kernel_module  [#gpu.object<#nvvm.target, "BLOB">]
   llvm.func @foo() {
-  // CHECK: [[S2:%.*]] = alloca ptr, i64 0, align 8
-  // CHECK: [[S3:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst, i64 4)
-  // CHECK: [[S4:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[S3]], ptr @kernel_module_kernel_kernel_name)
-  // CHECK: [[S5:%.*]] = call ptr @mgpuStreamCreate()
-  // CHECK: call void @mgpuLaunchClusterKernel(ptr [[S4]], i64 2, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i64 1, i32 0, ptr [[S5]], ptr [[S2]], ptr null)
-    %0 = llvm.mlir.constant(1 : index) : i64
-    %1 = llvm.mlir.constant(2 : index) : i64
-    gpu.launch_func @kernel_module::@kernel clusters in (%1, %0, %0) blocks in (%0, %0, %0) threads in (%0, %0, %0) : i64
+  // CHECK: call void @mgpuLaunchClusterKernel(
+  // CHECK-SAME: i64 1, i64 1, i64 1,
+  // CHECK-SAME: i64 2, i64 2, i64 2,
+  // CHECK-SAME: i64 3, i64 3, i64 3, i32 0, ptr
+    %c1 = llvm.mlir.constant(1 : index) : i64
+    %c2 = llvm.mlir.constant(2 : index) : i64
+    %c3 = llvm.mlir.constant(3 : index) : i64
+    gpu.launch_func @kernel_module::@kernel
+        clusters in (%c1, %c1, %c1)
+        blocks in (%c2, %c2, %c2)
+        threads in (%c3, %c3, %c3) : i64
     llvm.return
   }
 }
@@ -106,6 +105,6 @@ module attributes {gpu.container_module} {
 
 // Checking that ELF section is populated
 module attributes {gpu.container_module} {
-  // CHECK: @cuda_device_mod_bin_cst = internal constant [4 x i8] c"BLOB", section "__nv_rel_fatbin", align 8
+  // CHECK: @cuda_device_mod_binary = internal constant [4 x i8] c"BLOB", section "__nv_rel_fatbin", align 8
   gpu.binary @cuda_device_mod  [#gpu.object<#nvvm.target, properties = {section = "__nv_rel_fatbin"}, "BLOB">]
 }
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index cc4af7ce40067..e3f43e5e7d1ab 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8378,6 +8378,7 @@ cc_library(
         ":ToLLVMIRTranslation",
         "//llvm:Core",
         "//llvm:Support",
+        "//llvm:TransformUtils",
     ],
 )
 



More information about the llvm-commits mailing list