[flang-commits] [flang] [flang][cuda][NFC] Extract is cuda device attribute logic (PR #100809)

via flang-commits flang-commits at lists.llvm.org
Fri Jul 26 13:06:05 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/100809.diff


1 File Affected:

- (modified) flang/include/flang/Evaluate/tools.h (+23-26) 


``````````diff
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index 8555073a2d0d4..8c6d3b37166a9 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -1243,22 +1243,30 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
     const std::optional<ActualArgument> &, const std::string &procName,
     const std::string &argName);
 
-// Get the number of distinct symbols with CUDA attribute in the expression.
+inline bool IsCUDADeviceSymbol(const Symbol &sym) {
+  if (const auto *details =
+          sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
+    if (details->cudaDataAttr() &&
+        *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
+      return true;
+    }
+  }
+  return false;
+}
+
+// Get the number of distinct symbols with CUDA device
+// attribute in the expression.
 template <typename A> inline int GetNbOfCUDADeviceSymbols(const A &expr) {
   semantics::UnorderedSymbolSet symbols;
   for (const Symbol &sym : CollectCudaSymbols(expr)) {
-    if (const auto *details =
-            sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
-      if (details->cudaDataAttr() &&
-          *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
-        symbols.insert(sym);
-      }
+    if (IsCUDADeviceSymbol(sym)) {
+      symbols.insert(sym);
     }
   }
   return symbols.size();
 }
 
-// Check if any of the symbols part of the expression has a CUDA data
+// Check if any of the symbols part of the expression has a CUDA device
 // attribute.
 template <typename A> inline bool HasCUDADeviceAttrs(const A &expr) {
   return GetNbOfCUDADeviceSymbols(expr) > 0;
@@ -1270,26 +1278,15 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
   unsigned hostSymbols{0};
   unsigned deviceSymbols{0};
   for (const Symbol &sym : CollectCudaSymbols(expr)) {
-    if (const auto *details =
-            sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
-      if (details->cudaDataAttr() &&
-          *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
-        ++deviceSymbols;
-      } else {
-        if (sym.owner().IsDerivedType()) {
-          if (const auto *details =
-                  sym.owner()
-                      .GetSymbol()
-                      ->GetUltimate()
-                      .detailsIf<semantics::ObjectEntityDetails>()) {
-            if (details->cudaDataAttr() &&
-                *details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
-              ++deviceSymbols;
-            }
-          }
+    if (IsCUDADeviceSymbol(sym)) {
+      ++deviceSymbols;
+    } else {
+      if (sym.owner().IsDerivedType()) {
+        if (IsCUDADeviceSymbol(sym.owner().GetSymbol()->GetUltimate())) {
+          ++deviceSymbols;
         }
-        ++hostSymbols;
       }
+      ++hostSymbols;
     }
   }
   return hostSymbols > 0 && deviceSymbols > 0;

``````````

</details>


https://github.com/llvm/llvm-project/pull/100809


More information about the flang-commits mailing list