[flang-commits] [flang] [flang][cuda][openacc] Create new symbol in host_data region for CUDA Fortran interop (PR #161613)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Thu Oct 2 07:06:31 PDT 2025
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/161613
>From 6130bbc38e7bd3cba36179b52b9634db6c3de8fb Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 Oct 2025 11:19:57 -0700
Subject: [PATCH 1/5] DRAFT
---
flang/include/flang/Lower/OpenACC.h | 3 +-
flang/include/flang/Lower/SymbolMap.h | 4 ++
flang/include/flang/Semantics/symbol.h | 2 +-
flang/lib/Lower/Bridge.cpp | 2 +-
flang/lib/Lower/OpenACC.cpp | 33 +++++++---
flang/lib/Lower/SymbolMap.cpp | 10 +++
flang/lib/Semantics/check-declarations.cpp | 3 +-
flang/lib/Semantics/resolve-directives.cpp | 5 ++
flang/lib/Semantics/resolve-names.cpp | 65 ++++++++++++++++++-
.../OpenACC/acc-host-data-cuda-device.f90 | 43 ++++++++++++
10 files changed, 156 insertions(+), 14 deletions(-)
create mode 100644 flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 19d759479abaf..4622dbc8ccf64 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -77,7 +77,8 @@ static constexpr llvm::StringRef privatizationRecipePrefix = "privatization";
mlir::Value genOpenACCConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
pft::Evaluation &,
- const parser::OpenACCConstruct &);
+ const parser::OpenACCConstruct &,
+ Fortran::lower::SymMap &localSymbols);
void genOpenACCDeclarativeConstruct(
AbstractConverter &, Fortran::semantics::SemanticsContext &,
StatementContext &, const parser::OpenACCDeclarativeConstruct &);
diff --git a/flang/include/flang/Lower/SymbolMap.h b/flang/include/flang/Lower/SymbolMap.h
index 813df777d7a39..e57b6a42d3cc1 100644
--- a/flang/include/flang/Lower/SymbolMap.h
+++ b/flang/include/flang/Lower/SymbolMap.h
@@ -260,6 +260,10 @@ class SymMap {
return lookupSymbol(*sym);
}
+ /// Find a symbol by name in the current mappings; return it, or nullptr if
+ /// absent. This lookup is more expensive as it iterates over the maps.
+ const semantics::Symbol *lookupSymbolByName(llvm::StringRef symName);
+
/// Find `symbol` and return its value if it appears in the inner-most level
/// map.
SymbolBox shallowLookupSymbol(semantics::SymbolRef sym);
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index e90e9c617805d..a0d5ae7176141 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -801,7 +801,7 @@ class Symbol {
AccPrivate, AccFirstPrivate, AccShared,
// OpenACC data-mapping attribute
AccCopy, AccCopyIn, AccCopyInReadOnly, AccCopyOut, AccCreate, AccDelete,
- AccPresent, AccLink, AccDeviceResident, AccDevicePtr,
+ AccPresent, AccLink, AccDeviceResident, AccDevicePtr, AccUseDevice,
// OpenACC declare
AccDeclare,
// OpenACC data-movement attribute
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 149e51b501a82..780d56f085f69 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3182,7 +3182,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
mlir::Value exitCond = genOpenACCConstruct(
- *this, bridge.getSemanticsContext(), getEval(), acc);
+ *this, bridge.getSemanticsContext(), getEval(), acc, localSymbols);
const Fortran::parser::OpenACCLoopConstruct *accLoop =
std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 95d0adae02670..66f45451b3eeb 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3184,7 +3184,8 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::AccClauseList &accClauseList) {
+ const Fortran::parser::AccClauseList &accClauseList,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value ifCond;
llvm::SmallVector<mlir::Value> dataOperands;
bool addIfPresentAttr = false;
@@ -3199,6 +3200,18 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *useDevice =
std::get_if<Fortran::parser::AccClause::UseDevice>(
&clause.u)) {
+ // When CUDA Fotran is en
+ if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
+
+ const Fortran::parser::AccObjectList &objectList{useDevice->v};
+ for (const auto &accObject : objectList.v) {
+ Fortran::semantics::Symbol &symbol =
+ getSymbolFromAccObject(accObject);
+ const Fortran::semantics::Symbol *baseSym =
+ localSymbols.lookupSymbolByName(symbol.name().ToString());
+ localSymbols.copySymbolBinding(*baseSym, symbol);
+ }
+ }
genDataOperandOperations<mlir::acc::UseDeviceOp>(
useDevice->v, converter, semanticsContext, stmtCtx, dataOperands,
mlir::acc::DataClause::acc_use_device,
@@ -3239,11 +3252,11 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
hostDataOp.setIfPresentAttr(builder.getUnitAttr());
}
-static void
-genACC(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
+static void genACC(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenACCBlockConstruct &blockConstruct,
+ Fortran::lower::SymMap &localSymbols) {
const auto &beginBlockDirective =
std::get<Fortran::parser::AccBeginBlockDirective>(blockConstruct.t);
const auto &blockDirective =
@@ -3273,7 +3286,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
accClauseList);
} else if (blockDirective.v == llvm::acc::ACCD_host_data) {
genACCHostDataOp(converter, currentLocation, eval, semanticsContext,
- stmtCtx, accClauseList);
+ stmtCtx, accClauseList, localSymbols);
}
}
@@ -4647,13 +4660,15 @@ mlir::Value Fortran::lower::genOpenACCConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCConstruct &accConstruct) {
+ const Fortran::parser::OpenACCConstruct &accConstruct,
+ Fortran::lower::SymMap &localSymbols) {
mlir::Value exitCond;
Fortran::common::visit(
common::visitors{
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
- genACC(converter, semanticsContext, eval, blockConstruct);
+ genACC(converter, semanticsContext, eval, blockConstruct,
+ localSymbols);
},
[&](const Fortran::parser::OpenACCCombinedConstruct
&combinedConstruct) {
diff --git a/flang/lib/Lower/SymbolMap.cpp b/flang/lib/Lower/SymbolMap.cpp
index 080f21ec67400..78529e0d539fb 100644
--- a/flang/lib/Lower/SymbolMap.cpp
+++ b/flang/lib/Lower/SymbolMap.cpp
@@ -45,6 +45,16 @@ Fortran::lower::SymMap::lookupSymbol(Fortran::semantics::SymbolRef symRef) {
return SymbolBox::None{};
}
+const Fortran::semantics::Symbol *
+Fortran::lower::SymMap::lookupSymbolByName(llvm::StringRef symName) {
+ for (auto jmap = symbolMapStack.rbegin(), jend = symbolMapStack.rend();
+ jmap != jend; ++jmap)
+ for (auto const &[sym, symBox] : *jmap)
+ if (sym->name().ToString() == symName)
+ return sym;
+ return nullptr;
+}
+
Fortran::lower::SymbolBox Fortran::lower::SymMap::shallowLookupSymbol(
Fortran::semantics::SymbolRef symRef) {
auto *sym = symRef->HasLocalLocality() ? &*symRef : &symRef->GetUltimate();
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index 1049a6d2c1b2e..7b881008219df 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -1189,7 +1189,8 @@ void CheckHelper::CheckObjectEntity(
}
} else if (!subpDetails && symbol.owner().kind() != Scope::Kind::Module &&
symbol.owner().kind() != Scope::Kind::MainProgram &&
- symbol.owner().kind() != Scope::Kind::BlockConstruct) {
+ symbol.owner().kind() != Scope::Kind::BlockConstruct &&
+ symbol.owner().kind() != Scope::Kind::OpenACCConstruct) {
messages_.Say(
"ATTRIBUTES(%s) may apply only to module, host subprogram, block, or device subprogram data"_err_en_US,
parser::ToUpperCaseLetters(common::EnumToString(attr)));
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index bd7b8ac552fab..366ebba1b7263 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -328,6 +328,11 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
return false;
}
+ bool Pre(const parser::AccClause::UseDevice &x) {
+ ResolveAccObjectList(x.v, Symbol::Flag::AccUseDevice);
+ return false;
+ }
+
void Post(const parser::Name &);
private:
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index d1150a9eb67f4..7e55eb005dbb2 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1387,6 +1387,8 @@ class ConstructVisitor : public virtual DeclarationVisitor {
// Create scopes for OpenACC constructs
class AccVisitor : public virtual DeclarationVisitor {
public:
+ explicit AccVisitor(SemanticsContext &context) : context_{context} {}
+
void AddAccSourceRange(const parser::CharBlock &);
static bool NeedsScope(const parser::OpenACCBlockConstruct &);
@@ -1395,6 +1397,7 @@ class AccVisitor : public virtual DeclarationVisitor {
void Post(const parser::OpenACCBlockConstruct &);
bool Pre(const parser::OpenACCCombinedConstruct &);
void Post(const parser::OpenACCCombinedConstruct &);
+ bool Pre(const parser::AccClause::UseDevice &x);
bool Pre(const parser::AccBeginBlockDirective &x) {
AddAccSourceRange(x.source);
return true;
@@ -1430,6 +1433,11 @@ class AccVisitor : public virtual DeclarationVisitor {
void Post(const parser::AccBeginLoopDirective &x) {
messageHandler().set_currStmtSource(std::nullopt);
}
+
+ void CopySymbolWithDevice(const parser::Name *name);
+
+private:
+ SemanticsContext &context_;
};
bool AccVisitor::NeedsScope(const parser::OpenACCBlockConstruct &x) {
@@ -1459,6 +1467,60 @@ bool AccVisitor::Pre(const parser::OpenACCBlockConstruct &x) {
return true;
}
+void AccVisitor::CopySymbolWithDevice(const parser::Name *name) {
+ // When CUDA Fortran is enabled together with OpenACC, new
+ // symbols are created for the ones appearing in the use_device
+ // clause. These new symbols have the CUDA Fortran device
+ // attribute.
+ if (context_.languageFeatures().IsEnabled(common::LanguageFeature::CUDA)) {
+ name->symbol = currScope().CopySymbol(*name->symbol);
+ if (auto *object{name->symbol->detailsIf<ObjectEntityDetails>()}) {
+ object->set_cudaDataAttr(common::CUDADataAttr::Device);
+ }
+ }
+}
+
+bool AccVisitor::Pre(const parser::AccClause::UseDevice &x) {
+ for (const auto &accObject : x.v.v) {
+ common::visit(
+ common::visitors{
+ [&](const parser::Designator &designator) {
+ if (const auto *name{
+ semantics::getDesignatorNameIfDataRef(designator)}) {
+ Symbol *prev{currScope().FindSymbol(name->source)};
+ if (prev != name->symbol) {
+ name->symbol = prev;
+ }
+ CopySymbolWithDevice(name);
+ } else {
+ if (const auto *dataRef{
+ std::get_if<parser::DataRef>(&designator.u)}) {
+ using ElementIndirection =
+ common::Indirection<parser::ArrayElement>;
+ if (auto *ind{std::get_if<ElementIndirection>(&dataRef->u)}) {
+ const parser::ArrayElement &arrayElement{ind->value()};
+ Walk(arrayElement.subscripts);
+ const parser::DataRef &base{arrayElement.base};
+ if (auto *name{std::get_if<parser::Name>(&base.u)}) {
+ Symbol *prev{currScope().FindSymbol(name->source)};
+ if (prev != name->symbol) {
+ name->symbol = prev;
+ }
+ CopySymbolWithDevice(name);
+ }
+ }
+ }
+ }
+ },
+ [&](const parser::Name &name) {
+ // TODO: common block in use_device?
+ },
+ },
+ accObject.u);
+ }
+ return false;
+}
+
void AccVisitor::Post(const parser::OpenACCBlockConstruct &x) {
if (NeedsScope(x)) {
PopScope();
@@ -2038,7 +2100,8 @@ class ResolveNamesVisitor : public virtual ScopeHandler,
ResolveNamesVisitor(
SemanticsContext &context, ImplicitRulesMap &rules, Scope &top)
- : BaseVisitor{context, *this, rules}, topScope_{top} {
+ : BaseVisitor{context, *this, rules}, topScope_{top},
+ AccVisitor(context) {
PushScope(top);
}
diff --git a/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90 b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
new file mode 100644
index 0000000000000..da034adec39c4
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-host-data-cuda-device.f90
@@ -0,0 +1,43 @@
+
+! RUN: bbc -fopenacc -fcuda -emit-hlfir %s -o - | FileCheck %s
+
+module m
+
+interface doit
+subroutine __device_sub(a)
+ real(4), device, intent(in) :: a(:,:,:)
+ !dir$ ignore_tkr(c) a
+end
+subroutine __host_sub(a)
+ real(4), intent(in) :: a(:,:,:)
+ !dir$ ignore_tkr(c) a
+end
+end interface
+end module
+
+program testex1
+integer, parameter :: ntimes = 10
+integer, parameter :: ni=128
+integer, parameter :: nj=256
+integer, parameter :: nk=64
+real(4), dimension(ni,nj,nk) :: a
+
+!$acc enter data copyin(a)
+
+block; use m
+!$acc host_data use_device(a)
+do nt = 1, ntimes
+ call doit(a)
+end do
+!$acc end host_data
+end block
+
+block; use m
+do nt = 1, ntimes
+ call doit(a)
+end do
+end block
+end
+
+! CHECK: fir.call @_QP__device_sub
+! CHECK: fir.call @_QP__host_sub
>From 664e2234319a785aa79a8b8ed1c210e7cc804dc2 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 Oct 2025 18:19:30 -0700
Subject: [PATCH 2/5] Fix werror
---
flang/lib/Semantics/resolve-names.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 7e55eb005dbb2..5041a6a08fc3c 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -2100,8 +2100,8 @@ class ResolveNamesVisitor : public virtual ScopeHandler,
ResolveNamesVisitor(
SemanticsContext &context, ImplicitRulesMap &rules, Scope &top)
- : BaseVisitor{context, *this, rules}, topScope_{top},
- AccVisitor(context) {
+ : BaseVisitor{context, *this, rules}, AccVisitor(context),
+ topScope_{top} {
PushScope(top);
}
>From ee30ea6b42b448915d7325b9b789c8434039d4e5 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 Oct 2025 19:34:33 -0700
Subject: [PATCH 3/5] Complete comment
---
flang/lib/Lower/OpenACC.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 66f45451b3eeb..1e65f5559b83b 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3200,9 +3200,10 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *useDevice =
std::get_if<Fortran::parser::AccClause::UseDevice>(
&clause.u)) {
- // When CUDA Fotran is en
+ // When CUDA Fotran is enabled, extra symbolds are used in the host_data
+ // region. Look for them and bind their value with the symbol in the outer
+ // scope.
if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
-
const Fortran::parser::AccObjectList &objectList{useDevice->v};
for (const auto &accObject : objectList.v) {
Fortran::semantics::Symbol &symbol =
>From b4afb66184e769d3837cc188064c9a29bc8d33d9 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 1 Oct 2025 19:37:38 -0700
Subject: [PATCH 4/5] Fix typos
---
flang/lib/Lower/OpenACC.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 1e65f5559b83b..731293b79d9bf 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3200,9 +3200,9 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
} else if (const auto *useDevice =
std::get_if<Fortran::parser::AccClause::UseDevice>(
&clause.u)) {
- // When CUDA Fotran is enabled, extra symbolds are used in the host_data
- // region. Look for them and bind their value with the symbol in the outer
- // scope.
+ // When CUDA Fortran is enabled, extra symbols are used in the host_data
+ // region. Look for them and bind their values with the symbols in the
+ // outer scope.
if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
const Fortran::parser::AccObjectList &objectList{useDevice->v};
for (const auto &accObject : objectList.v) {
>From e223c122aad6f357f0d3371945387d59d9c85b14 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 2 Oct 2025 07:06:16 -0700
Subject: [PATCH 5/5] clang-format
---
flang/lib/Lower/OpenACC.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 731293b79d9bf..f9b9b850ad839 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3202,7 +3202,7 @@ genACCHostDataOp(Fortran::lower::AbstractConverter &converter,
&clause.u)) {
// When CUDA Fortran is enabled, extra symbols are used in the host_data
// region. Look for them and bind their values with the symbols in the
- // outer scope.
+ // outer scope.
if (semanticsContext.IsEnabled(Fortran::common::LanguageFeature::CUDA)) {
const Fortran::parser::AccObjectList &objectList{useDevice->v};
for (const auto &accObject : objectList.v) {
More information about the flang-commits
mailing list