[flang-commits] [flang] [flang][cuda] Accept constant as src for cuf.data_transfer (PR #92951)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue May 21 11:44:25 PDT 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/92951

Assignment of a constant (host) to a device variable is a special case that can be further lowered to `cudaMemset` or similar functions. This patch updates the lowering to avoid the creation of a temporary when we assign a constant to a device variable.

>From 4c504565e8b1c39a7ece0701a4dc16af06369bed Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 21 May 2024 10:06:55 -0700
Subject: [PATCH] [flang][cuda] Accept constant as src for cuf.data_transfer

Assignment of a constant to a device variable is a special case
that can be further lowered to cudaMemset or similar functions.
This patch updates the lowering to avoid the creation of a temporary
when we assign a constant to a device variable.
---
 .../flang/Optimizer/Dialect/CUF/CUFOps.td       |  2 +-
 flang/lib/Lower/Bridge.cpp                      | 17 ++++++++++++-----
 flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp      |  7 ++++++-
 flang/test/Lower/CUDA/cuda-data-transfer.cuf    |  9 +++++----
 4 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index f2992997c42cb..37b8da0181955 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -158,7 +158,7 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
     updated.
   }];
 
-  let arguments = (ins Arg<AnyRefOrBoxType, "", [MemRead]>:$src,
+  let arguments = (ins Arg<AnyType, "", [MemRead]>:$src,
                        Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
                        cuf_DataTransferKindAttr:$transfer_kind);
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 3e0a6da7fc327..898b37504a6e6 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -57,6 +57,7 @@
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -3798,11 +3799,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::HostDevice);
       if (!rhs.isVariable()) {
-        auto associate = hlfir::genAssociateExpr(
-            loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
-        builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
-                                            transferKindAttr);
-        builder.create<hlfir::EndAssociateOp>(loc, associate);
+        // Special case if the rhs is a constant.
+        if (matchPattern(rhs.getDefiningOp(), mlir::m_Constant())) {
+          builder.create<cuf::DataTransferOp>(loc, rhs, lhsVal,
+                                              transferKindAttr);
+        } else {
+          auto associate = hlfir::genAssociateExpr(
+              loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
+          builder.create<cuf::DataTransferOp>(loc, associate.getBase(), lhsVal,
+                                              transferKindAttr);
+          builder.create<hlfir::EndAssociateOp>(loc, associate);
+        }
       } else {
         builder.create<cuf::DataTransferOp>(loc, rhsVal, lhsVal,
                                             transferKindAttr);
diff --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index 0446c1db86b16..2c0c4c2cfae34 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -99,7 +99,12 @@ mlir::LogicalResult cuf::DataTransferOp::verify() {
   if ((fir::isa_ref_type(srcTy) && fir::isa_ref_type(dstTy)) ||
       (fir::isa_box_type(srcTy) && fir::isa_box_type(dstTy)))
     return mlir::success();
-  return emitOpError("expect src and dst to be both references or descriptors");
+  if (fir::isa_trivial(srcTy) &&
+      matchPattern(getSrc().getDefiningOp(), mlir::m_Constant()))
+    return mlir::success();
+  return emitOpError()
+         << "expect src and dst to be both references or descriptors or src to "
+            "be a constant";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index e23792e6efc55..42fa4d09c95e0 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -25,6 +25,8 @@ subroutine sub1()
 
   adev = ahost + bhost
 
+  adev = 10
+
 end
 
 ! CHECK-LABEL: func.func @_QPsub1()
@@ -41,10 +43,7 @@ end
 ! CHECK: cuf.data_transfer %[[ASSOC]]#0 to %[[M]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<i32>, !fir.ref<i32>
 ! CHECK: hlfir.end_associate %[[ASSOC]]#1, %[[ASSOC]]#2 : !fir.ref<i32>, i1
 
-! CHECK: %[[C1:.*]] = arith.constant 1 : i32
-! CHECK: %[[ASSOC:.*]]:3 = hlfir.associate %[[C1]] {uniq_name = ".cuf_host_tmp"} : (i32) -> (!fir.ref<i32>, !fir.ref<i32>, i1)
-! CHECK: cuf.data_transfer %[[ASSOC]]#0 to %[[M]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<i32>, !fir.ref<i32>
-! CHECK: hlfir.end_associate %[[ASSOC]]#1, %[[ASSOC]]#2 : !fir.ref<i32>, i1
+! CHECK: cuf.data_transfer %c1{{.*}} to %[[M]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<i32>
 
 ! CHECK: cuf.data_transfer %[[AHOST]]#0 to %[[ADEV]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
 
@@ -62,6 +61,8 @@ end
 ! CHECK: cuf.data_transfer %[[ASSOC]]#0 to %[[ADEV]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>
 ! CHECK: hlfir.end_associate %[[ASSOC]]#1, %[[ASSOC]]#2 : !fir.ref<!fir.array<10xi32>>, i1
 
+! CHECK: cuf.data_transfer %c10{{.*}} to %[[ADEV]]#0 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<!fir.array<10xi32>>
+
 subroutine sub2()
   integer, device :: m
   integer, device :: adev(10), bdev(10)



More information about the flang-commits mailing list