[flang-commits] [flang] [flang][cuda] Pass descriptor by reference for CUFMemsetDescriptor (PR #114338)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Wed Oct 30 18:12:26 PDT 2024
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/114338
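Similar to #114288: the `CUFMemsetDescriptor` runtime entry point now takes the descriptor as a `Descriptor *` instead of a `const Descriptor &`, and the CUF op conversion passes the destination reference (`op.getDst()`) directly rather than loading the box first.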
From ad3ce5aeb93e3db5481e6402429afc80ae6a2640 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 30 Oct 2024 18:10:02 -0700
Subject: [PATCH] [flang][cuda] Pass descriptor by reference for
CUFMemsetDescriptor
---
flang/include/flang/Runtime/CUDA/memory.h | 2 +-
flang/lib/Optimizer/Transforms/CUFOpConversion.cpp | 3 +--
flang/runtime/CUDA/memory.cpp | 4 ++--
flang/test/Fir/CUDA/cuda-data-transfer.fir | 10 ++++------
4 files changed, 8 insertions(+), 11 deletions(-)
diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index fb48152d707182..6d2e0c0f15942b 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -28,7 +28,7 @@ void RTDECL(CUFMemFree)(void *devicePtr, unsigned type,
/// Set value to the data hold by a descriptor. The \p value pointer must be
/// addressable to the same amount of bytes specified by the element size of
/// the descriptor \p desc.
-void RTDECL(CUFMemsetDescriptor)(const Descriptor &desc, void *value,
+void RTDECL(CUFMemsetDescriptor)(Descriptor *desc, void *value,
const char *sourceFile = nullptr, int sourceLine = 0);
/// Data transfer from a pointer to a pointer.
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index e3e441360e949b..4050064ebe95d8 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -552,9 +552,8 @@ struct CUFDataTransferOpConversion
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
- mlir::Value dst = builder.loadIfRef(loc, op.getDst());
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
- builder, loc, fTy, dst, val, sourceFile, sourceLine)};
+ builder, loc, fTy, op.getDst(), val, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
} else {
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index 4778a4ae77683f..d03f1cc0e48d66 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -49,8 +49,8 @@ void RTDEF(CUFMemFree)(
}
}
-void RTDEF(CUFMemsetDescriptor)(const Descriptor &desc, void *value,
- const char *sourceFile, int sourceLine) {
+void RTDEF(CUFMemsetDescriptor)(
+ Descriptor *desc, void *value, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
terminator.Crash("not yet implemented: CUDA data transfer from a scalar "
"value to a descriptor");
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index b99e09fb76468b..cee3048e279cc7 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -33,10 +33,9 @@ func.func @_QPsub2() {
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub2Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
// CHECK: %[[C2:.*]] = arith.constant 2 : i32
// CHECK: fir.store %[[C2]] to %[[TEMP]] : !fir.ref<i32>
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP]] : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[TEMP_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
func.func @_QPsub3() {
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -51,10 +50,9 @@ func.func @_QPsub3() {
// CHECK-LABEL: func.func @_QPsub3()
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub3Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
// CHECK: %[[V:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub3Ev"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
-// CHECK: %[[ADEV_LOAD:.*]] = fir.load %[[ADEV]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
-// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV_LOAD]] : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.box<none>
+// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[V_CONV:.*]] = fir.convert %[[V]]#0 : (!fir.ref<i32>) -> !fir.llvm_ptr<i8>
-// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.box<none>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFMemsetDescriptor(%[[ADEV_BOX]], %[[V_CONV]], %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> none
func.func @_QPsub4() {
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
More information about the flang-commits mailing list