[flang-commits] [flang] [flang][cuda] Enable data transfer for descriptors (PR #92804)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue May 21 10:00:28 PDT 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/92804

>From 0b06f7202ef6fe7df3ce915d8a705846b5eb7b3c Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 20 May 2024 11:40:04 -0700
Subject: [PATCH 1/3] [flang][cuda] Enable data transfer for descriptor

---
 .../flang/Optimizer/Dialect/CUF/CUFOps.td     |  6 ++--
 flang/lib/Lower/Bridge.cpp                    | 36 ++++++++++---------
 flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp    | 13 +++++++
 flang/test/Lower/CUDA/cuda-data-transfer.cuf  | 19 ++++++++++
 4 files changed, 55 insertions(+), 19 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 72157bce4f768..b33aeca590b56 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -154,13 +154,15 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
     ```
   }];
 
-  let arguments = (ins Arg<AnyReferenceLike, "", [MemWrite]>:$src,
-                       Arg<AnyReferenceLike, "", [MemRead]>:$dst,
+  let arguments = (ins Arg<AnyRefOrBoxType, "", [MemRead]>:$src,  // read
+                       Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst, // written
                        cuf_DataTransferKindAttr:$transfer_kind);
 
   let assemblyFormat = [{
     $src `to` $dst attr-dict `:` type(operands)
   }];
+
+  let hasVerifier = 1;
 }
 
 def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 7ded9adcd5c2a..8e9ce78119d18 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3717,8 +3717,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                            hlfir::Entity &lhs, hlfir::Entity &rhs) {
     bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
     bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
-    if (rhs.isBoxAddressOrValue() || lhs.isBoxAddressOrValue())
-      TODO(loc, "CUDA data transfler with descriptors");
+
+    auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value {
+      if (mlir::isa_and_nonnull<fir::LoadOp>(val.getDefiningOp())) {
+        auto loadOp = mlir::dyn_cast<fir::LoadOp>(val.getDefiningOp());
+        return loadOp.getMemref();
+      }
+      return val;
+    };
+
+    mlir::Value rhsVal = getRefIfLoaded(rhs.getBase());
+    mlir::Value lhsVal = getRefIfLoaded(lhs.getBase());
 
     // device = host
     if (lhsIsDevice && !rhsIsDevice) {
@@ -3727,11 +3736,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       if (!rhs.isVariable()) {
         auto associate = hlfir::genAssociateExpr(
             loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
-        builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhs,
+        builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
                                             transferKindAttr);
         builder.create<hlfir::EndAssociateOp>(loc, associate);
       } else {
-        builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
+        builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+                                            transferKindAttr);
       }
       return;
     }
@@ -3740,26 +3750,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     if (!lhsIsDevice && rhsIsDevice) {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceHost);
-      if (!rhs.isVariable()) {
-        // evaluateRhs loads scalar. Look for the memory reference to be used in
-        // the transfer.
-        if (mlir::isa_and_nonnull<fir::LoadOp>(rhs.getDefiningOp())) {
-          auto loadOp = mlir::dyn_cast<fir::LoadOp>(rhs.getDefiningOp());
-          builder.create<cuf::DataTransferOp>(loc, loadOp.getMemref(), lhs,
-                                              transferKindAttr);
-          return;
-        }
-      } else {
-        builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
-      }
+      builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+                                          transferKindAttr);
       return;
     }
 
+    // device = device
     if (lhsIsDevice && rhsIsDevice) {
       assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceDevice);
-      builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
+      builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+                                          transferKindAttr);
       return;
     }
     llvm_unreachable("Unhandled CUDA data transfer");
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index 870652c72fab7..b00c374682922 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -89,6 +89,19 @@ mlir::LogicalResult cuf::AllocateOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// DataTransferOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult cuf::DataTransferOp::verify() {
+  mlir::Type srcTy = getSrc().getType();
+  mlir::Type dstTy = getDst().getType();
+  if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
+      (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)))
+    return mlir::success(); // Operands agree: both refs or both boxes.
+  return emitOpError("expect src and dst to be both references or descriptors");
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocateOp
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 084314ed63ecd..e23792e6efc55 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -159,3 +159,22 @@ end subroutine
 
 ! CHECK-LABEL: func.func @_QPsub6
 ! CHECK: cuf.data_transfer
+
+subroutine sub7(a, b, c)
+  integer, device, allocatable :: a(:), c(:)
+  integer, allocatable :: b(:)
+  b = a
+
+  a = b
+
+  c = a
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub7(
+! CHECK-SAME:  %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "b"}, %[[ARG2:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}) {
+! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: %[[B:.*]]:2 = hlfir.declare %[[ARG1]] dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Eb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: %[[C:.*]]:2 = hlfir.declare %[[ARG2]] dummy_scope %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ec"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: cuf.data_transfer %[[A]]#0 to %[[B]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: cuf.data_transfer %[[B]]#0 to %[[A]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: cuf.data_transfer %[[A]]#0 to %[[C]]#0 {transfer_kind = #cuf.cuda_transfer<device_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>

>From 1205acecc9371533cc717fe7931835b732a55a73 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 20 May 2024 14:51:08 -0700
Subject: [PATCH 2/3] Use dyn_cast_or_null

---
 flang/lib/Lower/Bridge.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 8e9ce78119d18..e2a011ebbb763 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3719,10 +3719,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
 
     auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value {
-      if (mlir::isa_and_nonnull<fir::LoadOp>(val.getDefiningOp())) {
-        auto loadOp = mlir::dyn_cast<fir::LoadOp>(val.getDefiningOp());
+      if (auto loadOp =
+              mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp()))
         return loadOp.getMemref();
-      }
       return val;
     };
 

>From 8edd8f1ce8dba607704caede9981a54415e4bdc2 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 20 May 2024 14:56:15 -0700
Subject: [PATCH 3/3] Add comment about descriptor updates

---
 flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index b33aeca590b56..500f44365fb93 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -152,6 +152,10 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
       a = adev ! transfer device to host
       bdev = adev ! transfer device to device
     ```
+    
+    When the data transfer involves data held by descriptors, the data held
+    by the LHS descriptor is updated. When required, the LHS descriptor itself
+    is also updated.
   }];
 
   let arguments = (ins Arg<AnyRefOrBoxType, "", [MemWrite]>:$src,



More information about the flang-commits mailing list