[flang-commits] [flang] [flang][cuda] Fix kernel registration (PR #113372)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Tue Oct 22 12:39:35 PDT 2024
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/113372
Kernel registration needs both the function pointer and the function name. This patch adds an extra argument to the runtime entry point (CUFRegisterFunction) and updates the LLVM IR translation to pass it.
>From 125fe851da9e090e709244e2ced67e53798f07e6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 22 Oct 2024 12:37:54 -0700
Subject: [PATCH] [flang][cuda] Fix kernel registration
---
flang/include/flang/Runtime/CUDA/registration.h | 3 ++-
.../Dialect/CUF/CUFToLLVMIRTranslation.cpp | 14 ++++++++------
flang/runtime/CUDA/registration.cpp | 7 ++++---
3 files changed, 14 insertions(+), 10 deletions(-)
diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
index cbe202c4d23e0d..009715613e29f7 100644
--- a/flang/include/flang/Runtime/CUDA/registration.h
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -20,7 +20,8 @@ extern "C" {
void *RTDECL(CUFRegisterModule)(void *data);
/// Register a device function.
-void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
+void RTDECL(CUFRegisterFunction)(
+ void **module, const char *fctSym, char *fctName);
} // extern "C"
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
index c6c9f96b811352..63eac46a997718 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
@@ -63,13 +63,15 @@ LogicalResult registerKernel(cuf::RegisterKernelOp op,
llvm::Type *ptrTy = builder.getPtrTy(0);
llvm::FunctionCallee fct = module->getOrInsertFunction(
RTNAME_STRING(CUFRegisterFunction),
- llvm::FunctionType::get(ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
- false));
+ llvm::FunctionType::get(
+ ptrTy, ArrayRef<llvm::Type *>({ptrTy, ptrTy, ptrTy}), false));
llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
- builder.CreateCall(
- fct, {modulePtr, getOrCreateFunctionName(module, builder,
- op.getKernelModuleName().str(),
- op.getKernelName().str())});
+ llvm::Function *fctSym =
+ moduleTranslation.lookupFunction(op.getKernelName().str());
+ builder.CreateCall(fct, {modulePtr, fctSym,
+ getOrCreateFunctionName(
+ module, builder, op.getKernelModuleName().str(),
+ op.getKernelName().str())});
return mlir::success();
}
diff --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp
index e5d9503e95fd8f..22d43a7dc57a3a 100644
--- a/flang/runtime/CUDA/registration.cpp
+++ b/flang/runtime/CUDA/registration.cpp
@@ -26,9 +26,10 @@ void *RTDECL(CUFRegisterModule)(void *data) {
return fatHandle;
}
-void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
- __cudaRegisterFunction(module, fct, const_cast<char *>(fct), fct, -1,
- (uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
+void RTDEF(CUFRegisterFunction)(
+ void **module, const char *fctSym, char *fctName) {
+ __cudaRegisterFunction(module, fctSym, fctName, fctName, -1, (uint3 *)0,
+ (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
}
}
} // namespace Fortran::runtime::cuda
More information about the flang-commits mailing list