[flang-commits] [flang] 93b2e47 - [flang][cuda] Avoid assign element mismatch when doing data transfer from a constant (#128252)
via flang-commits
flang-commits at lists.llvm.org
Fri Feb 21 17:46:49 PST 2025
Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-02-21T17:46:46-08:00
New Revision: 93b2e47f12649bd33b99b88f25beb277a181274a
URL: https://github.com/llvm/llvm-project/commit/93b2e47f12649bd33b99b88f25beb277a181274a
DIFF: https://github.com/llvm/llvm-project/commit/93b2e47f12649bd33b99b88f25beb277a181274a.diff
LOG: [flang][cuda] Avoid assign element mismatch when doing data transfer from a constant (#128252)
Currently, when we do a CUDA data transfer from a constant, we embox
the constant and delegate the assignment to the runtime. When the type
of the constant does not exactly match the element type of the
destination descriptor, the runtime emits an assignment mismatch error.
Convert the constant to the destination element type when necessary so that the assignment succeeds.
Added:
Modified:
flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
flang/test/Fir/CUDA/cuda-data-transfer.fir
Removed:
################################################################################
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 1f0576aa82f83..2ab2d84f1643d 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -541,7 +541,8 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
cuf::DataTransferOp op,
- const mlir::SymbolTable &symtab) {
+ const mlir::SymbolTable &symtab,
+ mlir::Type dstEleTy = nullptr) {
auto mod = op->getParentOfType<mlir::ModuleOp>();
mlir::Location loc = op.getLoc();
fir::FirOpBuilder builder(rewriter, mod);
@@ -555,11 +556,21 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
// from a LOGICAL constant. Store it as a fir.logical.
srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
src = createConvertOp(rewriter, loc, srcTy, src);
+ addr = builder.createTemporary(loc, srcTy);
+ builder.create<fir::StoreOp>(loc, src, addr);
+ } else {
+ if (dstEleTy && fir::isa_trivial(dstEleTy) && srcTy != dstEleTy) {
+ // Use dstEleTy and convert to avoid assign mismatch.
+ addr = builder.createTemporary(loc, dstEleTy);
+ auto conv = builder.create<fir::ConvertOp>(loc, dstEleTy, src);
+ builder.create<fir::StoreOp>(loc, conv, addr);
+ srcTy = dstEleTy;
+ } else {
+ // Put constant in memory if it is not.
+ addr = builder.createTemporary(loc, srcTy);
+ builder.create<fir::StoreOp>(loc, src, addr);
+ }
}
- // Put constant in memory if it is not.
- mlir::Value alloc = builder.createTemporary(loc, srcTy);
- builder.create<fir::StoreOp>(loc, src, alloc);
- addr = alloc;
} else {
addr = op.getSrc();
}
@@ -729,7 +740,7 @@ struct CUFDataTransferOpConversion
};
// Conversion of data transfer involving at least one descriptor.
- if (mlir::isa<fir::BaseBoxType>(dstTy)) {
+ if (auto dstBoxTy = mlir::dyn_cast<fir::BaseBoxType>(dstTy)) {
// Transfer to a descriptor.
mlir::func::FuncOp func =
isDstGlobal(op)
@@ -740,7 +751,8 @@ struct CUFDataTransferOpConversion
mlir::Value dst = op.getDst();
mlir::Value src = op.getSrc();
if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
- src = emboxSrc(rewriter, op, symtab);
+ mlir::Type dstEleTy = fir::unwrapInnerType(dstBoxTy.getEleTy());
+ src = emboxSrc(rewriter, op, symtab, dstEleTy);
if (fir::isa_trivial(srcTy))
func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
loc, builder);
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index b62c500f4a2d3..a724d9f681fb6 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -582,4 +582,26 @@ func.func @_QPchecksums(%arg0: !fir.box<!fir.array<?xf64>> {cuf.data_attr = #cuf
// CHECK: %[[SRC:.*]] = fir.convert %{{.*}} : (!fir.ref<!fir.box<!fir.array<?xf64>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
+func.func @_QPsub20() {
+ %0 = cuf.alloc !fir.box<!fir.heap<f32>> {bindc_name = "r", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub20Er"} -> !fir.ref<!fir.box<!fir.heap<f32>>>
+ %1 = fir.zero_bits !fir.heap<f32>
+ %2 = fir.embox %1 {allocator_idx = 2 : i32} : (!fir.heap<f32>) -> !fir.box<!fir.heap<f32>>
+ fir.store %2 to %0 : !fir.ref<!fir.box<!fir.heap<f32>>>
+ %3:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub20Er"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+ %c0_i32 = arith.constant 0 : i32
+ cuf.data_transfer %c0_i32 to %3#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.box<!fir.heap<f32>>>
+ return
+}
+
+// CHECK-LABEL:func.func @_QPsub20
+// CHECK: %[[BOX_ALLOCA:.*]] = fir.alloca !fir.box<f32>
+// CHECK: %[[TMP:.*]] = fir.alloca f32
+// CHECK: %[[CONV:.*]] = fir.convert %c0{{.*}} : (i32) -> f32
+// CHECK: fir.store %[[CONV]] to %[[TMP]] : !fir.ref<f32>
+// CHECK: %[[BOX:.*]] = fir.embox %[[TMP]] : (!fir.ref<f32>) -> !fir.box<f32>
+// CHECK: fir.store %[[BOX]] to %[[BOX_ALLOCA]] : !fir.ref<!fir.box<f32>>
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[BOX_ALLOCA]] : (!fir.ref<!fir.box<f32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%13, %[[BOX_NONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
+
} // end of module
+
More information about the flang-commits
mailing list