[flang-commits] [flang] [flang][cuda] Only apply the implicit data attribute on the component for use_device (PR #192146)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Thu Apr 16 21:53:23 PDT 2026
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/192146
>From dd17cc2b3e00c30084818dc54ebf85d3bc159d68 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 13 Apr 2026 17:37:15 -0700
Subject: [PATCH 1/2] [flang][cuda] Only apply the implicit data attribute on
the component for use_device
---
flang/include/flang/Semantics/tools.h | 9 +
flang/include/flang/Semantics/type.h | 3 +
flang/lib/Semantics/expression.cpp | 39 +++-
flang/lib/Semantics/resolve-names.cpp | 174 +++++++++++++++---
flang/lib/Semantics/type.cpp | 62 +++++++
.../OpenACC/acc-host-data-cuda-device.f90 | 23 +++
6 files changed, 274 insertions(+), 36 deletions(-)
diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index d2e2be2548d2e..9f77d0ec5da2e 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -23,6 +23,7 @@
#include "flang/Semantics/expression.h"
#include "flang/Semantics/semantics.h"
#include "flang/Support/Fortran.h"
+#include "llvm/ADT/ArrayRef.h"
#include <functional>
namespace Fortran::semantics {
@@ -294,6 +295,14 @@ SymbolVector OrderParameterNames(const Symbol &);
const DeclTypeSpec &FindOrInstantiateDerivedType(Scope &, DerivedTypeSpec &&,
DeclTypeSpec::Category = DeclTypeSpec::TypeDerived);
+// Clone a derived type's component scope for OpenACC use_device with CUDA
+// Fortran: each component named in `path` (e.g. a%b%c -> {b,c}) gets a
+// distinct component symbol with cudaDataAttr Device in a new DerivedTypeSpec.
+// Returns nullptr if `path` is empty or `origType` is not derived.
+const DeclTypeSpec *CloneDerivedTypeForUseDevice(Scope &containingScope,
+ SemanticsContext &, const DeclTypeSpec &origType,
+ llvm::ArrayRef<SourceName> path);
+
// When a subprogram defined in a submodule defines a separate module
// procedure whose interface is defined in an ancestor (sub)module,
// returns a pointer to that interface, else null.
diff --git a/flang/include/flang/Semantics/type.h b/flang/include/flang/Semantics/type.h
index 3a07b6ee2ec1c..312f8a4b2075e 100644
--- a/flang/include/flang/Semantics/type.h
+++ b/flang/include/flang/Semantics/type.h
@@ -307,6 +307,9 @@ class DerivedTypeSpec {
// instantiations that have been specialized with actual type parameter
// values, which are cooked &/or evaluated if necessary.
void Instantiate(Scope &containingScope);
+ // Reset instantiation state so a copy can receive a fresh component scope
+ // (e.g. OpenACC use_device with CUDA Fortran component paths).
+ void PrepareForScopeClone();
ParamValue *FindParameter(SourceName);
const ParamValue *FindParameter(SourceName target) const {
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 756bfd551a90a..8b64fca2666e6 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -1497,14 +1497,24 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::StructureComponent &sc) {
Say(name,
"A type parameter inquiry must be applied to a designator"_err_en_US);
}
- } else if (!dtSpec || !dtSpec->scope()) {
+ } else if (!dtSpec || !dtSpec->GetScope()) {
CHECK(context_.AnyFatalError() || !foldingContext_.messages().empty());
return std::nullopt;
} else if (std::optional<DataRef> dataRef{
ExtractDataRef(std::move(*dtExpr))}) {
+ // The base may use a forked DerivedTypeSpec (e.g. OpenACC use_device with
+ // CUDA) while the parse tree still points at host-associated component
+ // symbols; resolve the component in the base type's instantiated scope.
+ const semantics::Scope &typeScope{DEREF(dtSpec->GetScope())};
+ Symbol *compSym{sym};
+ if (sym && sym->owner().IsDerivedType() && &sym->owner() != &typeScope) {
+ if (Symbol * via{typeScope.FindComponent(sym->name())}) {
+ compSym = via;
+ }
+ }
auto restorer{GetContextualMessages().SetLocation(name)};
if (auto component{
- CreateComponent(std::move(*dataRef), *sym, *dtSpec->scope())}) {
+ CreateComponent(std::move(*dataRef), *compSym, typeScope)}) {
return Designate(DataRef{std::move(*component)});
} else {
Say(name, "Component is not in scope of derived TYPE(%s)"_err_en_US,
@@ -2841,13 +2851,24 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr;
if (actual) {
if (auto *expr{actual->UnwrapExpr()}) {
- const auto *actualLastSymbol{evaluate::GetLastSymbol(*expr)};
- if (actualLastSymbol) {
- actualLastSymbol = &semantics::ResolveAssociations(*actualLastSymbol);
- if (const auto *actualObject{actualLastSymbol
- ? actualLastSymbol
- ->detailsIf<semantics::ObjectEntityDetails>()
- : nullptr}) {
+ if (evaluate::IsVariable(*expr)) {
+ // Match check-call.cpp: walk the whole designator so e.g. b%a picks up
+ // ATTRIBUTES(DEVICE) from the base b when the component a has no CUDA
+ // attribute (OpenACC use_device(b) + doit(b%a)), not only from the
+ // last symbol (GetLastSymbol would only see a).
+      for (const Symbol &s : evaluate::GetSymbolVector(*expr)) {
+        // Look through ASSOCIATE/SELECT TYPE names so an associate-name of a
+        // device variable still reports ATTRIBUTES(DEVICE), as the
+        // GetLastSymbol path below (and the previous code) does.
+        const Symbol &resolved{semantics::ResolveAssociations(s)};
+        if (const auto *object{
+                resolved.detailsIf<semantics::ObjectEntityDetails>()}) {
+          if (auto cudaAttr{object->cudaDataAttr()}) {
+            actualDataAttr = *cudaAttr;
+          }
+        }
+      }
+ } else if (const auto *actualLastSymbol{evaluate::GetLastSymbol(*expr)}) {
+ const Symbol &resolved{
+ semantics::ResolveAssociations(*actualLastSymbol)};
+ if (const auto *actualObject{
+ resolved.detailsIf<semantics::ObjectEntityDetails>()}) {
actualDataAttr = actualObject->cudaDataAttr();
}
}
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index b6907cc792d76..091439da3e812 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -42,6 +42,8 @@
#include "flang/Semantics/type.h"
#include "flang/Support/Fortran.h"
#include "flang/Support/default-kinds.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/raw_ostream.h"
#include <list>
@@ -1410,7 +1412,7 @@ class AccVisitor : public virtual DeclarationVisitor {
void Post(const parser::OpenACCBlockConstruct &);
bool Pre(const parser::OpenACCCombinedConstruct &);
void Post(const parser::OpenACCCombinedConstruct &);
- bool Pre(const parser::AccClause::UseDevice &x);
+ void Post(const parser::AccClause::UseDevice &x);
bool Pre(const parser::AccBeginBlockDirective &x) {
AddAccSourceRange(x.source);
return true;
@@ -1472,11 +1474,116 @@ class AccVisitor : public virtual DeclarationVisitor {
}
void CopySymbolWithDevice(const parser::Name *name);
+ void CopySymbolWithDeviceStructurePath(const parser::Name *baseName,
+ llvm::ArrayRef<SourceName> componentPath, parser::Designator &designator);
private:
SemanticsContext &context_;
};
+// Collects the base name and the ordered `%component` names of a use_device
+// data-ref, e.g. `a(i)%b%c` -> base `a`, components {b, c}. Array subscripts
+// and coindices are looked through; only structure-component steps are
+// recorded in `componentsOut`.
+static void GetUseDeviceDataRefPath(const parser::DataRef &dr,
+    const parser::Name *&baseOut,
+    llvm::SmallVectorImpl<SourceName> &componentsOut) {
+  common::visit(
+      common::visitors{
+          [&](const parser::Name &name) { baseOut = &name; },
+          [&](const common::Indirection<parser::StructureComponent> &sc) {
+            const parser::StructureComponent &comp{sc.value()};
+            GetUseDeviceDataRefPath(comp.Base(), baseOut, componentsOut);
+            componentsOut.push_back(comp.Component().source);
+          },
+          [&](const common::Indirection<parser::ArrayElement> &ae) {
+            GetUseDeviceDataRefPath(ae.value().Base(), baseOut, componentsOut);
+          },
+          [&](const common::Indirection<parser::CoindexedNamedObject> &ci) {
+            GetUseDeviceDataRefPath(std::get<parser::DataRef>(ci.value().t),
+                baseOut, componentsOut);
+          },
+      },
+      dr.u);
+}
+
+// Resets `baseOut`/`componentsOut`, then fills them from `d` when it is a
+// DataRef designator; otherwise they are left null/empty.
+static void GetUseDeviceStructurePath(const parser::Designator &d,
+    const parser::Name *&baseOut,
+    llvm::SmallVectorImpl<SourceName> &componentsOut) {
+  componentsOut.clear();
+  baseOut = nullptr;
+  if (const auto *dr{std::get_if<parser::DataRef>(&d.u)}) {
+    GetUseDeviceDataRefPath(*dr, baseOut, componentsOut);
+  }
+}
+
+// Point each %component name at the corresponding symbol in the cloned
+// DerivedTypeSpec scope so semantics and lowering see cudaDataAttr::Device.
+// `currDts` is the type whose scope is searched for path[idx]; it advances to
+// the matched component's derived type, and `idx` counts remapped entries.
+// Returns false when a component cannot be found; names already remapped
+// keep their new symbols and the rest keep their existing ones.
+static bool RemapUseDeviceComponentNames(parser::DataRef &dr,
+    const DerivedTypeSpec *&currDts, llvm::ArrayRef<SourceName> path,
+    std::size_t &idx) {
+  return common::visit(
+      common::visitors{
+          [&](parser::Name &) { return true; },
+          [&](common::Indirection<parser::StructureComponent> &sc) -> bool {
+            // Remap the base first so `idx` walks the path left to right.
+            if (!RemapUseDeviceComponentNames(
+                    const_cast<parser::DataRef &>(sc.value().Base()), currDts,
+                    path, idx)) {
+              return false;
+            }
+            if (idx >= path.size()) {
+              return true;
+            }
+            parser::StructureComponent &compStruct{sc.value()};
+            parser::Name &compName{std::get<parser::Name>(compStruct.t)};
+            const semantics::Scope *compTyScope{currDts->GetScope()};
+            Symbol *sym{
+                compTyScope ? compTyScope->FindComponent(path[idx]) : nullptr};
+            if (!sym && compTyScope) {
+              sym = compTyScope->FindComponent(compName.source);
+            }
+            if (!sym) {
+              return false;
+            }
+            compName.symbol = sym;
+            // Descend into the component's derived type for the next step.
+            if (const auto *oed{sym->detailsIf<ObjectEntityDetails>()}) {
+              if (const DeclTypeSpec *t{oed->type()}) {
+                if (const DerivedTypeSpec *inner{t->AsDerived()}) {
+                  currDts = inner;
+                }
+              }
+            }
+            ++idx;
+            return true;
+          },
+          [&](common::Indirection<parser::ArrayElement> &ae) -> bool {
+            return RemapUseDeviceComponentNames(
+                const_cast<parser::DataRef &>(ae.value().Base()), currDts, path,
+                idx);
+          },
+          [&](common::Indirection<parser::CoindexedNamedObject> &ci) -> bool {
+            return RemapUseDeviceComponentNames(
+                const_cast<parser::DataRef &>(
+                    std::get<parser::DataRef>(ci.value().t)),
+                currDts, path, idx);
+          },
+      },
+      dr.u);
+}
+
+// Remaps the component names of use_device designator `d` onto the cloned
+// type `newVarType`, following `path`. No-op unless the type is derived and
+// `path` is non-empty.
+static void RemapUseDeviceDesignator(parser::Designator &d,
+    const DeclTypeSpec &newVarType, llvm::ArrayRef<SourceName> path) {
+  const DerivedTypeSpec *currDts{newVarType.AsDerived()};
+  if (!currDts || path.empty()) {
+    return;
+  }
+  if (auto *dr{std::get_if<parser::DataRef>(&d.u)}) {
+    std::size_t idx{0};
+    // Best effort: ExpressionAnalyzer also re-resolves component symbols in
+    // the base's (possibly cloned) type scope, so a partial remap is tolerated.
+    (void)RemapUseDeviceComponentNames(*dr, currDts, path, idx);
+  }
+}
+
bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) {
const auto &beginBlockDir{std::get<parser::AccBeginBlockDirective>(x.t)};
const auto &beginDir{std::get<parser::AccBlockDirective>(beginBlockDir.t)};
@@ -1520,34 +1627,48 @@ void AccVisitor::CopySymbolWithDevice(const parser::Name *name) {
}
}
-bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
- for (const auto &accObject : x.v.v) {
- Walk(accObject);
+void AccVisitor::CopySymbolWithDeviceStructurePath(const parser::Name *baseName,
+ llvm::ArrayRef<SourceName> componentPath, parser::Designator &designator) {
+ if (!context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA) ||
+ !baseName || !baseName->symbol || componentPath.empty()) {
+ return;
+ }
+ const Symbol &orig{*baseName->symbol};
+ const DeclTypeSpec *origType{orig.GetType()};
+ if (!origType || !origType->AsDerived()) {
+ return;
+ }
+ if (const DeclTypeSpec *newType{CloneDerivedTypeForUseDevice(
+ currScope(), context_, *origType, componentPath)}) {
+ if (Symbol * copy{currScope().CopySymbol(orig)}) {
+ baseName->symbol = copy;
+ if (auto *object{copy->detailsIf<ObjectEntityDetails>()}) {
+ object->ReplaceType(*newType);
+ }
+ RemapUseDeviceDesignator(designator, *newType, componentPath);
+ }
+ }
+}
+
+void AccVisitor::Post(const parser::AccClause::UseDevice &x) {
+ // Run after clause acc-objects are walked so Post(Designator) has resolved
+ // Name::symbol, then fork types and remap %component symbols for CUDA.
+ for (auto &accObject : const_cast<std::list<parser::AccObject> &>(x.v.v)) {
common::visit(
common::visitors{
- [&](const parser::Designator &designator) {
- if (const auto *name{
- parser::GetDesignatorNameIfDataRef(designator)}) {
- CopySymbolWithDevice(name);
- } else if (const auto *dataRef{
- std::get_if<parser::DataRef>(&designator.u)}) {
- using ElementIndirection =
- common::Indirection<parser::ArrayElement>;
- using ComponentIndirection =
- common::Indirection<parser::StructureComponent>;
- if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
- const parser::ArrayElement &arrayElement{ind->value()};
- const parser::DataRef &base{arrayElement.Base()};
- if (auto *name{std::get_if<parser::Name>(&base.u)}) {
- CopySymbolWithDevice(name);
- }
- } else if (auto *ind{std::get_if<ComponentIndirection>(
- &dataRef->u)}) {
- const parser::StructureComponent &comp{ind->value()};
- const parser::DataRef &base{comp.Base()};
- if (auto *name{std::get_if<parser::Name>(&base.u)}) {
- CopySymbolWithDevice(name);
+ [&](parser::Designator &designator) {
+ const parser::Name *baseName{nullptr};
+ llvm::SmallVector<SourceName, 4> components;
+ GetUseDeviceStructurePath(designator, baseName, components);
+ if (baseName && baseName->symbol) {
+ if (components.empty()) {
+ if (const auto *bareName{
+ parser::GetDesignatorNameIfDataRef(designator)}) {
+ CopySymbolWithDevice(bareName);
}
+ } else {
+ CopySymbolWithDeviceStructurePath(
+ baseName, components, designator);
}
}
},
@@ -1557,7 +1678,6 @@ bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
},
accObject.u);
}
- return false;
}
void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) {
diff --git a/flang/lib/Semantics/type.cpp b/flang/lib/Semantics/type.cpp
index bb8fddbffd945..6acfd63dbac4d 100644
--- a/flang/lib/Semantics/type.cpp
+++ b/flang/lib/Semantics/type.cpp
@@ -18,6 +18,7 @@
#include "flang/Semantics/scope.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
namespace Fortran::semantics {
@@ -199,6 +200,8 @@ void DerivedTypeSpec::ReevaluateParameters(SemanticsContext &context) {
EvaluateParameters(context);
}
+void DerivedTypeSpec::PrepareForScopeClone() { scope_ = nullptr; }
+
void DerivedTypeSpec::AddParamValue(SourceName name, ParamValue &&value) {
CHECK(cooked_);
auto pair{parameters_.insert(std::make_pair(name, std::move(value)))};
@@ -282,13 +285,22 @@ bool DerivedTypeSpec::MatchesOrExtends(const DerivedTypeSpec &that) const {
return MatchKindParams(*typeSymbol, *this, that);
}
+static const DeclTypeSpec *CloneDerivedTypeForUseDeviceImpl(
+ Scope &containingScope, SemanticsContext &context,
+ const DerivedTypeSpec &sourceDts, DeclTypeSpec::Category category,
+ llvm::ArrayRef<SourceName> path);
+
class InstantiateHelper {
public:
InstantiateHelper(Scope &scope) : scope_{scope} {}
// Instantiate components from fromScope into scope_
void InstantiateComponents(const Scope &);
+ void SetUseDevicePath(llvm::ArrayRef<SourceName> path) {
+ useDevicePath_ = path;
+ }
private:
+ llvm::ArrayRef<SourceName> useDevicePath_{};
SemanticsContext &context() const { return scope_.context(); }
evaluate::FoldingContext &foldingContext() {
return context().foldingContext();
@@ -568,6 +580,14 @@ const DeclTypeSpec *InstantiateHelper::InstantiateType(const Symbol &symbol) {
if (!type) {
return nullptr; // error has occurred
} else if (const DerivedTypeSpec * spec{type->AsDerived()}) {
+ if (!useDevicePath_.empty() && symbol.name() == useDevicePath_[0] &&
+ useDevicePath_.size() > 1) {
+ if (const DeclTypeSpec *cloned{
+ CloneDerivedTypeForUseDeviceImpl(scope_, context(), *spec,
+ type->category(), useDevicePath_.drop_front())}) {
+ return cloned;
+ }
+ }
return &FindOrInstantiateDerivedType(scope_,
CreateDerivedTypeSpec(*spec, symbol.test(Symbol::Flag::ParentComp)),
type->category());
@@ -688,6 +708,48 @@ DerivedTypeSpec InstantiateHelper::CreateDerivedTypeSpec(
return result;
}
+// Builds a fresh DeclTypeSpec/DerivedTypeSpec for `sourceDts` whose component
+// scope is instantiated anew, so the component named path[0] can be marked
+// ATTRIBUTES(DEVICE) without touching the original type. Deeper path elements
+// are handled by InstantiateHelper::InstantiateType, which calls back here for
+// the matching component's type with path.drop_front().
+// Returns nullptr if `path` is empty or the source type has no scope.
+static const DeclTypeSpec *CloneDerivedTypeForUseDeviceImpl(
+    Scope &containingScope, SemanticsContext &context,
+    const DerivedTypeSpec &sourceDts, DeclTypeSpec::Category category,
+    llvm::ArrayRef<SourceName> path) {
+  const Scope *sourceScope{sourceDts.GetScope()};
+  if (path.empty() || !sourceScope) {
+    return nullptr;
+  }
+  DerivedTypeSpec newDts{sourceDts};
+  newDts.PrepareForScopeClone();
+  DeclTypeSpec &newDecl{
+      containingScope.MakeDerivedType(category, std::move(newDts))};
+  DerivedTypeSpec &dtsRef{newDecl.derivedTypeSpec()};
+  Scope &newScope{containingScope.MakeScope(Scope::Kind::DerivedType)};
+  dtsRef.ReplaceScope(newScope);
+  newScope.set_derivedTypeSpec(dtsRef);
+
+  InstantiateHelper helper{newScope};
+  helper.SetUseDevicePath(path);
+  helper.InstantiateComponents(*sourceScope);
+
+  // FindComponent also searches parent types, whose scopes are shared with
+  // every other use of those types. Only mark a component that was actually
+  // instantiated into the cloned scope so the original type is never changed.
+  // TODO: components inherited from a parent type are not marked yet.
+  if (Symbol * comp{newScope.FindComponent(path[0])};
+      comp && &comp->owner() == &newScope) {
+    if (auto *details{comp->detailsIf<ObjectEntityDetails>()}) {
+      details->set_cudaDataAttr(common::CUDADataAttr::Device);
+    }
+  }
+  return &newDecl;
+}
+
+// Entry point declared in tools.h: unwraps the derived spec of `origType` and
+// forwards to the implementation, preserving the declared type's category.
+const DeclTypeSpec *CloneDerivedTypeForUseDevice(Scope &containingScope,
+    SemanticsContext &context, const DeclTypeSpec &origType,
+    llvm::ArrayRef<SourceName> path) {
+  if (path.empty()) {
+    return nullptr;
+  }
+  const DerivedTypeSpec *spec{origType.AsDerived()};
+  if (!spec) {
+    return nullptr;
+  }
+  return CloneDerivedTypeForUseDeviceImpl(
+      containingScope, context, *spec, origType.category(), path);
+}
+
std::string DerivedTypeSpec::VectorTypeAsFortran() const {
std::string buf;
llvm::raw_string_ostream ss{buf};
diff --git a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
index cc301617444ac..55bf100ffa55d 100644
--- a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
+++ b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
@@ -13,9 +13,32 @@ module m
!dir$ ignore_tkr(c) a
end
end interface
+type t
+ integer, pointer :: p1
+ integer, pointer :: p2
+ end type
+interface foo
+subroutine foo_device(p)
+ integer, pointer, device :: p
+end subroutine
+subroutine foo_host(p)
+ integer, pointer :: p
+end subroutine
+end interface
contains
+ subroutine test(obj)
+ type(t) :: obj
+ !$acc host_data use_device(obj%p1)
+ call foo(obj%p1)
+ call foo(obj%p2)
+ !$acc end host_data
+ end subroutine
+! CHECK-LABEL: func.func @_QMmPtest
+! CHECK: fir.call @_QPfoo_device
+! CHECK: fir.call @_QPfoo_host
+
subroutine vectoraddarray(a, b, n)
implicit none
integer :: n
>From f33c099939772253f28a5235fc6abfc47c6a65d7 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 16 Apr 2026 21:53:04 -0700
Subject: [PATCH 2/2] Address review comments
---
flang/lib/Semantics/resolve-names.cpp | 28 +++++++++--------
.../OpenACC/acc-host-data-cuda-device.f90 | 30 +++++++++++++++++++
2 files changed, 45 insertions(+), 13 deletions(-)
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 091439da3e812..d1ef2c7caae0a 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1634,19 +1634,24 @@ void AccVisitor::CopySymbolWithDeviceStructurePath(const parser::Name *baseName,
return;
}
const Symbol &orig{*baseName->symbol};
- const DeclTypeSpec *origType{orig.GetType()};
- if (!origType || !origType->AsDerived()) {
+ Symbol *copy{currScope().CopySymbol(orig)};
+ if (!copy) {
+ copy = FindInScope(currScope(), baseName->symbol->name());
+ }
+ if (!copy) {
+ return;
+ }
+ baseName->symbol = copy;
+ const DeclTypeSpec *sourceType{copy->GetType()};
+ if (!sourceType || !sourceType->AsDerived()) {
return;
}
if (const DeclTypeSpec *newType{CloneDerivedTypeForUseDevice(
- currScope(), context_, *origType, componentPath)}) {
- if (Symbol * copy{currScope().CopySymbol(orig)}) {
- baseName->symbol = copy;
- if (auto *object{copy->detailsIf<ObjectEntityDetails>()}) {
- object->ReplaceType(*newType);
- }
- RemapUseDeviceDesignator(designator, *newType, componentPath);
+ currScope(), context_, *sourceType, componentPath)}) {
+ if (auto *object{copy->detailsIf<ObjectEntityDetails>()}) {
+ object->ReplaceType(*newType);
}
+ RemapUseDeviceDesignator(designator, *newType, componentPath);
}
}
@@ -1662,10 +1667,7 @@ void AccVisitor::Post(const parser::AccClause::UseDevice &x) {
GetUseDeviceStructurePath(designator, baseName, components);
if (baseName && baseName->symbol) {
if (components.empty()) {
- if (const auto *bareName{
- parser::GetDesignatorNameIfDataRef(designator)}) {
- CopySymbolWithDevice(bareName);
- }
+ CopySymbolWithDevice(baseName);
} else {
CopySymbolWithDeviceStructurePath(
baseName, components, designator);
diff --git a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
index 55bf100ffa55d..ce1e3682ea40b 100644
--- a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
+++ b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
@@ -26,6 +26,15 @@ subroutine foo_host(p)
end subroutine
end interface
+interface foo_array
+ subroutine foo_device_array(x)
+ real, device :: x(:,:)
+ end
+ subroutine foo_host_array(x)
+ real :: x(:,:)
+ end
+end interface
+
contains
subroutine test(obj)
@@ -34,10 +43,31 @@ subroutine test(obj)
call foo(obj%p1)
call foo(obj%p2)
!$acc end host_data
+
+ call foo(obj%p1)
+
+ !$acc host_data use_device(obj%p1, obj%p2)
+ call foo(obj%p1)
+ call foo(obj%p2)
+ !$acc end host_data
end subroutine
! CHECK-LABEL: func.func @_QMmPtest
! CHECK: fir.call @_QPfoo_device
! CHECK: fir.call @_QPfoo_host
+! CHECK: fir.call @_QPfoo_host
+! CHECK: fir.call @_QPfoo_device
+! CHECK: fir.call @_QPfoo_device
+
+ subroutine test_array(a, i)
+ real :: a(4,4,4)
+ integer :: i
+ !$acc host_data use_device(a(:,:,i))
+ call foo_array(a(:,:,i))
+ !$acc end host_data
+ end subroutine
+
+! CHECK-LABEL: func.func @_QMmPtest_array
+! CHECK: fir.call @_QPfoo_device_array
subroutine vectoraddarray(a, b, n)
implicit none
More information about the flang-commits
mailing list