[flang-commits] [flang] [flang] Rework CUDA kernel DO host array check (PR #116301)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Thu Nov 14 16:16:37 PST 2024


https://github.com/klausler created https://github.com/llvm/llvm-project/pull/116301

When checking a CUDA kernel DO construct for host arrays, ignore derived type components unless they are pointers or allocatables.

From f72a5c41abdea853a70ce3467167850992f3f311 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Thu, 14 Nov 2024 16:15:01 -0800
Subject: [PATCH] [flang] Rework CUDA kernel DO host array check

Don't worry about derived type components unless they are
pointers or allocatables.
---
 flang/lib/Semantics/check-cuda.cpp | 52 +++++++++++++++++++++---------
 1 file changed, 36 insertions(+), 16 deletions(-)

diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp
index eaf1d52a9fc1a8..79b7a26ef222f8 100644
--- a/flang/lib/Semantics/check-cuda.cpp
+++ b/flang/lib/Semantics/check-cuda.cpp
@@ -91,6 +91,37 @@ struct DeviceExprChecker
   }
 };
 
+struct FindHostArray
+    : public evaluate::AnyTraverse<FindHostArray, const Symbol *> {
+  using Result = const Symbol *;
+  using Base = evaluate::AnyTraverse<FindHostArray, Result>;
+  FindHostArray() : Base(*this) {}
+  using Base::operator();
+  Result operator()(const evaluate::Component &x) const {
+    const Symbol &symbol{x.GetLastSymbol()};
+    if (IsAllocatableOrPointer(symbol)) {
+      if (Result hostArray{(*this)(symbol)}) {
+        return hostArray;
+      }
+    }
+    return (*this)(x.base());
+  }
+  Result operator()(const Symbol &symbol) const {
+    if (const auto *details{
+            symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
+      if (details->IsArray() &&
+          (!details->cudaDataAttr() ||
+              (details->cudaDataAttr() &&
+                  *details->cudaDataAttr() != common::CUDADataAttr::Device &&
+                  *details->cudaDataAttr() != common::CUDADataAttr::Managed &&
+                  *details->cudaDataAttr() != common::CUDADataAttr::Unified))) {
+        return &symbol;
+      }
+    }
+    return nullptr;
+  }
+};
+
 template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
   if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
     return DeviceExprChecker{}(expr->typedExpr);
@@ -306,22 +337,11 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
     }
   }
   template <typename A>
-  void ErrorIfHostSymbol(const A &expr, const parser::CharBlock &source) {
-    for (const Symbol &sym : CollectCudaSymbols(expr)) {
-      if (const auto *details =
-              sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
-        if (details->IsArray() &&
-            (!details->cudaDataAttr() ||
-                (details->cudaDataAttr() &&
-                    *details->cudaDataAttr() != common::CUDADataAttr::Device &&
-                    *details->cudaDataAttr() != common::CUDADataAttr::Managed &&
-                    *details->cudaDataAttr() !=
-                        common::CUDADataAttr::Unified))) {
-          context_.Say(source,
-              "Host array '%s' cannot be present in CUF kernel"_err_en_US,
-              sym.name());
-        }
-      }
+  void ErrorIfHostSymbol(const A &expr, parser::CharBlock source) {
+    if (const Symbol * hostArray{FindHostArray{}(expr)}) {
+      context_.Say(source,
+          "Host array '%s' cannot be present in CUF kernel"_err_en_US,
+          hostArray->name());
     }
   }
   void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) {



More information about the flang-commits mailing list