[flang-commits] [flang] [flang][cuda] Update device descriptor on data transfer (PR #114838)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Mon Nov 4 09:31:43 PST 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/114838

When the destination of a data transfer is a global descriptor, the device copy of that descriptor may need to be synchronized after the transfer is done. This is the case for host-to-device and device-to-device transfers: the destination might have been reallocated during the transfer, so the descriptor on the device needs to take the new values written to the descriptor on the host.

A new entry point, `CUFDataTransferGlobalDescDesc`, is added; it performs the data transfer and then syncs the device descriptor when needed.

>From 28325e6d7d59c16f99e9077d76b93ea76735e5d2 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 1 Nov 2024 14:20:34 -0700
Subject: [PATCH] [flang][cuda] Update device descriptor on data transfer

---
 flang/include/flang/Runtime/CUDA/memory.h     |  4 +++
 .../Optimizer/Transforms/CUFOpConversion.cpp  | 17 +++++++++++--
 flang/runtime/CUDA/memory.cpp                 | 14 +++++++++++
 flang/test/Fir/CUDA/cuda-data-transfer.fir    | 25 +++++++++++++++++++
 4 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 6d2e0c0f15942b..51d6b8d4545f09 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -49,6 +49,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
 void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
     unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
 
+/// Data transfer from a descriptor to a global descriptor.
+void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dst, Descriptor *src,
+    unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
+
 } // extern "C"
 } // namespace Fortran::runtime::cuda
 #endif // FORTRAN_RUNTIME_CUDA_MEMORY_H_
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 4050064ebe95d8..a28d0a562f2f0b 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -429,6 +429,16 @@ struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
   }
 };
 
+static bool isDstGlobal(cuf::DataTransferOp op) {
+  mlir::Value memref;
+  if (auto firDecl = op.getDst().getDefiningOp<fir::DeclareOp>())
+    memref = firDecl.getMemref();
+  else if (auto hlfirDecl = op.getDst().getDefiningOp<hlfir::DeclareOp>())
+    memref = hlfirDecl.getMemref();
+  // Global when the declared storage is produced directly by fir.address_of.
+  return memref && memref.getDefiningOp<fir::AddrOfOp>();
+}
+
 struct CUFDataTransferOpConversion
     : public mlir::OpRewritePattern<cuf::DataTransferOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -522,8 +532,11 @@ struct CUFDataTransferOpConversion
         mlir::isa<fir::BaseBoxType>(dstTy)) {
       // Transfer between two descriptor.
       mlir::func::FuncOp func =
-          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
-              loc, builder);
+          isDstGlobal(op)
+              ? fir::runtime::getRuntimeFunc<mkRTKey(
+                    CUFDataTransferGlobalDescDesc)>(loc, builder)
+              : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
+                    loc, builder);
 
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index daf1db684a3d2e..d8b79e9b88a68f 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -9,6 +9,7 @@
 #include "flang/Runtime/CUDA/memory.h"
 #include "../terminator.h"
 #include "flang/Runtime/CUDA/common.h"
+#include "flang/Runtime/CUDA/descriptor.h"
 #include "flang/Runtime/assign.h"
 
 #include "cuda_runtime.h"
@@ -123,5 +124,18 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
   Fortran::runtime::Assign(
       *dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
 }
+
+void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dstDesc,
+    Descriptor *srcDesc, unsigned mode, const char *sourceFile,
+    int sourceLine) { // Transfer, then refresh the device copy of dstDesc.
+  RTNAME(CUFDataTransferDescDesc)(
+      dstDesc, srcDesc, mode, sourceFile, sourceLine);
+  if ((mode == kHostToDevice) || (mode == kDeviceToDevice)) {
+    void *deviceAddr{
+        RTNAME(CUFGetDeviceAddress)((void *)dstDesc, sourceFile, sourceLine)};
+    RTNAME(CUFDescriptorSync) // Source is dstDesc: the host copy just updated.
+    ((Descriptor *)deviceAddr, dstDesc, sourceFile, sourceLine);
+  }
+}
 }
 } // namespace Fortran::runtime::cuda
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index cee3048e279cc7..a760650d143583 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -224,4 +224,29 @@ func.func @_QPsub9() {
 // CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
 // CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
 // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+
+fir.global @_QMmod1Ea {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> {
+  %c0 = arith.constant 0 : index
+  %0 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
+  %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+  %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+  fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
+}
+
+func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"} {
+  %0 = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %1:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+  %2 = fir.address_of(@_QFEahost) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %3:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFEahost"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+  cuf.data_transfer %3#0 to %1#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  return
+}
+
+// CHECK-LABEL: func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"}
+// CHECK: %[[GLOBAL_ADDRESS:.*]] = fir.address_of(@_QMmod1Ea) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[GLOBAL_DECL:.*]]:2 = hlfir.declare %[[GLOBAL_ADDRESS]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod1Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[GLOBAL_DECL]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[BOX_NONE]],{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
+
+
 } // end of module



More information about the flang-commits mailing list