[flang-commits] [flang] [flang][cuda] Update implicit data transfer for device component (PR #147882)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Jul 9 21:17:22 PDT 2025


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/147882

Update the detection of implicit data transfer when a device-resident allocatable derived-type component is involved, and remove the TODOs.

From f45edd9dbf728cd3174f9890f3541e559e2bdd92 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 9 Jul 2025 21:15:38 -0700
Subject: [PATCH] [flang][cuda] Update implicit data transfer for device
 component

---
 flang/include/flang/Evaluate/tools.h         | 19 +--------------
 flang/lib/Evaluate/tools.cpp                 | 25 ++++++++++++++++++++
 flang/lib/Lower/Bridge.cpp                   |  2 --
 flang/test/Lower/CUDA/cuda-data-transfer.cuf | 17 +++++++++++++
 4 files changed, 43 insertions(+), 20 deletions(-)

diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 18c244f6f450f..96ed86f468350 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1359,24 +1359,7 @@ inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
 
 /// Check if the expression is a mix of host and device variables that require
 /// implicit data transfer.
-inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
-  unsigned hostSymbols{0};
-  unsigned deviceSymbols{0};
-  for (const Symbol &sym : CollectCudaSymbols(expr)) {
-    if (IsCUDADeviceSymbol(sym)) {
-      ++deviceSymbols;
-    } else {
-      if (sym.owner().IsDerivedType()) {
-        if (IsCUDADeviceSymbol(sym.owner().GetSymbol()->GetUltimate())) {
-          ++deviceSymbols;
-        }
-      }
-      ++hostSymbols;
-    }
-  }
-  bool hasConstant{HasConstant(expr)};
-  return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0;
-}
+bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr);
 
 // Checks whether the symbol on the LHS is present in the RHS expression.
 bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs);
diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp
index 6a57d87a30e93..3d9f06308d8c1 100644
--- a/flang/lib/Evaluate/tools.cpp
+++ b/flang/lib/Evaluate/tools.cpp
@@ -1154,6 +1154,31 @@ template semantics::UnorderedSymbolSet CollectCudaSymbols(
 template semantics::UnorderedSymbolSet CollectCudaSymbols(
     const Expr<SubscriptInteger> &);
 
+bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
+  semantics::UnorderedSymbolSet hostSymbols;
+  semantics::UnorderedSymbolSet deviceSymbols;
+
+  SymbolVector symbols{GetSymbolVector(expr)};
+  std::reverse(symbols.begin(), symbols.end());
+  bool skipNext{false};
+  for (const Symbol &sym : symbols) {
+    bool isComponent{sym.owner().IsDerivedType()};
+    bool skipComponent{false};
+    if (!skipNext) {
+      if (IsCUDADeviceSymbol(sym)) {
+        deviceSymbols.insert(sym);
+      } else if (isComponent) {
+        skipComponent = true; // Component is not device. Look on the base.
+      } else {
+        hostSymbols.insert(sym);
+      }
+    }
+    skipNext = isComponent && !skipComponent;
+  }
+  bool hasConstant{HasConstant(expr)};
+  return (hasConstant || (hostSymbols.size() > 0)) && deviceSymbols.size() > 0;
+}
+
 // HasVectorSubscript()
 struct HasVectorSubscriptHelper
     : public AnyTraverse<HasVectorSubscriptHelper, bool,
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 696473605a4e0..ff26aa8e33649 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4842,8 +4842,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                   .detailsIf<Fortran::semantics::ObjectEntityDetails>()) {
         if (details->cudaDataAttr() &&
             *details->cudaDataAttr() != Fortran::common::CUDADataAttr::Pinned) {
-          if (sym.owner().IsDerivedType() && IsAllocatable(sym.GetUltimate()))
-            TODO(loc, "Device resident allocatable derived-type component");
           // TODO: This should probably being checked in semantic and give a
           // proper error.
           assert(
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 68a0202f951fe..3a9b55996d9b1 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -7,6 +7,10 @@ module mod1
     integer :: i
   end type
 
+  type :: t2
+    integer, device, allocatable, dimension(:) :: x
+  end type
+
   integer, device, dimension(11:20) :: cdev
 
 contains
@@ -419,3 +423,16 @@ end subroutine
 ! CHECK: fir.do_concurrent.loop
 ! CHECK-NOT: cuf.data_transfer
 ! CHECK: hlfir.assign
+
+
+subroutine sub22()
+  use mod1
+  type(t2) :: a
+  integer :: b(100)
+  allocate(a%x(100))
+
+  b = a%x
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub22()
+! CHECK: cuf.data_transfer



More information about the flang-commits mailing list