[flang-commits] [flang] [flang] Rework CUDA kernel DO host array check (PR #116301)
Peter Klausler via flang-commits
flang-commits at lists.llvm.org
Thu Nov 14 16:16:37 PST 2024
https://github.com/klausler created https://github.com/llvm/llvm-project/pull/116301
Don't worry about derived type components unless they are pointers or allocatables.
From f72a5c41abdea853a70ce3467167850992f3f311 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Thu, 14 Nov 2024 16:15:01 -0800
Subject: [PATCH] [flang] Rework CUDA kernel DO host array check
Don't worry about derived type components unless they are
pointers or allocatables.
---
flang/lib/Semantics/check-cuda.cpp | 52 +++++++++++++++++++++---------
1 file changed, 36 insertions(+), 16 deletions(-)
diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp
index eaf1d52a9fc1a8..79b7a26ef222f8 100644
--- a/flang/lib/Semantics/check-cuda.cpp
+++ b/flang/lib/Semantics/check-cuda.cpp
@@ -91,6 +91,37 @@ struct DeviceExprChecker
}
};
+struct FindHostArray
+ : public evaluate::AnyTraverse<FindHostArray, const Symbol *> {
+ using Result = const Symbol *;
+ using Base = evaluate::AnyTraverse<FindHostArray, Result>;
+ FindHostArray() : Base(*this) {}
+ using Base::operator();
+ Result operator()(const evaluate::Component &x) const {
+ const Symbol &symbol{x.GetLastSymbol()};
+ if (IsAllocatableOrPointer(symbol)) {
+ if (Result hostArray{(*this)(symbol)}) {
+ return hostArray;
+ }
+ }
+ return (*this)(x.base());
+ }
+ Result operator()(const Symbol &symbol) const {
+ if (const auto *details{
+ symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
+ if (details->IsArray() &&
+ (!details->cudaDataAttr() ||
+ (details->cudaDataAttr() &&
+ *details->cudaDataAttr() != common::CUDADataAttr::Device &&
+ *details->cudaDataAttr() != common::CUDADataAttr::Managed &&
+ *details->cudaDataAttr() != common::CUDADataAttr::Unified))) {
+ return &symbol;
+ }
+ }
+ return nullptr;
+ }
+};
+
template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
return DeviceExprChecker{}(expr->typedExpr);
@@ -306,22 +337,11 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
}
}
template <typename A>
- void ErrorIfHostSymbol(const A &expr, const parser::CharBlock &source) {
- for (const Symbol &sym : CollectCudaSymbols(expr)) {
- if (const auto *details =
- sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
- if (details->IsArray() &&
- (!details->cudaDataAttr() ||
- (details->cudaDataAttr() &&
- *details->cudaDataAttr() != common::CUDADataAttr::Device &&
- *details->cudaDataAttr() != common::CUDADataAttr::Managed &&
- *details->cudaDataAttr() !=
- common::CUDADataAttr::Unified))) {
- context_.Say(source,
- "Host array '%s' cannot be present in CUF kernel"_err_en_US,
- sym.name());
- }
- }
+ void ErrorIfHostSymbol(const A &expr, parser::CharBlock source) {
+ if (const Symbol * hostArray{FindHostArray{}(expr)}) {
+ context_.Say(source,
+ "Host array '%s' cannot be present in CUF kernel"_err_en_US,
+ hostArray->name());
}
}
void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) {
More information about the flang-commits mailing list