[flang-commits] [flang] [flang][cuda] Add shape to cuf.data_transfer operation (PR #104631)

via flang-commits flang-commits at lists.llvm.org
Fri Aug 16 11:56:36 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

When transferring data for a dynamically sized array, we currently generate a data transfer between two descriptors. If the shape values can be provided, the data transfer can instead remain between two references. This patch adds shape operands to the operation.

This will be used by lowering in a follow-up patch.

---
Full diff: https://github.com/llvm/llvm-project/pull/104631.diff


4 Files Affected:

- (modified) flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td (+2-1) 
- (modified) flang/lib/Lower/Bridge.cpp (+11-8) 
- (modified) flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp (+5) 
- (modified) flang/test/Fir/cuf-invalid.fir (+31) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index e95af629ef32f1..3e2d897ff56156 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -161,10 +161,11 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
 
   let arguments = (ins Arg<AnyType, "", [MemRead]>:$src,
                        Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
+                       Variadic<AnyIntegerType>:$shape,
                        cuf_DataTransferKindAttr:$transfer_kind);
 
   let assemblyFormat = [{
-    $src `to` $dst attr-dict `:` type(operands)
+    $src `to` $dst (`,` $shape^ `:` type($shape) )? attr-dict `:` type($src) `,` type($dst)
   }];
 
   let hasVerifier = 1;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index ccbb481f472d81..3ab24bc163c7af 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4272,18 +4272,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           base = convertOp.getValue();
         // Special case if the rhs is a constant.
         if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) {
-          builder.create<cuf::DataTransferOp>(loc, base, lhsVal,
-                                              transferKindAttr);
+          builder.create<cuf::DataTransferOp>(
+              loc, base, lhsVal, mlir::ValueRange{}, transferKindAttr);
         } else {
           auto associate = hlfir::genAssociateExpr(
               loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
           builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
+                                              mlir::ValueRange{},
                                               transferKindAttr);
           builder.create<hlfir::EndAssociateOp>(loc, associate);
         }
       } else {
-        builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                            transferKindAttr);
+        builder.create<cuf::DataTransferOp>(
+            loc, rhsVal, lhsVal, mlir::ValueRange{}, transferKindAttr);
       }
       return;
     }
@@ -4293,7 +4294,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceHost);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          transferKindAttr);
+                                          mlir::ValueRange{}, transferKindAttr);
       return;
     }
 
@@ -4303,7 +4304,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceDevice);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          transferKindAttr);
+                                          mlir::ValueRange{}, transferKindAttr);
       return;
     }
     llvm_unreachable("Unhandled CUDA data transfer");
@@ -4346,8 +4347,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           addSymbol(sym,
                     hlfir::translateToExtendedValue(loc, builder, temp).first,
                     /*forced=*/true);
-          builder.create<cuf::DataTransferOp>(loc, addr, temp,
-                                              transferKindAttr);
+          builder.create<cuf::DataTransferOp>(
+              loc, addr, temp, mlir::ValueRange{}, transferKindAttr);
           ++nbDeviceResidentObject;
         }
       }
@@ -4444,7 +4445,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         !userDefinedAssignment) {
       Fortran::lower::StatementContext localStmtCtx;
       hlfir::Entity rhs = evaluateRhs(localStmtCtx);
+      llvm::errs() << rhs << "\n";
       hlfir::Entity lhs = evaluateLhs(localStmtCtx);
+      llvm::errs() << lhs << "\n";
       if (isCUDATransfer && !hasCUDAImplicitTransfer)
         genCUDADataTransfer(builder, loc, assign, lhs, rhs);
       else
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index f7b36b208a7deb..d02c5d752dc5a6 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -99,6 +99,11 @@ llvm::LogicalResult cuf::AllocateOp::verify() {
 llvm::LogicalResult cuf::DataTransferOp::verify() {
   mlir::Type srcTy = getSrc().getType();
   mlir::Type dstTy = getDst().getType();
+  if (!getShape().empty()) {
+    if (!fir::isa_ref_type(srcTy) || fir::isa_ref_type(dstTy))
+      return emitOpError()
+             << "shape can only be specified on data transfer with references";
+  }
   if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
       (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) ||
       (fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) ||
diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
index 06e08d14b2435c..add864b5bea354 100644
--- a/flang/test/Fir/cuf-invalid.fir
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -94,3 +94,34 @@ func.func @_QPsub1() {
   cuf.free %0 : !fir.ref<f32> {data_attr = #cuf.cuda<constant>}
   return
 }
+
+// -----
+
+func.func @_QPsub1(%arg0: !fir.ref<!fir.array<?xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "adev"}, %arg1: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "ahost"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n"}, %arg3: !fir.ref<i32> {fir.bindc_name = "m"}) {
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg2 dummy_scope %0 {uniq_name = "_QFsub1En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %2:2 = hlfir.declare %arg3 dummy_scope %0 {uniq_name = "_QFsub1Em"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.load %1#0 : !fir.ref<i32>
+  %4 = fir.load %2#0 : !fir.ref<i32>
+  %5 = arith.muli %3, %4 : i32
+  %6 = fir.convert %5 : (i32) -> i64
+  %7 = fir.convert %6 : (i64) -> index
+  %c0 = arith.constant 0 : index
+  %8 = arith.cmpi sgt, %7, %c0 : index
+  %9 = arith.select %8, %7, %c0 : index
+  %10 = fir.shape %9 : (index) -> !fir.shape<1>
+  %11:2 = hlfir.declare %arg0(%10) dummy_scope %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+  %12 = fir.load %1#0 : !fir.ref<i32>
+  %13 = fir.load %2#0 : !fir.ref<i32>
+  %14 = arith.muli %12, %13 : i32
+  %15 = fir.convert %14 : (i32) -> i64
+  %16 = fir.convert %15 : (i64) -> index
+  %c0_0 = arith.constant 0 : index
+  %17 = arith.cmpi sgt, %16, %c0_0 : index
+  %18 = arith.select %17, %16, %c0_0 : index
+  %19 = fir.shape %18 : (index) -> !fir.shape<1>
+  %20:2 = hlfir.declare %arg1(%19) dummy_scope %0 {uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+  // expected-error@+1{{'cuf.data_transfer' op shape can only be specified on data transfer with references}}
+  cuf.data_transfer %20#0 to %11#0, %18 : index {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/104631


More information about the flang-commits mailing list