[flang-commits] [flang] db69d69 - [flang][cuda] Support data transfer from descriptor to a pointer (#115023)

via flang-commits flang-commits at lists.llvm.org
Tue Nov 5 11:59:14 PST 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-11-05T11:59:08-08:00
New Revision: db69d6939a93d1e401abe6bfe114e55b69297975

URL: https://github.com/llvm/llvm-project/commit/db69d6939a93d1e401abe6bfe114e55b69297975
DIFF: https://github.com/llvm/llvm-project/commit/db69d6939a93d1e401abe6bfe114e55b69297975.diff

LOG: [flang][cuda] Support data transfer from descriptor to a pointer (#115023)

Data transfer from a variable with a descriptor to a pointer. We create
a descriptor for the pointer so we can use the flang runtime to perform
the transfer. The Assign function handles all corner cases. We add a new
entry point `CUFDataTransferDescDescNoRealloc` to avoid reallocation,
since the variable on the LHS is not an allocatable.

Added: 
    

Modified: 
    flang/include/flang/Runtime/CUDA/memory.h
    flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
    flang/runtime/CUDA/memory.cpp
    flang/test/Fir/CUDA/cuda-data-transfer.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 4ac2528c1aedbc..713bdf536aaf90 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -44,6 +44,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
 void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
     unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
 
+/// Data transfer from a descriptor to a descriptor without reallocation.
+void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dst, Descriptor *src,
+    unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
+
 /// Data transfer from a descriptor to a global descriptor.
 void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dst, Descriptor *src,
     unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);

diff  --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 89d0af1fcd136f..6187ca03d2c411 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -581,50 +581,27 @@ struct CUFDataTransferOpConversion
       builder.create<fir::CallOp>(loc, func, args);
       rewriter.eraseOp(op);
     } else {
-      // Type used to compute the width.
-      mlir::Type computeType = dstTy;
-      auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
-      if (mlir::isa<fir::BaseBoxType>(dstTy)) {
-        computeType = srcTy;
-        seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
-      }
-      int width = computeWidth(loc, computeType, kindMap);
+      // Transfer from a descriptor.
 
-      mlir::Value nbElement;
-      mlir::Type idxTy = rewriter.getIndexType();
-      if (!op.getShape()) {
-        nbElement = rewriter.create<mlir::arith::ConstantOp>(
-            loc, idxTy,
-            rewriter.getIntegerAttr(idxTy, seqTy.getConstantArraySize()));
-      } else {
-        auto shapeOp =
-            mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp());
-        nbElement =
-            createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[0]);
-        for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) {
-          auto operand =
-              createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[i]);
-          nbElement =
-              rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
-        }
-      }
+      mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
+      mlir::Type boxTy = fir::BoxType::get(dstTy);
+      llvm::SmallVector<mlir::Value> lenParams;
+      mlir::Value box =
+          builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getDst()),
+                            /*slice=*/nullptr, lenParams,
+                            /*tdesc=*/nullptr);
+      mlir::Value memBox = builder.createTemporary(loc, box.getType());
+      builder.create<fir::StoreOp>(loc, box, memBox);
 
-      mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
-          loc, idxTy, rewriter.getIntegerAttr(idxTy, width));
-      mlir::Value bytes =
-          rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
+      mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
+          CUFDataTransferDescDescNoRealloc)>(loc, builder);
 
-      mlir::func::FuncOp func =
-          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
-              loc, builder);
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
-          fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
-      mlir::Value dst = op.getDst();
-      mlir::Value src = op.getSrc();
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
       llvm::SmallVector<mlir::Value> args{
-          fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
+          fir::runtime::createArguments(builder, loc, fTy, memBox, op.getSrc(),
                                         modeValue, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
       rewriter.eraseOp(op);

diff  --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index 2d499f93fbaece..7b40b837e7666e 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -120,6 +120,24 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
       *dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
 }
 
+void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dstDesc,
+    Descriptor *srcDesc, unsigned mode, const char *sourceFile,
+    int sourceLine) { // Copies *srcDesc into *dstDesc via runtime Assign.
+  MemmoveFct memmoveFct; // Copy routine, selected from `mode` below.
+  Terminator terminator{sourceFile, sourceLine}; // Source loc for errors.
+  if (mode == kHostToDevice) {
+    memmoveFct = &MemmoveHostToDevice;
+  } else if (mode == kDeviceToHost) {
+    memmoveFct = &MemmoveDeviceToHost;
+  } else if (mode == kDeviceToDevice) {
+    memmoveFct = &MemmoveDeviceToDevice;
+  } else { // Any other mode (e.g. host to host) is rejected.
+    terminator.Crash("host to host copy not supported"); // Presumed noreturn.
+  }
+  Fortran::runtime::Assign( // Unlike DescDesc, omits MaybeReallocate.
+      *dstDesc, *srcDesc, terminator, NoAssignFlags, memmoveFct);
+}
+
 void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dstDesc,
     Descriptor *srcDesc, unsigned mode, const char *sourceFile,
     int sourceLine) {

diff  --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index 6a33190168024f..d9588942b21e81 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -73,6 +73,7 @@ func.func @_QPsub4() {
   return
 }
 // CHECK-LABEL: func.func @_QPsub4()
+// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
 // CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
@@ -81,13 +82,11 @@ func.func @_QPsub4() {
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
-// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
   %0 = fir.dummy_scope : !fir.dscope
@@ -115,6 +114,7 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
 }
 
 // CHECK-LABEL: func.func @_QPsub5
+// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
 // CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
 // CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2>
@@ -124,13 +124,11 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
-// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<?x?xi32>>>
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub6() {
   %0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} -> !fir.ref<i32>


        


More information about the flang-commits mailing list