[llvm-branch-commits] [clang] [CIR][CUDA] Do Runtime Kernel Registration (PR #188926)

David Rivera via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Apr 2 14:51:53 PDT 2026


https://github.com/RiverDave updated https://github.com/llvm/llvm-project/pull/188926

>From f41fc9f0f01be7eca3000ec69b406002c0fddfe7 Mon Sep 17 00:00:00 2001
From: David Rivera <davidriverg at gmail.com>
Date: Thu, 2 Apr 2026 17:47:02 -0400
Subject: [PATCH 1/4] [CIR][CUDA] Remove unused voidTy local from buildCUDAModuleCtor

---
 clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
index b5140c281ed2d..8185385f92b50 100644
--- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
@@ -1741,7 +1741,6 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
   CIRBaseBuilderTy builder(getContext());
   builder.setInsertionPointToStart(mlirModule.getBody());
 
-  VoidType voidTy = builder.getVoidTy();
   PointerType voidPtrTy = builder.getVoidPtrTy();
   PointerType voidPtrPtrTy = builder.getPointerTo(voidPtrTy);
   IntType intTy = builder.getSIntNTy(32);

>From 2f6160dce1ce8595f0769a3a49e41459aa84ac7d Mon Sep 17 00:00:00 2001
From: David Rivera <davidriverg at gmail.com>
Date: Wed, 25 Mar 2026 22:29:47 -0400
Subject: [PATCH 2/4] [CIR][CUDA] Handle CUDA module constructor and destructor
 emission.

---
 .../Dialect/Transforms/LoweringPrepare.cpp    | 126 +++++++++++++++++-
 clang/test/CIR/CodeGenCUDA/device-stub.cu     |  41 ++++++
 2 files changed, 165 insertions(+), 2 deletions(-)

diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
index 8185385f92b50..78462e571e85b 100644
--- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
@@ -10,8 +10,10 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Value.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Mangle.h"
+#include "clang/Basic/Cuda.h"
 #include "clang/Basic/Module.h"
 #include "clang/Basic/Specifiers.h"
 #include "clang/Basic/TargetCXXABI.h"
@@ -27,10 +29,15 @@
 #include "clang/CIR/MissingFeatures.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/TypeSwitch.h"
+<<<<<<< HEAD
 #include "llvm/Support/MemoryBuffer.h"
+=======
+#include "llvm/IR/Instructions.h"
+>>>>>>> fff0ddb60480 ([CIR][CUDA] Handle CUDA module constructor and destructor emission.)
 #include "llvm/Support/Path.h"
 
 #include <memory>
+#include <optional>
 
 using namespace mlir;
 using namespace cir;
@@ -121,6 +128,7 @@ struct LoweringPreparePass
   /// Build the CUDA module constructor that registers the fat binary
   /// with the CUDA runtime.
   void buildCUDAModuleCtor();
+  std::optional<FuncOp> buildCUDAModuleDtor();
 
   /// Handle static local variable initialization with guard variables.
   void handleStaticLocal(cir::GlobalOp globalOp, cir::GetGlobalOp getGlobalOp);
@@ -1806,8 +1814,122 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
   gpuBinHandle.setInitialValueAttr(builder.getConstNullPtrAttr(voidPtrPtrTy));
   gpuBinHandle.setPrivate();
 
-  // TODO: ctor/dtor/register_globals
-  assert(!cir::MissingFeatures::globalRegistration());
+  // Declare this function:
+  //    void **__{cuda|hip}RegisterFatBinary(void *);
+
+  std::string regFuncName =
+      addUnderscoredPrefix(cudaPrefix, "RegisterFatBinary");
+  FuncType regFuncType = FuncType::get({voidPtrTy}, voidPtrPtrTy);
+  cir::FuncOp regFunc =
+      buildRuntimeFunction(builder, regFuncName, loc, regFuncType);
+
+  std::string moduleCtorName = addUnderscoredPrefix(cudaPrefix, "_module_ctor");
+  cir::FuncOp moduleCtor = buildRuntimeFunction(
+      builder, moduleCtorName, loc, FuncType::get({}, voidTy),
+      GlobalLinkageKind::InternalLinkage);
+
+  globalCtorList.emplace_back(moduleCtorName,
+                              cir::GlobalCtorAttr::getDefaultPriority());
+  builder.setInsertionPointToStart(moduleCtor.addEntryBlock());
+  assert(!cir::MissingFeatures::opGlobalCtorPriority());
+  if (isHIP) {
+    llvm_unreachable("HIP Module Constructor Support");
+  } else if (!astCtx->getLangOpts().GPURelocatableDeviceCode) {
+
+    // --- Create CUDA CTOR-DTOR ---
+    // Register binary with CUDA runtime. This is substantially different in
+    // default mode vs. separate compilation.
+    // Corresponding code:
+    //     gpuBinaryHandle = __cudaRegisterFatBinary(&fatbinWrapper);
+    mlir::Value wrapper = builder.createGetGlobal(fatbinWrapper);
+    mlir::Value fatbinVoidPtr = builder.createBitcast(wrapper, voidPtrTy);
+    cir::CallOp gpuBinaryHandleCall =
+        builder.createCallOp(loc, regFunc, fatbinVoidPtr);
+    mlir::Value gpuBinaryHandle = gpuBinaryHandleCall.getResult();
+    // Store the value back to the global `__cuda_gpubin_handle`.
+    mlir::Value gpuBinaryHandleGlobal = builder.createGetGlobal(gpuBinHandle);
+    builder.createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
+
+    // TODO: Generate __cuda_register_globals and emit a call.
+    assert(!cir::MissingFeatures::globalRegistration());
+
+    // From CUDA 10.1 onwards, we must call this function to end registration:
+    //      void __cudaRegisterFatBinaryEnd(void **fatbinHandle);
+    // This is CUDA-specific, so no need to use `addUnderscoredPrefix`.
+    if (clang::CudaFeatureEnabled(
+            astCtx->getTargetInfo().getSDKVersion(),
+            clang::CudaFeature::CUDA_USES_FATBIN_REGISTER_END)) {
+      cir::CIRBaseBuilderTy globalBuilder(getContext());
+      globalBuilder.setInsertionPointToStart(mlirModule.getBody());
+      FuncOp endFunc =
+          buildRuntimeFunction(globalBuilder, "__cudaRegisterFatBinaryEnd", loc,
+                               FuncType::get({voidPtrPtrTy}, voidTy));
+      builder.createCallOp(loc, endFunc, gpuBinaryHandle);
+    }
+  }
+
+  // Create destructor and register it with atexit() the way NVCC does it. Doing
+  // it during regular destructor phase worked in CUDA before 9.2 but results in
+  // double-free in 9.2.
+  if (std::optional<FuncOp> dtor = buildCUDAModuleDtor()) {
+
+    // extern "C" int atexit(void (*f)(void));
+    cir::CIRBaseBuilderTy globalBuilder(getContext());
+    globalBuilder.setInsertionPointToStart(mlirModule.getBody());
+    FuncOp atexit = buildRuntimeFunction(
+        globalBuilder, "atexit", loc,
+        FuncType::get(PointerType::get(dtor->getFunctionType()), intTy));
+    mlir::Value dtorFunc = GetGlobalOp::create(
+        builder, loc, PointerType::get(dtor->getFunctionType()),
+        mlir::FlatSymbolRefAttr::get(dtor->getSymNameAttr()));
+    builder.createCallOp(loc, atexit, dtorFunc);
+  }
+  cir::ReturnOp::create(builder, loc);
+}
+
+std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor() {
+  if (!mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
+    return {};
+  // Otherwise build __{cuda|hip}_module_dtor to unregister the fat binary.
+  llvm::StringRef prefix = getCUDAPrefix(astCtx);
+
+  VoidType voidTy = VoidType::get(&getContext());
+  PointerType voidPtrPtrTy = PointerType::get(PointerType::get(voidTy));
+
+  mlir::Location loc = mlirModule.getLoc();
+
+  cir::CIRBaseBuilderTy builder(getContext());
+  builder.setInsertionPointToStart(mlirModule.getBody());
+
+  // Declare: void __{cuda|hip}UnregisterFatBinary(void **handle);
+  std::string unregisterFuncName =
+      addUnderscoredPrefix(prefix, "UnregisterFatBinary");
+  FuncOp unregisterFunc = buildRuntimeFunction(
+      builder, unregisterFuncName, loc, FuncType::get({voidPtrPtrTy}, voidTy));
+
+  // void __{cuda|hip}_module_dtor();
+  // Despite the name, classic codegen (OG) doesn't treat it as a destructor,
+  // so it must not go into globalDtorList: as a real dtor it would cause a
+  // double free from CUDA 9.2 on. Instead, the module ctor registers it
+  // with atexit().
+  std::string dtorName = addUnderscoredPrefix(prefix, "_module_dtor");
+  FuncOp dtor =
+      buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy),
+                           GlobalLinkageKind::InternalLinkage);
+
+  builder.setInsertionPointToStart(dtor.addEntryBlock());
+
+  // For dtor, we only need to call:
+  //    __cudaUnregisterFatBinary(__cuda_gpubin_handle);
+  // Expected to be created by the module ctor; cast<> asserts it exists.
+  std::string gpubinName = addUnderscoredPrefix(prefix, "_gpubin_handle");
+  GlobalOp gpubinGlobal = cast<GlobalOp>(mlirModule.lookupSymbol(gpubinName));
+  mlir::Value gpubinAddress = builder.createGetGlobal(gpubinGlobal);
+  mlir::Value gpubin = builder.createLoad(loc, gpubinAddress);
+  builder.createCallOp(loc, unregisterFunc, gpubin);
+  ReturnOp::create(builder, loc);
+
+  return dtor;
 }
 
 void LoweringPreparePass::runOnOperation() {
diff --git a/clang/test/CIR/CodeGenCUDA/device-stub.cu b/clang/test/CIR/CodeGenCUDA/device-stub.cu
index 2e9deaee9b225..4562bf1523141 100644
--- a/clang/test/CIR/CodeGenCUDA/device-stub.cu
+++ b/clang/test/CIR/CodeGenCUDA/device-stub.cu
@@ -21,6 +21,22 @@ __global__ void kernelfunc(int i, int j, int k) {}
 
 void hostfunc(void) { kernelfunc<<<1, 1>>>(1, 1, 1); }
 
+// Check module constructor is registered in module attributes.
+// CIR: cir.global_ctors = [#cir.global_ctor<"__cuda_module_ctor", 65535>]
+
+// Check runtime function declarations (appear before dtor in output).
+// CIR: cir.func private @atexit(!cir.ptr<!cir.func<()>>) -> !s32i
+// CIR: cir.func private @__cudaUnregisterFatBinary(!cir.ptr<!cir.ptr<!void>>)
+
+// Check the module destructor body: load handle and call UnregisterFatBinary.
+// CIR: cir.func internal private @__cuda_module_dtor()
+// CIR-NEXT: %[[HANDLE_ADDR:.*]] = cir.get_global @__cuda_gpubin_handle
+// CIR-NEXT: %[[HANDLE:.*]] = cir.load %[[HANDLE_ADDR]]
+// CIR-NEXT: cir.call @__cudaUnregisterFatBinary(%[[HANDLE]])
+// CIR-NEXT: cir.return
+
+// CIR: cir.func private @__cudaRegisterFatBinaryEnd(!cir.ptr<!cir.ptr<!void>>)
+
 // CIR: cir.global "private" constant cir_private @__cuda_fatbin_str = #cir.const_array<"GPU binary would be here."> : !cir.array<!u8i x 25> {alignment = 8 : i64, section = ".nv_fatbin"}
 
 // Check the fatbin wrapper struct: { magic, version, ptr to fatbin, null }, with section.
@@ -34,9 +50,34 @@ void hostfunc(void) { kernelfunc<<<1, 1>>>(1, 1, 1); }
 // Check the GPU binary handle global.
 // CIR: cir.global "private" internal @__cuda_gpubin_handle = #cir.ptr<null> : !cir.ptr<!cir.ptr<!void>>
 
+// CIR: cir.func private @__cudaRegisterFatBinary(!cir.ptr<!void>) -> !cir.ptr<!cir.ptr<!void>>
+
+// Check the module constructor body: register fatbin, store handle,
+// call RegisterFatBinaryEnd (CUDA >= 10.1), then register dtor with atexit.
+// CIR: cir.func internal private @__cuda_module_ctor()
+// CIR-NEXT: %[[WRAPPER:.*]] = cir.get_global @__cuda_fatbin_wrapper
+// CIR-NEXT: %[[VOID_PTR:.*]] = cir.cast bitcast %[[WRAPPER]]
+// CIR-NEXT: %[[RET:.*]] = cir.call @__cudaRegisterFatBinary(%[[VOID_PTR]])
+// CIR-NEXT: %[[HANDLE_ADDR:.*]] = cir.get_global @__cuda_gpubin_handle
+// CIR-NEXT: cir.store %[[RET]], %[[HANDLE_ADDR]]
+// CIR-NEXT: cir.call @__cudaRegisterFatBinaryEnd(%[[RET]])
+// CIR-NEXT: %[[DTOR_PTR:.*]] = cir.get_global @__cuda_module_dtor
+// CIR-NEXT: {{.*}} = cir.call @atexit(%[[DTOR_PTR]])
+// CIR-NEXT: cir.return
+
 // OGCG: constant [25 x i8] c"GPU binary would be here.", section ".nv_fatbin", align 8
 // OGCG: @__cuda_fatbin_wrapper = internal constant { i32, i32, ptr, ptr } { i32 1180844977, i32 1, ptr @{{.*}}, ptr null }, section ".nvFatBinSegment"
 // OGCG: @__cuda_gpubin_handle = internal global ptr null
+// OGCG: @llvm.global_ctors = appending global {{.*}}@__cuda_module_ctor
+
+// OGCG: define internal void @__cuda_module_ctor
+// OGCG: call{{.*}}__cudaRegisterFatBinary(ptr @__cuda_fatbin_wrapper)
+// OGCG: store ptr %{{.*}}, ptr @__cuda_gpubin_handle
+// OGCG: call i32 @atexit(ptr @__cuda_module_dtor)
+
+// OGCG: define internal void @__cuda_module_dtor
+// OGCG: load ptr, ptr @__cuda_gpubin_handle
+// OGCG: call void @__cudaUnregisterFatBinary
 
 // No GPU binary — no registration infrastructure at all.
 // NOGPUBIN-NOT: fatbin

>From 031a73232f2eec25ee3915604724960c4cb78210 Mon Sep 17 00:00:00 2001
From: David Rivera <davidriverg at gmail.com>
Date: Sun, 29 Mar 2026 14:04:44 -0400
Subject: [PATCH 3/4] [CIR][CUDA] Mark GPU RDC module ctor emission as unreachable (NYI)

---
 clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
index 78462e571e85b..d9783e3eb5983 100644
--- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
@@ -29,11 +29,8 @@
 #include "clang/CIR/MissingFeatures.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/TypeSwitch.h"
-<<<<<<< HEAD
-#include "llvm/Support/MemoryBuffer.h"
-=======
 #include "llvm/IR/Instructions.h"
->>>>>>> fff0ddb60480 ([CIR][CUDA] Handle CUDA module constructor and destructor emission.)
+#include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/Path.h"
 
 #include <memory>
@@ -1866,7 +1863,8 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
                                FuncType::get({voidPtrPtrTy}, voidTy));
       builder.createCallOp(loc, endFunc, gpuBinaryHandle);
     }
-  }
+  } else
+    llvm_unreachable("GPU RDC NYI");
 
   // Create destructor and register it with atexit() the way NVCC does it. Doing
   // it during regular destructor phase worked in CUDA before 9.2 but results in

>From 7f4a61f9117053af35ca020a9d8073cf5551757f Mon Sep 17 00:00:00 2001
From: David Rivera <davidriverg at gmail.com>
Date: Fri, 27 Mar 2026 04:40:08 -0400
Subject: [PATCH 4/4] [CIR][CUDA] Do Runtime Kernel Registration

---
 .../Dialect/Transforms/LoweringPrepare.cpp    | 122 +++++++++++++++++-
 clang/test/CIR/CodeGenCUDA/device-stub.cu     |  30 ++++-
 2 files changed, 148 insertions(+), 4 deletions(-)

diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
index d9783e3eb5983..5338d3685d9ab 100644
--- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
@@ -10,6 +10,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Location.h"
 #include "mlir/IR/Value.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Mangle.h"
@@ -30,6 +31,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/Path.h"
 
@@ -126,6 +128,9 @@ struct LoweringPreparePass
   /// with the CUDA runtime.
   void buildCUDAModuleCtor();
   std::optional<FuncOp> buildCUDAModuleDtor();
+  std::optional<FuncOp> buildCUDARegisterGlobals();
+  void buildCUDARegisterGlobalFunctions(cir::CIRBaseBuilderTy &builder,
+                                        FuncOp regGlobalFunc);
 
   /// Handle static local variable initialization with guard variables.
   void handleStaticLocal(cir::GlobalOp globalOp, cir::GetGlobalOp getGlobalOp);
@@ -1746,6 +1751,7 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
   CIRBaseBuilderTy builder(getContext());
   builder.setInsertionPointToStart(mlirModule.getBody());
 
+  Type voidTy = builder.getVoidTy();
   PointerType voidPtrTy = builder.getVoidPtrTy();
   PointerType voidPtrPtrTy = builder.getPointerTo(voidPtrTy);
   IntType intTy = builder.getSIntNTy(32);
@@ -1847,8 +1853,11 @@ void LoweringPreparePass::buildCUDAModuleCtor() {
     mlir::Value gpuBinaryHandleGlobal = builder.createGetGlobal(gpuBinHandle);
     builder.createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
 
-    // TODO: Generate __cuda_register_globals and emit a call.
-    assert(!cir::MissingFeatures::globalRegistration());
+    // --- Generate __cuda_register_globals and call it ---
+    std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals();
+    if (regGlobal) {
+      builder.createCallOp(loc, *regGlobal, gpuBinaryHandle);
+    }
 
     // From CUDA 10.1 onwards, we must call this function to end registration:
     //      void __cudaRegisterFatBinaryEnd(void **fatbinHandle);
@@ -1930,6 +1939,115 @@ std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor() {
   return dtor;
 }
 
+std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals() {
+  // Builds __{cuda|hip}_register_globals, or nothing if there are no kernels.
+  if (cudaKernelMap.empty())
+    return {};
+
+  cir::CIRBaseBuilderTy builder(getContext());
+  builder.setInsertionPointToStart(mlirModule.getBody());
+
+  mlir::Location loc = mlirModule.getLoc();
+  llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
+
+  auto voidTy = VoidType::get(&getContext());
+  auto voidPtrTy = PointerType::get(voidTy);
+  auto voidPtrPtrTy = PointerType::get(voidPtrTy);
+
+  // Create the internal-linkage function:
+  //      void __{cuda|hip}_register_globals(void **fatbinHandle)
+  std::string regGlobalFuncName =
+      addUnderscoredPrefix(cudaPrefix, "_register_globals");
+  auto regGlobalFuncTy = FuncType::get({voidPtrPtrTy}, voidTy);
+  FuncOp regGlobalFunc =
+      buildRuntimeFunction(builder, regGlobalFuncName, loc, regGlobalFuncTy,
+                           /*linkage=*/GlobalLinkageKind::InternalLinkage);
+  builder.setInsertionPointToStart(regGlobalFunc.addEntryBlock());
+  // Body: one __{cuda|hip}RegisterFunction call per kernel in cudaKernelMap.
+  buildCUDARegisterGlobalFunctions(builder, regGlobalFunc);
+  // TODO: Handle shadow registration; only kernels are registered for now.
+  assert(!cir::MissingFeatures::globalRegistration());
+
+  ReturnOp::create(builder, loc);
+  return regGlobalFunc;
+}
+
+void LoweringPreparePass::buildCUDARegisterGlobalFunctions(
+    cir::CIRBaseBuilderTy &builder, FuncOp regGlobalFunc) {
+  // Emits, at builder's insertion point, one RegisterFunction call per kernel.
+  mlir::Location loc = mlirModule.getLoc();
+  llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
+
+  auto voidTy = VoidType::get(&getContext());
+  auto voidPtrTy = PointerType::get(voidTy);
+  auto voidPtrPtrTy = PointerType::get(voidPtrTy);
+  IntType intTy = builder.getSIntNTy(32);
+  IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
+                                     /*isSigned=*/false);
+
+  // Extract the GPU binary handle argument.
+  mlir::Value fatbinHandle = *regGlobalFunc.args_begin();
+
+  cir::CIRBaseBuilderTy globalBuilder(getContext());
+  globalBuilder.setInsertionPointToStart(mlirModule.getBody());
+
+  // Declare CUDA internal functions:
+  // int __cudaRegisterFunction(
+  //   void **fatbinHandle,
+  //   const char *hostFunc,
+  //   char *deviceFunc,
+  //   const char *deviceName,
+  //   int threadLimit,
+  //   uint3 *tid, uint3 *bid, dim3 *bDim, dim3 *gDim,
+  //   int *wsize
+  // )
+  // As in classic codegen (OG), the precise pointer types are erased to void*.
+  FuncOp cudaRegisterFunction = buildRuntimeFunction(
+      globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterFunction"), loc,
+      FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
+                     voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy},
+                    intTy));
+
+  // Creates a private constant global ".str<str>" holding str plus a NUL.
+  auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
+    auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
+    auto tmpString = cir::GlobalOp::create(
+        globalBuilder, loc, (".str" + str).str(), strType,
+        /*isConstant=*/true, {},
+        /*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
+    // strType has one more element than str; ConstArrayAttr fills it with a
+    // trailing zero, which is what NUL-terminates the string.
+    tmpString.setInitialValueAttr(ConstArrayAttr::get(
+        strType, StringAttr::get(&getContext(), str)));
+    tmpString.setPrivate();
+    return tmpString;
+  };
+
+  cir::ConstantOp cirNullPtr = builder.getNullPtr(voidPtrTy, loc);
+  bool isHIP = astCtx->getLangOpts().HIP;
+  for (auto kernelName : cudaKernelMap.keys()) {
+    FuncOp deviceStub = cudaKernelMap.lookup(kernelName);
+    GlobalOp deviceFuncStr = makeConstantString(kernelName);
+    mlir::Value deviceFunc = builder.createBitcast(
+        builder.createGetGlobal(deviceFuncStr), voidPtrTy);
+
+    if (isHIP)
+      llvm_unreachable("HIP kernel registration NYI");
+
+    // Stub address is hostFunc; kernelName is both deviceFunc and deviceName.
+    mlir::Value hostFunc = builder.createBitcast(
+        GetGlobalOp::create(
+            builder, loc, PointerType::get(deviceStub.getFunctionType()),
+            mlir::FlatSymbolRefAttr::get(deviceStub.getSymNameAttr())),
+        voidPtrTy);
+    builder.createCallOp(
+        loc, cudaRegisterFunction,
+        {fatbinHandle, hostFunc, deviceFunc, deviceFunc,
+         ConstantOp::create(builder, loc, IntAttr::get(intTy, -1)),
+         cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr});
+  }
+}
+
 void LoweringPreparePass::runOnOperation() {
   mlir::Operation *op = getOperation();
   if (isa<::mlir::ModuleOp>(op))
diff --git a/clang/test/CIR/CodeGenCUDA/device-stub.cu b/clang/test/CIR/CodeGenCUDA/device-stub.cu
index 4562bf1523141..0f9d4d68d67ff 100644
--- a/clang/test/CIR/CodeGenCUDA/device-stub.cu
+++ b/clang/test/CIR/CodeGenCUDA/device-stub.cu
@@ -24,7 +24,7 @@ void hostfunc(void) { kernelfunc<<<1, 1>>>(1, 1, 1); }
 // Check module constructor is registered in module attributes.
 // CIR: cir.global_ctors = [#cir.global_ctor<"__cuda_module_ctor", 65535>]
 
-// Check runtime function declarations (appear before dtor in output).
+// Check runtime function declarations.
 // CIR: cir.func private @atexit(!cir.ptr<!cir.func<()>>) -> !s32i
 // CIR: cir.func private @__cudaUnregisterFatBinary(!cir.ptr<!cir.ptr<!void>>)
 
@@ -37,6 +37,25 @@ void hostfunc(void) { kernelfunc<<<1, 1>>>(1, 1, 1); }
 
 // CIR: cir.func private @__cudaRegisterFatBinaryEnd(!cir.ptr<!cir.ptr<!void>>)
 
+// Check the __cudaRegisterFunction runtime declaration:
+//   int __cudaRegisterFunction(void**, void*, void*, void*, int,
+//                              void*, void*, void*, void*, void*)
+// CIR: cir.func private @__cudaRegisterFunction(!cir.ptr<!cir.ptr<!void>>, !cir.ptr<!void>, !cir.ptr<!void>, !cir.ptr<!void>, !s32i, !cir.ptr<!void>, !cir.ptr<!void>, !cir.ptr<!void>, !cir.ptr<!void>, !cir.ptr<!void>) -> !s32i
+
+// Check the device-side name string for kernelfunc (mangled, null-terminated).
+// CIR: cir.global "private" constant cir_private @".str_Z10kernelfunciii" = #cir.const_array<"_Z10kernelfunciii", trailing_zeros> : !cir.array<!u8i x 18>
+
+// Check __cuda_register_globals body: one __cudaRegisterFunction call per kernel.
+// CIR: cir.func internal private @__cuda_register_globals(%arg0: !cir.ptr<!cir.ptr<!void>>
+// CIR-NEXT: %[[NULL:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!void>
+// CIR-NEXT: %[[STR_ADDR:.*]] = cir.get_global @".str_Z10kernelfunciii"
+// CIR-NEXT: %[[DEVICE_FUNC:.*]] = cir.cast bitcast %[[STR_ADDR]]
+// CIR-NEXT: %[[HOST_FUNC_RAW:.*]] = cir.get_global @{{.*}}kernelfunc{{.*}}
+// CIR-NEXT: %[[HOST_FUNC:.*]] = cir.cast bitcast %[[HOST_FUNC_RAW]]
+// CIR-NEXT: %[[THREAD_LIMIT:.*]] = cir.const #cir.int<-1> : !s32i
+// CIR-NEXT: cir.call @__cudaRegisterFunction(%{{.*}}, %[[HOST_FUNC]], %[[DEVICE_FUNC]], %[[DEVICE_FUNC]], %[[THREAD_LIMIT]], %[[NULL]], %[[NULL]], %[[NULL]], %[[NULL]], %[[NULL]])
+// CIR-NEXT: cir.return
+
 // CIR: cir.global "private" constant cir_private @__cuda_fatbin_str = #cir.const_array<"GPU binary would be here."> : !cir.array<!u8i x 25> {alignment = 8 : i64, section = ".nv_fatbin"}
 
 // Check the fatbin wrapper struct: { magic, version, ptr to fatbin, null }, with section.
@@ -53,13 +72,15 @@ void hostfunc(void) { kernelfunc<<<1, 1>>>(1, 1, 1); }
 // CIR: cir.func private @__cudaRegisterFatBinary(!cir.ptr<!void>) -> !cir.ptr<!cir.ptr<!void>>
 
 // Check the module constructor body: register fatbin, store handle,
-// call RegisterFatBinaryEnd (CUDA >= 10.1), then register dtor with atexit.
+// call __cuda_register_globals, call RegisterFatBinaryEnd (CUDA >= 10.1),
+// then register dtor with atexit.
 // CIR: cir.func internal private @__cuda_module_ctor()
 // CIR-NEXT: %[[WRAPPER:.*]] = cir.get_global @__cuda_fatbin_wrapper
 // CIR-NEXT: %[[VOID_PTR:.*]] = cir.cast bitcast %[[WRAPPER]]
 // CIR-NEXT: %[[RET:.*]] = cir.call @__cudaRegisterFatBinary(%[[VOID_PTR]])
 // CIR-NEXT: %[[HANDLE_ADDR:.*]] = cir.get_global @__cuda_gpubin_handle
 // CIR-NEXT: cir.store %[[RET]], %[[HANDLE_ADDR]]
+// CIR-NEXT: cir.call @__cuda_register_globals(%[[RET]])
 // CIR-NEXT: cir.call @__cudaRegisterFatBinaryEnd(%[[RET]])
 // CIR-NEXT: %[[DTOR_PTR:.*]] = cir.get_global @__cuda_module_dtor
 // CIR-NEXT: {{.*}} = cir.call @atexit(%[[DTOR_PTR]])
@@ -70,9 +91,14 @@ void hostfunc(void) { kernelfunc<<<1, 1>>>(1, 1, 1); }
 // OGCG: @__cuda_gpubin_handle = internal global ptr null
 // OGCG: @llvm.global_ctors = appending global {{.*}}@__cuda_module_ctor
 
+// OGCG: define internal void @__cuda_register_globals
+// OGCG: call{{.*}}__cudaRegisterFunction(ptr %0, {{.*}}kernelfunc{{.*}}, ptr @0
+// OGCG: ret void
+
 // OGCG: define internal void @__cuda_module_ctor
 // OGCG: call{{.*}}__cudaRegisterFatBinary(ptr @__cuda_fatbin_wrapper)
 // OGCG: store ptr %{{.*}}, ptr @__cuda_gpubin_handle
+// OGCG-NEXT: call void @__cuda_register_globals
 // OGCG: call i32 @atexit(ptr @__cuda_module_dtor)
 
 // OGCG: define internal void @__cuda_module_dtor



More information about the llvm-branch-commits mailing list