[flang-commits] [flang] [flang][cuda] Add cuf.register_module operation (PR #112971)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Oct 18 13:25:34 PDT 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/112971

Add a new operation, `cuf.register_module`, that registers the fatbin and returns a module handle, and pass that handle to `cuf.register_kernel`.

From 9dcc52b0fc70a9aee16d0da560d7857a2bf74f7a Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 17 Oct 2024 13:20:59 -0700
Subject: [PATCH] [flang][cuda] Add cuf.register_module operation

---
 .../flang/Optimizer/Dialect/CUF/CUFOps.h      |  1 +
 .../flang/Optimizer/Dialect/CUF/CUFOps.td     | 20 +++++++++++++++++--
 .../Transforms/CUFAddConstructor.cpp          |  5 ++++-
 flang/test/Fir/CUDA/cuda-register-func.fir    |  5 +++--
 flang/test/Fir/cuf-invalid.fir                | 15 +++++++++-----
 5 files changed, 36 insertions(+), 10 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
index 4132db672e394d..1edded090f8ce1 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
@@ -12,6 +12,7 @@
 #include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
 #include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/OpDefinition.h"
 
 #define GET_OP_CLASSES
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 98d1ef529738c7..d34a8af0394a44 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -18,6 +18,7 @@ include "flang/Optimizer/Dialect/CUF/CUFDialect.td"
 include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
 include "flang/Optimizer/Dialect/FIRTypes.td"
 include "flang/Optimizer/Dialect/FIRAttr.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/BuiltinAttributes.td"
 
@@ -288,15 +289,30 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
   let hasVerifier = 1;
 }
 
+def cuf_RegisterModuleOp : cuf_Op<"register_module", []> {
+  let summary = "Register a CUDA module";
+
+  let arguments = (ins
+    SymbolRefAttr:$name
+  );
+
+  let assemblyFormat = [{
+    $name attr-dict `->` type($modulePtr)
+  }];
+
+  let results = (outs LLVM_AnyPointer:$modulePtr);
+}
+
 def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
   let summary = "Register a CUDA kernel";
 
   let arguments = (ins
-    SymbolRefAttr:$name
+    SymbolRefAttr:$name,
+    LLVM_AnyPointer:$modulePtr
   );
 
   let assemblyFormat = [{
-    $name attr-dict
+    $name `(` $modulePtr `:` type($modulePtr) `)`attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
index 3db24226e75042..f260437e710417 100644
--- a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
@@ -62,12 +62,15 @@ struct CUFAddConstructor
     // Register kernels
     auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
     if (gpuMod) {
+      auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
+      auto registeredMod = builder.create<cuf::RegisterModuleOp>(
+          loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
       for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
         if (func.isKernel()) {
           auto kernelName = mlir::SymbolRefAttr::get(
               builder.getStringAttr(cudaModName),
               {mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
-          builder.create<cuf::RegisterKernelOp>(loc, kernelName);
+          builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
         }
       }
     }
diff --git a/flang/test/Fir/CUDA/cuda-register-func.fir b/flang/test/Fir/CUDA/cuda-register-func.fir
index 277475f0883dcc..6b0cbfd3aca63d 100644
--- a/flang/test/Fir/CUDA/cuda-register-func.fir
+++ b/flang/test/Fir/CUDA/cuda-register-func.fir
@@ -12,5 +12,6 @@ module attributes {gpu.container_module} {
 }
 
 // CHECK-LABEL: llvm.func internal @__cudaFortranConstructor()
-// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1
-// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2
+// CHECK: %[[MOD_HANDLE:.*]] = cuf.register_module @cuda_device_mod -> !llvm.ptr
+// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%[[MOD_HANDLE]] : !llvm.ptr)
+// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%[[MOD_HANDLE]] : !llvm.ptr)
diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
index 8a1eb48576832c..a3b9be3ee8223b 100644
--- a/flang/test/Fir/cuf-invalid.fir
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -135,8 +135,9 @@ module attributes {gpu.container_module} {
     }
   }
   llvm.func internal @__cudaFortranConstructor() {
+    %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
    // expected-error@+1{{'cuf.register_kernel' op only kernel gpu.func can be registered}}
-    cuf.register_kernel @cuda_device_mod::@_QPsub_device1
+    cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
     llvm.return
   }
 }
@@ -150,8 +151,9 @@ module attributes {gpu.container_module} {
     }
   }
   llvm.func internal @__cudaFortranConstructor() {
+    %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
    // expected-error@+1{{'cuf.register_kernel' op device function not found}}
-    cuf.register_kernel @cuda_device_mod::@_QPsub_device2
+    cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%0 : !llvm.ptr)
     llvm.return
   }
 }
@@ -160,8 +162,9 @@ module attributes {gpu.container_module} {
 
 module attributes {gpu.container_module} {
   llvm.func internal @__cudaFortranConstructor() {
+    %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
    // expected-error@+1{{'cuf.register_kernel' op gpu module not found}}
-    cuf.register_kernel @cuda_device_mod::@_QPsub_device1
+    cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
     llvm.return
   }
 }
@@ -170,8 +173,9 @@ module attributes {gpu.container_module} {
 
 module attributes {gpu.container_module} {
   llvm.func internal @__cudaFortranConstructor() {
+    %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
    // expected-error@+1{{'cuf.register_kernel' op expect a module and a kernel name}}
-    cuf.register_kernel @_QPsub_device1
+    cuf.register_kernel @_QPsub_device1(%0 : !llvm.ptr)
     llvm.return
   }
 }
@@ -185,8 +189,9 @@ module attributes {gpu.container_module} {
     }
   }
   llvm.func internal @__cudaFortranConstructor() {
+    %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
    // expected-error@+1{{'cuf.register_kernel' op only gpu.kernel llvm.func can be registered}}
-    cuf.register_kernel @cuda_device_mod::@_QPsub_device1
+    cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
     llvm.return
   }
 }



More information about the flang-commits mailing list