[flang-commits] [flang] 8588014 - [flang][cuda] Add kernel registration in CUF constructor (#112416)

via flang-commits flang-commits at lists.llvm.org
Tue Oct 15 14:18:41 PDT 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-10-15T14:18:37-07:00
New Revision: 85880140be35cdcdcad53cbb7255a85d5634af88

URL: https://github.com/llvm/llvm-project/commit/85880140be35cdcdcad53cbb7255a85d5634af88
DIFF: https://github.com/llvm/llvm-project/commit/85880140be35cdcdcad53cbb7255a85d5634af88.diff

LOG: [flang][cuda] Add kernel registration in CUF constructor (#112416)

Update the CUF constructor with the cuf.register_kernel operations.

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/lib/Optimizer/Transforms/CMakeLists.txt
    flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
    flang/test/Fir/CUDA/cuda-register-func.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index bf75123e853779..af6bd41cbb71da 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -439,7 +439,7 @@ def CufImplicitDeviceGlobal :
 def CUFAddConstructor : Pass<"cuf-add-constructor", "mlir::ModuleOp"> {
   let summary = "Add constructor to register CUDA Fortran allocators";
   let dependentDialects = [
-    "mlir::func::FuncDialect"
+    "cuf::CUFDialect", "mlir::func::FuncDialect"
   ];
 }
 

diff  --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 5e1a0293e63c97..352fe4cbe09e99 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -49,6 +49,7 @@ add_flang_library(FIRTransforms
   HLFIRDialect
   MLIRAffineUtils
   MLIRFuncDialect
+  MLIRGPUDialect
   MLIRLLVMDialect
   MLIRLLVMCommonConversion
   MLIRMathTransforms

diff  --git a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
index 48620fbc585861..3db24226e75042 100644
--- a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
@@ -12,6 +12,7 @@
 #include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Runtime/entry-names.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/ADT/SmallVector.h"
@@ -23,6 +24,8 @@ namespace fir {
 
 namespace {
 
+static constexpr llvm::StringRef cudaModName{"cuda_device_mod"};
+
 static constexpr llvm::StringRef cudaFortranCtorName{
     "__cudaFortranConstructor"};
 
@@ -31,6 +34,7 @@ struct CUFAddConstructor
 
   void runOnOperation() override {
     mlir::ModuleOp mod = getOperation();
+    mlir::SymbolTable symTab(mod);
     mlir::OpBuilder builder{mod.getBodyRegion()};
     builder.setInsertionPointToEnd(mod.getBody());
     mlir::Location loc = mod.getLoc();
@@ -48,13 +52,25 @@ struct CUFAddConstructor
         mod.getContext(), RTNAME_STRING(CUFRegisterAllocator));
     builder.setInsertionPointToEnd(mod.getBody());
 
-    // Create the constructor function that cal CUFRegisterAllocator.
-    builder.setInsertionPointToEnd(mod.getBody());
+    // Create the constructor function that calls CUFRegisterAllocator.
     auto func = builder.create<mlir::LLVM::LLVMFuncOp>(loc, cudaFortranCtorName,
                                                        funcTy);
     func.setLinkage(mlir::LLVM::Linkage::Internal);
     builder.setInsertionPointToStart(func.addEntryBlock(builder));
     builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);
+
+    // Register kernels
+    auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
+    if (gpuMod) {
+      for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
+        if (func.isKernel()) {
+          auto kernelName = mlir::SymbolRefAttr::get(
+              builder.getStringAttr(cudaModName),
+              {mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
+          builder.create<cuf::RegisterKernelOp>(loc, kernelName);
+        }
+      }
+    }
     builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});
 
     // Create the llvm.global_ctor with the function.

diff  --git a/flang/test/Fir/CUDA/cuda-register-func.fir b/flang/test/Fir/CUDA/cuda-register-func.fir
index a428f68eb3bf42..277475f0883dcc 100644
--- a/flang/test/Fir/CUDA/cuda-register-func.fir
+++ b/flang/test/Fir/CUDA/cuda-register-func.fir
@@ -1,4 +1,4 @@
-// RUN: fir-opt %s | FileCheck %s
+// RUN: fir-opt --cuf-add-constructor %s | FileCheck %s
 
 module attributes {gpu.container_module} {
   gpu.module @cuda_device_mod {
@@ -9,12 +9,8 @@ module attributes {gpu.container_module} {
       gpu.return
     }
   }
-  llvm.func internal @__cudaFortranConstructor() {
-    cuf.register_kernel @cuda_device_mod::@_QPsub_device1
-    cuf.register_kernel @cuda_device_mod::@_QPsub_device2
-    llvm.return
-  }
 }
 
+// CHECK-LABEL: llvm.func internal @__cudaFortranConstructor()
 // CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1
 // CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2


        


More information about the flang-commits mailing list