[flang-commits] [flang] 5406834 - [flang][cuda] Add cuf.register_module operation (#112971)
via flang-commits
flang-commits at lists.llvm.org
Fri Oct 18 21:30:42 PDT 2024
Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-10-18T21:30:38-07:00
New Revision: 5406834cdaa6d26b98484d634df579606ae02229
URL: https://github.com/llvm/llvm-project/commit/5406834cdaa6d26b98484d634df579606ae02229
DIFF: https://github.com/llvm/llvm-project/commit/5406834cdaa6d26b98484d634df579606ae02229.diff
LOG: [flang][cuda] Add cuf.register_module operation (#112971)
Add a new operation to register the fatbin and pass it to
`cuf.register_kernel`
Added:
Modified:
flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
flang/test/Fir/CUDA/cuda-register-func.fir
flang/test/Fir/cuf-invalid.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
index 4132db672e394d..1edded090f8ce1 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.h
@@ -12,6 +12,7 @@
#include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.h"
#include "flang/Optimizer/Dialect/CUF/CUFDialect.h"
#include "flang/Optimizer/Dialect/FIRType.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/OpDefinition.h"
#define GET_OP_CLASSES
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 98d1ef529738c7..d34a8af0394a44 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -18,6 +18,7 @@ include "flang/Optimizer/Dialect/CUF/CUFDialect.td"
include "flang/Optimizer/Dialect/CUF/Attributes/CUFAttr.td"
include "flang/Optimizer/Dialect/FIRTypes.td"
include "flang/Optimizer/Dialect/FIRAttr.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"
@@ -288,15 +289,30 @@ def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,
let hasVerifier = 1;
}
+def cuf_RegisterModuleOp : cuf_Op<"register_module", []> {
+ let summary = "Register a CUDA module";
+
+ let arguments = (ins
+ SymbolRefAttr:$name
+ );
+
+ let assemblyFormat = [{
+ $name attr-dict `->` type($modulePtr)
+ }];
+
+ let results = (outs LLVM_AnyPointer:$modulePtr);
+}
+
def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
let summary = "Register a CUDA kernel";
let arguments = (ins
- SymbolRefAttr:$name
+ SymbolRefAttr:$name,
+ LLVM_AnyPointer:$modulePtr
);
let assemblyFormat = [{
- $name attr-dict
+ $name `(` $modulePtr `:` type($modulePtr) `)`attr-dict
}];
let hasVerifier = 1;
diff --git a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
index 3db24226e75042..f260437e710417 100644
--- a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp
@@ -62,12 +62,15 @@ struct CUFAddConstructor
// Register kernels
auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaModName);
if (gpuMod) {
+ auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
+ auto registeredMod = builder.create<cuf::RegisterModuleOp>(
+ loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));
for (auto func : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
if (func.isKernel()) {
auto kernelName = mlir::SymbolRefAttr::get(
builder.getStringAttr(cudaModName),
{mlir::SymbolRefAttr::get(builder.getContext(), func.getName())});
- builder.create<cuf::RegisterKernelOp>(loc, kernelName);
+ builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
}
}
}
diff --git a/flang/test/Fir/CUDA/cuda-register-func.fir b/flang/test/Fir/CUDA/cuda-register-func.fir
index 277475f0883dcc..6b0cbfd3aca63d 100644
--- a/flang/test/Fir/CUDA/cuda-register-func.fir
+++ b/flang/test/Fir/CUDA/cuda-register-func.fir
@@ -12,5 +12,6 @@ module attributes {gpu.container_module} {
}
// CHECK-LABEL: llvm.func internal @__cudaFortranConstructor()
-// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1
-// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2
+// CHECK: %[[MOD_HANDLE:.*]] = cuf.register_module @cuda_device_mod -> !llvm.ptr
+// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%[[MOD_HANDLE]] : !llvm.ptr)
+// CHECK: cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%[[MOD_HANDLE]] : !llvm.ptr)
diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
index 8a1eb48576832c..a3b9be3ee8223b 100644
--- a/flang/test/Fir/cuf-invalid.fir
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -135,8 +135,9 @@ module attributes {gpu.container_module} {
}
}
llvm.func internal @__cudaFortranConstructor() {
+ %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op only kernel gpu.func can be registered}}
- cuf.register_kernel @cuda_device_mod::@_QPsub_device1
+ cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
@@ -150,8 +151,9 @@ module attributes {gpu.container_module} {
}
}
llvm.func internal @__cudaFortranConstructor() {
+ %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op device function not found}}
- cuf.register_kernel @cuda_device_mod::@_QPsub_device2
+ cuf.register_kernel @cuda_device_mod::@_QPsub_device2(%0 : !llvm.ptr)
llvm.return
}
}
@@ -160,8 +162,9 @@ module attributes {gpu.container_module} {
module attributes {gpu.container_module} {
llvm.func internal @__cudaFortranConstructor() {
+ %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op gpu module not found}}
- cuf.register_kernel @cuda_device_mod::@_QPsub_device1
+ cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
@@ -170,8 +173,9 @@ module attributes {gpu.container_module} {
module attributes {gpu.container_module} {
llvm.func internal @__cudaFortranConstructor() {
+ %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op expect a module and a kernel name}}
- cuf.register_kernel @_QPsub_device1
+ cuf.register_kernel @_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
@@ -185,8 +189,9 @@ module attributes {gpu.container_module} {
}
}
llvm.func internal @__cudaFortranConstructor() {
+ %0 = cuf.register_module @cuda_device_mod -> !llvm.ptr
// expected-error@+1{{'cuf.register_kernel' op only gpu.kernel llvm.func can be registered}}
- cuf.register_kernel @cuda_device_mod::@_QPsub_device1
+ cuf.register_kernel @cuda_device_mod::@_QPsub_device1(%0 : !llvm.ptr)
llvm.return
}
}
More information about the flang-commits
mailing list