[flang-commits] [flang] 1fc3ce1 - [flang][cuda] Enable data transfer for descriptors (#92804)

via flang-commits flang-commits at lists.llvm.org
Tue May 21 11:24:00 PDT 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-05-21T11:23:55-07:00
New Revision: 1fc3ce1cdb8390ed64feea939a9555d3642439ea

URL: https://github.com/llvm/llvm-project/commit/1fc3ce1cdb8390ed64feea939a9555d3642439ea
DIFF: https://github.com/llvm/llvm-project/commit/1fc3ce1cdb8390ed64feea939a9555d3642439ea.diff

LOG: [flang][cuda] Enable data transfer for descriptors (#92804)

Remove the TODO when data transfer is done with descriptor variables.

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
    flang/lib/Lower/Bridge.cpp
    flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
    flang/test/Lower/CUDA/cuda-data-transfer.cuf

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 1c98b4131a139..f2992997c42cb 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -152,15 +152,21 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
       a = adev ! transfer device to host
       bdev = adev ! transfer device to device
     ```
+    
+    When the data transfer is done on data held by descriptors, the LHS data
+    held by the descriptor are updated. When required, the LHS descriptor is also
+    updated.
   }];
 
-  let arguments = (ins Arg<AnyReferenceLike, "", [MemRead]>:$src,
-                       Arg<AnyReferenceLike, "", [MemWrite]>:$dst,
+  let arguments = (ins Arg<AnyRefOrBoxType, "", [MemRead]>:$src,
+                       Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
                        cuf_DataTransferKindAttr:$transfer_kind);
 
   let assemblyFormat = [{
     $src `to` $dst attr-dict `:` type(operands)
   }];
+
+  let hasVerifier = 1;
 }
 
 def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 4e50de3e7ee9c..3e0a6da7fc327 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3782,8 +3782,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                            hlfir::Entity &lhs, hlfir::Entity &rhs) {
     bool lhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.lhs);
     bool rhsIsDevice = Fortran::evaluate::HasCUDAAttrs(assign.rhs);
-    if (rhs.isBoxAddressOrValue() || lhs.isBoxAddressOrValue())
-      TODO(loc, "CUDA data transfler with descriptors");
+
+    auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value {
+      if (auto loadOp =
+              mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp()))
+        return loadOp.getMemref();
+      return val;
+    };
+
+    mlir::Value rhsVal = getRefIfLoaded(rhs.getBase());
+    mlir::Value lhsVal = getRefIfLoaded(lhs.getBase());
 
     // device = host
     if (lhsIsDevice && !rhsIsDevice) {
@@ -3792,11 +3800,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       if (!rhs.isVariable()) {
         auto associate = hlfir::genAssociateExpr(
             loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
-        builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhs,
+        builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
                                             transferKindAttr);
         builder.create<hlfir::EndAssociateOp>(loc, associate);
       } else {
-        builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
+        builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+                                            transferKindAttr);
       }
       return;
     }
@@ -3805,26 +3814,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     if (!lhsIsDevice && rhsIsDevice) {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceHost);
-      if (!rhs.isVariable()) {
-        // evaluateRhs loads scalar. Look for the memory reference to be used in
-        // the transfer.
-        if (mlir::isa_and_nonnull<fir::LoadOp>(rhs.getDefiningOp())) {
-          auto loadOp = mlir::dyn_cast<fir::LoadOp>(rhs.getDefiningOp());
-          builder.create<cuf::DataTransferOp>(loc, loadOp.getMemref(), lhs,
-                                              transferKindAttr);
-          return;
-        }
-      } else {
-        builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
-      }
+      builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+                                          transferKindAttr);
       return;
     }
 
+    // device = device
     if (lhsIsDevice && rhsIsDevice) {
       assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceDevice);
-      builder.create<cuf::DataTransferOp>(loc, rhs, lhs, transferKindAttr);
+      builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+                                          transferKindAttr);
       return;
     }
     llvm_unreachable("Unhandled CUDA data transfer");

diff  --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index 870652c72fab7..b00c374682922 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -89,6 +89,19 @@ mlir::LogicalResult cuf::AllocateOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// DataTransferOp
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult cuf::DataTransferOp::verify() {
+  mlir::Type srcTy = getSrc().getType();
+  mlir::Type dstTy = getDst().getType();
+  if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
+      (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)))
+    return mlir::success(); // src and dst agree: both refs or both boxes.
+  return emitOpError("expect src and dst to be both references or descriptors");
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocateOp
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 084314ed63ecd..e23792e6efc55 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -159,3 +159,22 @@ end subroutine
 
 ! CHECK-LABEL: func.func @_QPsub6
 ! CHECK: cuf.data_transfer
+
+subroutine sub7(a, b, c)
+  integer, device, allocatable :: a(:), c(:)
+  integer, allocatable :: b(:)
+  b = a
+
+  a = b
+
+  c = a
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub7(
+! CHECK-SAME:  %[[ARG0:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {fir.bindc_name = "b"}, %[[ARG2:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}) {
+! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]] dummy_scope %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ea"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: %[[B:.*]]:2 = hlfir.declare %[[ARG1]] dummy_scope %{{.*}} {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Eb"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: %[[C:.*]]:2 = hlfir.declare %[[ARG2]] dummy_scope %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub7Ec"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.dscope) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
+! CHECK: cuf.data_transfer %[[A]]#0 to %[[B]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: cuf.data_transfer %[[B]]#0 to %[[A]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: cuf.data_transfer %[[A]]#0 to %[[C]]#0 {transfer_kind = #cuf.cuda_transfer<device_device>} : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>


        


More information about the flang-commits mailing list