[flang-commits] [flang] 60105ac - [flang][cuda] Fix kernel registration (#113372)

via flang-commits flang-commits at lists.llvm.org
Wed Oct 23 11:26:04 PDT 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-10-23T11:25:58-07:00
New Revision: 60105ac6bab130c2694fc7f5b7b6a5fddaaab752

URL: https://github.com/llvm/llvm-project/commit/60105ac6bab130c2694fc7f5b7b6a5fddaaab752
DIFF: https://github.com/llvm/llvm-project/commit/60105ac6bab130c2694fc7f5b7b6a5fddaaab752.diff

LOG: [flang][cuda] Fix kernel registration (#113372)

Kernel registration needs both the function pointer and the kernel name.
This patch adds an extra argument to the runtime entry point and updates
the LLVM IR translation to pass the function symbol along with the name.

Added: 
    

Modified: 
    flang/include/flang/Runtime/CUDA/registration.h
    flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
    flang/runtime/CUDA/registration.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
index cbe202c4d23e0d..009715613e29f7 100644
--- a/flang/include/flang/Runtime/CUDA/registration.h
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -20,7 +20,8 @@ extern "C" {
 void *RTDECL(CUFRegisterModule)(void *data);
 
 /// Register a device function.
-void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
+void RTDECL(CUFRegisterFunction)(
+    void **module, const char *fctSym, char *fctName);
 
 } // extern "C"
 

diff  --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
index c6c9f96b811352..63eac46a997718 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
@@ -63,13 +63,15 @@ LogicalResult registerKernel(cuf::RegisterKernelOp op,
   llvm::Type *ptrTy = builder.getPtrTy(0);
   llvm::FunctionCallee fct = module->getOrInsertFunction(
       RTNAME_STRING(CUFRegisterFunction),
-      llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
-                              false));
+      llvm::FunctionType::get(
+          ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false));
   llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
-  builder.CreateCall(
-      fct, {modulePtr, getOrCreateFunctionName(module, builder,
-                                               op.getKernelModuleName().str(),
-                                               op.getKernelName().str())});
+  llvm::Function *fctSym =
+      moduleTranslation.lookupFunction(op.getKernelName().str());
+  builder.CreateCall(fct, {modulePtr, fctSym,
+                           getOrCreateFunctionName(
+                               module, builder, op.getKernelModuleName().str(),
+                               op.getKernelName().str())});
   return mlir::success();
 }
 

diff  --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp
index e5d9503e95fd8f..22d43a7dc57a3a 100644
--- a/flang/runtime/CUDA/registration.cpp
+++ b/flang/runtime/CUDA/registration.cpp
@@ -26,9 +26,10 @@ void *RTDECL(CUFRegisterModule)(void *data) {
   return fatHandle;
 }
 
-void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
-  __cudaRegisterFunction(module, fct, const_cast<char *>(fct), fct, -1,
-      (uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
+void RTDEF(CUFRegisterFunction)(
+    void **module, const char *fctSym, char *fctName) {
+  __cudaRegisterFunction(module, fctSym, fctName, fctName, -1, (uint3 *)0,
+      (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
 }
 }
 } // namespace Fortran::runtime::cuda


        


More information about the flang-commits mailing list