[flang-commits] [flang] [flang][cuda] Enable data transfer for descriptors (PR #92804)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Mon May 20 14:51:52 PDT 2024
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/92804
From 0b06f7202ef6fe7df3ce915d8a705846b5eb7b3c Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 20 May 2024 11:40:04 -0700
Subject: [PATCH 1/2] [flang][cuda] Enable data transfer for descriptor
---
.../flang/Optimizer/Dialect/CUF/CUFOps.td | 6 ++--
flang/lib/Lower/Bridge.cpp | 36 ++++++++++---------
flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp | 13 +++++++
flang/test/Lower/CUDA/cuda-data-transfer.cuf | 19 ++++++++++
4 files changed, 55 insertions(+), 19 deletions(-)
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 72157bce4f768..b33aeca590b56 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -154,13 +154,15 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
```
}];
- let arguments = (ins Arg<AnyReferenceLike, "", [MemWrite]>:$src,
- Arg<AnyReferenceLike, "", [MemRead]>:$dst,
+ let arguments = (ins Arg<AnyRefOrBoxType, "", [MemWrite]>:$src,
+ Arg<AnyRefOrBoxType, "", [MemRead]>:$dst,
cuf_DataTransferKindAttr:$transfer_kind);
let assemblyFormat = [{
$src `to` $dst attr-dict `:` type(operands)
}];
+
+ let hasVerifier = 1;
}
def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 7ded9adcd5c2a..8e9ce78119d18 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3717,8 +3717,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
hlfir::Entity &lhs, hlfir::Entity &rhs) {
bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
- if (rhs.isBoxAddressOrValue() || lhs.isBoxAddressOrValue())
- TODO(loc, "CUDA data transfler with descriptors");
+
+ auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value {
+ if (mlir::isa_and_nonnull<fir::LoadOp>(val.getDefiningOp())) {
+ auto loadOp = mlir::dyn_cast<fir::LoadOp>(val.getDefiningOp());
+ return loadOp.getMemref();
+ }
+ return val;
+ };
+
+ mlir::Value rhsVal = getRefIfLoaded(rhs.getBase());
+ mlir::Value lhsVal = getRefIfLoaded(lhs.getBase());
// device = host
if (lhsIsDevice && !rhsIsDevice) {
@@ -3727,11 +3736,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (!rhs.isVariable()) {
auto associate = hlfir::genAssociateExpr(
loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
- builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhs,
+ builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
transferKindAttr);
builder.create<hlfir::EndAssociateOp>(loc, associate);
} else {
- builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
+ builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+ transferKindAttr);
}
return;
}
@@ -3740,26 +3750,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (!lhsIsDevice && rhsIsDevice) {
auto transferKindAttr = cuf::DataTransferKindAttr::get(
builder.getContext(), cuf::DataTransferKind::DeviceHost);
- if (!rhs.isVariable()) {
- // evaluateRhs loads scalar. Look for the memory reference to be used in
- // the transfer.
- if (mlir::isa_and_nonnull<fir::LoadOp>(rhs.getDefiningOp())) {
- auto loadOp = mlir::dyn_cast<fir::LoadOp>(rhs.getDefiningOp());
- builder.create<cuf::DataTransferOp>(loc, loadOp.getMemref(), lhs,
- transferKindAttr);
- return;
- }
- } else {
- builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
- }
+ builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+ transferKindAttr);
return;
}
+ // device = device
if (lhsIsDevice && rhsIsDevice) {
assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
auto transferKindAttr = cuf::DataTransferKindAttr::get(
builder.getContext(), cuf::DataTransferKind::DeviceDevice);
- builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
+ builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+ transferKindAttr);
return;
}
llvm_unreachable("Unhandled CUDA data transfer");
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index 870652c72fab7..b00c374682922 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -89,6 +89,19 @@ mlir::LogicalResult cuf::AllocateOp::verify() {
return mlir::success();
}
+//===----------------------------------------------------------------------===//
+// DataTransferOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult cuf::DataTransferOp::verify() {
+ mlir::Type srcTy = getSrc().getType();
+ mlir::Type dstTy = getDst().getType();
+ if (fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy) ||
+ fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy))
+ return mlir::success();
+ return emitOpError("expect src and dst to be both references or descriptors");
+}
+
//===----------------------------------------------------------------------===//
// DeallocateOp
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 084314ed63ecd..e23792e6efc55 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -159,3 +159,22 @@ end subroutine
! CHECK-LABEL: func.func @_QPsub6
! CHECK: cuf.data_transfer
+
+subroutine sub7(a, b, c)
+ integer, device, allocatable :: a(:), c(:)
+ integer, allocatable :: b(:)
+ b = a
+
+ a = b
+
+ c = a
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub7(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "b"}, %[[ARG2:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}) {
+! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: %[[B:.*]]:2 = hlfir.declare %[[ARG1]] dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Eb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: %[[C:.*]]:2 = hlfir.declare %[[ARG2]] dummy_scope %0 {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ec"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: cuf.data_transfer %[[A]]#0 to %[[B]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: cuf.data_transfer %[[B]]#0 to %[[A]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: cuf.data_transfer %[[A]]#0 to %[[C]]#0 {transfer_kind = #cuf.cuda_transfer<device_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
From 1205acecc9371533cc717fe7931835b732a55a73 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 20 May 2024 14:51:08 -0700
Subject: [PATCH 2/2] Use dyn_cast_or_null
---
flang/lib/Lower/Bridge.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 8e9ce78119d18..e2a011ebbb763 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3719,10 +3719,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value {
- if (mlir::isa_and_nonnull<fir::LoadOp>(val.getDefiningOp())) {
- auto loadOp = mlir::dyn_cast<fir::LoadOp>(val.getDefiningOp());
+ if (auto loadOp =
+ mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp()))
return loadOp.getMemref();
- }
return val;
};
More information about the flang-commits mailing list