[flang-commits] [flang] [flang][cuda] Allow unsupported data transfer to be done on the host (PR #129160)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Thu Feb 27 16:28:43 PST 2025
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/129160
Some data transfers marked as unsupported can instead be performed as a plain assignment on the host when the variables involved are managed or unified.
>From a75b570296bc63a4a2ef46c2a10db5362eb3b66f Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 27 Feb 2025 16:27:23 -0800
Subject: [PATCH] [flang][cuda] Allow unsupported data transfer to be done on
the host
---
flang/include/flang/Evaluate/tools.h | 42 ++++++++++++++++++++
flang/lib/Lower/Bridge.cpp | 6 +--
flang/lib/Semantics/assignment.cpp | 4 ++
flang/test/Lower/CUDA/cuda-data-transfer.cuf | 9 +++++
4 files changed, 58 insertions(+), 3 deletions(-)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 352f6b36458ce..729aef3a7c9f2 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1303,6 +1303,18 @@ inline bool IsCUDADeviceSymbol(const Symbol &sym) {
return false;
}
+// Whether sym's ultimate object entity is CUDA managed or unified data.
+inline bool IsCUDAManagedOrUnifiedSymbol(const Symbol &sym) {
+  const auto *details{
+      sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()};
+  if (!details || !details->cudaDataAttr()) {
+    return false;
+  }
+  const common::CUDADataAttr attr{*details->cudaDataAttr()};
+  return attr == common::CUDADataAttr::Managed ||
+      attr == common::CUDADataAttr::Unified;
+}
+
// Get the number of distinct symbols with CUDA device
// attribute in the expression.
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
@@ -1315,12 +1327,42 @@ template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
return symbols.size();
}
+// Number of distinct symbols in expr with CUDA managed or unified attribute.
+template <typename A>
+inline int GetNbOfCUDAManagedOrUnifiedSymbols(const A &expr) {
+  semantics::UnorderedSymbolSet managedOrUnified;
+  for (const Symbol &cudaSym : CollectCudaSymbols(expr)) {
+    if (!IsCUDAManagedOrUnifiedSymbol(cudaSym)) {
+      continue;
+    }
+    managedOrUnified.insert(cudaSym);
+  }
+  return managedOrUnified.size();
+}
+
// Check if any of the symbols part of the expression has a CUDA device
// attribute.
template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
return GetNbOfCUDADeviceSymbols(expr) > 0;
}
+// Check whether lhs = rhs needs a CUDA data transfer: either side references
+// a symbol with a CUDA device attribute, unless the special case below holds.
+template <typename A, typename B>
+inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
+  int lhsNbManagedSymbols{GetNbOfCUDAManagedOrUnifiedSymbols(lhs)};
+  int rhsNbManagedSymbols{GetNbOfCUDAManagedOrUnifiedSymbols(rhs)};
+  int rhsNbSymbols{GetNbOfCUDADeviceSymbols(rhs)};
+  // Special case where only managed or unified symbols are involved (one on
+  // the lhs, one as the sole device symbol of the rhs): done on the host.
+  // NOTE(review): non-managed device symbols on the lhs are not checked.
+  if (lhsNbManagedSymbols == 1 && rhsNbManagedSymbols == 1 &&
+      rhsNbSymbols == 1) {
+    return false;
+  }
+  return HasCUDADeviceAttrs(lhs) || rhsNbSymbols > 0;
+}
+
/// Check if the expression is a mix of host and device variables that require
/// implicit data transfer.
inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index f824e4c621c8e..cc19f335cd017 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4640,10 +4640,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
bool isInDeviceContext = Fortran::lower::isCudaDeviceContext(builder);
- bool isCUDATransfer = (Fortran::evaluate::HasCUDADeviceAttrs(assign.lhs) ||
- Fortran::evaluate::HasCUDADeviceAttrs(assign.rhs)) &&
- !isInDeviceContext;
+ bool isCUDATransfer =
+ IsCUDADataTransfer(assign.lhs, assign.rhs) && !isInDeviceContext;
bool hasCUDAImplicitTransfer =
+ isCUDATransfer &&
Fortran::evaluate::HasCUDAImplicitTransfer(assign.rhs);
llvm::SmallVector<mlir::Value> implicitTemps;
diff --git a/flang/lib/Semantics/assignment.cpp b/flang/lib/Semantics/assignment.cpp
index 627983d19a822..8de20d3126a6c 100644
--- a/flang/lib/Semantics/assignment.cpp
+++ b/flang/lib/Semantics/assignment.cpp
@@ -98,6 +98,10 @@ void AssignmentContext::Analyze(const parser::AssignmentStmt &stmt) {
if (!IsCUDADeviceContext(&progUnit) && deviceConstructDepth_ == 0) {
if (Fortran::evaluate::HasCUDADeviceAttrs(lhs) &&
Fortran::evaluate::HasCUDAImplicitTransfer(rhs)) {
+ if (GetNbOfCUDAManagedOrUnifiedSymbols(lhs) == 1 &&
+ GetNbOfCUDAManagedOrUnifiedSymbols(rhs) == 1 &&
+ GetNbOfCUDADeviceSymbols(rhs) == 1)
+ return; // This is a special case handled on the host.
context_.Say(lhsLoc, "Unsupported CUDA data transfer"_err_en_US);
}
}
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index cbddcd79c6333..aa6cc44f599aa 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -394,3 +394,12 @@ end subroutine
! CHECK: %[[TMP:.*]]:2 = hlfir.declare %[[ALLOC_TMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>)
! CHECK: cuf.data_transfer %[[ADEV_DECL]]#1 to %[[TMP]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<10xi32>>, !fir.heap<!fir.array<10xi32>>
! CHECL: hlfir.assign
+
+subroutine sub20()
+  integer, managed :: a(10)
+  a = a + 2 ! ok. Only managed data: no data transfer, assignment on the host.
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub20()
+! CHECK-NOT: cuf.data_transfer
+! CHECK: hlfir.assign
More information about the flang-commits
mailing list