[flang-commits] [flang] 0bc710f - [flang][cuda] Accept constant as src for cuf.data_transfer (#92951)

via flang-commits flang-commits at lists.llvm.org
Tue May 21 12:42:33 PDT 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-05-21T12:42:30-07:00
New Revision: 0bc710f7c19910817ccff254c43496602635bbc9

URL: https://github.com/llvm/llvm-project/commit/0bc710f7c19910817ccff254c43496602635bbc9
DIFF: https://github.com/llvm/llvm-project/commit/0bc710f7c19910817ccff254c43496602635bbc9.diff

LOG: [flang][cuda] Accept constant as src for cuf.data_transfer (#92951)

Assignment of a constant (host) to a device variable is a special case
that can be further lowered to `cudaMemset` or similar functions. This
patch updates the lowering to avoid creating a temporary when a constant
is assigned to a device variable.

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
    flang/lib/Lower/Bridge.cpp
    flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
    flang/test/Lower/CUDA/cuda-data-transfer.cuf

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index f2992997c42cb..37b8da0181955 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -158,7 +158,7 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
     updated.
   }];
 
-  let arguments = (ins Arg<AnyRefOrBoxType, "", [MemRead]>:$src,
+  let arguments = (ins Arg<AnyType, "", [MemRead]>:$src,
                        Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
                        cuf_DataTransferKindAttr:$transfer_kind);
 

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 3e0a6da7fc327..898b37504a6e6 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -57,6 +57,7 @@
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -3798,11 +3799,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::HostDevice);
       if (!rhs.isVariable()) {
-        auto associate = hlfir::genAssociateExpr(
-            loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
-        builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
-                                            transferKindAttr);
-        builder.create<hlfir::EndAssociateOp>(loc, associate);
+        // Special case if the rhs is a constant.
+        if (matchPattern(rhs.getDefiningOp(), mlir::m_Constant())) {
+          builder.create<cuf::DataTransferOp>(loc, rhs, lhsVal,
+                                              transferKindAttr);
+        } else {
+          auto associate = hlfir::genAssociateExpr(
+              loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
+          builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
+                                              transferKindAttr);
+          builder.create<hlfir::EndAssociateOp>(loc, associate);
+        }
       } else {
         builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
                                             transferKindAttr);

diff  --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index 0446c1db86b16..2c0c4c2cfae34 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -99,7 +99,12 @@ mlir::LogicalResult cuf::DataTransferOp::verify() {
   if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
       (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)))
     return mlir::success();
-  return emitOpError("expect src and dst to be both references or descriptors");
+  if (fir::isa_trivial(srcTy) &&
+      matchPattern(getSrc().getDefiningOp(), mlir::m_Constant()))
+    return mlir::success();
+  return emitOpError()
+         << "expect src and dst to be both references or descriptors or src to "
+            "be a constant";
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index e23792e6efc55..42fa4d09c95e0 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -25,6 +25,8 @@ subroutine sub1()
 
   adev = ahost + bhost
 
+  adev = 10
+
 end
 
 ! CHECK-LABEL: func.func @_QPsub1()
@@ -41,10 +43,7 @@ end
 ! CHECK: cuf.data_transfer %[[ASSOC]]#0 to %[[M]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<i32>, !fir.ref<i32>
 ! CHECK: hlfir.end_associate %[[ASSOC]]#1, %[[ASSOC]]#2 : !fir.ref<i32>, i1
 
-! CHECK: %[[C1:.*]] = arith.constant 1 : i32
-! CHECK: %[[ASSOC:.*]]:3 = hlfir.associate %[[C1]] {uniq_name = ".cuf_host_tmp"} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
-! CHECK: cuf.data_transfer %[[ASSOC]]#0 to %[[M]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<i32>, !fir.ref<i32>
-! CHECK: hlfir.end_associate %[[ASSOC]]#1, %[[ASSOC]]#2 : !fir.ref<i32>, i1
+! CHECK: cuf.data_transfer %c1{{.*}} to %[[M]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<i32>
 
 ! CHECK: cuf.data_transfer %[[AHOST]]#0 to %[[ADEV]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
 
@@ -62,6 +61,8 @@ end
 ! CHECK: cuf.data_transfer %[[ASSOC]]#0 to %[[ADEV]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
 ! CHECK: hlfir.end_associate %[[ASSOC]]#1, %[[ASSOC]]#2 : !fir.ref<!fir.array<10xi32>>, i1
 
+! CHECK: cuf.data_transfer %c10{{.*}} to %[[ADEV]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.array<10xi32>>
+
 subroutine sub2()
   integer, device :: m
   integer, device :: adev(10), bdev(10)


        


More information about the flang-commits mailing list