[flang-commits] [flang] [flang] Rework CUDA kernel DO host array check (PR #116301)

via flang-commits flang-commits at lists.llvm.org
Sat Nov 16 06:15:37 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-semantics

Author: Peter Klausler (klausler)

<details>
<summary>Changes</summary>

Don't worry about derived type components unless they are pointers or allocatables.

---
Full diff: https://github.com/llvm/llvm-project/pull/116301.diff


1 File Affected:

- (modified) flang/lib/Semantics/check-cuda.cpp (+36-16) 


``````````diff
diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp
index eaf1d52a9fc1a8..79b7a26ef222f8 100644
--- a/flang/lib/Semantics/check-cuda.cpp
+++ b/flang/lib/Semantics/check-cuda.cpp
@@ -91,6 +91,37 @@ struct DeviceExprChecker
   }
 };
 
+struct FindHostArray
+    : public evaluate::AnyTraverse<FindHostArray, const Symbol *> {
+  using Result = const Symbol *;
+  using Base = evaluate::AnyTraverse<FindHostArray, Result>;
+  FindHostArray() : Base(*this) {}
+  using Base::operator();
+  Result operator()(const evaluate::Component &x) const {
+    const Symbol &symbol{x.GetLastSymbol()};
+    if (IsAllocatableOrPointer(symbol)) {
+      if (Result hostArray{(*this)(symbol)}) {
+        return hostArray;
+      }
+    }
+    return (*this)(x.base());
+  }
+  Result operator()(const Symbol &symbol) const {
+    if (const auto *details{
+            symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
+      if (details->IsArray() &&
+          (!details->cudaDataAttr() ||
+              (details->cudaDataAttr() &&
+                  *details->cudaDataAttr() != common::CUDADataAttr::Device &&
+                  *details->cudaDataAttr() != common::CUDADataAttr::Managed &&
+                  *details->cudaDataAttr() != common::CUDADataAttr::Unified))) {
+        return &symbol;
+      }
+    }
+    return nullptr;
+  }
+};
+
 template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
   if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
     return DeviceExprChecker{}(expr->typedExpr);
@@ -306,22 +337,11 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
     }
   }
   template <typename A>
-  void ErrorIfHostSymbol(const A &expr, const parser::CharBlock &source) {
-    for (const Symbol &sym : CollectCudaSymbols(expr)) {
-      if (const auto *details =
-              sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
-        if (details->IsArray() &&
-            (!details->cudaDataAttr() ||
-                (details->cudaDataAttr() &&
-                    *details->cudaDataAttr() != common::CUDADataAttr::Device &&
-                    *details->cudaDataAttr() != common::CUDADataAttr::Managed &&
-                    *details->cudaDataAttr() !=
-                        common::CUDADataAttr::Unified))) {
-          context_.Say(source,
-              "Host array '%s' cannot be present in CUF kernel"_err_en_US,
-              sym.name());
-        }
-      }
+  void ErrorIfHostSymbol(const A &expr, parser::CharBlock source) {
+    if (const Symbol * hostArray{FindHostArray{}(expr)}) {
+      context_.Say(source,
+          "Host array '%s' cannot be present in CUF kernel"_err_en_US,
+          hostArray->name());
     }
   }
   void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/116301


More information about the flang-commits mailing list