[flang-commits] [flang] 73216cd - [flang] Rework CUDA kernel DO host array check (#116301)

via flang-commits flang-commits at lists.llvm.org
Tue Nov 19 16:19:35 PST 2024


Author: Peter Klausler
Date: 2024-11-19T16:19:32-08:00
New Revision: 73216cd71229fc7ccd380c334d45f809787f41b1

URL: https://github.com/llvm/llvm-project/commit/73216cd71229fc7ccd380c334d45f809787f41b1
DIFF: https://github.com/llvm/llvm-project/commit/73216cd71229fc7ccd380c334d45f809787f41b1.diff

LOG: [flang] Rework CUDA kernel DO host array check (#116301)

Don't worry about derived type components unless they are pointers or
allocatables.

Added: 
    

Modified: 
    flang/lib/Semantics/check-cuda.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp
index eaf1d52a9fc1a8..79b7a26ef222f8 100644
--- a/flang/lib/Semantics/check-cuda.cpp
+++ b/flang/lib/Semantics/check-cuda.cpp
@@ -91,6 +91,37 @@ struct DeviceExprChecker
   }
 };
 
+struct FindHostArray
+    : public evaluate::AnyTraverse<FindHostArray, const Symbol *> {
+  using Result = const Symbol *;
+  using Base = evaluate::AnyTraverse<FindHostArray, Result>;
+  FindHostArray() : Base(*this) {}
+  using Base::operator();
+  Result operator()(const evaluate::Component &x) const {
+    const Symbol &symbol{x.GetLastSymbol()};
+    if (IsAllocatableOrPointer(symbol)) {
+      if (Result hostArray{(*this)(symbol)}) {
+        return hostArray;
+      }
+    }
+    return (*this)(x.base());
+  }
+  Result operator()(const Symbol &symbol) const {
+    if (const auto *details{
+            symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
+      if (details->IsArray() &&
+          (!details->cudaDataAttr() ||
+              (details->cudaDataAttr() &&
+                  *details->cudaDataAttr() != common::CUDADataAttr::Device &&
+                  *details->cudaDataAttr() != common::CUDADataAttr::Managed &&
+                  *details->cudaDataAttr() != common::CUDADataAttr::Unified))) {
+        return &symbol;
+      }
+    }
+    return nullptr;
+  }
+};
+
 template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
   if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
     return DeviceExprChecker{}(expr->typedExpr);
@@ -306,22 +337,11 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
     }
   }
   template <typename A>
-  void ErrorIfHostSymbol(const A &expr, const parser::CharBlock &source) {
-    for (const Symbol &sym : CollectCudaSymbols(expr)) {
-      if (const auto *details =
-              sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
-        if (details->IsArray() &&
-            (!details->cudaDataAttr() ||
-                (details->cudaDataAttr() &&
-                    *details->cudaDataAttr() != common::CUDADataAttr::Device &&
-                    *details->cudaDataAttr() != common::CUDADataAttr::Managed &&
-                    *details->cudaDataAttr() !=
-                        common::CUDADataAttr::Unified))) {
-          context_.Say(source,
-              "Host array '%s' cannot be present in CUF kernel"_err_en_US,
-              sym.name());
-        }
-      }
+  void ErrorIfHostSymbol(const A &expr, parser::CharBlock source) {
+    if (const Symbol * hostArray{FindHostArray{}(expr)}) {
+      context_.Say(source,
+          "Host array '%s' cannot be present in CUF kernel"_err_en_US,
+          hostArray->name());
     }
   }
   void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) {


        


More information about the flang-commits mailing list