[flang-commits] [flang] [flang][cuda] Simplify data transfer when possible (PR #106120)

via flang-commits flang-commits at lists.llvm.org
Mon Aug 26 11:57:07 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

When possible, avoid using descriptors for `cuf.data_transfer`: pass the underlying reference together with its shape instead.

---
Full diff: https://github.com/llvm/llvm-project/pull/106120.diff


3 Files Affected:

- (modified) flang/lib/Lower/Bridge.cpp (+32-13) 
- (modified) flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp (+3-1) 
- (modified) flang/test/Lower/CUDA/cuda-data-transfer.cuf (+17-2) 


``````````diff
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 24cd6b22b89259..a414cd5cb0b934 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4251,15 +4251,37 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     bool lhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.lhs);
     bool rhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.rhs);
 
-    auto getRefIfLoaded = [](mlir::Value val) -> mlir::Value {
+    auto getRefFromValue = [](mlir::Value val) -> mlir::Value {
       if (auto loadOp =
               mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp()))
         return loadOp.getMemref();
+      if (!mlir::isa<fir::BaseBoxType>(val.getType()))
+        return val;
+      if (auto declOp =
+              mlir::dyn_cast_or_null<hlfir::DeclareOp>(val.getDefiningOp())) {
+        if (!declOp.getShape())
+          return val;
+        if (mlir::isa<fir::ReferenceType>(declOp.getMemref().getType()))
+          return declOp.getMemref();
+      }
       return val;
     };
 
-    mlir::Value rhsVal = getRefIfLoaded(rhs.getBase());
-    mlir::Value lhsVal = getRefIfLoaded(lhs.getBase());
+    auto getShapeFromDecl = [](mlir::Value val) -> mlir::Value {
+      if (!mlir::isa<fir::BaseBoxType>(val.getType()))
+        return {};
+      if (auto declOp =
+              mlir::dyn_cast_or_null<hlfir::DeclareOp>(val.getDefiningOp()))
+        return declOp.getShape();
+      return {};
+    };
+
+    mlir::Value rhsVal = getRefFromValue(rhs.getBase());
+    mlir::Value lhsVal = getRefFromValue(lhs.getBase());
+    // Get shape from the rhs if available otherwise get it from lhs.
+    mlir::Value shape = getShapeFromDecl(rhs.getBase());
+    if (!shape)
+      shape = getShapeFromDecl(lhs.getBase());
 
     // device = host
     if (lhsIsDevice && !rhsIsDevice) {
@@ -4272,19 +4294,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           base = convertOp.getValue();
         // Special case if the rhs is a constant.
         if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) {
-          builder.create<cuf::DataTransferOp>(
-              loc, base, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr);
+          builder.create<cuf::DataTransferOp>(loc, base, lhsVal, shape,
+                                              transferKindAttr);
         } else {
           auto associate = hlfir::genAssociateExpr(
               loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
           builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
-                                              /*shape=*/mlir::Value{},
-                                              transferKindAttr);
+                                              shape, transferKindAttr);
           builder.create<hlfir::EndAssociateOp>(loc, associate);
         }
       } else {
-        builder.create<cuf::DataTransferOp>(
-            loc, rhsVal, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr);
+        builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal, shape,
+                                            transferKindAttr);
       }
       return;
     }
@@ -4293,8 +4314,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     if (!lhsIsDevice && rhsIsDevice) {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceHost);
-      builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          /*shape=*/mlir::Value{},
+      builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal, shape,
                                           transferKindAttr);
       return;
     }
@@ -4304,8 +4324,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       assert(rhs.isVariable() && "CUDA Fortran assignment rhs is not legal");
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceDevice);
-      builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          /*shape=*/mlir::Value{},
+      builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal, shape,
                                           transferKindAttr);
       return;
     }
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index 3b4ad95cafe6b5..7fb2dcf4af115c 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -112,9 +112,11 @@ llvm::LogicalResult cuf::DataTransferOp::verify() {
   if (fir::isa_trivial(srcTy) &&
       matchPattern(getSrc().getDefiningOp(), mlir::m_Constant()))
     return mlir::success();
+
   return emitOpError()
          << "expect src and dst to be references or descriptors or src to "
-            "be a constant";
+            "be a constant: "
+         << srcTy << " - " << dstTy;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 7eb74a4234f5a7..b2ae8c9f82ebb9 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -11,6 +11,7 @@ contains
   function dev1(a)
     integer, device :: a(:)
     integer :: dev1
+    dev1 = 1
   end function
 end
 
@@ -198,8 +199,8 @@ end subroutine
 ! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}, %[[ARG1:.*]]: !fir.ref<!fir.array<10xi32>> {fir.bindc_name = "b"}, %[[ARG2:.*]]: !fir.ref<i32> {fir.bindc_name = "n"})
 ! CHECK: %[[B:.*]]:2 = hlfir.declare %[[ARG1]](%{{.*}}) dummy_scope %{{.*}} {uniq_name = "_QFsub8Eb"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, !fir.dscope) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
 ! CHECK: %[[A:.*]]:2 = hlfir.declare %[[ARG0]](%{{.*}}) dummy_scope %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub8Ea"} : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<?xi32>>)
-! CHECK: cuf.data_transfer %[[A]]#0 to %[[B]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<10xi32>>
-! CHECK: cuf.data_transfer %[[B]]#0 to %[[A]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.box<!fir.array<?xi32>>
+! CHECK: cuf.data_transfer %[[ARG0]] to %[[B]]#0, %{{.*}} : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<?xi32>>, !fir.ref<!fir.array<10xi32>>
+! CHECK: cuf.data_transfer %[[B]]#0 to %[[ARG0]], %{{.*}} : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<?xi32>>
 
 subroutine sub9(a)
   integer, pinned, allocatable :: a(:)
@@ -274,3 +275,17 @@ end subroutine
 ! CHECK-LABEL: func.func @_QPsub14()
 ! CHECK: %[[TRUE:.*]] = arith.constant true
 ! CHECK: cuf.data_transfer %[[TRUE]] to %{{.*}}#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i1, !fir.ref<!fir.array<10x!fir.logical<4>>>
+
+subroutine sub15(a_dev, a_host, n, m)
+  integer, intent(in) :: n, m
+  real, device :: a_dev(n*m)
+  real :: a_host(n*m)
+
+  a_dev = a_host
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub15(
+! CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a_dev"}, %[[ARG1:.*]]: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a_host"}
+! CHECK: %{{.*}} = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: %[[SHAPE:.*]] = fir.shape %{{.*}} : (index) -> !fir.shape<1>
+! CHECK: cuf.data_transfer %[[ARG1]] to %[[ARG0]], %[[SHAPE]] : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>

``````````

</details>


https://github.com/llvm/llvm-project/pull/106120


More information about the flang-commits mailing list