[flang-commits] [flang] d1fd369 - [flang][cuda] Allow unsupported data transfer to be done on the host (#129160)

via flang-commits flang-commits at lists.llvm.org
Sun Mar 2 16:12:05 PST 2025


Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-03-02T16:12:01-08:00
New Revision: d1fd3698a9b755250f622fd1b14c57a27e2a9d77

URL: https://github.com/llvm/llvm-project/commit/d1fd3698a9b755250f622fd1b14c57a27e2a9d77
DIFF: https://github.com/llvm/llvm-project/commit/d1fd3698a9b755250f622fd1b14c57a27e2a9d77.diff

LOG: [flang][cuda] Allow unsupported data transfer to be done on the host (#129160)

Some data transfers marked as unsupported can actually be deferred to an
assignment on the host when the variables involved are unified or
managed.

Added: 
    

Modified: 
    flang/include/flang/Evaluate/tools.h
    flang/lib/Lower/Bridge.cpp
    flang/lib/Semantics/assignment.cpp
    flang/test/Lower/CUDA/cuda-data-transfer.cuf

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index f94981011b6e5..44e0e73028bf7 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1303,6 +1303,18 @@ inline bool IsCUDADeviceSymbol(const Symbol &sym) {
   return false;
 }
 
+inline bool IsCUDAManagedOrUnifiedSymbol(const Symbol &sym) {
+  if (const auto *details =
+          sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
+    if (details->cudaDataAttr() &&
+        (*details->cudaDataAttr() == common::CUDADataAttr::Managed ||
+            *details->cudaDataAttr() == common::CUDADataAttr::Unified)) {
+      return true;
+    }
+  }
+  return false;
+}
+
 // Get the number of distinct symbols with CUDA device
 // attribute in the expression.
 template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
@@ -1315,12 +1327,42 @@ template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
   return symbols.size();
 }
 
+// Get the number of distinct symbols with CUDA managed or unified
+// attribute in the expression.
+template <typename A>
+inline int GetNbOfCUDAManagedOrUnifiedSymbols(const A &expr) {
+  semantics::UnorderedSymbolSet symbols;
+  for (const Symbol &sym : CollectCudaSymbols(expr)) {
+    if (IsCUDAManagedOrUnifiedSymbol(sym)) {
+      symbols.insert(sym);
+    }
+  }
+  return symbols.size();
+}
+
 // Check if any of the symbols part of the expression has a CUDA device
 // attribute.
 template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
   return GetNbOfCUDADeviceSymbols(expr) > 0;
 }
 
+// Check if any of the symbols part of the lhs or rhs expression has a CUDA
+// device attribute.
+template <typename A, typename B>
+inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
+  int lhsNbManagedSymbols = {GetNbOfCUDAManagedOrUnifiedSymbols(lhs)};
+  int rhsNbManagedSymbols = {GetNbOfCUDAManagedOrUnifiedSymbols(rhs)};
+  int rhsNbSymbols{GetNbOfCUDADeviceSymbols(rhs)};
+
+  // Special case where only managed or unified symbols are involved. This is
+  // performed on the host.
+  if (lhsNbManagedSymbols == 1 && rhsNbManagedSymbols == 1 &&
+      rhsNbSymbols == 1) {
+    return false;
+  }
+  return HasCUDADeviceAttrs(lhs) || rhsNbSymbols > 0;
+}
+
 /// Check if the expression is a mix of host and device variables that require
 /// implicit data transfer.
 inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index f824e4c621c8e..cc19f335cd017 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4640,10 +4640,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
     bool isInDeviceContext = Fortran::lower::isCudaDeviceContext(builder);
 
-    bool isCUDATransfer = (Fortran::evaluate::HasCUDADeviceAttrs(assign.lhs) ||
-                           Fortran::evaluate::HasCUDADeviceAttrs(assign.rhs)) &&
-                          !isInDeviceContext;
+    bool isCUDATransfer =
+        IsCUDADataTransfer(assign.lhs, assign.rhs) && !isInDeviceContext;
     bool hasCUDAImplicitTransfer =
+        isCUDATransfer &&
         Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
     llvm::SmallVector<mlir::Value> implicitTemps;
 

diff  --git a/flang/lib/Semantics/assignment.cpp b/flang/lib/Semantics/assignment.cpp
index 627983d19a822..8de20d3126a6c 100644
--- a/flang/lib/Semantics/assignment.cpp
+++ b/flang/lib/Semantics/assignment.cpp
@@ -98,6 +98,10 @@ void AssignmentContext::Analyze(const parser::AssignmentStmt &stmt) {
       if (!IsCUDADeviceContext(&progUnit) && deviceConstructDepth_ == 0) {
         if (Fortran::evaluate::HasCUDADeviceAttrs(lhs) &&
             Fortran::evaluate::HasCUDAImplicitTransfer(rhs)) {
+          if (GetNbOfCUDAManagedOrUnifiedSymbols(lhs) == 1 &&
+              GetNbOfCUDAManagedOrUnifiedSymbols(rhs) == 1 &&
+              GetNbOfCUDADeviceSymbols(rhs) == 1)
+            return; // This is a special case handled on the host.
           context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);
         }
       }

diff  --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index 1c03a76cae76a..fa324abb137ec 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -393,4 +393,13 @@ end subroutine
 ! CHECK: %[[ALLOC_TMP:.*]] = fir.allocmem !fir.array<10xi32> {bindc_name = ".tmp", uniq_name = ""}
 ! CHECK: %[[TMP:.*]]:2 = hlfir.declare %[[ALLOC_TMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>)
 ! CHECK: cuf.data_transfer %[[ADEV_DECL]]#1 to %[[TMP]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>
+! CHECK: hlfir.assign 
+
+subroutine sub20()
+  integer, managed :: a(10)
+  a = a + 2 ! ok. No data transfer. Assignment on the host.
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub20()
+! CHECK-NOT: cuf.data_transfer
 ! CHECK: hlfir.assign


        


More information about the flang-commits mailing list