[flang-commits] [flang] [flang][cuda] Support data transfer from pointer to a descriptor (PR #114892)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Mon Nov 4 21:48:27 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/114892

>From 2b0a9dd0b871143a28b451ceff899d5d553ab708 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 4 Nov 2024 13:28:40 -0800
Subject: [PATCH 1/2] [flang][cuda] Support data transfer from pointer to a
 descriptor

---
 flang/include/flang/Runtime/CUDA/memory.h     |  5 --
 .../Optimizer/Transforms/CUFOpConversion.cpp  | 51 ++++++++++----
 flang/runtime/CUDA/memory.cpp                 |  7 --
 flang/test/Fir/CUDA/cuda-data-transfer.fir    | 70 ++++++++++++++-----
 4 files changed, 91 insertions(+), 42 deletions(-)

diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 51d6b8d4545f09..4ac2528c1aedbc 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -35,11 +35,6 @@ void RTDECL(CUFMemsetDescriptor)(Descriptor *desc, void *value,
 void RTDECL(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
     unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
 
-/// Data transfer from a pointer to a descriptor.
-void RTDECL(CUFDataTransferDescPtr)(Descriptor *dst, void *src,
-    std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
-    int sourceLine = 0);
-
 /// Data transfer from a descriptor to a pointer.
 void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
     std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index a28d0a562f2f0b..9b8bcf2f719281 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -23,6 +23,7 @@
 #include "flang/Runtime/allocatable.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -439,6 +440,14 @@ static bool isDstGlobal(cuf::DataTransferOp op) {
   return false;
 }
 
+static mlir::Value getShapeFromDecl(mlir::Value src) { // decl shape, or null
+  if (auto declareOp = src.getDefiningOp<fir::DeclareOp>()) // FIR declare
+    return declareOp.getShape();
+  if (auto declareOp = src.getDefiningOp<hlfir::DeclareOp>()) // HLFIR declare
+    return declareOp.getShape();
+  return mlir::Value{}; // src not produced by a declare op: no shape
+}
+
 struct CUFDataTransferOpConversion
     : public mlir::OpRewritePattern<cuf::DataTransferOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -528,22 +537,45 @@ struct CUFDataTransferOpConversion
     }
 
     // Conversion of data transfer involving at least one descriptor.
-    if (mlir::isa<fir::BaseBoxType>(srcTy) &&
-        mlir::isa<fir::BaseBoxType>(dstTy)) {
-      // Transfer between two descriptor.
+    if (mlir::isa<fir::BaseBoxType>(dstTy)) {
+      // Transfer to a descriptor.
       mlir::func::FuncOp func =
           isDstGlobal(op)
               ? fir::runtime::getRuntimeFunc<mkRTKey(
                     CUFDataTransferGlobalDescDesc)>(loc, builder)
               : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
                     loc, builder);
+      mlir::Value dst = op.getDst();
+      mlir::Value src = op.getSrc();
+
+      if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
+        // If src is not a descriptor, create one.
+        mlir::Value addr;
+        // src may be a block argument, so match the Value, not its defining op.
+        if (fir::isa_trivial(srcTy) &&
+            mlir::matchPattern(op.getSrc(), mlir::m_Constant())) {
+          // Put constant in memory if it is not.
+          mlir::Value alloc = builder.createTemporary(loc, srcTy);
+          builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
+          addr = alloc;
+        } else {
+          addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
+        }
+        mlir::Type boxTy = fir::BoxType::get(srcTy);
+        llvm::SmallVector<mlir::Value> lenParams;
+        mlir::Value box =
+            builder.createBox(loc, boxTy, addr, getShapeFromDecl(src),
+                              /*slice=*/nullptr, lenParams,
+                              /*tdesc=*/nullptr);
+        mlir::Value memBox = builder.createTemporary(loc, box.getType());
+        builder.create<fir::StoreOp>(loc, box, memBox);
+        src = memBox;
+      }
 
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
           fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
-      mlir::Value dst = op.getDst();
-      mlir::Value src = op.getSrc();
       llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
           builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
@@ -573,9 +605,7 @@ struct CUFDataTransferOpConversion
       // Type used to compute the width.
       mlir::Type computeType = dstTy;
       auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
-      bool dstIsDesc = false;
       if (mlir::isa<fir::BaseBoxType>(dstTy)) {
-        dstIsDesc = true;
         computeType = srcTy;
         seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
       }
@@ -606,11 +636,8 @@ struct CUFDataTransferOpConversion
           rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
 
       mlir::func::FuncOp func =
-          dstIsDesc
-              ? fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescPtr)>(
-                    loc, builder)
-              : fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
-                    loc, builder);
+          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
+              loc, builder);
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index 0e03c618663ebd..2d499f93fbaece 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -96,13 +96,6 @@ void RTDEF(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
   CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, bytes, kind));
 }
 
-void RTDEF(CUFDataTransferDescPtr)(Descriptor *desc, void *addr,
-    std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
-  Terminator terminator{sourceFile, sourceLine};
-  terminator.Crash(
-      "not yet implemented: CUDA data transfer from a pointer to a descriptor");
-}
-
 void RTDEF(CUFDataTransferPtrDesc)(void *addr, Descriptor *desc,
     std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
   Terminator terminator{sourceFile, sourceLine};
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index a760650d143583..6a33190168024f 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -29,13 +29,16 @@ func.func @_QPsub2() {
 }
 
 // CHECK-LABEL: func.func @_QPsub2()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<i32>
 // CHECK: %[[TEMP:.*]] = fir.alloca i32
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[C2:.*]] = arith.constant 2 : i32
 // CHECK: fir.store %[[C2]] to %[[TEMP]] : !fir.ref<i32>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[TEMP]] : (!fir.ref<i32>) -> !fir.box<i32>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub3() {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -48,12 +51,15 @@ func.func @_QPsub3() {
 }
 
 // CHECK-LABEL: func.func @_QPsub3()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<i32>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[V:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub3Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[V]]#0 : (!fir.ref<i32>) -> !fir.box<i32>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[V_CONV:.*]] = fir.convert %[[V]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
-
+// CHECK: %[[V_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
+  
 func.func @_QPsub4() {
   %0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
   %4:2 = hlfir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
@@ -67,15 +73,14 @@ func.func @_QPsub4() {
   return
 }
 // CHECK-LABEL: func.func @_QPsub4()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
-// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}}) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
-// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
+// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 // CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
@@ -110,16 +115,15 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
 }
 
 // CHECK-LABEL: func.func @_QPsub5
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
 // CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2>
 // CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[SHAPE]]) {uniq_name = "_QFsub5Eahost"} : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x?xi32>>, !fir.ref<!fir.array<?x?xi32>>)
-// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<?x?xi32>>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 // CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
 // CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
 // CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
@@ -248,5 +252,35 @@ func.func @_QQdesc_global() attributes {fir.bindc_name = "host_sub"} {
 // CHECK: %[[BOX_NONE:.*]] = fir.convert %[[GLOBAL_DECL:.*]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[BOX_NONE]],{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
+fir.global @_QMmod2Eadev {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xi32>>> { // device global descriptor
+  %c0 = arith.constant 0 : index
+  %0 = fir.zero_bits !fir.heap<!fir.array<?xi32>> // null base address (unallocated)
+  %1 = fir.shape %c0 : (index) -> !fir.shape<1> // extent 0
+  %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.heap<!fir.array<?xi32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xi32>>>
+  fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xi32>>>
+}
+func.func @_QPdesc_global_ptr() { // host array (no descriptor) -> device global descriptor
+  %c10 = arith.constant 10 : index
+  %0 = fir.address_of(@_QMmod2Eadev) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %1 = fir.declare %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  %2 = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"} // host source array
+  %3 = fir.shape %c10 : (index) -> !fir.shape<1>
+  %4 = fir.declare %2(%3) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+  cuf.data_transfer %4 to %1 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+  return
+}
+
+// CHECK-LABEL: func.func @_QPdesc_global_ptr()
+// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
+// CHECK: %[[ADDR_ADEV:.*]] = fir.address_of(@_QMmod2Eadev) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[DECL_ADEV:.*]] = fir.declare %[[ADDR_ADEV]] {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMmod2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+// CHECK: %[[AHOST:.*]] = fir.alloca !fir.array<10xi32> {bindc_name = "ahost", uniq_name = "_QFdesc_global_ptrEahost"}
+// CHECK: %[[SHAPE:.*]] = fir.shape %c10 : (index) -> !fir.shape<1>
+// CHECK: %[[DECL_AHOST:.*]] = fir.declare %[[AHOST]](%[[SHAPE]]) {uniq_name = "_QFdesc_global_ptrEahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[DECL_AHOST]](%[[SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
+// CHECK: %[[ADEV_BOXNONE:.*]] = fir.convert %[[DECL_ADEV]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: %[[AHOST_BOXNONE:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferGlobalDescDesc(%[[ADEV_BOXNONE]], %[[AHOST_BOXNONE]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 } // end of module

>From c85735aad2f9a76a705d5ebfb3da5cc44804993f Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 4 Nov 2024 21:48:15 -0800
Subject: [PATCH 2/2] Remove unreachable scalar-to-descriptor transfer path

---
 .../Optimizer/Transforms/CUFOpConversion.cpp  | 21 -------------------
 1 file changed, 21 deletions(-)

diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 9b8bcf2f719281..89d0af1fcd136f 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -580,27 +580,6 @@ struct CUFDataTransferOpConversion
           builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
       rewriter.eraseOp(op);
-    } else if (mlir::isa<fir::BaseBoxType>(dstTy) && fir::isa_trivial(srcTy)) {
-      // Scalar to descriptor transfer.
-      mlir::Value val = op.getSrc();
-      if (op.getSrc().getDefiningOp() &&
-          mlir::isa<mlir::arith::ConstantOp>(op.getSrc().getDefiningOp())) {
-        mlir::Value alloc = builder.createTemporary(loc, srcTy);
-        builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
-        val = alloc;
-      }
-
-      mlir::func::FuncOp func =
-          fir::runtime::getRuntimeFunc<mkRTKey(CUFMemsetDescriptor)>(loc,
-                                                                     builder);
-      auto fTy = func.getFunctionType();
-      mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
-      mlir::Value sourceLine =
-          fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
-      llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
-          builder, loc, fTy, op.getDst(), val, sourceFile, sourceLine)};
-      builder.create<fir::CallOp>(loc, func, args);
-      rewriter.eraseOp(op);
     } else {
       // Type used to compute the width.
       mlir::Type computeType = dstTy;



More information about the flang-commits mailing list