[flang-commits] [flang] [flang][cuda] Add cuf.device_address operation (PR #122975)
via flang-commits
flang-commits at lists.llvm.org
Tue Jan 14 13:53:28 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Introduce a new op to get the device address from a host symbol. This simplifies the current conversion and is also in preparation for some legalization work that needs to be done in cuf kernel and cuf kernel launch, similar to
https://github.com/llvm/llvm-project/pull/122802
---
Full diff: https://github.com/llvm/llvm-project/pull/122975.diff
5 Files Affected:
- (modified) flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td (+12)
- (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+52-24)
- (modified) flang/test/Fir/CUDA/cuda-data-transfer.fir (+3)
- (modified) flang/test/Fir/CUDA/cuda-global-addr.mlir (+1)
- (modified) flang/test/Fir/CUDA/cuda-launch.fir (+3-3)
``````````diff
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 6f886726b12834..a270e69b394104 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -335,4 +335,16 @@ def cuf_RegisterKernelOp : cuf_Op<"register_kernel", []> {
}];
}
+def cuf_DeviceAddressOp : cuf_Op<"device_address", []> {
+ let summary = "Get the device address from a host symbol";
+
+ let arguments = (ins SymbolRefAttr:$hostSymbol);
+
+ let assemblyFormat = [{
+ $hostSymbol attr-dict `->` type($addr)
+ }];
+
+ let results = (outs fir_ReferenceType:$addr);
+}
+
#endif // FORTRAN_DIALECT_CUF_CUF_OPS
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index d61d9f63cb2949..e93bed37d39f78 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -366,22 +366,47 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
const fir::LLVMTypeConverter *typeConverter;
};
-static mlir::Value genGetDeviceAddress(mlir::PatternRewriter &rewriter,
- mlir::ModuleOp mod, mlir::Location loc,
- mlir::Value inputArg) {
- fir::FirOpBuilder builder(rewriter, mod);
- mlir::func::FuncOp callee =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
- auto fTy = callee.getFunctionType();
- mlir::Value conv = createConvertOp(rewriter, loc, fTy.getInput(0), inputArg);
- mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
- mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
- llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, conv, sourceFile, sourceLine)};
- auto call = rewriter.create<fir::CallOp>(loc, callee, args);
- return createConvertOp(rewriter, loc, inputArg.getType(), call->getResult(0));
-}
+struct CUFDeviceAddressOpConversion
+ : public mlir::OpRewritePattern<cuf::DeviceAddressOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ CUFDeviceAddressOpConversion(mlir::MLIRContext *context,
+ const mlir::SymbolTable &symtab)
+ : OpRewritePattern(context), symTab{symtab} {}
+
+ mlir::LogicalResult
+ matchAndRewrite(cuf::DeviceAddressOp op,
+ mlir::PatternRewriter &rewriter) const override {
+ if (auto global = symTab.lookup<fir::GlobalOp>(
+ op.getHostSymbol().getRootReference().getValue())) {
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ mlir::Location loc = op.getLoc();
+ auto hostAddr = rewriter.create<fir::AddrOfOp>(
+ loc, fir::ReferenceType::get(global.getType()), op.getHostSymbol());
+ fir::FirOpBuilder builder(rewriter, mod);
+ mlir::func::FuncOp callee =
+ fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc,
+ builder);
+ auto fTy = callee.getFunctionType();
+ mlir::Value conv =
+ createConvertOp(rewriter, loc, fTy.getInput(0), hostAddr);
+ mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
+ mlir::Value sourceLine =
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
+ llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
+ builder, loc, fTy, conv, sourceFile, sourceLine)};
+ auto call = rewriter.create<fir::CallOp>(loc, callee, args);
+ mlir::Value addr = createConvertOp(rewriter, loc, hostAddr.getType(),
+ call->getResult(0));
+ rewriter.replaceOp(op, addr.getDefiningOp());
+ return success();
+ }
+ return failure();
+ }
+
+private:
+ const mlir::SymbolTable &symTab;
+};
struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
using OpRewritePattern::OpRewritePattern;
@@ -398,9 +423,8 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
addrOfOp.getSymbol().getRootReference().getValue())) {
if (cuf::isRegisteredDeviceGlobal(global)) {
rewriter.setInsertionPointAfter(addrOfOp);
- auto mod = op->getParentOfType<mlir::ModuleOp>();
- mlir::Value devAddr = genGetDeviceAddress(rewriter, mod, op.getLoc(),
- addrOfOp.getResult());
+ mlir::Value devAddr = rewriter.create<cuf::DeviceAddressOp>(
+ op.getLoc(), addrOfOp.getType(), addrOfOp.getSymbol());
rewriter.startOpModification(op);
op.getMemrefMutable().assign(devAddr);
rewriter.finalizeOpModification(op);
@@ -773,7 +797,6 @@ struct CUFLaunchOpConversion
}
}
llvm::SmallVector<mlir::Value> args;
- auto mod = op->getParentOfType<mlir::ModuleOp>();
for (mlir::Value arg : op.getArgs()) {
// If the argument is a global descriptor, make sure we pass the device
// copy of this descriptor and not the host one.
@@ -785,8 +808,11 @@ struct CUFLaunchOpConversion
if (auto global = symTab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getRootReference().getValue())) {
if (cuf::isRegisteredDeviceGlobal(global)) {
- arg = genGetDeviceAddress(rewriter, mod, op.getLoc(),
- declareOp.getResult());
+ arg = rewriter
+ .create<cuf::DeviceAddressOp>(op.getLoc(),
+ addrOfOp.getType(),
+ addrOfOp.getSymbol())
+ .getResult();
}
}
}
@@ -907,10 +933,12 @@ void cuf::populateCUFToFIRConversionPatterns(
patterns.getContext());
patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab,
&dl, &converter);
- patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
+ patterns.insert<CUFLaunchOpConversion, CUFDeviceAddressOpConversion>(
+ patterns.getContext(), symtab);
}
void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns) {
- patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
+ patterns.insert<DeclareOpConversion, CUFDeviceAddressOpConversion>(
+ patterns.getContext(), symtab);
}
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index 7203c33e7eb11f..5ed27f1be0a430 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -198,6 +198,7 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
// CHECK-LABEL: func.func @_QPsub8()
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
@@ -222,6 +223,7 @@ func.func @_QPsub9() {
// CHECK-LABEL: func.func @_QPsub9()
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
+// CHECK: fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
@@ -380,6 +382,7 @@ func.func @_QPdevice_addr_conv() {
}
// CHECK-LABEL: func.func @_QPdevice_addr_conv()
+// CHECK: fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
diff --git a/flang/test/Fir/CUDA/cuda-global-addr.mlir b/flang/test/Fir/CUDA/cuda-global-addr.mlir
index 94ee74736f6508..0ccd0c797fb6f5 100644
--- a/flang/test/Fir/CUDA/cuda-global-addr.mlir
+++ b/flang/test/Fir/CUDA/cuda-global-addr.mlir
@@ -26,6 +26,7 @@ func.func @_QQmain() attributes {fir.bindc_name = "test"} {
}
// CHECK-LABEL: func.func @_QQmain()
+// CHECK: fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
// CHECK: %[[ADDR:.*]] = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
// CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
diff --git a/flang/test/Fir/CUDA/cuda-launch.fir b/flang/test/Fir/CUDA/cuda-launch.fir
index 1e19b3bea1296f..8432b9ec926e38 100644
--- a/flang/test/Fir/CUDA/cuda-launch.fir
+++ b/flang/test/Fir/CUDA/cuda-launch.fir
@@ -98,9 +98,9 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
}
// CHECK-LABEL: func.func @_QQmain()
+// CHECK: _FortranACUFSyncGlobalDescriptor
// CHECK: %[[ADDROF:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
-// CHECK: %[[DECL:.*]] = fir.declare %[[ADDROF]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QMdevptrEdev_ptr"} : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
-// CHECK: %[[CONV_DECL:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_DECL]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
+// CHECK: %[[CONV_ADDR:.*]] = fir.convert %[[ADDROF]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[DEVADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[CONV_ADDR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[CONV_DEVADDR:.*]] = fir.convert %[[DEVADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>
// CHECK: gpu.launch_func @cuda_device_mod::@_QMdevptrPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) dynamic_shared_memory_size %{{.*}} args(%[[CONV_DEVADDR]] : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>)
``````````
</details>
https://github.com/llvm/llvm-project/pull/122975
More information about the flang-commits mailing list