[flang-commits] [flang] [flang][cuda] Avoid assign element mismatch when doing data transfer from a constant (PR #128252)
via flang-commits
flang-commits at lists.llvm.org
Fri Feb 21 15:56:14 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
Currently, when we do a CUDA data transfer from a constant, we embox it and delegate the assignment to the runtime. When the type of the constant does not exactly match the element type of the destination descriptor, the runtime emits an assignment mismatch error.
Convert the constant to the destination element type when necessary so that the assignment succeeds.
---
Full diff: https://github.com/llvm/llvm-project/pull/128252.diff
2 Files Affected:
- (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+19-7)
- (modified) flang/test/Fir/CUDA/cuda-data-transfer.fir (+22)
``````````diff
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 1f0576aa82f83..2ab2d84f1643d 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -541,7 +541,8 @@ static mlir::Value getShapeFromDecl(mlir::Value src) {
static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
cuf::DataTransferOp op,
- const mlir::SymbolTable &symtab) {
+ const mlir::SymbolTable &symtab,
+ mlir::Type dstEleTy = nullptr) {
auto mod = op->getParentOfType<mlir::ModuleOp>();
mlir::Location loc = op.getLoc();
fir::FirOpBuilder builder(rewriter, mod);
@@ -555,11 +556,21 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
// from a LOGICAL constant. Store it as a fir.logical.
srcTy = fir::LogicalType::get(rewriter.getContext(), 4);
src = createConvertOp(rewriter, loc, srcTy, src);
+ addr = builder.createTemporary(loc, srcTy);
+ builder.create<fir::StoreOp>(loc, src, addr);
+ } else {
+ if (dstEleTy && fir::isa_trivial(dstEleTy) && srcTy != dstEleTy) {
+ // Use dstEleTy and convert to avoid assign mismatch.
+ addr = builder.createTemporary(loc, dstEleTy);
+ auto conv = builder.create<fir::ConvertOp>(loc, dstEleTy, src);
+ builder.create<fir::StoreOp>(loc, conv, addr);
+ srcTy = dstEleTy;
+ } else {
+ // Put constant in memory if it is not.
+ addr = builder.createTemporary(loc, srcTy);
+ builder.create<fir::StoreOp>(loc, src, addr);
+ }
}
- // Put constant in memory if it is not.
- mlir::Value alloc = builder.createTemporary(loc, srcTy);
- builder.create<fir::StoreOp>(loc, src, alloc);
- addr = alloc;
} else {
addr = op.getSrc();
}
@@ -729,7 +740,7 @@ struct CUFDataTransferOpConversion
};
// Conversion of data transfer involving at least one descriptor.
- if (mlir::isa<fir::BaseBoxType>(dstTy)) {
+ if (auto dstBoxTy = mlir::dyn_cast<fir::BaseBoxType>(dstTy)) {
// Transfer to a descriptor.
mlir::func::FuncOp func =
isDstGlobal(op)
@@ -740,7 +751,8 @@ struct CUFDataTransferOpConversion
mlir::Value dst = op.getDst();
mlir::Value src = op.getSrc();
if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
- src = emboxSrc(rewriter, op, symtab);
+ mlir::Type dstEleTy = fir::unwrapInnerType(dstBoxTy.getEleTy());
+ src = emboxSrc(rewriter, op, symtab, dstEleTy);
if (fir::isa_trivial(srcTy))
func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
loc, builder);
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index b62c500f4a2d3..a724d9f681fb6 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -582,4 +582,26 @@ func.func @_QPchecksums(%arg0: !fir.box<!fir.array<?xf64>> {cuf.data_attr = #cuf
// CHECK: %[[SRC:.*]] = fir.convert %{{.*}} : (!fir.ref<!fir.box<!fir.array<?xf64>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
+func.func @_QPsub20() {
+ %0 = cuf.alloc !fir.box<!fir.heap<f32>> {bindc_name = "r", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub20Er"} -> !fir.ref<!fir.box<!fir.heap<f32>>>
+ %1 = fir.zero_bits !fir.heap<f32>
+ %2 = fir.embox %1 {allocator_idx = 2 : i32} : (!fir.heap<f32>) -> !fir.box<!fir.heap<f32>>
+ fir.store %2 to %0 : !fir.ref<!fir.box<!fir.heap<f32>>>
+ %3:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub20Er"} : (!fir.ref<!fir.box<!fir.heap<f32>>>) -> (!fir.ref<!fir.box<!fir.heap<f32>>>, !fir.ref<!fir.box<!fir.heap<f32>>>)
+ %c0_i32 = arith.constant 0 : i32
+ cuf.data_transfer %c0_i32 to %3#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.box<!fir.heap<f32>>>
+ return
+}
+
+// CHECK-LABEL:func.func @_QPsub20
+// CHECK: %[[BOX_ALLOCA:.*]] = fir.alloca !fir.box<f32>
+// CHECK: %[[TMP:.*]] = fir.alloca f32
+// CHECK: %[[CONV:.*]] = fir.convert %c0{{.*}} : (i32) -> f32
+// CHECK: fir.store %[[CONV]] to %[[TMP]] : !fir.ref<f32>
+// CHECK: %[[BOX:.*]] = fir.embox %[[TMP]] : (!fir.ref<f32>) -> !fir.box<f32>
+// CHECK: fir.store %[[BOX]] to %[[BOX_ALLOCA]] : !fir.ref<!fir.box<f32>>
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[BOX_ALLOCA]] : (!fir.ref<!fir.box<f32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%13, %[[BOX_NONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> ()
+
} // end of module
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/128252
More information about the flang-commits
mailing list