[flang-commits] [flang] 60105ac - [flang][cuda] Fix kernel registration (#113372)
via flang-commits
flang-commits at lists.llvm.org
Wed Oct 23 11:26:04 PDT 2024
Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-10-23T11:25:58-07:00
New Revision: 60105ac6bab130c2694fc7f5b7b6a5fddaaab752
URL: https://github.com/llvm/llvm-project/commit/60105ac6bab130c2694fc7f5b7b6a5fddaaab752
DIFF: https://github.com/llvm/llvm-project/commit/60105ac6bab130c2694fc7f5b7b6a5fddaaab752.diff
LOG: [flang][cuda] Fix kernel registration (#113372)
The registration needs both the function pointer and the function name. This patch
adds an extra argument to the runtime entry point and updates the LLVM IR translation accordingly.
Added:
Modified:
flang/include/flang/Runtime/CUDA/registration.h
flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
flang/runtime/CUDA/registration.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
index cbe202c4d23e0d..009715613e29f7 100644
--- a/flang/include/flang/Runtime/CUDA/registration.h
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -20,7 +20,8 @@ extern "C" {
void *RTDECL(CUFRegisterModule)(void *data);
/// Register a device function.
-void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
+void RTDECL(CUFRegisterFunction)(
+ void **module, const char *fctSym, char *fctName);
} // extern "C"
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
index c6c9f96b811352..63eac46a997718 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
@@ -63,13 +63,15 @@ LogicalResult registerKernel(cuf::RegisterKernelOp op,
llvm::Type *ptrTy = builder.getPtrTy(0);
llvm::FunctionCallee fct = module->getOrInsertFunction(
RTNAME_STRING(CUFRegisterFunction),
- llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
- false));
+ llvm::FunctionType::get(
+ ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false));
llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
- builder.CreateCall(
- fct, {modulePtr, getOrCreateFunctionName(module, builder,
- op.getKernelModuleName().str(),
- op.getKernelName().str())});
+ llvm::Function *fctSym =
+ moduleTranslation.lookupFunction(op.getKernelName().str());
+ builder.CreateCall(fct, {modulePtr, fctSym,
+ getOrCreateFunctionName(
+ module, builder, op.getKernelModuleName().str(),
+ op.getKernelName().str())});
return mlir::success();
}
diff --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp
index e5d9503e95fd8f..22d43a7dc57a3a 100644
--- a/flang/runtime/CUDA/registration.cpp
+++ b/flang/runtime/CUDA/registration.cpp
@@ -26,9 +26,10 @@ void *RTDECL(CUFRegisterModule)(void *data) {
return fatHandle;
}
-void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
- __cudaRegisterFunction(module, fct, const_cast<char *>(fct), fct, -1,
- (uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
+void RTDEF(CUFRegisterFunction)(
+ void **module, const char *fctSym, char *fctName) {
+ __cudaRegisterFunction(module, fctSym, fctName, fctName, -1, (uint3 *)0,
+ (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
}
}
} // namespace Fortran::runtime::cuda
More information about the flang-commits mailing list