[flang-commits] [flang] 7af61d5 - [flang][cuda] Add shape to cuf.data_transfer operation (#104631)

via flang-commits flang-commits at lists.llvm.org
Mon Aug 26 09:50:20 PDT 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-08-26T09:50:17-07:00
New Revision: 7af61d5cf464f1d716c82bc77907fa3fe4ebc841

URL: https://github.com/llvm/llvm-project/commit/7af61d5cf464f1d716c82bc77907fa3fe4ebc841
DIFF: https://github.com/llvm/llvm-project/commit/7af61d5cf464f1d716c82bc77907fa3fe4ebc841.diff

LOG: [flang][cuda] Add shape to cuf.data_transfer operation (#104631)

When doing a data transfer with a dynamically sized array, we currently
generate a data transfer between two descriptors. If the shape values
can be provided, we can keep the data transfer between two references.
This patch adds an optional shape operand to the operation.

This will be exploited in lowering in a follow-up patch.

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
    flang/lib/Lower/Bridge.cpp
    flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
    flang/test/Fir/cuf-invalid.fir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index e95af629ef32f1..f643674f1d5d6b 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -161,10 +161,11 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
 
   let arguments = (ins Arg<AnyType, "", [MemRead]>:$src,
                        Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
+                       Optional<fir_ShapeType>:$shape,
                        cuf_DataTransferKindAttr:$transfer_kind);
 
   let assemblyFormat = [{
-    $src `to` $dst attr-dict `:` type(operands)
+    $src `to` $dst (`,` $shape^ `:` type($shape) )? attr-dict `:` type($src) `,` type($dst)
   }];
 
   let hasVerifier = 1;

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index ccbb481f472d81..24cd6b22b89259 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4272,18 +4272,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           base = convertOp.getValue();
         // Special case if the rhs is a constant.
         if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) {
-          builder.create<cuf::DataTransferOp>(loc, base, lhsVal,
-                                              transferKindAttr);
+          builder.create<cuf::DataTransferOp>(
+              loc, base, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr);
         } else {
           auto associate = hlfir::genAssociateExpr(
               loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
           builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
+                                              /*shape=*/mlir::Value{},
                                               transferKindAttr);
           builder.create<hlfir::EndAssociateOp>(loc, associate);
         }
       } else {
-        builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                            transferKindAttr);
+        builder.create<cuf::DataTransferOp>(
+            loc, rhsVal, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr);
       }
       return;
     }
@@ -4293,6 +4294,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceHost);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+                                          /*shape=*/mlir::Value{},
                                           transferKindAttr);
       return;
     }
@@ -4303,6 +4305,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceDevice);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
+                                          /*shape=*/mlir::Value{},
                                           transferKindAttr);
       return;
     }
@@ -4346,8 +4349,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           addSymbol(sym,
                     hlfir::translateToExtendedValue(loc, builder, temp).first,
                     /*forced=*/true);
-          builder.create<cuf::DataTransferOp>(loc, addr, temp,
-                                              transferKindAttr);
+          builder.create<cuf::DataTransferOp>(
+              loc, addr, temp, /*shape=*/mlir::Value{}, transferKindAttr);
           ++nbDeviceResidentObject;
         }
       }

diff  --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index f7b36b208a7deb..3b4ad95cafe6b5 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -99,6 +99,11 @@ llvm::LogicalResult cuf::AllocateOp::verify() {
 llvm::LogicalResult cuf::DataTransferOp::verify() {
   mlir::Type srcTy = getSrc().getType();
   mlir::Type dstTy = getDst().getType();
+  if (getShape()) {
+    if (!fir::isa_ref_type(srcTy) || !fir::isa_ref_type(dstTy))
+      return emitOpError()
+             << "shape can only be specified on data transfer with references";
+  }
   if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
       (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) ||
       (fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) ||

diff  --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
index 06e08d14b2435c..e9aeaa281e2a85 100644
--- a/flang/test/Fir/cuf-invalid.fir
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -94,3 +94,34 @@ func.func @_QPsub1() {
   cuf.free %0 : !fir.ref<f32> {data_attr = #cuf.cuda<constant>}
   return
 }
+
+// -----
+// A shape operand is only accepted when both src and dst are references; here both are boxes.
+func.func @_QPsub1(%arg0: !fir.ref<!fir.array<?xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "adev"}, %arg1: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "ahost"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n"}, %arg3: !fir.ref<i32> {fir.bindc_name = "m"}) {
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg2 dummy_scope %0 {uniq_name = "_QFsub1En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %2:2 = hlfir.declare %arg3 dummy_scope %0 {uniq_name = "_QFsub1Em"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.load %1#0 : !fir.ref<i32>
+  %4 = fir.load %2#0 : !fir.ref<i32>
+  %5 = arith.muli %3, %4 : i32
+  %6 = fir.convert %5 : (i32) -> i64
+  %7 = fir.convert %6 : (i64) -> index
+  %c0 = arith.constant 0 : index
+  %8 = arith.cmpi sgt, %7, %c0 : index
+  %9 = arith.select %8, %7, %c0 : index
+  %10 = fir.shape %9 : (index) -> !fir.shape<1>
+  %11:2 = hlfir.declare %arg0(%10) dummy_scope %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+  %12 = fir.load %1#0 : !fir.ref<i32>
+  %13 = fir.load %2#0 : !fir.ref<i32>
+  %14 = arith.muli %12, %13 : i32
+  %15 = fir.convert %14 : (i32) -> i64
+  %16 = fir.convert %15 : (i64) -> index
+  %c0_0 = arith.constant 0 : index
+  %17 = arith.cmpi sgt, %16, %c0_0 : index
+  %18 = arith.select %17, %16, %c0_0 : index
+  %19 = fir.shape %18 : (index) -> !fir.shape<1>
+  %20:2 = hlfir.declare %arg1(%19) dummy_scope %0 {uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+  // expected-error@+1{{'cuf.data_transfer' op shape can only be specified on data transfer with references}}
+  cuf.data_transfer %20#0 to %11#0, %19 : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
+  return
+}


        


More information about the flang-commits mailing list