[flang-commits] [flang] db69d69 - [flang][cuda] Support data transfer from descriptor to a pointer (#115023)
via flang-commits
flang-commits at lists.llvm.org
Tue Nov 5 11:59:14 PST 2024
Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-11-05T11:59:08-08:00
New Revision: db69d6939a93d1e401abe6bfe114e55b69297975
URL: https://github.com/llvm/llvm-project/commit/db69d6939a93d1e401abe6bfe114e55b69297975
DIFF: https://github.com/llvm/llvm-project/commit/db69d6939a93d1e401abe6bfe114e55b69297975.diff
LOG: [flang][cuda] Support data transfer from descriptor to a pointer (#115023)
Support data transfer from a variable with a descriptor to a pointer. We create
a descriptor for the pointer so we can use the flang runtime to perform
the transfer. The Assign function handles all corner cases. We add a new
entry point, `CUFDataTransferDescDescNoRealloc`, to avoid reallocation
since the variable on the LHS is not allocatable.
Added:
Modified:
flang/include/flang/Runtime/CUDA/memory.h
flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
flang/runtime/CUDA/memory.cpp
flang/test/Fir/CUDA/cuda-data-transfer.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 4ac2528c1aedbc..713bdf536aaf90 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -44,6 +44,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
+/// Data transfer from a descriptor to a descriptor.
+void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dst, Descriptor *src,
+ unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
+
/// Data transfer from a descriptor to a global descriptor.
void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 89d0af1fcd136f..6187ca03d2c411 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -581,50 +581,27 @@ struct CUFDataTransferOpConversion
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
} else {
- // Type used to compute the width.
- mlir::Type computeType = dstTy;
- auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
- if (mlir::isa<fir::BaseBoxType>(dstTy)) {
- computeType = srcTy;
- seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
- }
- int width = computeWidth(loc, computeType, kindMap);
+ // Transfer from a descriptor.
- mlir::Value nbElement;
- mlir::Type idxTy = rewriter.getIndexType();
- if (!op.getShape()) {
- nbElement = rewriter.create<mlir::arith::ConstantOp>(
- loc, idxTy,
- rewriter.getIntegerAttr(idxTy, seqTy.getConstantArraySize()));
- } else {
- auto shapeOp =
- mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp());
- nbElement =
- createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[0]);
- for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) {
- auto operand =
- createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[i]);
- nbElement =
- rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
- }
- }
+ mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
+ mlir::Type boxTy = fir::BoxType::get(dstTy);
+ llvm::SmallVector<mlir::Value> lenParams;
+ mlir::Value box =
+ builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getDst()),
+ /*slice=*/nullptr, lenParams,
+ /*tdesc=*/nullptr);
+ mlir::Value memBox = builder.createTemporary(loc, box.getType());
+ builder.create<fir::StoreOp>(loc, box, memBox);
- mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
- loc, idxTy, rewriter.getIntegerAttr(idxTy, width));
- mlir::Value bytes =
- rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
+ mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFDataTransferDescDescNoRealloc)>(loc, builder);
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
- loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
- mlir::Value dst = op.getDst();
- mlir::Value src = op.getSrc();
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
llvm::SmallVector<mlir::Value> args{
- fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
+ fir::runtime::createArguments(builder, loc, fTy, memBox, op.getSrc(),
modeValue, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index 2d499f93fbaece..7b40b837e7666e 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -120,6 +120,24 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
}
+void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dstDesc,
+ Descriptor *srcDesc, unsigned mode, const char *sourceFile,
+ int sourceLine) {
+ MemmoveFct memmoveFct;
+ Terminator terminator{sourceFile, sourceLine};
+ if (mode == kHostToDevice) {
+ memmoveFct = &MemmoveHostToDevice;
+ } else if (mode == kDeviceToHost) {
+ memmoveFct = &MemmoveDeviceToHost;
+ } else if (mode == kDeviceToDevice) {
+ memmoveFct = &MemmoveDeviceToDevice;
+ } else {
+ terminator.Crash("host to host copy not supported");
+ }
+ Fortran::runtime::Assign(
+ *dstDesc, *srcDesc, terminator, NoAssignFlags, memmoveFct);
+}
+
void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dstDesc,
Descriptor *srcDesc, unsigned mode, const char *sourceFile,
int sourceLine) {
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index 6a33190168024f..d9588942b21e81 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -73,6 +73,7 @@ func.func @_QPsub4() {
return
}
// CHECK-LABEL: func.func @_QPsub4()
+// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
@@ -81,13 +82,11 @@ func.func @_QPsub4() {
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
-// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
%0 = fir.dummy_scope : !fir.dscope
@@ -115,6 +114,7 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
}
// CHECK-LABEL: func.func @_QPsub5
+// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
// CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2>
@@ -124,13 +124,11 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
-// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<?x?xi32>>>
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
func.func @_QPsub6() {
%0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} -> !fir.ref<i32>
More information about the flang-commits
mailing list