[flang-commits] [flang] bfe486f - Passing descriptors by reference to CUDA runtime calls (#114288)
via flang-commits
flang-commits at lists.llvm.org
Wed Oct 30 13:24:51 PDT 2024
Author: Renaud Kauffmann
Date: 2024-10-30T13:24:47-07:00
New Revision: bfe486fe764667d514124faf2b39afb7e7322640
URL: https://github.com/llvm/llvm-project/commit/bfe486fe764667d514124faf2b39afb7e7322640
DIFF: https://github.com/llvm/llvm-project/commit/bfe486fe764667d514124faf2b39afb7e7322640.diff
LOG: Passing descriptors by reference to CUDA runtime calls (#114288)
Passing a descriptor as a `const Descriptor &` or a `const Descriptor *`
generates a FIR signature where the box is passed by value.
This is an issue because the box must be loaded before it can be passed.
But since, ultimately, all boxes are passed by reference, a temporary is
generated in LLVM and a reference to that temporary is passed instead.
The box addresses are registered with the CUDA runtime but the
temporaries are not, which prevents the runtime from properly mapping a
host-side address to its device-side counterpart.
To address this issue, this PR changes the signatures of the transfer
functions to pass a descriptor as a `Descriptor *`, which will in turn
generate a FIR signature that takes a box reference as an argument.
Added:
Modified:
flang/include/flang/Runtime/CUDA/memory.h
flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
flang/runtime/CUDA/memory.cpp
flang/test/Fir/CUDA/cuda-data-transfer.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 3c3ae73d4ad7a1..fb48152d707182 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -36,19 +36,18 @@ void RTDECL(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
/// Data transfer from a pointer to a descriptor.
-void RTDECL(CUFDataTransferDescPtr)(const Descriptor &dst, void *src,
+void RTDECL(CUFDataTransferDescPtr)(Descriptor *dst, void *src,
std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
int sourceLine = 0);
/// Data transfer from a descriptor to a pointer.
-void RTDECL(CUFDataTransferPtrDesc)(void *dst, const Descriptor &src,
+void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
std::size_t bytes, unsigned mode, const char *sourceFile = nullptr,
int sourceLine = 0);
/// Data transfer from a descriptor to a descriptor.
-void RTDECL(CUFDataTransferDescDesc)(const Descriptor &dst,
- const Descriptor &src, unsigned mode, const char *sourceFile = nullptr,
- int sourceLine = 0);
+void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
+ unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
} // extern "C"
} // namespace Fortran::runtime::cuda
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index f1f3a95b220df5..e3e441360e949b 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -529,8 +529,8 @@ struct CUFDataTransferOpConversion
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
- mlir::Value dst = builder.loadIfRef(loc, op.getDst());
- mlir::Value src = builder.loadIfRef(loc, op.getSrc());
+ mlir::Value dst = op.getDst();
+ mlir::Value src = op.getSrc();
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
@@ -603,11 +603,8 @@ struct CUFDataTransferOpConversion
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
- mlir::Value dst =
- dstIsDesc ? builder.loadIfRef(loc, op.getDst()) : op.getDst();
- mlir::Value src = mlir::isa<fir::BaseBoxType>(srcTy)
- ? builder.loadIfRef(loc, op.getSrc())
- : op.getSrc();
+ mlir::Value dst = op.getDst();
+ mlir::Value src = op.getSrc();
llvm::SmallVector<mlir::Value> args{
fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
modeValue, sourceFile, sourceLine)};
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index fc48b4343eea9d..4778a4ae77683f 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -73,23 +73,22 @@ void RTDEF(CUFDataTransferPtrPtr)(void *dst, void *src, std::size_t bytes,
CUDA_REPORT_IF_ERROR(cudaMemcpy(dst, src, bytes, kind));
}
-void RTDEF(CUFDataTransferDescPtr)(const Descriptor &desc, void *addr,
+void RTDEF(CUFDataTransferDescPtr)(Descriptor *desc, void *addr,
std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
terminator.Crash(
"not yet implemented: CUDA data transfer from a pointer to a descriptor");
}
-void RTDEF(CUFDataTransferPtrDesc)(void *addr, const Descriptor &desc,
+void RTDEF(CUFDataTransferPtrDesc)(void *addr, Descriptor *desc,
std::size_t bytes, unsigned mode, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
terminator.Crash(
"not yet implemented: CUDA data transfer from a descriptor to a pointer");
}
-void RTDECL(CUFDataTransferDescDesc)(const Descriptor &dstDesc,
- const Descriptor &srcDesc, unsigned mode, const char *sourceFile,
- int sourceLine) {
+void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
+ unsigned mode, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
terminator.Crash(
"not yet implemented: CUDA data transfer between two descriptors");
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index c33c50115b9fc0..b99e09fb76468b 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -15,11 +15,9 @@ func.func @_QPsub1() {
// CHECK-LABEL: func.func @_QPsub1()
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
-// CHECK: %[[AHOST_LOAD:.*]] = fir.load %[[AHOST]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[AHOST_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
-// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.box<none>, i32, !fir.ref<i8>, i32) -> none
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
+// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
func.func @_QPsub2() {
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub2Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -76,19 +74,17 @@ func.func @_QPsub4() {
// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.box<none>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
%0 = fir.dummy_scope : !fir.dscope
@@ -122,19 +118,17 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescPtr(%[[ADEV_BOX]], %[[AHOST_PTR]], %[[BYTES_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>
// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?x?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.box<none>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
func.func @_QPsub6() {
%0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} -> !fir.ref<i32>
More information about the flang-commits
mailing list