[flang-commits] [flang] 9d09c6f - [flang][cuda] Update device descriptor on data transfer (#114838)
via flang-commits
flang-commits at lists.llvm.org
Mon Nov 4 13:22:09 PST 2024
Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-11-04T13:22:06-08:00
New Revision: 9d09c6fd9c38a0e631e42a902ae215fb2f828da9
URL: https://github.com/llvm/llvm-project/commit/9d09c6fd9c38a0e631e42a902ae215fb2f828da9
DIFF: https://github.com/llvm/llvm-project/commit/9d09c6fd9c38a0e631e42a902ae215fb2f828da9.diff
LOG: [flang][cuda] Update device descriptor on data transfer (#114838)
When the destination of the data transfer is a global, we might need to
sync the descriptor after the data transfer is done. This is the case
when the data transfer is from host or device to device, because
reallocation might have happened and the descriptor on the device needs
to take the new values written to the host copy of the descriptor.
A new entry point, `CUFDataTransferGlobalDescDesc`, is added to perform
the sync when needed.
Added:
Modified:
flang/include/flang/Runtime/CUDA/memory.h
flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
flang/runtime/CUDA/memory.cpp
flang/test/Fir/CUDA/cuda-data-transfer.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 6d2e0c0f15942b..51d6b8d4545f09 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -49,6 +49,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
+/// Transfer into a global descriptor, syncing its device copy for H2D/D2D.
+void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dst, Descriptor *src,
+ unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
+
} // extern "C"
} // namespace Fortran::runtime::cuda
#endif // FORTRAN_RUNTIME_CUDA_MEMORY_H_
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 4050064ebe95d8..a28d0a562f2f0b 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -429,6 +429,22 @@ struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
}
};
+/// Returns true when the destination of \p op is a fir.declare or
+/// hlfir.declare whose memref is produced directly by fir.address_of, i.e.
+/// the transfer writes into a global. Such descriptor-to-descriptor
+/// transfers are lowered to CUFDataTransferGlobalDescDesc.
+static bool isDstGlobal(cuf::DataTransferOp op) {
+  auto isGlobalAddress = [](mlir::Value memref) {
+    return static_cast<bool>(memref.getDefiningOp<fir::AddrOfOp>());
+  };
+  mlir::Value dst = op.getDst();
+  if (auto firDecl = dst.getDefiningOp<fir::DeclareOp>())
+    return isGlobalAddress(firDecl.getMemref());
+  if (auto hlfirDecl = dst.getDefiningOp<hlfir::DeclareOp>())
+    return isGlobalAddress(hlfirDecl.getMemref());
+  return false;
+}
+
struct CUFDataTransferOpConversion
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
using OpRewritePattern::OpRewritePattern;
@@ -522,8 +532,11 @@ struct CUFDataTransferOpConversion
mlir::isa<fir::BaseBoxType>(dstTy)) {
// Transfer between two descriptor.
mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
- loc, builder);
+ isDstGlobal(op)
+ ? fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFDataTransferGlobalDescDesc)>(loc, builder)
+ : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
+ loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index 34442d05f495d3..0e03c618663ebd 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -9,6 +9,7 @@
#include "flang/Runtime/CUDA/memory.h"
#include "../terminator.h"
#include "flang/Runtime/CUDA/common.h"
+#include "flang/Runtime/CUDA/descriptor.h"
#include "flang/Runtime/assign.h"
#include "cuda_runtime.h"
@@ -125,5 +126,25 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
Fortran::runtime::Assign(
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
}
+
+/// Transfers the data described by \p srcDesc into the global descriptor
+/// \p dstDesc, then, for kHostToDevice and kDeviceToDevice transfers, syncs
+/// the device-side copy of the global descriptor with the host-side one.
+void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dstDesc,
+    Descriptor *srcDesc, unsigned mode, const char *sourceFile,
+    int sourceLine) {
+  RTNAME(CUFDataTransferDescDesc)
+  (dstDesc, srcDesc, mode, sourceFile, sourceLine);
+  if ((mode == kHostToDevice) || (mode == kDeviceToDevice)) {
+    // The transfer may reallocate the destination (MaybeReallocate), which
+    // only updates the host copy of its descriptor. Push that updated
+    // descriptor (dstDesc, not srcDesc) to the device copy: syncing from
+    // srcDesc would hand device code the source's base address and bounds.
+    void *deviceAddr{
+        RTNAME(CUFGetDeviceAddress)((void *)dstDesc, sourceFile, sourceLine)};
+    RTNAME(CUFDescriptorSync)
+    ((Descriptor *)deviceAddr, dstDesc, sourceFile, sourceLine);
+  }
+}
}
} // namespace Fortran::runtime::cuda
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index cee3048e279cc7..a760650d143583 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -224,4 +224,31 @@ func.func @_QPsub9() {
// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+
+// Descriptor-to-descriptor transfer into a module-level device allocatable:
+// must be lowered to the runtime entry point that also syncs the device copy.
+fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> {
+ %c0 = arith.constant 0 : index
+ %0 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
+ %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+ %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+ fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
+}
+
+func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"} {
+ %0 = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+ %1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+ %2 = fir.address_of(@_QFEahost) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+ %3:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEahost"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+ cuf.data_transfer %3#0 to %1#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+ return
+}
+
+// CHECK-LABEL: func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"}
+// CHECK: %[[GLOBAL_ADDRESS:.*]] = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[GLOBAL_DECL:.*]]:2 = hlfir.declare %[[GLOBAL_ADDRESS]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[GLOBAL_DECL]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[BOX_NONE]],{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
+
+
} // end of module
More information about the flang-commits
mailing list