[flang-commits] [flang] [flang][cuda][NFC] Extract is cuda device attribute logic (PR #100809)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Fri Jul 26 13:05:29 PDT 2024
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/100809
None
>From cb34cda22b8ad7e54e7f1e424d2730a3384637cb Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 26 Jul 2024 13:04:48 -0700
Subject: [PATCH] [flang][cuda][NFC] Extract is cuda device attribute logic
---
flang/include/flang/Evaluate/tools.h | 49 +++++++++++++---------------
1 file changed, 23 insertions(+), 26 deletions(-)
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 8555073a2d0d4..8c6d3b37166a9 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1243,22 +1243,30 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
const std::optional<ActualArgument> &, const std::string &procName,
const std::string &argName);
-// Get the number of distinct symbols with CUDA attribute in the expression.
+inline bool IsCUDADeviceSymbol(const Symbol &sym) {
+ if (const auto *details =
+ sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
+ if (details->cudaDataAttr() &&
+ *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Get the number of distinct symbols with CUDA device
+// attribute in the expression.
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
semantics::UnorderedSymbolSet symbols;
for (const Symbol &sym : CollectCudaSymbols(expr)) {
- if (const auto *details =
- sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
- if (details->cudaDataAttr() &&
- *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
- symbols.insert(sym);
- }
+ if (IsCUDADeviceSymbol(sym)) {
+ symbols.insert(sym);
}
}
return symbols.size();
}
-// Check if any of the symbols part of the expression has a CUDA data
+// Check if any of the symbols in the expression has a CUDA device
// attribute.
template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
return GetNbOfCUDADeviceSymbols(expr) > 0;
@@ -1270,26 +1278,15 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
unsigned hostSymbols{0};
unsigned deviceSymbols{0};
for (const Symbol &sym : CollectCudaSymbols(expr)) {
- if (const auto *details =
- sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
- if (details->cudaDataAttr() &&
- *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
- ++deviceSymbols;
- } else {
- if (sym.owner().IsDerivedType()) {
- if (const auto *details =
- sym.owner()
- .GetSymbol()
- ->GetUltimate()
- .detailsIf<semantics::ObjectEntityDetails>()) {
- if (details->cudaDataAttr() &&
- *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
- ++deviceSymbols;
- }
- }
+ if (IsCUDADeviceSymbol(sym)) {
+ ++deviceSymbols;
+ } else {
+ if (sym.owner().IsDerivedType()) {
+ if (IsCUDADeviceSymbol(sym.owner().GetSymbol()->GetUltimate())) {
+ ++deviceSymbols;
}
- ++hostSymbols;
}
+ ++hostSymbols;
}
}
return hostSymbols > 0 && deviceSymbols > 0;
More information about the flang-commits
mailing list