[flang-commits] [flang] d37bc32 - [flang][cuda] Translate cuf.register_kernel and cuf.register_module (#112972)

via flang-commits flang-commits at lists.llvm.org
Fri Oct 18 21:31:49 PDT 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-10-18T21:31:47-07:00
New Revision: d37bc32a65651e647148236ffb9728ea2e77eac3

URL: https://github.com/llvm/llvm-project/commit/d37bc32a65651e647148236ffb9728ea2e77eac3
DIFF: https://github.com/llvm/llvm-project/commit/d37bc32a65651e647148236ffb9728ea2e77eac3.diff

LOG: [flang][cuda] Translate cuf.register_kernel and cuf.register_module (#112972)

Add LLVM IR Translation for `cuf.register_module` and
`cuf.register_kernel`. These are lowered to function calls to the CUF
runtime entry points.

Added: 
    flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
    flang/include/flang/Runtime/CUDA/registration.h
    flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
    flang/runtime/CUDA/registration.cpp

Modified: 
    flang/include/flang/Optimizer/Support/InitFIR.h
    flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
    flang/lib/Optimizer/Transforms/CufOpConversion.cpp
    flang/runtime/CUDA/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h b/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
new file mode 100644
index 00000000000000..f3edb7fca649d0
--- /dev/null
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h
@@ -0,0 +1,29 @@
+//===- CUFToLLVMIRTranslation.h - CUF Dialect to LLVM IR --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This provides registration calls for CUF dialect to LLVM IR translation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FLANG_OPTIMIZER_DIALECT_CUF_CUFTOLLVMIRTRANSLATION_H_
+#define FLANG_OPTIMIZER_DIALECT_CUF_CUFTOLLVMIRTRANSLATION_H_
+
+// Forward declarations; only references to these types are used below.
+namespace mlir {
+class DialectRegistry;
+class MLIRContext;
+} // namespace mlir
+
+namespace cuf {
+
+/// Register the CUF dialect and the translation from it to the LLVM IR in
+/// the given registry.
+void registerCUFDialectTranslation(mlir::DialectRegistry &registry);
+
+} // namespace cuf
+
+#endif // FLANG_OPTIMIZER_DIALECT_CUF_CUFTOLLVMIRTRANSLATION_H_

diff  --git a/flang/include/flang/Optimizer/Support/InitFIR.h b/flang/include/flang/Optimizer/Support/InitFIR.h
index 04a5dd323e5508..1c61c367199923 100644
--- a/flang/include/flang/Optimizer/Support/InitFIR.h
+++ b/flang/include/flang/Optimizer/Support/InitFIR.h
@@ -14,6 +14,7 @@
 #define FORTRAN_OPTIMIZER_SUPPORT_INITFIR_H
 
 #include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
+#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "mlir/Conversion/Passes.h"
@@ -61,6 +62,7 @@ inline void addFIRExtensions(mlir::DialectRegistry &registry,
   if (addFIRInlinerInterface)
     addFIRInlinerExtension(registry);
   addFIRToLLVMIRExtension(registry);
+  cuf::registerCUFDialectTranslation(registry);
 }
 
 inline void loadNonCodegenDialects(mlir::MLIRContext &context) {

diff  --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h
new file mode 100644
index 00000000000000..cbe202c4d23e0d
--- /dev/null
+++ b/flang/include/flang/Runtime/CUDA/registration.h
@@ -0,0 +1,28 @@
+//===-- include/flang/Runtime/CUDA/registration.h ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
+#define FORTRAN_RUNTIME_CUDA_REGISTRATION_H_
+
+#include "flang/Runtime/entry-names.h"
+#include <cstddef>
+
+namespace Fortran::runtime::cuda {
+
+extern "C" {
+
+/// Register a CUDA module. \p data is forwarded unchanged to
+/// `__cudaRegisterFatBinary` and the value it returns is handed back to the
+/// caller as the module handle.
+void *RTDECL(CUFRegisterModule)(void *data);
+
+/// Register a device function with the module handle \p module. The string
+/// \p fct is passed to `__cudaRegisterFunction` as the host function, the
+/// device function and the device name alike.
+void RTDECL(CUFRegisterFunction)(void **module, const char *fct);
+
+} // extern "C"
+
+} // namespace Fortran::runtime::cuda
+#endif // FORTRAN_RUNTIME_CUDA_REGISTRATION_H_

diff  --git a/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt b/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
index b2221199995d58..5d4bd0785971f7 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
+++ b/flang/lib/Optimizer/Dialect/CUF/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(Attributes)
 add_flang_library(CUFDialect
   CUFDialect.cpp
   CUFOps.cpp
+  CUFToLLVMIRTranslation.cpp
 
   DEPENDS
   MLIRIR

diff  --git a/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
new file mode 100644
index 00000000000000..c6c9f96b811352
--- /dev/null
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.cpp
@@ -0,0 +1,104 @@
+//===- CUFToLLVMIRTranslation.cpp - Translate CUF dialect to LLVM IR ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR CUF dialect and LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/CUF/CUFToLLVMIRTranslation.h"
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
+#include "flang/Runtime/entry-names.h"
+#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Translate `cuf.register_module` into a call to the CUFRegisterModule
+/// runtime entry. The call argument is the module-level global named
+/// `<leaf symbol name>_bin_cst`; an error is emitted on \p op when no such
+/// global exists. The call result becomes the LLVM value of the op result.
+LogicalResult registerModule(cuf::RegisterModuleOp op,
+                             llvm::IRBuilderBase &builder,
+                             LLVM::ModuleTranslation &moduleTranslation) {
+  std::string binaryIdentifier = op.getName().getLeafReference().str();
+  binaryIdentifier += "_bin_cst";
+
+  llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::Value *fatBinary =
+      llvmModule->getGlobalVariable(binaryIdentifier, /*AllowInternal=*/true);
+  if (fatBinary == nullptr)
+    return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
+
+  // The runtime entry takes one pointer and returns one pointer.
+  llvm::Type *ptrTy = builder.getPtrTy(0);
+  llvm::Type *paramTypes[] = {ptrTy};
+  llvm::FunctionType *entryTy =
+      llvm::FunctionType::get(ptrTy, paramTypes, /*isVarArg=*/false);
+  llvm::FunctionCallee entry = llvmModule->getOrInsertFunction(
+      RTNAME_STRING(CUFRegisterModule), entryTy);
+
+  llvm::CallInst *handle = builder.CreateCall(entry, {fatBinary});
+  moduleTranslation.mapValue(op->getResults().front()) = handle;
+  return mlir::success();
+}
+
+/// Return the module-level global holding the string \p kernelName, creating
+/// it on first use. The global is named
+/// `<moduleName>_<kernelName>_kernel_name`.
+llvm::Value *getOrCreateFunctionName(llvm::Module *module,
+                                     llvm::IRBuilderBase &builder,
+                                     llvm::StringRef moduleName,
+                                     llvm::StringRef kernelName) {
+  std::string globalName =
+      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, kernelName));
+
+  // IRBuilder::CreateGlobalString emits the string with private linkage, and
+  // Module::getGlobalVariable skips local-linkage globals by default. Pass
+  // AllowInternal so that a string created by an earlier call is found and
+  // reused instead of being emitted a second time.
+  if (llvm::GlobalVariable *gv =
+          module->getGlobalVariable(globalName, /*AllowInternal=*/true))
+    return gv;
+
+  return builder.CreateGlobalString(kernelName, globalName);
+}
+
+/// Translate `cuf.register_kernel` into a call to the CUFRegisterFunction
+/// runtime entry. The call receives the LLVM value mapped to the op's module
+/// pointer operand and a global string holding the kernel name.
+LogicalResult registerKernel(cuf::RegisterKernelOp op,
+                             llvm::IRBuilderBase &builder,
+                             LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::Module *module = moduleTranslation.getLLVMModule();
+  llvm::Type *ptrTy = builder.getPtrTy(0);
+  // The runtime entry is declared as
+  //   void CUFRegisterFunction(void **module, const char *fct)
+  // so the callee type returns void rather than a pointer.
+  llvm::FunctionCallee fct = module->getOrInsertFunction(
+      RTNAME_STRING(CUFRegisterFunction),
+      llvm::FunctionType::get(builder.getVoidTy(),
+                              ArrayRef<llvm::Type *>({ptrTy, ptrTy}), false));
+  llvm::Value *modulePtr = moduleTranslation.lookupValue(op.getModulePtr());
+  builder.CreateCall(
+      fct, {modulePtr, getOrCreateFunctionName(module, builder,
+                                               op.getKernelModuleName().str(),
+                                               op.getKernelName().str())});
+  return mlir::success();
+}
+
+/// Implementation of the dialect interface that translates operations of the
+/// CUF dialect to LLVM IR.
+class CUFDialectLLVMIRTranslationInterface
+    : public LLVMTranslationDialectInterface {
+public:
+  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+  /// Dispatch \p operation to the matching translation routine.
+  /// `cuf.register_module` and `cuf.register_kernel` are handled; any other
+  /// operation is reported as an error on that operation.
+  LogicalResult
+  convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) const override {
+    return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
+        .Case([&](cuf::RegisterModuleOp op) {
+          return registerModule(op, builder, moduleTranslation);
+        })
+        .Case([&](cuf::RegisterKernelOp op) {
+          return registerKernel(op, builder, moduleTranslation);
+        })
+        .Default([&](Operation *op) {
+          // This interface is attached to the CUF dialect, so name that
+          // dialect in the diagnostic.
+          return op->emitError("unsupported CUF operation: ") << op->getName();
+        });
+  }
+};
+
+} // namespace
+
+/// Insert the CUF dialect into \p registry and add an extension that attaches
+/// the LLVM IR translation interface to the dialect.
+void cuf::registerCUFDialectTranslation(DialectRegistry &registry) {
+  // Converted to a plain function pointer with unary plus, as addExtension
+  // is given a function pointer here.
+  auto attachTranslation = +[](MLIRContext *, cuf::CUFDialect *cufDialect) {
+    cufDialect->addInterfaces<CUFDialectLLVMIRTranslationInterface>();
+  };
+
+  registry.insert<cuf::CUFDialect>();
+  registry.addExtension(attachTranslation);
+}

diff  --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
index 9df559ee0ab1f8..629f0c69f8cb5d 100644
--- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp
@@ -20,6 +20,7 @@
 #include "flang/Runtime/CUDA/descriptor.h"
 #include "flang/Runtime/CUDA/memory.h"
 #include "flang/Runtime/allocatable.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

diff  --git a/flang/runtime/CUDA/CMakeLists.txt b/flang/runtime/CUDA/CMakeLists.txt
index 193dd77e934558..86523b419f8711 100644
--- a/flang/runtime/CUDA/CMakeLists.txt
+++ b/flang/runtime/CUDA/CMakeLists.txt
@@ -18,6 +18,7 @@ add_flang_library(${CUFRT_LIBNAME}
   allocatable.cpp
   descriptor.cpp
   memory.cpp
+  registration.cpp
 )
 
 if (BUILD_SHARED_LIBS)

diff  --git a/flang/runtime/CUDA/registration.cpp b/flang/runtime/CUDA/registration.cpp
new file mode 100644
index 00000000000000..aed275e964680e
--- /dev/null
+++ b/flang/runtime/CUDA/registration.cpp
@@ -0,0 +1,31 @@
+//===-- runtime/CUDA/registration.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Runtime/CUDA/registration.h"
+
+#include "cuda_runtime.h"
+
+namespace Fortran::runtime::cuda {
+
+extern "C" {
+
+extern void **__cudaRegisterFatBinary(void *data);
+extern void __cudaRegisterFunction(void **fatCubinHandle, const char *hostFun,
+    char *deviceFun, const char *deviceName, int thread_limit, uint3 *tid,
+    uint3 *bid, dim3 *bDim, dim3 *gDim, int *wSize);
+
+/// Forward \p data to `__cudaRegisterFatBinary` and return the resulting
+/// handle.
+void *RTDEF(CUFRegisterModule)(void *data) {
+  return __cudaRegisterFatBinary(data);
+}
+
+/// Register the function named \p fct with the module handle \p module by
+/// calling `__cudaRegisterFunction`. The same string serves as host function,
+/// device function and device name; the thread limit is -1 and every
+/// remaining pointer argument is null.
+void RTDEF(CUFRegisterFunction)(void **module, const char *fct) {
+  char *deviceFun = const_cast<char *>(fct);
+  __cudaRegisterFunction(module, /*hostFun=*/fct, deviceFun,
+      /*deviceName=*/fct, /*thread_limit=*/-1, /*tid=*/nullptr,
+      /*bid=*/nullptr, /*bDim=*/nullptr, /*gDim=*/nullptr, /*wSize=*/nullptr);
+}
+}
+} // namespace Fortran::runtime::cuda


        


More information about the flang-commits mailing list