[flang-commits] [flang] 0ff9259 - [flang][cuda][NFC] Extract is cuda device attribute logic (#100809)

via flang-commits flang-commits at lists.llvm.org
Fri Jul 26 15:28:23 PDT 2024


Author: Valentin Clement (バレンタイン クレメン)
Date: 2024-07-26T15:28:20-07:00
New Revision: 0ff92593d2d8f35c99332471ef6c42df997341aa

URL: https://github.com/llvm/llvm-project/commit/0ff92593d2d8f35c99332471ef6c42df997341aa
DIFF: https://github.com/llvm/llvm-project/commit/0ff92593d2d8f35c99332471ef6c42df997341aa.diff

LOG: [flang][cuda][NFC] Extract is cuda device attribute logic (#100809)

Added: 
    

Modified: 
    flang/include/flang/Evaluate/tools.h

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 8555073a2d0d4..8c6d3b37166a9 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1243,22 +1243,30 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
     const std::optional<ActualArgument> &, const std::string &procName,
     const std::string &argName);
 
-// Get the number of distinct symbols with CUDA attribute in the expression.
+inline bool IsCUDADeviceSymbol(const Symbol &sym) {
+  if (const auto *details =
+          sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
+    if (details->cudaDataAttr() &&
+        *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Get the number of distinct symbols with CUDA device
+// attribute in the expression.
 template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
   semantics::UnorderedSymbolSet symbols;
   for (const Symbol &sym : CollectCudaSymbols(expr)) {
-    if (const auto *details =
-            sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
-      if (details->cudaDataAttr() &&
-          *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
-        symbols.insert(sym);
-      }
+    if (IsCUDADeviceSymbol(sym)) {
+      symbols.insert(sym);
     }
   }
   return symbols.size();
 }
 
-// Check if any of the symbols part of the expression has a CUDA data
+// Check if any of the symbols part of the expression has a CUDA device
 // attribute.
 template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
   return GetNbOfCUDADeviceSymbols(expr) > 0;
@@ -1270,26 +1278,15 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
   unsigned hostSymbols{0};
   unsigned deviceSymbols{0};
   for (const Symbol &sym : CollectCudaSymbols(expr)) {
-    if (const auto *details =
-            sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
-      if (details->cudaDataAttr() &&
-          *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
-        ++deviceSymbols;
-      } else {
-        if (sym.owner().IsDerivedType()) {
-          if (const auto *details =
-                  sym.owner()
-                      .GetSymbol()
-                      ->GetUltimate()
-                      .detailsIf<semantics::ObjectEntityDetails>()) {
-            if (details->cudaDataAttr() &&
-                *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
-              ++deviceSymbols;
-            }
-          }
+    if (IsCUDADeviceSymbol(sym)) {
+      ++deviceSymbols;
+    } else {
+      if (sym.owner().IsDerivedType()) {
+        if (IsCUDADeviceSymbol(sym.owner().GetSymbol()->GetUltimate())) {
+          ++deviceSymbols;
         }
-        ++hostSymbols;
       }
+      ++hostSymbols;
     }
   }
   return hostSymbols > 0 && deviceSymbols > 0;


        


More information about the flang-commits mailing list