[flang-commits] [flang] 652db7e - [flang][cuda] Support data transfer from pointer to a descriptor (#114892)

via flang-commits flang-commits at lists.llvm.org
Tue Nov 5 08:56:23 PST 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-11-05T08:56:19-08:00
New Revision: 652db7e4ff773df1bc78c920d1bc75a93e92bae6

URL: https://github.com/llvm/llvm-project/commit/652db7e4ff773df1bc78c920d1bc75a93e92bae6
DIFF: https://github.com/llvm/llvm-project/commit/652db7e4ff773df1bc78c920d1bc75a93e92bae6.diff

LOG: [flang][cuda] Support data transfer from pointer to a descriptor (#114892)

When the source is a pointer to an array or a scalar, embox it and use the
`CUFDataTransferDescDesc` or `CUFDataTransferGlobalDescDesc` entry
points. The runtime already handles all the corner cases, such as
non-contiguous arrays, so we exploit this.

Memset might still be used for simple cases, for example when we want to
initialize to 0. This will come in a follow-up patch.

Added: 
    

Modified: 
    flang/include/flang/Runtime/CUDA/memory.h
    flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
    flang/runtime/CUDA/memory.cpp
    flang/test/Fir/CUDA/cuda-data-transfer.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 51d6b8d4545f09..4ac2528c1aedbc 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -35,11 +35,6 @@ void RTDECL(CUFMemsetDescriptor)(Descriptor *desc, void *value,
 void RTDECL(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
     unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
 
-/// Data transfer from a pointer to a descriptor.
-void RTDECL(CUFDataTransferDescPtr)(Descriptor *dst, void *src,
-    std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
-    int sourceLine = 0);
-
 /// Data transfer from a descriptor to a pointer.
 void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
     std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,

diff  --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index a28d0a562f2f0b..89d0af1fcd136f 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -23,6 +23,7 @@
 #include "flang/Runtime/allocatable.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -439,6 +440,14 @@ static bool isDstGlobal(cuf::DataTransferOp op) {
   return false;
 }
 
+static mlir::Value getShapeFromDecl(mlir::Value src) {
+  if (auto declareOp = src.getDefiningOp<fir::DeclareOp>())
+    return declareOp.getShape();
+  if (auto declareOp = src.getDefiningOp<hlfir::DeclareOp>())
+    return declareOp.getShape();
+  return mlir::Value{};
+}
+
 struct CUFDataTransferOpConversion
     : public mlir::OpRewritePattern<cuf::DataTransferOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -528,54 +537,54 @@ struct CUFDataTransferOpConversion
     }
 
     // Conversion of data transfer involving at least one descriptor.
-    if (mlir::isa<fir::BaseBoxType>(srcTy) &&
-        mlir::isa<fir::BaseBoxType>(dstTy)) {
-      // Transfer between two descriptor.
+    if (mlir::isa<fir::BaseBoxType>(dstTy)) {
+      // Transfer to a descriptor.
       mlir::func::FuncOp func =
           isDstGlobal(op)
               ? fir::runtime::getRuntimeFunc<mkRTKey(
                     CUFDataTransferGlobalDescDesc)>(loc, builder)
               : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
                     loc, builder);
-
-      auto fTy = func.getFunctionType();
-      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-      mlir::Value sourceLine =
-          fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
       mlir::Value dst = op.getDst();
       mlir::Value src = op.getSrc();
-      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-          builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
-      builder.create<fir::CallOp>(loc, func, args);
-      rewriter.eraseOp(op);
-    } else if (mlir::isa<fir::BaseBoxType>(dstTy) && fir::isa_trivial(srcTy)) {
-      // Scalar to descriptor transfer.
-      mlir::Value val = op.getSrc();
-      if (op.getSrc().getDefiningOp() &&
-          mlir::isa<mlir::arith::ConstantOp>(op.getSrc().getDefiningOp())) {
-        mlir::Value alloc = builder.createTemporary(loc, srcTy);
-        builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
-        val = alloc;
+
+      if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
+        // If src is not a descriptor, create one.
+        mlir::Value addr;
+        if (fir::isa_trivial(srcTy) &&
+            mlir::matchPattern(op.getSrc().getDefiningOp(),
+                               mlir::m_Constant())) {
+          // Put constant in memory if it is not.
+          mlir::Value alloc = builder.createTemporary(loc, srcTy);
+          builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
+          addr = alloc;
+        } else {
+          addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+        }
+        mlir::Type boxTy = fir::BoxType::get(srcTy);
+        llvm::SmallVector<mlir::Value> lenParams;
+        mlir::Value box =
+            builder.createBox(loc, boxTy, addr, getShapeFromDecl(src),
+                              /*slice=*/nullptr, lenParams,
+                              /*tdesc=*/nullptr);
+        mlir::Value memBox = builder.createTemporary(loc, box.getType());
+        builder.create<fir::StoreOp>(loc, box, memBox);
+        src = memBox;
       }
 
-      mlir::func::FuncOp func =
-          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemsetDescriptor)>(loc,
-                                                                     builder);
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
-          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
       llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-          builder, loc, fTy, op.getDst(), val, sourceFile, sourceLine)};
+          builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
       rewriter.eraseOp(op);
     } else {
       // Type used to compute the width.
       mlir::Type computeType = dstTy;
       auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
-      bool dstIsDesc = false;
       if (mlir::isa<fir::BaseBoxType>(dstTy)) {
-        dstIsDesc = true;
         computeType = srcTy;
         seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
       }
@@ -606,11 +615,8 @@ struct CUFDataTransferOpConversion
           rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
 
       mlir::func::FuncOp func =
-          dstIsDesc
-              ? fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescPtr)>(
-                    loc, builder)
-              : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
-                    loc, builder);
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
+              loc, builder);
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =

diff  --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index 0e03c618663ebd..2d499f93fbaece 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -96,13 +96,6 @@ void RTDEF(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
   CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, bytes, kind));
 }
 
-void RTDEF(CUFDataTransferDescPtr)(Descriptor *desc, void *addr,
-    std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
-  Terminator terminator{sourceFile, sourceLine};
-  terminator.Crash(
-      "not yet implemented: CUDA data transfer from a pointer to a descriptor");
-}
-
 void RTDEF(CUFDataTransferPtrDesc)(void *addr, Descriptor *desc,
     std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
   Terminator terminator{sourceFile, sourceLine};

diff  --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index a760650d143583..6a33190168024f 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -29,13 +29,16 @@ func.func @_QPsub2() {
 }
 
 // CHECK-LABEL: func.func @_QPsub2()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<i32>
 // CHECK: %[[TEMP:.*]] = fir.alloca i32
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[C2:.*]] = arith.constant 2 : i32
 // CHECK: fir.store %[[C2]] to %[[TEMP]] : !fir.ref<i32>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[TEMP]] : (!fir.ref<i32>) -> !fir.box<i32>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub3() {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -48,12 +51,15 @@ func.func @_QPsub3() {
 }
 
 // CHECK-LABEL: func.func @_QPsub3()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<i32>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[V:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub3Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[V]]#0 : (!fir.ref<i32>) -> !fir.box<i32>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[V_CONV:.*]] = fir.convert %[[V]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
-
+// CHECK: %[[V_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
+  
 func.func @_QPsub4() {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
   %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
@@ -67,15 +73,14 @@ func.func @_QPsub4() {
   return
 }
 // CHECK-LABEL: func.func @_QPsub4()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
-// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
-// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
+// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 // CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
@@ -110,16 +115,15 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
 }
 
 // CHECK-LABEL: func.func @_QPsub5
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
 // CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2>
 // CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[SHAPE]]) {uniq_name = "_QFsub5Eahost"} : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.ref<!fir.array<?x?xi32>>)
-// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<?x?xi32>>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 // CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
@@ -248,5 +252,35 @@ func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"} {
 // CHECK: %[[BOX_NONE:.*]] = fir.convert %[[GLOBAL_DECL:.*]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[BOX_NONE]],{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
+fir.global @_QMmod2Eadev {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> {
+  %c0 = arith.constant 0 : index
+  %0 = fir.zero_bits !fir.heap<!fir.array<?xi32>>
+  %1 = fir.shape %c0 : (index) -> !fir.shape<1>
+  %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+  fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
+}
+func.func @_QPdesc_global_ptr() {
+  %c10 = arith.constant 10 : index
+  %0 = fir.address_of(@_QMmod2Eadev) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %1 = fir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %2 = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"}
+  %3 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %4 = fir.declare %2(%3) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+  cuf.data_transfer %4 to %1 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  return
+}
+
+// CHECK-LABEL: func.func @_QPdesc_global_ptr()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
+// CHECK: %[[ADDR_ADEV:.*]] = fir.address_of(@_QMmod2Eadev) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[DECL_ADEV:.*]] = fir.declare %[[ADDR_ADEV]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[AHOST:.*]] = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"}
+// CHECK: %[[SHAPE:.*]] = fir.shape %c10 : (index) -> !fir.shape<1>
+// CHECK: %[[DECL_AHOST:.*]] = fir.declare %[[AHOST]](%[[SHAPE]]) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[DECL_AHOST]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
+// CHECK: %[[ADEV_BOXNONE:.*]] = fir.convert %[[DECL_ADEV]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: %[[AHOST_BOXNONE:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[ADEV_BOXNONE]], %[[AHOST_BOXNONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 } // end of module


        


More information about the flang-commits mailing list