[flang-commits] [flang] [flang][cuda] Avoid assign element mismatch when doing data transfer from a constant (PR #128252)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Fri Feb 21 15:55:44 PST 2025
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/128252
Currently when we do a CUDA data transfer from a constant, we embox it and delegate the assignment to the runtime. When the type of the constant is not exactly the same as the destination descriptor, the runtime will emit an assignment mismatch error.
Convert the constant to the destination element type when necessary so that the assignment succeeds.
>From 20c03043d7a215229a42be42b5d733eba9a72717 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 21 Feb 2025 15:35:43 -0800
Subject: [PATCH] [flang][cuda] Avoid assign element mismatch when doing data
transfer from a constant
---
.../Optimizer/Transforms/CUFOpConversion.cpp | 26 ++++++++++++++-----
flang/test/Fir/CUDA/cuda-data-transfer.fir | 22 ++++++++++++++++
2 files changed, 41 insertions(+), 7 deletions(-)
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 1f0576aa82f83..2ab2d84f1643d 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -541,7 +541,8 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
cuf::DataTransferOp op,
- const mlir::SymbolTable &symtab) {
+ const mlir::SymbolTable &symtab,
+ mlir::Type dstEleTy = nullptr) {
auto mod = op->getParentOfType<mlir::ModuleOp>();
mlir::Location loc = op.getLoc();
fir::FirOpBuilder builder(rewriter, mod);
@@ -555,11 +556,21 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
// from a LOGICAL constant. Store it as a fir.logical.
srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
src = createConvertOp(rewriter, loc, srcTy, src);
+ addr = builder.createTemporary(loc, srcTy);
+ builder.create<fir::StoreOp>(loc, src, addr);
+ } else {
+ if (dstEleTy && fir::isa_trivial(dstEleTy) && srcTy != dstEleTy) {
+ // Use dstEleTy and convert to avoid assign mismatch.
+ addr = builder.createTemporary(loc, dstEleTy);
+ auto conv = builder.create<fir::ConvertOp>(loc, dstEleTy, src);
+ builder.create<fir::StoreOp>(loc, conv, addr);
+ srcTy = dstEleTy;
+ } else {
+ // Put constant in memory if it is not.
+ addr = builder.createTemporary(loc, srcTy);
+ builder.create<fir::StoreOp>(loc, src, addr);
+ }
}
- // Put constant in memory if it is not.
- mlir::Value alloc = builder.createTemporary(loc, srcTy);
- builder.create<fir::StoreOp>(loc, src, alloc);
- addr = alloc;
} else {
addr = op.getSrc();
}
@@ -729,7 +740,7 @@ struct CUFDataTransferOpConversion
};
// Conversion of data transfer involving at least one descriptor.
- if (mlir::isa<fir::BaseBoxType>(dstTy)) {
+ if (auto dstBoxTy = mlir::dyn_cast<fir::BaseBoxType>(dstTy)) {
// Transfer to a descriptor.
mlir::func::FuncOp func =
isDstGlobal(op)
@@ -740,7 +751,8 @@ struct CUFDataTransferOpConversion
mlir::Value dst = op.getDst();
mlir::Value src = op.getSrc();
if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
- src = emboxSrc(rewriter, op, symtab);
+ mlir::Type dstEleTy = fir::unwrapInnerType(dstBoxTy.getEleTy());
+ src = emboxSrc(rewriter, op, symtab, dstEleTy);
if (fir::isa_trivial(srcTy))
func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
loc, builder);
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index b62c500f4a2d3..a724d9f681fb6 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -582,4 +582,26 @@ func.func @_QPchecksums(%arg0: !fir.box<!fir.array<?xf64>> {cuf.data_attr = #cuf
// CHECK: %[[SRC:.*]] = fir.convert %{{.*}} : (!fir.ref<!fir.box<!fir.array<?xf64>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
+func.func @_QPsub20() {
+  %0 = cuf.alloc !fir.box<!fir.heap<f32>> {bindc_name = "r", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub20Er"} -> !fir.ref<!fir.box<!fir.heap<f32>>>
+  %1 = fir.zero_bits !fir.heap<f32>
+  %2 = fir.embox %1 {allocator_idx = 2 : i32} : (!fir.heap<f32>) -> !fir.box<!fir.heap<f32>>
+  fir.store %2 to %0 : !fir.ref<!fir.box<!fir.heap<f32>>>
+  %3:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub20Er"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+  %c0_i32 = arith.constant 0 : i32
+  cuf.data_transfer %c0_i32 to %3#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.box<!fir.heap<f32>>>
+  return
+}
+
+// An i32 constant transferred into an f32 descriptor must be converted to
+// the destination element type (f32) before it is emboxed.
+// CHECK-LABEL:func.func @_QPsub20
+// CHECK: %[[BOX_ALLOCA:.*]] = fir.alloca !fir.box<f32>
+// CHECK: %[[TMP:.*]] = fir.alloca f32
+// CHECK: %[[CONV:.*]] = fir.convert %c0{{.*}} : (i32) -> f32
+// CHECK: fir.store %[[CONV]] to %[[TMP]] : !fir.ref<f32>
+// CHECK: %[[BOX:.*]] = fir.embox %[[TMP]] : (!fir.ref<f32>) -> !fir.box<f32>
+// CHECK: fir.store %[[BOX]] to %[[BOX_ALLOCA]] : !fir.ref<!fir.box<f32>>
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[BOX_ALLOCA]] : (!fir.ref<!fir.box<f32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%{{.*}}, %[[BOX_NONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
+
} // end of module
+
More information about the flang-commits
mailing list