[flang-commits] [flang] [flang][cuda] Only apply the implicit data attribute on the component for use_device (PR #192146)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Apr 16 21:53:23 PDT 2026


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/192146

>From dd17cc2b3e00c30084818dc54ebf85d3bc159d68 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Mon, 13 Apr 2026 17:37:15 -0700
Subject: [PATCH 1/2] [flang][cuda] Only apply the implicit data attribute on
 the component for use_device

---
 flang/include/flang/Semantics/tools.h         |   9 +
 flang/include/flang/Semantics/type.h          |   3 +
 flang/lib/Semantics/expression.cpp            |  39 +++-
 flang/lib/Semantics/resolve-names.cpp         | 174 +++++++++++++++---
 flang/lib/Semantics/type.cpp                  |  62 +++++++
 .../OpenACC/acc-host-data-cuda-device.f90     |  23 +++
 6 files changed, 274 insertions(+), 36 deletions(-)

diff --git a/flang/include/flang/Semantics/tools.h b/flang/include/flang/Semantics/tools.h
index d2e2be2548d2e..9f77d0ec5da2e 100644
--- a/flang/include/flang/Semantics/tools.h
+++ b/flang/include/flang/Semantics/tools.h
@@ -23,6 +23,7 @@
 #include "flang/Semantics/expression.h"
 #include "flang/Semantics/semantics.h"
 #include "flang/Support/Fortran.h"
+#include "llvm/ADT/ArrayRef.h"
 #include <functional>
 
 namespace Fortran::semantics {
@@ -294,6 +295,14 @@ SymbolVector OrderParameterNames(const Symbol &);
 const DeclTypeSpec &FindOrInstantiateDerivedType(Scope &, DerivedTypeSpec &&,
     DeclTypeSpec::Category = DeclTypeSpec::TypeDerived);
 
+// Clone a derived type for OpenACC use_device with CUDA Fortran: the copy is
+// created in `containingScope`, and each component named in `path` (e.g.
+// a%b%c -> {b,c}) becomes a distinct symbol with cudaDataAttr Device.
+// Returns nullptr if `path` is empty or `origType` is not derived.
+const DeclTypeSpec *CloneDerivedTypeForUseDevice(Scope &containingScope,
+    SemanticsContext &, const DeclTypeSpec &origType,
+    llvm::ArrayRef<SourceName> path);
+
 // When a subprogram defined in a submodule defines a separate module
 // procedure whose interface is defined in an ancestor (sub)module,
 // returns a pointer to that interface, else null.
diff --git a/flang/include/flang/Semantics/type.h b/flang/include/flang/Semantics/type.h
index 3a07b6ee2ec1c..312f8a4b2075e 100644
--- a/flang/include/flang/Semantics/type.h
+++ b/flang/include/flang/Semantics/type.h
@@ -307,6 +307,9 @@ class DerivedTypeSpec {
   // instantiations that have been specialized with actual type parameter
   // values, which are cooked &/or evaluated if necessary.
   void Instantiate(Scope &containingScope);
+  // Null out scope_ so a copy of this spec can be given a fresh component
+  // scope via ReplaceScope (OpenACC use_device with CUDA component paths).
+  void PrepareForScopeClone();
 
   ParamValue *FindParameter(SourceName);
   const ParamValue *FindParameter(SourceName target) const {
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index 756bfd551a90a..8b64fca2666e6 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -1497,14 +1497,24 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::StructureComponent &sc) {
         Say(name,
             "A type parameter inquiry must be applied to a designator"_err_en_US);
       }
-    } else if (!dtSpec || !dtSpec->scope()) {
+    } else if (!dtSpec || !dtSpec->GetScope()) {
       CHECK(context_.AnyFatalError() || !foldingContext_.messages().empty());
       return std::nullopt;
     } else if (std::optional<DataRef> dataRef{
                    ExtractDataRef(std::move(*dtExpr))}) {
+      // The base may use a forked DerivedTypeSpec (e.g. OpenACC use_device with
+      // CUDA) while the parse tree still points at host-associated component
+      // symbols; resolve the component in the base type's instantiated scope.
+      const semantics::Scope &typeScope{DEREF(dtSpec->GetScope())};
+      Symbol *compSym{sym};
+      if (sym && sym->owner().IsDerivedType() && &sym->owner() != &typeScope) {
+        if (Symbol * via{typeScope.FindComponent(sym->name())}) {
+          compSym = via;
+        }
+      }
       auto restorer{GetContextualMessages().SetLocation(name)};
       if (auto component{
-              CreateComponent(std::move(*dataRef), *sym, *dtSpec->scope())}) {
+              CreateComponent(std::move(*dataRef), *compSym, typeScope)}) {
         return Designate(DataRef{std::move(*component)});
       } else {
         Say(name, "Component is not in scope of derived TYPE(%s)"_err_en_US,
@@ -2841,13 +2851,24 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
   std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr;
   if (actual) {
     if (auto *expr{actual->UnwrapExpr()}) {
-      const auto *actualLastSymbol{evaluate::GetLastSymbol(*expr)};
-      if (actualLastSymbol) {
-        actualLastSymbol = &semantics::ResolveAssociations(*actualLastSymbol);
-        if (const auto *actualObject{actualLastSymbol
-                    ? actualLastSymbol
-                          ->detailsIf<semantics::ObjectEntityDetails>()
-                    : nullptr}) {
+      if (evaluate::IsVariable(*expr)) {
+        // Match check-call.cpp: walk the whole designator so e.g. b%a picks up
+        // ATTRIBUTES(DEVICE) from the base b when the component a has no CUDA
+        // attribute (OpenACC use_device(b) + doit(b%a)). As in the else branch,
+        // resolve associations so use/host/construct-associated names work.
+        for (const Symbol &s : evaluate::GetSymbolVector(*expr)) {
+          if (const auto *object{semantics::ResolveAssociations(s)
+                      .detailsIf<semantics::ObjectEntityDetails>()}) {
+            if (auto cudaAttr{object->cudaDataAttr()}) {
+              actualDataAttr = *cudaAttr;
+            }
+          }
+        }
+      } else if (const auto *actualLastSymbol{evaluate::GetLastSymbol(*expr)}) {
+        const Symbol &resolved{
+            semantics::ResolveAssociations(*actualLastSymbol)};
+        if (const auto *actualObject{
+                resolved.detailsIf<semantics::ObjectEntityDetails>()}) {
           actualDataAttr = actualObject->cudaDataAttr();
         }
       }
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index b6907cc792d76..091439da3e812 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -42,6 +42,8 @@
 #include "flang/Semantics/type.h"
 #include "flang/Support/Fortran.h"
 #include "flang/Support/default-kinds.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/raw_ostream.h"
 #include <list>
@@ -1410,7 +1412,7 @@ class AccVisitor : public virtual DeclarationVisitor {
   void Post(const parser::OpenACCBlockConstruct &);
   bool Pre(const parser::OpenACCCombinedConstruct &);
   void Post(const parser::OpenACCCombinedConstruct &);
-  bool Pre(const parser::AccClause::UseDevice &x);
+  void Post(const parser::AccClause::UseDevice &x);
   bool Pre(const parser::AccBeginBlockDirective &x) {
     AddAccSourceRange(x.source);
     return true;
@@ -1472,11 +1474,116 @@ class AccVisitor : public virtual DeclarationVisitor {
   }
 
   void CopySymbolWithDevice(const parser::Name *name);
+  void CopySymbolWithDeviceStructurePath(const parser::Name *baseName,
+      llvm::ArrayRef<SourceName> componentPath, parser::Designator &designator);
 
 private:
   SemanticsContext &context_;
 };
 
+namespace {
+
+// Walk a use_device DataRef: record its base name in `baseOut` and append the
+// %component names base-first (a(i)%b%c -> base a, components {b, c}).
+void GetUseDeviceDataRefPath(const parser::DataRef &dr,
+    const parser::Name *&baseOut,
+    llvm::SmallVectorImpl<SourceName> &componentsOut) {
+  common::visit(
+      common::visitors{
+          [&](const parser::Name &name) { baseOut = &name; },
+          [&](const common::Indirection<parser::StructureComponent> &sc) {
+            const parser::StructureComponent &comp{sc.value()};
+            GetUseDeviceDataRefPath(comp.Base(), baseOut, componentsOut);
+            componentsOut.push_back(comp.Component().source);
+          },
+          [&](const common::Indirection<parser::ArrayElement> &ae) {
+            GetUseDeviceDataRefPath(ae.value().Base(), baseOut, componentsOut);
+          },
+          [&](const common::Indirection<parser::CoindexedNamedObject> &ci) {
+            GetUseDeviceDataRefPath(std::get<parser::DataRef>(ci.value().t),
+                baseOut, componentsOut);
+          },
+      },
+      dr.u);
+}
+
+// Reset both outputs, then fill them from `d` if it is a DataRef.
+void GetUseDeviceStructurePath(const parser::Designator &d,
+    const parser::Name *&baseOut,
+    llvm::SmallVectorImpl<SourceName> &componentsOut) {
+  componentsOut.clear();
+  baseOut = nullptr;
+  if (const auto *dr{std::get_if<parser::DataRef>(&d.u)}) {
+    GetUseDeviceDataRefPath(*dr, baseOut, componentsOut);
+  }
+}
+
+// Point each %component name at the corresponding symbol in the cloned
+// DerivedTypeSpec scope so semantics and lowering see cudaDataAttr::Device.
+// Name::symbol is mutable, so no const_cast of the parse tree is needed.
+bool RemapUseDeviceComponentNames(const parser::DataRef &dr,
+    const DerivedTypeSpec *&currDts, llvm::ArrayRef<SourceName> path,
+    std::size_t &idx) {
+  return common::visit(
+      common::visitors{
+          [&](const parser::Name &) { return true; },
+          [&](const common::Indirection<parser::StructureComponent> &sc) {
+            if (!RemapUseDeviceComponentNames(
+                    sc.value().Base(), currDts, path, idx)) {
+              return false;
+            }
+            if (idx >= path.size()) {
+              return true; // every path component is already remapped
+            }
+            const parser::Name &compName{sc.value().Component()};
+            const semantics::Scope *compTyScope{currDts->GetScope()};
+            Symbol *sym{
+                compTyScope ? compTyScope->FindComponent(path[idx]) : nullptr};
+            if (!sym && compTyScope) {
+              sym = compTyScope->FindComponent(compName.source);
+            }
+            if (!sym) {
+              return false;
+            }
+            compName.symbol = sym;
+            if (const auto *oed{sym->detailsIf<ObjectEntityDetails>()}) {
+              if (const DeclTypeSpec *t{oed->type()}) {
+                if (const DerivedTypeSpec *inner{t->AsDerived()}) {
+                  currDts = inner;
+                }
+              }
+            }
+            ++idx;
+            return true;
+          },
+          [&](const common::Indirection<parser::ArrayElement> &ae) {
+            return RemapUseDeviceComponentNames(
+                ae.value().Base(), currDts, path, idx);
+          },
+          [&](const common::Indirection<parser::CoindexedNamedObject> &ci) {
+            return RemapUseDeviceComponentNames(
+                std::get<parser::DataRef>(ci.value().t), currDts, path, idx);
+          },
+      },
+      dr.u);
+}
+
+// Rebind the %component names of `d` to the components of `newVarType`, the
+// cloned type that CopySymbolWithDeviceStructurePath gives the base variable.
+void RemapUseDeviceDesignator(const parser::Designator &d,
+    const DeclTypeSpec &newVarType, llvm::ArrayRef<SourceName> path) {
+  const DerivedTypeSpec *currDts{newVarType.AsDerived()};
+  if (!currDts || path.empty()) {
+    return;
+  }
+  std::size_t idx{0};
+  if (const auto *dr{std::get_if<parser::DataRef>(&d.u)}) {
+    (void)RemapUseDeviceComponentNames(*dr, currDts, path, idx);
+  }
+}
+
+} // namespace
+
 bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) {
   const auto &beginBlockDir{std::get<parser::AccBeginBlockDirective>(x.t)};
   const auto &beginDir{std::get<parser::AccBlockDirective>(beginBlockDir.t)};
@@ -1520,34 +1627,48 @@ void AccVisitor::CopySymbolWithDevice(const parser::Name *name) {
   }
 }
 
-bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
-  for (const auto &accObject : x.v.v) {
-    Walk(accObject);
+void AccVisitor::CopySymbolWithDeviceStructurePath(const parser::Name *baseName,
+    llvm::ArrayRef<SourceName> componentPath, parser::Designator &designator) {
+  if (!context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA) ||
+      !baseName || !baseName->symbol || componentPath.empty()) {
+    return;
+  }
+  const Symbol &orig{*baseName->symbol};
+  const DeclTypeSpec *origType{orig.GetType()};
+  if (!origType || !origType->AsDerived()) {
+    return;
+  }
+  if (const DeclTypeSpec *newType{CloneDerivedTypeForUseDevice(
+          currScope(), context_, *origType, componentPath)}) {
+    if (Symbol * copy{currScope().CopySymbol(orig)}) {
+      baseName->symbol = copy;
+      if (auto *object{copy->detailsIf<ObjectEntityDetails>()}) {
+        object->ReplaceType(*newType);
+      }
+      RemapUseDeviceDesignator(designator, *newType, componentPath);
+    }
+  }
+}
+
+void AccVisitor::Post(const parser::AccClause::UseDevice &x) {
+  // Run after clause acc-objects are walked so Post(Designator) has resolved
+  // Name::symbol, then fork types and remap %component symbols for CUDA.
+  for (auto &accObject : const_cast<std::list<parser::AccObject> &>(x.v.v)) {
     common::visit(
         common::visitors{
-            [&](const parser::Designator &designator) {
-              if (const auto *name{
-                      parser::GetDesignatorNameIfDataRef(designator)}) {
-                CopySymbolWithDevice(name);
-              } else if (const auto *dataRef{
-                             std::get_if<parser::DataRef>(&designator.u)}) {
-                using ElementIndirection =
-                    common::Indirection<parser::ArrayElement>;
-                using ComponentIndirection =
-                    common::Indirection<parser::StructureComponent>;
-                if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
-                  const parser::ArrayElement &arrayElement{ind->value()};
-                  const parser::DataRef &base{arrayElement.Base()};
-                  if (auto *name{std::get_if<parser::Name>(&base.u)}) {
-                    CopySymbolWithDevice(name);
-                  }
-                } else if (auto *ind{std::get_if<ComponentIndirection>(
-                               &dataRef->u)}) {
-                  const parser::StructureComponent &comp{ind->value()};
-                  const parser::DataRef &base{comp.Base()};
-                  if (auto *name{std::get_if<parser::Name>(&base.u)}) {
-                    CopySymbolWithDevice(name);
+            [&](parser::Designator &designator) {
+              const parser::Name *baseName{nullptr};
+              llvm::SmallVector<SourceName, 4> components;
+              GetUseDeviceStructurePath(designator, baseName, components);
+              if (baseName && baseName->symbol) {
+                if (components.empty()) {
+                  if (const auto *bareName{
+                          parser::GetDesignatorNameIfDataRef(designator)}) {
+                    CopySymbolWithDevice(bareName);
                   }
+                } else {
+                  CopySymbolWithDeviceStructurePath(
+                      baseName, components, designator);
                 }
               }
             },
@@ -1557,7 +1678,6 @@ bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
         },
         accObject.u);
   }
-  return false;
 }
 
 void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) {
diff --git a/flang/lib/Semantics/type.cpp b/flang/lib/Semantics/type.cpp
index bb8fddbffd945..6acfd63dbac4d 100644
--- a/flang/lib/Semantics/type.cpp
+++ b/flang/lib/Semantics/type.cpp
@@ -18,6 +18,7 @@
 #include "flang/Semantics/scope.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/raw_ostream.h"
 
 namespace Fortran::semantics {
@@ -199,6 +200,8 @@ void DerivedTypeSpec::ReevaluateParameters(SemanticsContext &context) {
   EvaluateParameters(context);
 }
 
+void DerivedTypeSpec::PrepareForScopeClone() { scope_ = nullptr; }
+
 void DerivedTypeSpec::AddParamValue(SourceName name, ParamValue &&value) {
   CHECK(cooked_);
   auto pair{parameters_.insert(std::make_pair(name, std::move(value)))};
@@ -282,13 +285,22 @@ bool DerivedTypeSpec::MatchesOrExtends(const DerivedTypeSpec &that) const {
   return MatchKindParams(*typeSymbol, *this, that);
 }
 
+static const DeclTypeSpec *CloneDerivedTypeForUseDeviceImpl(
+    Scope &containingScope, SemanticsContext &context,
+    const DerivedTypeSpec &sourceDts, DeclTypeSpec::Category category,
+    llvm::ArrayRef<SourceName> path);
+
 class InstantiateHelper {
 public:
   InstantiateHelper(Scope &scope) : scope_{scope} {}
   // Instantiate components from fromScope into scope_
   void InstantiateComponents(const Scope &);
+  void SetUseDevicePath(llvm::ArrayRef<SourceName> path) {
+    useDevicePath_ = path;
+  }
 
 private:
+  llvm::ArrayRef<SourceName> useDevicePath_{};
   SemanticsContext &context() const { return scope_.context(); }
   evaluate::FoldingContext &foldingContext() {
     return context().foldingContext();
@@ -568,6 +580,14 @@ const DeclTypeSpec *InstantiateHelper::InstantiateType(const Symbol &symbol) {
   if (!type) {
     return nullptr; // error has occurred
   } else if (const DerivedTypeSpec * spec{type->AsDerived()}) {
+    if (!useDevicePath_.empty() && symbol.name() == useDevicePath_[0] &&
+        useDevicePath_.size() > 1) {
+      if (const DeclTypeSpec *cloned{
+              CloneDerivedTypeForUseDeviceImpl(scope_, context(), *spec,
+                  type->category(), useDevicePath_.drop_front())}) {
+        return cloned;
+      }
+    }
     return &FindOrInstantiateDerivedType(scope_,
         CreateDerivedTypeSpec(*spec, symbol.test(Symbol::Flag::ParentComp)),
         type->category());
@@ -688,6 +708,48 @@ DerivedTypeSpec InstantiateHelper::CreateDerivedTypeSpec(
   return result;
 }
 
+static const DeclTypeSpec *CloneDerivedTypeForUseDeviceImpl(
+    Scope &containingScope, SemanticsContext &context,
+    const DerivedTypeSpec &sourceDts, DeclTypeSpec::Category category,
+    llvm::ArrayRef<SourceName> path) {
+  const Scope *sourceScope{sourceDts.GetScope()};
+  if (path.empty() || !sourceScope) { // no scope: e.g. after prior errors
+    return nullptr;
+  }
+  DerivedTypeSpec newDts{sourceDts};
+  newDts.PrepareForScopeClone();
+  DeclTypeSpec &newDecl{
+      containingScope.MakeDerivedType(category, std::move(newDts))};
+  DerivedTypeSpec &dtsRef{newDecl.derivedTypeSpec()};
+  Scope &newScope{containingScope.MakeScope(Scope::Kind::DerivedType)};
+  dtsRef.ReplaceScope(newScope);
+  newScope.set_derivedTypeSpec(dtsRef);
+  InstantiateHelper helper{newScope};
+  helper.SetUseDevicePath(path); // InstantiateType clones deeper path levels
+  helper.InstantiateComponents(*sourceScope);
+  // Mark this level's path component as device data in the cloned scope.
+  if (Symbol * comp{newScope.FindComponent(path[0])}) {
+    if (auto *details{comp->detailsIf<ObjectEntityDetails>()}) {
+      details->set_cudaDataAttr(common::CUDADataAttr::Device);
+    }
+  }
+  return &newDecl;
+}
+
+const DeclTypeSpec *CloneDerivedTypeForUseDevice(Scope &containingScope,
+    SemanticsContext &context, const DeclTypeSpec &origType,
+    llvm::ArrayRef<SourceName> path) {
+  // Public entry point (declared in tools.h): unwrap the derived type spec
+  // and delegate to CloneDerivedTypeForUseDeviceImpl, which is also reached
+  // from InstantiateHelper::InstantiateType for deeper path levels.
+  const DerivedTypeSpec *spec{origType.AsDerived()};
+  if (path.empty() || !spec) {
+    return nullptr;
+  }
+  return CloneDerivedTypeForUseDeviceImpl(
+      containingScope, context, *spec, origType.category(), path);
+}
+
 std::string DerivedTypeSpec::VectorTypeAsFortran() const {
   std::string buf;
   llvm::raw_string_ostream ss{buf};
diff --git a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
index cc301617444ac..55bf100ffa55d 100644
--- a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
+++ b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
@@ -13,9 +13,32 @@ module m
     !dir$ ignore_tkr(c) a
 end
 end interface
+type t
+  integer, pointer :: p1
+  integer, pointer :: p2
+end type
+interface foo
+subroutine foo_device(p)
+  integer, pointer, device :: p
+end subroutine
+subroutine foo_host(p)
+  integer, pointer :: p
+end subroutine
+end interface
 
 contains
 
+  subroutine test(obj)
+    type(t) :: obj
+    !$acc host_data use_device(obj%p1)
+    call foo(obj%p1)
+    call foo(obj%p2)
+    !$acc end host_data
+  end subroutine
+! CHECK-LABEL: func.func @_QMmPtest
+! CHECK: fir.call @_QPfoo_device
+! CHECK: fir.call @_QPfoo_host
+
   subroutine vectoraddarray(a, b, n)
     implicit none
     integer :: n

>From f33c099939772253f28a5235fc6abfc47c6a65d7 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 16 Apr 2026 21:53:04 -0700
Subject: [PATCH 2/2] Address review comments: reuse the base copy across use_device components and handle array-section objects

---
 flang/lib/Semantics/resolve-names.cpp         | 28 +++++++++--------
 .../OpenACC/acc-host-data-cuda-device.f90     | 30 +++++++++++++++++++
 2 files changed, 45 insertions(+), 13 deletions(-)

diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 091439da3e812..d1ef2c7caae0a 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1634,19 +1634,24 @@ void AccVisitor::CopySymbolWithDeviceStructurePath(const parser::Name *baseName,
     return;
   }
   const Symbol &orig{*baseName->symbol};
-  const DeclTypeSpec *origType{orig.GetType()};
-  if (!origType || !origType->AsDerived()) {
+  Symbol *copy{currScope().CopySymbol(orig)};
+  if (!copy) {
+    copy = FindInScope(currScope(), baseName->symbol->name());
+  }
+  if (!copy) {
+    return;
+  }
+  baseName->symbol = copy;
+  const DeclTypeSpec *sourceType{copy->GetType()};
+  if (!sourceType || !sourceType->AsDerived()) {
     return;
   }
   if (const DeclTypeSpec *newType{CloneDerivedTypeForUseDevice(
-          currScope(), context_, *origType, componentPath)}) {
-    if (Symbol * copy{currScope().CopySymbol(orig)}) {
-      baseName->symbol = copy;
-      if (auto *object{copy->detailsIf<ObjectEntityDetails>()}) {
-        object->ReplaceType(*newType);
-      }
-      RemapUseDeviceDesignator(designator, *newType, componentPath);
+          currScope(), context_, *sourceType, componentPath)}) {
+    if (auto *object{copy->detailsIf<ObjectEntityDetails>()}) {
+      object->ReplaceType(*newType);
     }
+    RemapUseDeviceDesignator(designator, *newType, componentPath);
   }
 }
 
@@ -1662,10 +1667,7 @@ void AccVisitor::Post(const parser::AccClause::UseDevice &x) {
               GetUseDeviceStructurePath(designator, baseName, components);
               if (baseName && baseName->symbol) {
                 if (components.empty()) {
-                  if (const auto *bareName{
-                          parser::GetDesignatorNameIfDataRef(designator)}) {
-                    CopySymbolWithDevice(bareName);
-                  }
+                  CopySymbolWithDevice(baseName);
                 } else {
                   CopySymbolWithDeviceStructurePath(
                       baseName, components, designator);
diff --git a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
index 55bf100ffa55d..ce1e3682ea40b 100644
--- a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
+++ b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
@@ -26,6 +26,15 @@ subroutine foo_host(p)
 end subroutine
 end interface
 
+interface foo_array
+  subroutine foo_device_array(x)
+    real, device :: x(:,:)
+  end
+  subroutine foo_host_array(x)
+    real :: x(:,:)
+  end
+end interface
+
 contains
 
   subroutine test(obj)
@@ -34,10 +43,31 @@ subroutine test(obj)
     call foo(obj%p1)
     call foo(obj%p2)
     !$acc end host_data
+
+    call foo(obj%p1)
+
+    !$acc host_data use_device(obj%p1, obj%p2)
+    call foo(obj%p1)
+    call foo(obj%p2)
+    !$acc end host_data
   end subroutine
 ! CHECK-LABEL: func.func @_QMmPtest
 ! CHECK: fir.call @_QPfoo_device
 ! CHECK: fir.call @_QPfoo_host
+! CHECK: fir.call @_QPfoo_host
+! CHECK: fir.call @_QPfoo_device
+! CHECK: fir.call @_QPfoo_device
+
+  subroutine test_array(a, i)
+    real :: a(4,4,4)
+    integer :: i
+    !$acc host_data use_device(a(:,:,i))
+    call foo_array(a(:,:,i))
+    !$acc end host_data
+  end subroutine
+
+! CHECK-LABEL: func.func @_QMmPtest_array
+! CHECK: fir.call @_QPfoo_device_array
 
   subroutine vectoraddarray(a, b, n)
     implicit none



More information about the flang-commits mailing list