[flang-commits] [flang] 0ff9259 - [flang][cuda][NFC] Extract is cuda device attribute logic (#100809)
via flang-commits
flang-commits at lists.llvm.org
Fri Jul 26 15:28:23 PDT 2024
Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-07-26T15:28:20-07:00
New Revision: 0ff92593d2d8f35c99332471ef6c42df997341aa
URL: https://github.com/llvm/llvm-project/commit/0ff92593d2d8f35c99332471ef6c42df997341aa
DIFF: https://github.com/llvm/llvm-project/commit/0ff92593d2d8f35c99332471ef6c42df997341aa.diff
LOG: [flang][cuda][NFC] Extract is cuda device attribute logic (#100809)
Added:
Modified:
flang/include/flang/Evaluate/tools.h
Removed:
################################################################################
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 8555073a2d0d4..8c6d3b37166a9 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1243,22 +1243,30 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
const std::optional<ActualArgument> &, const std::string &procName,
const std::string &argName);
-// Get the number of distinct symbols with CUDA attribute in the expression.
+inline bool IsCUDADeviceSymbol(const Symbol &sym) {
+ if (const auto *details =
+ sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
+ if (details->cudaDataAttr() &&
+ *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Get the number of distinct symbols with CUDA device
+// attribute in the expression.
template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
semantics::UnorderedSymbolSet symbols;
for (const Symbol &sym : CollectCudaSymbols(expr)) {
- if (const auto *details =
- sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
- if (details->cudaDataAttr() &&
- *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
- symbols.insert(sym);
- }
+ if (IsCUDADeviceSymbol(sym)) {
+ symbols.insert(sym);
}
}
return symbols.size();
}
-// Check if any of the symbols part of the expression has a CUDA data
+// Check if any of the symbols part of the expression has a CUDA device
// attribute.
template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
return GetNbOfCUDADeviceSymbols(expr) > 0;
@@ -1270,26 +1278,15 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
unsigned hostSymbols{0};
unsigned deviceSymbols{0};
for (const Symbol &sym : CollectCudaSymbols(expr)) {
- if (const auto *details =
- sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
- if (details->cudaDataAttr() &&
- *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
- ++deviceSymbols;
- } else {
- if (sym.owner().IsDerivedType()) {
- if (const auto *details =
- sym.owner()
- .GetSymbol()
- ->GetUltimate()
- .detailsIf<semantics::ObjectEntityDetails>()) {
- if (details->cudaDataAttr() &&
- *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
- ++deviceSymbols;
- }
- }
+ if (IsCUDADeviceSymbol(sym)) {
+ ++deviceSymbols;
+ } else {
+ if (sym.owner().IsDerivedType()) {
+ if (IsCUDADeviceSymbol(sym.owner().GetSymbol()->GetUltimate())) {
+ ++deviceSymbols;
}
- ++hostSymbols;
}
+ ++hostSymbols;
}
}
return hostSymbols > 0 && deviceSymbols > 0;
More information about the flang-commits
mailing list