[flang-commits] [flang] [flang][cuda] Add shape to cuf.data_transfer operation (PR #104631)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Aug 16 11:56:00 PDT 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/104631

When transferring a dynamically sized array, we currently generate a data transfer between two descriptors. If the shape values can be provided, the data transfer can instead be kept between two references. This patch adds shape operands to the operation.

This will be used in lowering in a follow-up patch.

>From b5926296de15c5b36d86925bf3c1c0761ba01087 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 16 Aug 2024 11:52:32 -0700
Subject: [PATCH] [flang][cuda] Add shape to cuf.data_transfer operation

When transferring a dynamically sized array, we currently generate a
data transfer between two descriptors. If the shape values can be
provided, the data transfer can instead be kept between two
references. This patch adds shape operands to the operation.
---
 .../flang/Optimizer/Dialect/CUF/CUFOps.td     |  3 +-
 flang/lib/Lower/Bridge.cpp                    | 19 +++++++-----
 flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp    |  5 +++
 flang/test/Fir/cuf-invalid.fir                | 31 +++++++++++++++++++
 4 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index e95af629ef32f1..3e2d897ff56156 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -161,10 +161,11 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
 
   let arguments = (ins Arg<AnyType, "", [MemRead]>:$src,
                        Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
+                       Variadic<AnyIntegerType>:$shape, // Optional extents.
                        cuf_DataTransferKindAttr:$transfer_kind);
 
   let assemblyFormat = [{
-    $src `to` $dst attr-dict `:` type(operands)
+    $src `to` $dst (`,` $shape^ `:` type($shape) )? attr-dict `:` type($src) `,` type($dst)
   }];
 
   let hasVerifier = 1;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index ccbb481f472d81..3ab24bc163c7af 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4272,18 +4272,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           base = convertOp.getValue();
         // Special case if the rhs is a constant.
         if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) {
-          builder.create<cuf::DataTransferOp>(loc, base, lhsVal,
-                                              transferKindAttr);
+          builder.create<cuf::DataTransferOp>(
+              loc, base, lhsVal, mlir::ValueRange{}, transferKindAttr);
         } else {
           auto associate = hlfir::genAssociateExpr(
               loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
           builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
+                                              mlir::ValueRange{},
                                               transferKindAttr);
           builder.create<hlfir::EndAssociateOp>(loc, associate);
         }
       } else {
-        builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                            transferKindAttr);
+        builder.create<cuf::DataTransferOp>(
+            loc, rhsVal, lhsVal, mlir::ValueRange{}, transferKindAttr);
       }
       return;
     }
@@ -4293,7 +4294,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceHost);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          transferKindAttr);
+                                          mlir::ValueRange{}, transferKindAttr);
       return;
     }
 
@@ -4303,7 +4304,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceDevice);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          transferKindAttr);
+                                          mlir::ValueRange{}, transferKindAttr);
       return;
     }
     llvm_unreachable("Unhandled CUDA data transfer");
@@ -4346,8 +4347,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           addSymbol(sym,
                     hlfir::translateToExtendedValue(loc, builder, temp).first,
                     /*forced=*/true);
-          builder.create<cuf::DataTransferOp>(loc, addr, temp,
-                                              transferKindAttr);
+          builder.create<cuf::DataTransferOp>(
+              loc, addr, temp, mlir::ValueRange{}, transferKindAttr);
           ++nbDeviceResidentObject;
         }
       }
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index f7b36b208a7deb..d02c5d752dc5a6 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -99,6 +99,12 @@ llvm::LogicalResult cuf::AllocateOp::verify() {
 llvm::LogicalResult cuf::DataTransferOp::verify() {
   mlir::Type srcTy = getSrc().getType();
   mlir::Type dstTy = getDst().getType();
+  // Shape operands may only be provided when both src and dst are references.
+  if (!getShape().empty()) {
+    if (!fir::isa_ref_type(srcTy) || !fir::isa_ref_type(dstTy))
+      return emitOpError()
+             << "shape can only be specified on data transfer with references";
+  }
   if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
       (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) ||
       (fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) ||
diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
index 06e08d14b2435c..add864b5bea354 100644
--- a/flang/test/Fir/cuf-invalid.fir
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -94,3 +94,43 @@ func.func @_QPsub1() {
   cuf.free %0 : !fir.ref<f32> {data_attr = #cuf.cuda<constant>}
   return
 }
+
+// -----
+
+func.func @_QPsub1(%arg0: !fir.ref<!fir.array<?xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "adev"}, %arg1: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "ahost"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n"}, %arg3: !fir.ref<i32> {fir.bindc_name = "m"}) {
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg2 dummy_scope %0 {uniq_name = "_QFsub1En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %2:2 = hlfir.declare %arg3 dummy_scope %0 {uniq_name = "_QFsub1Em"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.load %1#0 : !fir.ref<i32>
+  %4 = fir.load %2#0 : !fir.ref<i32>
+  %5 = arith.muli %3, %4 : i32
+  %6 = fir.convert %5 : (i32) -> i64
+  %7 = fir.convert %6 : (i64) -> index
+  %c0 = arith.constant 0 : index
+  %8 = arith.cmpi sgt, %7, %c0 : index
+  %9 = arith.select %8, %7, %c0 : index
+  %10 = fir.shape %9 : (index) -> !fir.shape<1>
+  %11:2 = hlfir.declare %arg0(%10) dummy_scope %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+  %12 = fir.load %1#0 : !fir.ref<i32>
+  %13 = fir.load %2#0 : !fir.ref<i32>
+  %14 = arith.muli %12, %13 : i32
+  %15 = fir.convert %14 : (i32) -> i64
+  %16 = fir.convert %15 : (i64) -> index
+  %c0_0 = arith.constant 0 : index
+  %17 = arith.cmpi sgt, %16, %c0_0 : index
+  %18 = arith.select %17, %16, %c0_0 : index
+  %19 = fir.shape %18 : (index) -> !fir.shape<1>
+  %20:2 = hlfir.declare %arg1(%19) dummy_scope %0 {uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+  // expected-error at +1{{'cuf.data_transfer' op shape can only be specified on data transfer with references}}
+  cuf.data_transfer %20#0 to %11#0, %18 : index {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
+  return
+}
+
+// -----
+
+// Shape operands are rejected unless both src and dst are references.
+func.func @_QPsub2(%arg0: !fir.ref<!fir.array<?xf32>>, %arg1: !fir.box<!fir.array<?xf32>>, %arg2: index) {
+  // expected-error at +1{{'cuf.data_transfer' op shape can only be specified on data transfer with references}}
+  cuf.data_transfer %arg0 to %arg1, %arg2 : index {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
+  return
+}



More information about the flang-commits mailing list