[flang-commits] [flang] [flang][cuda] Pass the device address for global descriptor (PR #122802)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Mon Jan 13 13:56:32 PST 2025
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/122802
Module variables that require a descriptor are implemented with two descriptors: one residing on the host and one on the device.
When passing a global descriptor to a kernel launch, the address of the device descriptor must be substituted so the kernel will access the descriptor on the device.
This patch inserts calls to CUFGetDeviceAddress during the conversion of the `cuf.kernel_launch` operation so that the kernel arguments are correct.
>From 6b1ae45b0971616987ece8ac57b7aab0f774aa39 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 13 Jan 2025 13:48:40 -0800
Subject: [PATCH] [flang][cuda] Pass the device address for global descriptor
---
.../Optimizer/Transforms/CUFOpConversion.cpp | 65 +++++++++++++------
flang/test/Fir/CUDA/cuda-launch.fir | 42 ++++++++++++
2 files changed, 86 insertions(+), 21 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 8c525fc6daff5e..d61d9f63cb2949 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -366,6 +366,23 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
const fir::LLVMTypeConverter *typeConverter;
};
+/// Replace a global's address with the address of its device counterpart.
+///
+/// Emits a call to the `CUFGetDeviceAddress` runtime entry point. The call
+/// receives \p inputArg converted to the runtime's first parameter type,
+/// followed by the source file name and line number derived from \p loc.
+/// The pointer returned by the runtime is converted back to the type of
+/// \p inputArg, so the result can be used wherever \p inputArg was used.
+///
+/// All operations are created through \p rewriter (and a FirOpBuilder built
+/// from it), so callers must set the rewriter's insertion point beforehand.
+static mlir::Value genGetDeviceAddress(mlir::PatternRewriter &rewriter,
+                                       mlir::ModuleOp mod, mlir::Location loc,
+                                       mlir::Value inputArg) {
+  fir::FirOpBuilder builder(rewriter, mod);
+  mlir::func::FuncOp getDevAddrFn =
+      fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
+  mlir::FunctionType rtFnTy = getDevAddrFn.getFunctionType();
+
+  // Operand 0 is the address to translate; operand 2 is the source line.
+  // The ops below are created in this exact order on purpose.
+  mlir::Value addrOperand =
+      createConvertOp(rewriter, loc, rtFnTy.getInput(0), inputArg);
+  mlir::Value fileOperand = fir::factory::locationToFilename(builder, loc);
+  mlir::Value lineOperand =
+      fir::factory::locationToLineNo(builder, loc, rtFnTy.getInput(2));
+  llvm::SmallVector<mlir::Value> rtOperands = fir::runtime::createArguments(
+      builder, loc, rtFnTy, addrOperand, fileOperand, lineOperand);
+
+  fir::CallOp devAddrCall =
+      rewriter.create<fir::CallOp>(loc, getDevAddrFn, rtOperands);
+  // Hand back a value with the caller's original type.
+  mlir::Type resultTy = inputArg.getType();
+  return createConvertOp(rewriter, loc, resultTy, devAddrCall->getResult(0));
+}
+
struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
using OpRewritePattern::OpRewritePattern;
@@ -382,26 +399,10 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
if (cuf::isRegisteredDeviceGlobal(global)) {
rewriter.setInsertionPointAfter(addrOfOp);
auto mod = op->getParentOfType<mlir::ModuleOp>();
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::Location loc = op.getLoc();
- mlir::func::FuncOp callee =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(
- loc, builder);
- auto fTy = callee.getFunctionType();
- mlir::Type toTy = fTy.getInput(0);
- mlir::Value inputArg =
- createConvertOp(rewriter, loc, toTy, addrOfOp.getResult());
- mlir::Value sourceFile =
- fir::factory::locationToFilename(builder, loc);
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, inputArg, sourceFile, sourceLine)};
- auto call = rewriter.create<fir::CallOp>(loc, callee, args);
- mlir::Value cast = createConvertOp(
- rewriter, loc, op.getMemref().getType(), call->getResult(0));
+ mlir::Value devAddr = genGetDeviceAddress(rewriter, mod, op.getLoc(),
+ addrOfOp.getResult());
rewriter.startOpModification(op);
- op.getMemrefMutable().assign(cast);
+ op.getMemrefMutable().assign(devAddr);
rewriter.finalizeOpModification(op);
return success();
}
@@ -771,10 +772,32 @@ struct CUFLaunchOpConversion
loc, clusterDimsAttr.getZ().getInt());
}
}
+ llvm::SmallVector<mlir::Value> args;
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ for (mlir::Value arg : op.getArgs()) {
+ // If the argument is a global descriptor, make sure we pass the device
+ // copy of this descriptor and not the host one.
+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(arg.getType()))) {
+ if (auto declareOp =
+ mlir::dyn_cast_or_null<fir::DeclareOp>(arg.getDefiningOp())) {
+ if (auto addrOfOp = mlir::dyn_cast_or_null<fir::AddrOfOp>(
+ declareOp.getMemref().getDefiningOp())) {
+ if (auto global = symTab.lookup<fir::GlobalOp>(
+ addrOfOp.getSymbol().getRootReference().getValue())) {
+ if (cuf::isRegisteredDeviceGlobal(global)) {
+ arg = genGetDeviceAddress(rewriter, mod, op.getLoc(),
+ declareOp.getResult());
+ }
+ }
+ }
+ }
+ }
+ args.push_back(arg);
+ }
+
auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
- mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero,
- op.getArgs());
+ mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero, args);
if (clusterDimX && clusterDimY && clusterDimZ) {
gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
diff --git a/flang/test/Fir/CUDA/cuda-launch.fir b/flang/test/Fir/CUDA/cuda-launch.fir
index f11bcbdb7fce55..1e19b3bea1296f 100644
--- a/flang/test/Fir/CUDA/cuda-launch.fir
+++ b/flang/test/Fir/CUDA/cuda-launch.fir
@@ -62,3 +62,45 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
// CHECK-LABEL: func.func @_QMmod1Phost_sub()
// CHECK: gpu.launch_func @cuda_device_mod::@_QMmod1Psub1 clusters in (%c2{{.*}}, %c2{{.*}}, %c1{{.*}})
+// -----
+
+// Descriptor-global case: @_QMdevptrEdev_ptr is a module-level pointer
+// descriptor global carrying #cuf.cuda<device>. When it is passed to
+// cuf.kernel_launch, the conversion must pass the address returned by
+// CUFGetDeviceAddress rather than the address of the host copy.
+module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
+ gpu.module @cuda_device_mod {
+ gpu.func @_QMdevptrPtest(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) kernel {
+ gpu.return
+ }
+ }
+ // The descriptor global itself; the device data attribute is what the
+ // conversion inspects when deciding to substitute the device address.
+ fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>> {
+ %c0 = arith.constant 0 : index
+ %0 = fir.zero_bits !fir.ptr<!fir.array<?xf32>>
+ %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+ %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ }
+ func.func @_QMdevptrPtest(%arg0: !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "dp"}) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
+ return
+ }
+ func.func @_QQmain() {
+ %c1_i32 = arith.constant 1 : i32
+ %c4 = arith.constant 4 : index
+ %0 = cuf.alloc !fir.array<4xf32> {bindc_name = "a_dev", data_attr = #cuf.cuda<device>, uniq_name = "_QFEa_dev"} -> !fir.ref<!fir.array<4xf32>>
+ %1 = fir.shape %c4 : (index) -> !fir.shape<1>
+ %2 = fir.declare %0(%1) {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<target>, uniq_name = "_QFEa_dev"} : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.ref<!fir.array<4xf32>>
+ %3 = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ %4 = fir.declare %3 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ %5 = fir.embox %2(%1) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>>
+ fir.store %5 to %4 : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+ cuf.sync_descriptor @_QMdevptrEdev_ptr
+ // The kernel argument %4 is a fir.declare of fir.address_of(@_QMdevptrEdev_ptr);
+ // the lowered gpu.launch_func must receive the device address instead.
+ cuf.kernel_launch @_QMdevptrPtest<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32>>>(%4) : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
+ cuf.free %2 : !fir.ref<!fir.array<4xf32>> {data_attr = #cuf.cuda<device>}
+ return
+ }
+}
+
+// CHECK-LABEL: func.func @_QQmain()
+// CHECK: %[[ADDROF:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[DECL:.*]] = fir.declare %[[ADDROF]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: %[[CONV_DECL:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_DECL]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[CONV_DEVADDR:.*]] = fir.convert %[[DEVADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
+// CHECK: gpu.launch_func @cuda_device_mod::@_QMdevptrPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) dynamic_shared_memory_size %{{.*}} args(%[[CONV_DEVADDR]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
More information about the flang-commits
mailing list