[flang-commits] [flang] [flang][cuda][openacc] Create new symbol in host_data region for CUDA Fortran interop (PR #161613)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Oct 2 07:06:31 PDT 2025


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/161613

>From 6130bbc38e7bd3cba36179b52b9634db6c3de8fb Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 Oct 2025 11:19:57 -0700
Subject: [PATCH 1/5] DRAFT

---
 flang/include/flang/Lower/OpenACC.h           |  3 +-
 flang/include/flang/Lower/SymbolMap.h         |  4 ++
 flang/include/flang/Semantics/symbol.h        |  2 +-
 flang/lib/Lower/Bridge.cpp                    |  2 +-
 flang/lib/Lower/OpenACC.cpp                   | 33 +++++++---
 flang/lib/Lower/SymbolMap.cpp                 | 10 +++
 flang/lib/Semantics/check-declarations.cpp    |  3 +-
 flang/lib/Semantics/resolve-directives.cpp    |  5 ++
 flang/lib/Semantics/resolve-names.cpp         | 65 ++++++++++++++++++-
 .../OpenACC/acc-host-data-cuda-device.f90     | 43 ++++++++++++
 10 files changed, 156 insertions(+), 14 deletions(-)
 create mode 100644 flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90

diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 19d759479abaf..4622dbc8ccf64 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -77,7 +77,8 @@ static constexpr llvm::StringRef privatizationRecipePrefix = "privatization";
 mlir::Value genOpenACCConstruct(AbstractConverter &,
                                 Fortran::semantics::SemanticsContext &,
                                 pft::Evaluation &,
-                                const parser::OpenACCConstruct &);
+                                const parser::OpenACCConstruct &,
+                                Fortran::lower::SymMap &localSymbols);
 void genOpenACCDeclarativeConstruct(
     AbstractConverter &, Fortran::semantics::SemanticsContext &,
     StatementContext &, const parser::OpenACCDeclarativeConstruct &);
diff --git a/flang/include/flang/Lower/SymbolMap.h b/flang/include/flang/Lower/SymbolMap.h
index 813df777d7a39..e57b6a42d3cc1 100644
--- a/flang/include/flang/Lower/SymbolMap.h
+++ b/flang/include/flang/Lower/SymbolMap.h
@@ -260,6 +260,10 @@ class SymMap {
     return lookupSymbol(*sym);
   }
 
+  /// Find a symbol by name in the current mappings; return it, or nullptr if
+  /// absent. This lookup is more expensive as it iterates over the map.
+  const semantics::Symbol *lookupSymbolByName(llvm::StringRef symName);
+
   /// Find `symbol` and return its value if it appears in the inner-most level
   /// map.
   SymbolBox shallowLookupSymbol(semantics::SymbolRef sym);
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index e90e9c617805d..a0d5ae7176141 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -801,7 +801,7 @@ class Symbol {
       AccPrivate, AccFirstPrivate, AccShared,
       // OpenACC data-mapping attribute
       AccCopy, AccCopyIn, AccCopyInReadOnly, AccCopyOut, AccCreate, AccDelete,
-      AccPresent, AccLink, AccDeviceResident, AccDevicePtr,
+      AccPresent, AccLink, AccDeviceResident, AccDevicePtr, AccUseDevice,
       // OpenACC declare
       AccDeclare,
       // OpenACC data-movement attribute
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 149e51b501a82..780d56f085f69 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3182,7 +3182,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     localSymbols.pushScope();
     mlir::Value exitCond = genOpenACCConstruct(
-        *this, bridge.getSemanticsContext(), getEval(), acc);
+        *this, bridge.getSemanticsContext(), getEval(), acc, localSymbols);
 
     const Fortran::parser::OpenACCLoopConstruct *accLoop =
         std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 95d0adae02670..66f45451b3eeb 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
                  Fortran::lower::pft::Evaluation &eval,
                  Fortran::semantics::SemanticsContext &semanticsContext,
                  Fortran::lower::StatementContext &stmtCtx,
-                 const Fortran::parser::AccClauseList &accClauseList) {
+                 const Fortran::parser::AccClauseList &accClauseList,
+                 Fortran::lower::SymMap &localSymbols) {
   mlir::Value ifCond;
   llvm::SmallVector<mlir::Value> dataOperands;
   bool addIfPresentAttr = false;
@@ -3199,6 +3200,18 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *useDevice =
                    std::get_if<Fortran::parser::AccClause::UseDevice>(
                        &clause.u)) {
+      // When CUDA Fotran is en
+      if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
+
+        const Fortran::parser::AccObjectList &objectList{useDevice->v};
+        for (const auto &accObject : objectList.v) {
+          Fortran::semantics::Symbol &symbol =
+              getSymbolFromAccObject(accObject);
+          const Fortran::semantics::Symbol *baseSym =
+              localSymbols.lookupSymbolByName(symbol.name().ToString());
+          localSymbols.copySymbolBinding(*baseSym, symbol);
+        }
+      }
       genDataOperandOperations<mlir::acc::UseDeviceOp>(
           useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
           mlir::acc::DataClause::acc_use_device,
@@ -3239,11 +3252,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
     hostDataOp.setIfPresentAttr(builder.getUnitAttr());
 }
 
-static void
-genACC(Fortran::lower::AbstractConverter &converter,
-       Fortran::semantics::SemanticsContext &semanticsContext,
-       Fortran::lower::pft::Evaluation &eval,
-       const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
+static void genACC(Fortran::lower::AbstractConverter &converter,
+                   Fortran::semantics::SemanticsContext &semanticsContext,
+                   Fortran::lower::pft::Evaluation &eval,
+                   const Fortran::parser::OpenACCBlockConstruct &blockConstruct,
+                   Fortran::lower::SymMap &localSymbols) {
   const auto &beginBlockDirective =
       std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
   const auto &blockDirective =
@@ -3273,7 +3286,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
                                           accClauseList);
   } else if (blockDirective.v == llvm::acc::ACCD_host_data) {
     genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
-                     stmtCtx, accClauseList);
+                     stmtCtx, accClauseList, localSymbols);
   }
 }
 
@@ -4647,13 +4660,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semanticsContext,
     Fortran::lower::pft::Evaluation &eval,
-    const Fortran::parser::OpenACCConstruct &accConstruct) {
+    const Fortran::parser::OpenACCConstruct &accConstruct,
+    Fortran::lower::SymMap &localSymbols) {
 
   mlir::Value exitCond;
   Fortran::common::visit(
       common::visitors{
           [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
-            genACC(converter, semanticsContext, eval, blockConstruct);
+            genACC(converter, semanticsContext, eval, blockConstruct,
+                   localSymbols);
           },
           [&](const Fortran::parser::OpenACCCombinedConstruct
                   &combinedConstruct) {
diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
index 080f21ec67400..78529e0d539fb 100644
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
   return SymbolBox::None{};
 }
 
+const Fortran::semantics::Symbol *
+Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) {
+  for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
+       jmap != jend; ++jmap)
+    for (auto const &[sym, symBox] : *jmap)
+      if (sym->name().ToString() == symName)
+        return sym;
+  return nullptr;
+}
+
 Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
     Fortran::semantics::SymbolRef symRef) {
   auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index 1049a6d2c1b2e..7b881008219df 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity(
       }
     } else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module &&
         symbol.owner().kind() != Scope::Kind::MainProgram &&
-        symbol.owner().kind() != Scope::Kind::BlockConstruct) {
+        symbol.owner().kind() != Scope::Kind::BlockConstruct &&
+        symbol.owner().kind() != Scope::Kind::OpenACCConstruct) {
       messages_.Say(
           "ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US,
           parser::ToUpperCaseLetters(common::EnumToString(attr)));
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index bd7b8ac552fab..366ebba1b7263 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -328,6 +328,11 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
     return false;
   }
 
+  bool Pre(const parser::AccClause::UseDevice &x) {
+    ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice);
+    return false;
+  }
+
   void Post(const parser::Name &);
 
 private:
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index d1150a9eb67f4..7e55eb005dbb2 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1387,6 +1387,8 @@ class ConstructVisitor : public virtual DeclarationVisitor {
 // Create scopes for OpenACC constructs
 class AccVisitor : public virtual DeclarationVisitor {
 public:
+  explicit AccVisitor(SemanticsContext &context) : context_{context} {}
+
   void AddAccSourceRange(const parser::CharBlock &);
 
   static bool NeedsScope(const parser::OpenACCBlockConstruct &);
@@ -1395,6 +1397,7 @@ class AccVisitor : public virtual DeclarationVisitor {
   void Post(const parser::OpenACCBlockConstruct &);
   bool Pre(const parser::OpenACCCombinedConstruct &);
   void Post(const parser::OpenACCCombinedConstruct &);
+  bool Pre(const parser::AccClause::UseDevice &x);
   bool Pre(const parser::AccBeginBlockDirective &x) {
     AddAccSourceRange(x.source);
     return true;
@@ -1430,6 +1433,11 @@ class AccVisitor : public virtual DeclarationVisitor {
   void Post(const parser::AccBeginLoopDirective &x) {
     messageHandler().set_currStmtSource(std::nullopt);
   }
+
+  void CopySymbolWithDevice(const parser::Name *name);
+
+private:
+  SemanticsContext &context_;
 };
 
 bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) {
@@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) {
   return true;
 }
 
+void AccVisitor::CopySymbolWithDevice(const parser::Name *name) {
+  // When CUDA Fortran is enabled together with OpenACC, new
+  // symbols are created for the ones appearing in the use_device
+  // clause. These new symbols have the CUDA Fortran device
+  // attribute.
+  if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) {
+    name->symbol = currScope().CopySymbol(*name->symbol);
+    if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) {
+      object->set_cudaDataAttr(common::CUDADataAttr::Device);
+    }
+  }
+}
+
+bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
+  for (const auto &accObject : x.v.v) {
+    common::visit(
+        common::visitors{
+            [&](const parser::Designator &designator) {
+              if (const auto *name{
+                      semantics::getDesignatorNameIfDataRef(designator)}) {
+                Symbol *prev{currScope().FindSymbol(name->source)};
+                if (prev != name->symbol) {
+                  name->symbol = prev;
+                }
+                CopySymbolWithDevice(name);
+              } else {
+                if (const auto *dataRef{
+                        std::get_if<parser::DataRef>(&designator.u)}) {
+                  using ElementIndirection =
+                      common::Indirection<parser::ArrayElement>;
+                  if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
+                    const parser::ArrayElement &arrayElement{ind->value()};
+                    Walk(arrayElement.subscripts);
+                    const parser::DataRef &base{arrayElement.base};
+                    if (auto *name{std::get_if<parser::Name>(&base.u)}) {
+                      Symbol *prev{currScope().FindSymbol(name->source)};
+                      if (prev != name->symbol) {
+                        name->symbol = prev;
+                      }
+                      CopySymbolWithDevice(name);
+                    }
+                  }
+                }
+              }
+            },
+            [&](const parser::Name &name) {
+              // TODO: common block in use_device?
+            },
+        },
+        accObject.u);
+  }
+  return false;
+}
+
 void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) {
   if (NeedsScope(x)) {
     PopScope();
@@ -2038,7 +2100,8 @@ class ResolveNamesVisitor : public virtual ScopeHandler,
 
   ResolveNamesVisitor(
       SemanticsContext &context, ImplicitRulesMap &rules, Scope &top)
-      : BaseVisitor{context, *this, rules}, topScope_{top} {
+      : BaseVisitor{context, *this, rules}, topScope_{top},
+        AccVisitor(context) {
     PushScope(top);
   }
 
diff --git a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
new file mode 100644
index 0000000000000..da034adec39c4
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
@@ -0,0 +1,43 @@
+
+! RUN: bbc -fopenacc -fcuda -emit-hlfir %s -o - | FileCheck %s
+
+module m
+
+interface doit
+subroutine __device_sub(a)
+    real(4), device, intent(in) :: a(:,:,:)
+    !dir$ ignore_tkr(c) a
+end
+subroutine __host_sub(a)
+    real(4), intent(in) :: a(:,:,:)
+    !dir$ ignore_tkr(c) a
+end
+end interface
+end module
+
+program testex1
+integer, parameter :: ntimes = 10
+integer, parameter :: ni=128
+integer, parameter :: nj=256
+integer, parameter :: nk=64
+real(4), dimension(ni,nj,nk) :: a
+
+!$acc enter data copyin(a)
+
+block; use m
+!$acc host_data use_device(a)
+do nt = 1, ntimes
+  call doit(a)
+end do
+!$acc end host_data
+end block
+
+block; use m
+do nt = 1, ntimes
+  call doit(a)
+end do
+end block
+end
+
+! CHECK: fir.call @_QP__device_sub
+! CHECK: fir.call @_QP__host_sub

>From 664e2234319a785aa79a8b8ed1c210e7cc804dc2 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 Oct 2025 18:19:30 -0700
Subject: [PATCH 2/5] Fix werror

---
 flang/lib/Semantics/resolve-names.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 7e55eb005dbb2..5041a6a08fc3c 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -2100,8 +2100,8 @@ class ResolveNamesVisitor : public virtual ScopeHandler,
 
   ResolveNamesVisitor(
       SemanticsContext &context, ImplicitRulesMap &rules, Scope &top)
-      : BaseVisitor{context, *this, rules}, topScope_{top},
-        AccVisitor(context) {
+      : BaseVisitor{context, *this, rules}, AccVisitor(context),
+        topScope_{top} {
     PushScope(top);
   }
 

>From ee30ea6b42b448915d7325b9b789c8434039d4e5 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 Oct 2025 19:34:33 -0700
Subject: [PATCH 3/5] Complete comment

---
 flang/lib/Lower/OpenACC.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 66f45451b3eeb..1e65f5559b83b 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3200,9 +3200,10 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *useDevice =
                    std::get_if<Fortran::parser::AccClause::UseDevice>(
                        &clause.u)) {
-      // When CUDA Fotran is en
+      // When CUDA Fotran is enabled, extra symbolds are used in the host_data
+      // region. Look for them and bind their value with the symbol in the outer
+      // scope. 
       if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
-
         const Fortran::parser::AccObjectList &objectList{useDevice->v};
         for (const auto &accObject : objectList.v) {
           Fortran::semantics::Symbol &symbol =

>From b4afb66184e769d3837cc188064c9a29bc8d33d9 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 Oct 2025 19:37:38 -0700
Subject: [PATCH 4/5] Fix typos

---
 flang/lib/Lower/OpenACC.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 1e65f5559b83b..731293b79d9bf 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3200,9 +3200,9 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
     } else if (const auto *useDevice =
                    std::get_if<Fortran::parser::AccClause::UseDevice>(
                        &clause.u)) {
-      // When CUDA Fotran is enabled, extra symbolds are used in the host_data
-      // region. Look for them and bind their value with the symbol in the outer
-      // scope. 
+      // When CUDA Fortran is enabled, extra symbols are used in the host_data
+      // region. Look for them and bind their values with the symbols in the
+      // outer scope. 
       if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
         const Fortran::parser::AccObjectList &objectList{useDevice->v};
         for (const auto &accObject : objectList.v) {

>From e223c122aad6f357f0d3371945387d59d9c85b14 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 2 Oct 2025 07:06:16 -0700
Subject: [PATCH 5/5] clang-format

---
 flang/lib/Lower/OpenACC.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 731293b79d9bf..f9b9b850ad839 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3202,7 +3202,7 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
                        &clause.u)) {
       // When CUDA Fortran is enabled, extra symbols are used in the host_data
       // region. Look for them and bind their values with the symbols in the
-      // outer scope. 
+      // outer scope.
       if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
         const Fortran::parser::AccObjectList &objectList{useDevice->v};
         for (const auto &accObject : objectList.v) {



More information about the flang-commits mailing list