[flang-commits] [flang] 5406834 - [flang][cuda] Add cuf.register_module operation (#112971)

via flang-commits flang-commits at lists.llvm.org
Fri Oct 18 21:30:42 PDT 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-10-18T21:30:38-07:00
New Revision: 5406834cdaa6d26b98484d634df579606ae02229

URL: https://github.com/llvm/llvm-project/commit/5406834cdaa6d26b98484d634df579606ae02229
DIFF: https://github.com/llvm/llvm-project/commit/5406834cdaa6d26b98484d634df579606ae02229.diff

LOG: [flang][cuda] Add cuf.register_module operation (#112971)

Add a new operation to register the fatbin and pass the returned module handle to
`cuf.register_kernel`

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
    flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
    flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
    flang/test/Fir/CUDA/cuda-register-func.fir
    flang/test/Fir/cuf-invalid.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
index 4132db672e394d..1edded090f8ce1 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
@@ -12,6 +12,7 @@
 #include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
 #include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/OpDefinition.h"
 
 #define GET_OP_CLASSES

diff  --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 98d1ef529738c7..d34a8af0394a44 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -18,6 +18,7 @@ include "flang/Optimizer/Dialect/CUF/CUFDialect.td"
 include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
 include "flang/Optimizer/Dialect/FIRTypes.td"
 include "flang/Optimizer/Dialect/FIRAttr.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/BuiltinAttributes.td"
 
@@ -288,15 +289,30 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
   let hasVerifier = 1;
 }
 
+def cuf_RegisterModuleOp : cuf_Op<"register_module", []> {
+  let summary = "Register a CUDA module";
+
+  let arguments = (ins
+    SymbolRefAttr:$name
+  );
+
+  let assemblyFormat = [{
+    $name attr-dict `->` type($modulePtr)
+  }];
+
+  let results = (outs LLVM_AnyPointer:$modulePtr);
+}
+
 def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
   let summary = "Register a CUDA kernel";
 
   let arguments = (ins
-    SymbolRefAttr:$name
+    SymbolRefAttr:$name,
+    LLVM_AnyPointer:$modulePtr
   );
 
   let assemblyFormat = [{
-    $name attr-dict
+    $name `(` $modulePtr `:` type($modulePtr) `)`attr-dict
   }];
 
   let hasVerifier = 1;

diff  --git a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
index 3db24226e75042..f260437e710417 100644
--- a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
@@ -62,12 +62,15 @@ struct CUFAddConstructor
     // Register kernels
     auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
     if (gpuMod) {
+      auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
+      auto registeredMod = builder.create<cuf::RegisterModuleOp>(
+          loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
       for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
         if (func.isKernel()) {
           auto kernelName = mlir::SymbolRefAttr::get(
               builder.getStringAttr(cudaModName),
               {mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
-          builder.create<cuf::RegisterKernelOp>(loc, kernelName);
+          builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
         }
       }
     }

diff  --git a/flang/test/Fir/CUDA/cuda-register-func.fir b/flang/test/Fir/CUDA/cuda-register-func.fir
index 277475f0883dcc..6b0cbfd3aca63d 100644
--- a/flang/test/Fir/CUDA/cuda-register-func.fir
+++ b/flang/test/Fir/CUDA/cuda-register-func.fir
@@ -12,5 +12,6 @@ module attributes {gpu.container_module} {
 }
 
 // CHECK-LABEL: llvm.func internal @__cudaFortranConstructor()
-// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1
-// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2
+// CHECK: %[[MOD_HANDLE:.*]] = cuf.register_module @cuda_device_mod -> !llvm.ptr
+// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%[[MOD_HANDLE]] : !llvm.ptr)
+// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%[[MOD_HANDLE]] : !llvm.ptr)

diff  --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
index 8a1eb48576832c..a3b9be3ee8223b 100644
--- a/flang/test/Fir/cuf-invalid.fir
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -135,8 +135,9 @@ module attributes {gpu.container_module} {
     }
   }
   llvm.func internal @__cudaFortranConstructor() {
+    %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
     // expected-error@+1{{'cuf.register_kernel' op only kernel gpu.func can be registered}}
-    cuf.register_kernel @cuda_device_mod::@_QPsub_device1
+    cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
     llvm.return
   }
 }
@@ -150,8 +151,9 @@ module attributes {gpu.container_module} {
     }
   }
   llvm.func internal @__cudaFortranConstructor() {
+    %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
     // expected-error@+1{{'cuf.register_kernel' op device function not found}}
-    cuf.register_kernel @cuda_device_mod::@_QPsub_device2
+    cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%0 : !llvm.ptr)
     llvm.return
   }
 }
@@ -160,8 +162,9 @@ module attributes {gpu.container_module} {
 
 module attributes {gpu.container_module} {
   llvm.func internal @__cudaFortranConstructor() {
+    %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
     // expected-error@+1{{'cuf.register_kernel' op gpu module not found}}
-    cuf.register_kernel @cuda_device_mod::@_QPsub_device1
+    cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
     llvm.return
   }
 }
@@ -170,8 +173,9 @@ module attributes {gpu.container_module} {
 
 module attributes {gpu.container_module} {
   llvm.func internal @__cudaFortranConstructor() {
+    %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
     // expected-error@+1{{'cuf.register_kernel' op expect a module and a kernel name}}
-    cuf.register_kernel @_QPsub_device1
+    cuf.register_kernel @_QPsub_device1(%0 : !llvm.ptr)
     llvm.return
   }
 }
@@ -185,8 +189,9 @@ module attributes {gpu.container_module} {
     }
   }
   llvm.func internal @__cudaFortranConstructor() {
+    %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
     // expected-error@+1{{'cuf.register_kernel' op only gpu.kernel llvm.func can be registered}}
-    cuf.register_kernel @cuda_device_mod::@_QPsub_device1
+    cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
     llvm.return
   }
 }


        


More information about the flang-commits mailing list