[flang-commits] [flang] d37bc32 - [flang][cuda] Translate cuf.register_kernel and cuf.register_module (#112972)
via flang-commits
flang-commits at lists.llvm.org
Fri Oct 18 21:31:49 PDT 2024
Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-10-18T21:31:47-07:00
New Revision: d37bc32a65651e647148236ffb9728ea2e77eac3
URL: https://github.com/llvm/llvm-project/commit/d37bc32a65651e647148236ffb9728ea2e77eac3
DIFF: https://github.com/llvm/llvm-project/commit/d37bc32a65651e647148236ffb9728ea2e77eac3.diff
LOG: [flang][cuda] Translate cuf.register_kernel and cuf.register_module (#112972)
Add LLVM IR Translation for `cuf.register_module` and
`cuf.register_kernel`. These are lowered to function calls to the CUF
runtime entry points.
Added:
flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
flang/include/flang/Runtime/CUDA/registration.h
flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
flang/runtime/CUDA/registration.cpp
Modified:
flang/include/flang/Optimizer/Support/InitFIR.h
flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
flang/lib/Optimizer/Transforms/CufOpConversion.cpp
flang/runtime/CUDA/CMakeLists.txt
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h b/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
new file mode 100644
index 00000000000000..f3edb7fca649d0
--- /dev/null
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
@@ -0,0 +1,29 @@
+//===- CUFToLLVMIRTranslation.h - CUF Dialect to LLVM IR --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for CUF dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FLANG_OPTIMIZER_DIALECT_CUF_CUFTOLLVMIRTRANSLATION_H_
+#define FLANG_OPTIMIZER_DIALECT_CUF_CUFTOLLVMIRTRANSLATION_H_
+
+namespace mlir {
+class DialectRegistry;
+class MLIRContext;
+} // namespace mlir
+
+namespace cuf {
+
+/// Register the CUF dialect and the translation from it to the LLVM IR in
+/// the given registry.
+void registerCUFDialectTranslation(mlir::DialectRegistry &registry);
+
+} // namespace cuf
+
+#endif // FLANG_OPTIMIZER_DIALECT_CUF_CUFTOLLVMIRTRANSLATION_H_
diff --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 04a5dd323e5508..1c61c367199923 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -14,6 +14,7 @@
#define FORTRAN_OPTIMIZER_SUPPORT_INITFIR_H
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
+#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "mlir/Conversion/Passes.h"
@@ -61,6 +62,7 @@ inline void addFIRExtensions(mlir::DialectRegistry &registry,
if (addFIRInlinerInterface)
addFIRInlinerExtension(registry);
addFIRToLLVMIRExtension(registry);
+ cuf::registerCUFDialectTranslation(registry);
}
inline void loadNonCodegenDialects(mlir::MLIRContext &context) {
diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
new file mode 100644
index 00000000000000..cbe202c4d23e0d
--- /dev/null
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -0,0 +1,28 @@
+//===-- include/flang/Runtime/CUDA/registration.h ---------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
+#define FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
+
+#include "flang/Runtime/entry-names.h"
+#include <cstddef>
+
+namespace Fortran::runtime::cuda {
+
+extern "C" {
+
+/// Register a CUDA module. \p data is forwarded unchanged to
+/// `__cudaRegisterFatBinary` and the handle it returns is handed back to the
+/// caller.
+void *RTDECL(CUFRegisterModule)(void *data);
+
+/// Register a device function. \p module is the handle to register into and
+/// \p fct is the function name; the implementation passes \p fct to
+/// `__cudaRegisterFunction` as the host function, the device function and the
+/// device name alike.
+void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
+
+} // extern "C"
+
+} // namespace Fortran::runtime::cuda
+#endif // FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
diff --git a/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt b/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
index b2221199995d58..5d4bd0785971f7 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
+++ b/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(Attributes)
add_flang_library(CUFDialect
CUFDialect.cpp
CUFOps.cpp
+ CUFToLLVMIRTranslation.cpp
DEPENDS
MLIRIR
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
new file mode 100644
index 00000000000000..c6c9f96b811352
--- /dev/null
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
@@ -0,0 +1,104 @@
+//===- CUFToLLVMIRTranslation.cpp - Translate CUF dialect to LLVM IR ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR CUF dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Runtime/entry-names.h"
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Translate a `cuf.register_module` operation into a call to the
+/// `CUFRegisterModule` runtime entry.
+///
+/// The call argument is the global variable named after the leaf reference of
+/// the op's name with the `_bin_cst` suffix appended. An error is emitted on
+/// \p op when no such global exists in the LLVM module. The call result is
+/// recorded as the translation of the op's first result.
+LogicalResult registerModule(cuf::RegisterModuleOp op,
+                             llvm::IRBuilderBase &builder,
+                             LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+
+  // Build the symbol name under which the binary is looked up.
+  std::string binarySymbol = op.getName().getLeafReference().str();
+  binarySymbol += "_bin_cst";
+
+  // Local-linkage globals are accepted by the lookup as well.
+  llvm::Value *binaryGlobal =
+      llvmModule->getGlobalVariable(binarySymbol, /*AllowInternal=*/true);
+  if (!binaryGlobal)
+    return op.emitError() << "Couldn't find the binary: " << binarySymbol;
+
+  // Runtime entry signature used here: ptr (ptr).
+  llvm::Type *opaquePtrTy = builder.getPtrTy(0);
+  llvm::FunctionType *registerFnTy = llvm::FunctionType::get(
+      opaquePtrTy, ArrayRef<llvm::Type *>({opaquePtrTy}), /*isVarArg=*/false);
+  llvm::FunctionCallee registerFn = llvmModule->getOrInsertFunction(
+      RTNAME_STRING(CUFRegisterModule), registerFnTy);
+
+  llvm::CallInst *moduleHandle =
+      builder.CreateCall(registerFn, {binaryGlobal});
+  moduleTranslation.mapValue(op->getResults().front()) = moduleHandle;
+  return mlir::success();
+}
+
+/// Return the global string holding \p kernelName, creating it on first use.
+///
+/// The global is named `<moduleName>_<kernelName>_kernel_name`. When a global
+/// with that name already exists in \p module it is returned as is; otherwise
+/// a new global string is created through \p builder.
+llvm::Value *getOrCreateFunctionName(llvm::Module *module,
+                                     llvm::IRBuilderBase &builder,
+                                     llvm::StringRef moduleName,
+                                     llvm::StringRef kernelName) {
+  std::string globalName =
+      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, kernelName));
+
+  // CreateGlobalString emits a global with private linkage, so the lookup must
+  // accept local-linkage symbols. Without this the existing string is never
+  // found and a duplicate global is created on every call.
+  if (llvm::GlobalVariable *gv =
+          module->getGlobalVariable(globalName, /*AllowInternal=*/true))
+    return gv;
+
+  return builder.CreateGlobalString(kernelName, globalName);
+}
+
+/// Translate a `cuf.register_kernel` operation into a call to the
+/// `CUFRegisterFunction` runtime entry.
+///
+/// The call receives the translated module pointer operand of \p op and the
+/// global string holding the kernel name. Always returns success.
+LogicalResult registerKernel(cuf::RegisterKernelOp op,
+                             llvm::IRBuilderBase &builder,
+                             LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Module *module = moduleTranslation.getLLVMModule();
+  llvm::Type *ptrTy = builder.getPtrTy(0);
+  // The runtime entry is declared as
+  //   void CUFRegisterFunction(void **module, const char *fct)
+  // so the callee is typed as returning void rather than a pointer.
+  llvm::FunctionCallee fct = module->getOrInsertFunction(
+      RTNAME_STRING(CUFRegisterFunction),
+      llvm::FunctionType::get(builder.getVoidTy(),
+                              ArrayRef<llvm::Type *>({ptrTy, ptrTy}),
+                              /*isVarArg=*/false));
+  llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
+  llvm::Value *kernelName = getOrCreateFunctionName(
+      module, builder, op.getKernelModuleName().str(),
+      op.getKernelName().str());
+  builder.CreateCall(fct, {modulePtr, kernelName});
+  return mlir::success();
+}
+
+/// Dialect interface that translates operations of the CUF dialect to LLVM IR.
+class CUFDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Translate \p operation using \p builder and \p moduleTranslation.
+  ///
+  /// `cuf.register_module` and `cuf.register_kernel` are dispatched to their
+  /// dedicated translation functions; any other operation produces an error
+  /// naming the offending operation.
+  LogicalResult
+  convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const override {
+    return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
+        .Case([&](cuf::RegisterModuleOp op) {
+          return registerModule(op, builder, moduleTranslation);
+        })
+        .Case([&](cuf::RegisterKernelOp op) {
+          return registerKernel(op, builder, moduleTranslation);
+        })
+        .Default([&](Operation *op) {
+          // This interface is attached to the CUF dialect, so the diagnostic
+          // names CUF rather than GPU.
+          return op->emitError("unsupported CUF operation: ") << op->getName();
+        });
+  }
+};
+
+} // namespace
+
+/// Insert the CUF dialect into \p registry and add an extension that attaches
+/// the CUF-to-LLVM-IR translation interface to the dialect.
+void cuf::registerCUFDialectTranslation(DialectRegistry &registry) {
+  registry.insert<cuf::CUFDialect>();
+  registry.addExtension(+[](MLIRContext *ctx, cuf::CUFDialect *dialect) {
+    dialect->addInterfaces<CUFDialectLLVMIRTranslationInterface>();
+  });
+}
diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index 9df559ee0ab1f8..629f0c69f8cb5d 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -20,6 +20,7 @@
#include "flang/Runtime/CUDA/descriptor.h"
#include "flang/Runtime/CUDA/memory.h"
#include "flang/Runtime/allocatable.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/flang/runtime/CUDA/CMakeLists.txt b/flang/runtime/CUDA/CMakeLists.txt
index 193dd77e934558..86523b419f8711 100644
--- a/flang/runtime/CUDA/CMakeLists.txt
+++ b/flang/runtime/CUDA/CMakeLists.txt
@@ -18,6 +18,7 @@ add_flang_library(${CUFRT_LIBNAME}
allocatable.cpp
descriptor.cpp
memory.cpp
+ registration.cpp
)
if (BUILD_SHARED_LIBS)
diff --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp
new file mode 100644
index 00000000000000..aed275e964680e
--- /dev/null
+++ b/flang/runtime/CUDA/registration.cpp
@@ -0,0 +1,31 @@
+//===-- runtime/CUDA/registration.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Runtime/CUDA/registration.h"
+
+#include "cuda_runtime.h"
+
+namespace Fortran::runtime::cuda {
+
+extern "C" {
+
+// Registration entry points that are called below but not declared by the
+// headers included in this file.
+extern void **__cudaRegisterFatBinary(void *data);
+extern void __cudaRegisterFunction(void **fatCubinHandle, const char *hostFun,
+    char *deviceFun, const char *deviceName, int thread_limit, uint3 *tid,
+    uint3 *bid, dim3 *bDim, dim3 *gDim, int *wSize);
+
+/// Register a CUDA module: forwards \p data to `__cudaRegisterFatBinary` and
+/// returns the resulting handle.
+// RTDEF (not RTDECL) is used because this is a definition, consistent with
+// CUFRegisterFunction below.
+void *RTDEF(CUFRegisterModule)(void *data) {
+  return __cudaRegisterFatBinary(data);
+}
+
+/// Register a device function in \p module. \p fct is passed as the host
+/// function, the device function and the device name; the thread limit is -1
+/// and the remaining pointer arguments are null.
+void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
+  __cudaRegisterFunction(module, fct, const_cast<char *>(fct), fct, -1,
+      (uint3 *)0, (uint3 *)0, (dim3 *)0, (dim3 *)0, (int *)0);
+}
+
+} // extern "C"
+
+} // namespace Fortran::runtime::cuda
More information about the flang-commits
mailing list