[flang-commits] [flang] [flang][cuda] Add hasManagedOrUnifedSymbols attribute to cuf.data_transfer op (PR #185106)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri Mar 6 12:59:02 PST 2026


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/185106

Add an attribute to signal the presence of managed or unified symbols in the data transfer. In some cases, the presence of such symbols requires inserting synchronization. Adding the attribute to the op during lowering facilitates the recognition of such data transfers.

>From bf40b415781e94ffc4e4598cb182ab0dca151568 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 6 Mar 2026 12:57:15 -0800
Subject: [PATCH] [flang][cuda] Add hasManagedOrUnifedSymbols attribute to
 cuf.data_transfer op

---
 .../flang/Optimizer/Dialect/CUF/CUFOps.td     |  7 +++---
 flang/lib/Lower/Bridge.cpp                    | 22 ++++++++++++++-----
 flang/test/Lower/CUDA/cuda-data-transfer.cuf  | 12 +++++++++-
 3 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index 2bde3ac00a439..e5134a591e3ce 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -176,9 +176,10 @@ def cuf_DataTransferOp : cuf_Op<"data_transfer", []> {
   }];
 
   let arguments = (ins Arg<AnyType, "", [MemRead]>:$src,
-                       Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
-                       Optional<AnyShapeOrShiftType>:$shape,
-                       cuf_DataTransferKindAttr:$transfer_kind);
+      Arg<AnyRefOrBoxType, "", [MemWrite]>:$dst,
+      Optional<AnyShapeOrShiftType>:$shape,
+      cuf_DataTransferKindAttr:$transfer_kind,
+      UnitAttr:$hasManagedOrUnifedSymbols);
 
   let assemblyFormat = [{
     $src `to` $dst (`,` $shape^ `:` type($shape) )? attr-dict `:` type($src) `,` type($dst)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index a8f405dd03d1c..a3a607be0d01f 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -5374,6 +5374,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                            bool keepLhsLengthInAllocatableAssignment) {
     bool lhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.lhs);
     bool rhsIsDevice = Fortran::evaluate::HasCUDADeviceAttrs(assign.rhs);
+    mlir::UnitAttr hasManagedOrUnifedSymbols =
+        (Fortran::evaluate::GetNbOfCUDAManagedOrUnifiedSymbols(assign.lhs) >
+             0 ||
+         Fortran::evaluate::GetNbOfCUDAManagedOrUnifiedSymbols(assign.rhs) > 0)
+            ? mlir::UnitAttr::get(builder.getContext())
+            : nullptr;
 
     auto getRefFromValue = [](mlir::Value val) -> mlir::Value {
       if (auto loadOp =
@@ -5419,17 +5425,20 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         // Special case if the rhs is a constant.
         if (matchPattern(base.getDefiningOp(), mlir::m_Constant())) {
           cuf::DataTransferOp::create(builder, loc, base, lhsVal, shape,
-                                      transferKindAttr);
+                                      transferKindAttr,
+                                      hasManagedOrUnifedSymbols);
         } else {
           auto associate = hlfir::genAssociateExpr(
               loc, builder, rhs, rhs.getType(), ".cuf_host_tmp");
           cuf::DataTransferOp::create(builder, loc, associate.getBase(), lhsVal,
-                                      shape, transferKindAttr);
+                                      shape, transferKindAttr,
+                                      hasManagedOrUnifedSymbols);
           hlfir::EndAssociateOp::create(builder, loc, associate);
         }
       } else {
         cuf::DataTransferOp::create(builder, loc, rhsVal, lhsVal, shape,
-                                    transferKindAttr);
+                                    transferKindAttr,
+                                    hasManagedOrUnifedSymbols);
       }
       return;
     }
@@ -5448,7 +5457,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         auto transferKindAttr = cuf::DataTransferKindAttr::get(
             builder.getContext(), cuf::DataTransferKind::DeviceHost);
         cuf::DataTransferOp::create(builder, loc, designateOp.getMemref(), temp,
-                                    /*shape=*/mlir::Value{}, transferKindAttr);
+                                    /*shape=*/mlir::Value{}, transferKindAttr,
+                                    hasManagedOrUnifedSymbols);
         designateOp.getMemrefMutable().assign(temp);
         builder.setInsertionPointAfter(elOp);
         hlfir::AssignOp::create(builder, loc, elOp, lhs,
@@ -5459,7 +5469,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceHost);
       cuf::DataTransferOp::create(builder, loc, rhsVal, lhsVal, shape,
-                                  transferKindAttr);
+                                  transferKindAttr, hasManagedOrUnifedSymbols);
       return;
     }
 
@@ -5468,7 +5478,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       auto transferKindAttr = cuf::DataTransferKindAttr::get(
           builder.getContext(), cuf::DataTransferKind::DeviceDevice);
       cuf::DataTransferOp::create(builder, loc, rhsVal, lhsVal, shape,
-                                  transferKindAttr);
+                                  transferKindAttr, hasManagedOrUnifedSymbols);
       return;
     }
     llvm_unreachable("Unhandled CUDA data transfer");
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 866a63abd36d6..e73085a5c0077 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -482,7 +482,7 @@ end
 ! CHECK: %[[D:.*]]:2 = hlfir.declare %1(%2) {data_attr = #cuf.cuda<device>, uniq_name = "_QFsub24Ed"} : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<4xf32>>, !fir.ref<!fir.array<4xf32>>)
 ! CHECK: %[[M:.*]]:2 = hlfir.declare %4 {data_attr = #cuf.cuda<managed>, uniq_name = "_QFsub24Em"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
 ! CHECK: %[[D1:.*]] = hlfir.designate %[[D]]#0 (%c1{{.*}})  : (!fir.ref<!fir.array<4xf32>>, index) -> !fir.ref<f32>
-! CHECK: cuf.data_transfer %[[D1]] to %[[M]]#0 {transfer_kind = #cuf.cuda_transfer<device_device>} : !fir.ref<f32>, !fir.ref<f32>
+! CHECK: cuf.data_transfer %[[D1]] to %[[M]]#0 {hasManagedOrUnifedSymbols, transfer_kind = #cuf.cuda_transfer<device_device>} : !fir.ref<f32>, !fir.ref<f32>
 
 subroutine sub25()
   use mod1
@@ -629,3 +629,13 @@ end subroutine
 ! CHECK: cuf.data_transfer 
 ! CHECK-COUNT-2: hlfir.elemental
 ! CHECK: hlfir.assign
+
+subroutine sub34(n)
+  integer :: n
+  real(2), managed, allocatable :: dx(:)
+  allocate(dx(1:n))
+  dx(1:n) = 7.0_2
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub34
+! CHECK:  cuf.data_transfer %{{.*}} to %{{.*}} {hasManagedOrUnifedSymbols, transfer_kind = #cuf.cuda_transfer<host_device>} : f16, !fir.box<!fir.array<?xf16>>



More information about the flang-commits mailing list