[flang-commits] [flang] [flang][cuda] Add shape to cuf.data_transfer operation (PR #104631)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Aug 16 14:05:25 PDT 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/104631

>From b5926296de15c5b36d86925bf3c1c0761ba01087 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 16 Aug 2024 11:52:32 -0700
Subject: [PATCH 1/2] [flang][cuda] Add shape to cuf.data_transfer operation

When doing a data transfer with a dynamically sized array, we currently
generate a data transfer between two descriptors. If the shape
values can be provided, we can keep the data transfer between two
references. This patch adds the shape operands to the operation.
---
 .../flang/Optimizer/Dialect/CUF/CUFOps.td     |  3 +-
 flang/lib/Lower/Bridge.cpp                    | 19 +++++++-----
 flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp    |  5 +++
 flang/test/Fir/cuf-invalid.fir                | 31 +++++++++++++++++++
 4 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index e95af629ef32f1..3e2d897ff56156 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -161,10 +161,11 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
 
   let arguments = (ins Arg<AnyType, "", [MemRead]>:$src,
                        Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
+                       Variadic<AnyIntegerType>:$shape,
                        cuf_DataTransferKindAttr:$transfer_kind);
 
   let assemblyFormat = [{
-    $src `to` $dst attr-dict `:` type(operands)
+    $src `to` $dst (`,` $shape^ `:` type($shape) )? attr-dict `:` type($src) `,` type($dst)
   }];
 
   let hasVerifier = 1;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index ccbb481f472d81..3ab24bc163c7af 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4272,18 +4272,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           base = convertOp.getValue();
         // Special case if the rhs is a constant.
         if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) {
-          builder.create<cuf::DataTransferOp>(loc, base, lhsVal,
-                                              transferKindAttr);
+          builder.create<cuf::DataTransferOp>(
+              loc, base, lhsVal, mlir::ValueRange{}, transferKindAttr);
         } else {
           auto associate = hlfir::genAssociateExpr(
               loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
           builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
+                                              mlir::ValueRange{},
                                               transferKindAttr);
           builder.create<hlfir::EndAssociateOp>(loc, associate);
         }
       } else {
-        builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                            transferKindAttr);
+        builder.create<cuf::DataTransferOp>(
+            loc, rhsVal, lhsVal, mlir::ValueRange{}, transferKindAttr);
       }
       return;
     }
@@ -4293,7 +4294,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceHost);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          transferKindAttr);
+                                          mlir::ValueRange{}, transferKindAttr);
       return;
     }
 
@@ -4303,7 +4304,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceDevice);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          transferKindAttr);
+                                          mlir::ValueRange{}, transferKindAttr);
       return;
     }
     llvm_unreachable("Unhandled CUDA data transfer");
@@ -4346,8 +4347,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           addSymbol(sym,
                     hlfir::translateToExtendedValue(loc, builder, temp).first,
                     /*forced=*/true);
-          builder.create<cuf::DataTransferOp>(loc, addr, temp,
-                                              transferKindAttr);
+          builder.create<cuf::DataTransferOp>(
+              loc, addr, temp, mlir::ValueRange{}, transferKindAttr);
           ++nbDeviceResidentObject;
         }
       }
@@ -4444,7 +4445,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         !userDefinedAssignment) {
       Fortran::lower::StatementContext localStmtCtx;
       hlfir::Entity rhs = evaluateRhs(localStmtCtx);
+      llvm::errs() << rhs << "\n";
       hlfir::Entity lhs = evaluateLhs(localStmtCtx);
+      llvm::errs() << lhs << "\n";
       if (isCUDATransfer && !hasCUDAImplicitTransfer)
         genCUDADataTransfer(builder, loc, assign, lhs, rhs);
       else
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index f7b36b208a7deb..d02c5d752dc5a6 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -99,6 +99,11 @@ llvm::LogicalResult cuf::AllocateOp::verify() {
 llvm::LogicalResult cuf::DataTransferOp::verify() {
   mlir::Type srcTy = getSrc().getType();
   mlir::Type dstTy = getDst().getType();
+  if (!getShape().empty()) {
+    if (!fir::isa_ref_type(srcTy) || fir::isa_ref_type(dstTy))
+      return emitOpError()
+             << "shape can only be specified on data transfer with references";
+  }
   if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
       (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)) ||
       (fir::isa_ref_type(srcTy) && fir::isa_box_type(dstTy)) ||
diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
index 06e08d14b2435c..add864b5bea354 100644
--- a/flang/test/Fir/cuf-invalid.fir
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -94,3 +94,34 @@ func.func @_QPsub1() {
   cuf.free %0 : !fir.ref<f32> {data_attr = #cuf.cuda<constant>}
   return
 }
+
+// -----
+
+func.func @_QPsub1(%arg0: !fir.ref<!fir.array<?xf32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "adev"}, %arg1: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "ahost"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n"}, %arg3: !fir.ref<i32> {fir.bindc_name = "m"}) {
+  %0 = fir.dummy_scope : !fir.dscope
+  %1:2 = hlfir.declare %arg2 dummy_scope %0 {uniq_name = "_QFsub1En"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %2:2 = hlfir.declare %arg3 dummy_scope %0 {uniq_name = "_QFsub1Em"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %3 = fir.load %1#0 : !fir.ref<i32>
+  %4 = fir.load %2#0 : !fir.ref<i32>
+  %5 = arith.muli %3, %4 : i32
+  %6 = fir.convert %5 : (i32) -> i64
+  %7 = fir.convert %6 : (i64) -> index
+  %c0 = arith.constant 0 : index
+  %8 = arith.cmpi sgt, %7, %c0 : index
+  %9 = arith.select %8, %7, %c0 : index
+  %10 = fir.shape %9 : (index) -> !fir.shape<1>
+  %11:2 = hlfir.declare %arg0(%10) dummy_scope %0 {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub1Eadev"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+  %12 = fir.load %1#0 : !fir.ref<i32>
+  %13 = fir.load %2#0 : !fir.ref<i32>
+  %14 = arith.muli %12, %13 : i32
+  %15 = fir.convert %14 : (i32) -> i64
+  %16 = fir.convert %15 : (i64) -> index
+  %c0_0 = arith.constant 0 : index
+  %17 = arith.cmpi sgt, %16, %c0_0 : index
+  %18 = arith.select %17, %16, %c0_0 : index
+  %19 = fir.shape %18 : (index) -> !fir.shape<1>
+  %20:2 = hlfir.declare %arg1(%19) dummy_scope %0 {uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
+  // expected-error at +1{{'cuf.data_transfer' op shape can only be specified on data transfer with references}}
+  cuf.data_transfer %20#0 to %11#0, %18 : index {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
+  return
+}

>From d220af9a582d7b110d829b0e3610f27437152c33 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 16 Aug 2024 14:04:47 -0700
Subject: [PATCH 2/2] Address comments

---
 .../flang/Optimizer/Dialect/CUF/CUFOps.td        |  2 +-
 flang/lib/Lower/Bridge.cpp                       | 16 ++++++++--------
 flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp       |  4 ++--
 flang/test/Fir/cuf-invalid.fir                   |  2 +-
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 3e2d897ff56156..f643674f1d5d6b 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -161,7 +161,7 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
 
   let arguments = (ins Arg<AnyType, "", [MemRead]>:$src,
                        Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
-                       Variadic<AnyIntegerType>:$shape,
+                       Optional<fir_ShapeType>:$shape,
                        cuf_DataTransferKindAttr:$transfer_kind);
 
   let assemblyFormat = [{
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 3ab24bc163c7af..24cd6b22b89259 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4273,18 +4273,18 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         // Special case if the rhs is a constant.
         if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) {
           builder.create<cuf::DataTransferOp>(
-              loc, base, lhsVal, mlir::ValueRange{}, transferKindAttr);
+              loc, base, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr);
         } else {
           auto associate = hlfir::genAssociateExpr(
               loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
           builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
-                                              mlir::ValueRange{},
+                                              /*shape=*/mlir::Value{},
                                               transferKindAttr);
           builder.create<hlfir::EndAssociateOp>(loc, associate);
         }
       } else {
         builder.create<cuf::DataTransferOp>(
-            loc, rhsVal, lhsVal, mlir::ValueRange{}, transferKindAttr);
+            loc, rhsVal, lhsVal, /*shape=*/mlir::Value{}, transferKindAttr);
       }
       return;
     }
@@ -4294,7 +4294,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceHost);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          mlir::ValueRange{}, transferKindAttr);
+                                          /*shape=*/mlir::Value{},
+                                          transferKindAttr);
       return;
     }
 
@@ -4304,7 +4305,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceDevice);
       builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
-                                          mlir::ValueRange{}, transferKindAttr);
+                                          /*shape=*/mlir::Value{},
+                                          transferKindAttr);
       return;
     }
     llvm_unreachable("Unhandled CUDA data transfer");
@@ -4348,7 +4350,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                     hlfir::translateToExtendedValue(loc, builder, temp).first,
                     /*forced=*/true);
           builder.create<cuf::DataTransferOp>(
-              loc, addr, temp, mlir::ValueRange{}, transferKindAttr);
+              loc, addr, temp, /*shape=*/mlir::Value{}, transferKindAttr);
           ++nbDeviceResidentObject;
         }
       }
@@ -4445,9 +4447,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         !userDefinedAssignment) {
       Fortran::lower::StatementContext localStmtCtx;
       hlfir::Entity rhs = evaluateRhs(localStmtCtx);
-      llvm::errs() << rhs << "\n";
       hlfir::Entity lhs = evaluateLhs(localStmtCtx);
-      llvm::errs() << lhs << "\n";
       if (isCUDATransfer && !hasCUDAImplicitTransfer)
         genCUDADataTransfer(builder, loc, assign, lhs, rhs);
       else
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index d02c5d752dc5a6..3b4ad95cafe6b5 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -99,8 +99,8 @@ llvm::LogicalResult cuf::AllocateOp::verify() {
 llvm::LogicalResult cuf::DataTransferOp::verify() {
   mlir::Type srcTy = getSrc().getType();
   mlir::Type dstTy = getDst().getType();
-  if (!getShape().empty()) {
-    if (!fir::isa_ref_type(srcTy) || fir::isa_ref_type(dstTy))
+  if (getShape()) {
+    if (!fir::isa_ref_type(srcTy) || !fir::isa_ref_type(dstTy))
       return emitOpError()
              << "shape can only be specified on data transfer with references";
   }
diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
index add864b5bea354..e9aeaa281e2a85 100644
--- a/flang/test/Fir/cuf-invalid.fir
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -122,6 +122,6 @@ func.func @_QPsub1(%arg0: !fir.ref<!fir.array<?xf32>> {cuf.data_attr = #cuf.cuda
   %19 = fir.shape %18 : (index) -> !fir.shape<1>
   %20:2 = hlfir.declare %arg1(%19) dummy_scope %0 {uniq_name = "_QFsub1Eahost"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
   // expected-error at +1{{'cuf.data_transfer' op shape can only be specified on data transfer with references}}
-  cuf.data_transfer %20#0 to %11#0, %18 : index {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
+  cuf.data_transfer %20#0 to %11#0, %19 : !fir.shape<1> {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.box<!fir.array<?xf32>>, !fir.box<!fir.array<?xf32>>
   return
 }



More information about the flang-commits mailing list