[llvm-branch-commits] [flang] [OpenMP][Flang] Emit default declare mappers implicitly for derived types (PR #140562)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon May 19 08:38:12 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: Akash Banerjee (TIFitis)
<details>
<summary>Changes</summary>
This patch adds support for emitting default declare mappers for the implicit mapping of derived types when the user has not supplied one. This is especially helpful for handling the mapping of allocatables of derived types.
---
Patch is 33.53 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140562.diff
5 Files Affected:
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+23-16)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+131-1)
- (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+1-1)
- (modified) flang/lib/Semantics/resolve-names.cpp (+47-48)
- (modified) flang/test/Lower/OpenMP/derived-type-map.f90 (+21-1)
``````````diff
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 8dcc8be9be5bf..cf25e91a437b8 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1102,23 +1102,30 @@ void ClauseProcessor::processMapObjects(
auto getDefaultMapperID = [&](const omp::Object &object,
std::string &mapperIdName) {
- if (!mlir::isa<mlir::omp::DeclareMapperOp>(
- firOpBuilder.getRegion().getParentOp())) {
- const semantics::DerivedTypeSpec *typeSpec = nullptr;
-
- if (object.sym()->owner().IsDerivedType())
- typeSpec = object.sym()->owner().derivedTypeSpec();
- else if (object.sym()->GetType() &&
- object.sym()->GetType()->category() ==
- semantics::DeclTypeSpec::TypeDerived)
- typeSpec = &object.sym()->GetType()->derivedTypeSpec();
-
- if (typeSpec) {
- mapperIdName = typeSpec->name().ToString() + ".omp.default.mapper";
- if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
- mapperIdName = converter.mangleName(mapperIdName, sym->owner());
- }
+ const semantics::DerivedTypeSpec *typeSpec = nullptr;
+
+ if (object.sym()->GetType() && object.sym()->GetType()->category() ==
+ semantics::DeclTypeSpec::TypeDerived)
+ typeSpec = &object.sym()->GetType()->derivedTypeSpec();
+ else if (object.sym()->owner().IsDerivedType())
+ typeSpec = object.sym()->owner().derivedTypeSpec();
+
+ if (typeSpec) {
+ mapperIdName = typeSpec->name().ToString() + ".omp.default.mapper";
+ if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
+ mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+ else
+ mapperIdName =
+ converter.mangleName(mapperIdName, *typeSpec->GetScope());
}
+
+ // Make sure we don't return a mapper to self
+ llvm::StringRef parentOpName;
+ if (auto declMapOp = mlir::dyn_cast<mlir::omp::DeclareMapperOp>(
+ firOpBuilder.getRegion().getParentOp()))
+ parentOpName = declMapOp.getSymName();
+ if (mapperIdName == parentOpName)
+ mapperIdName = "";
};
// Create the mapper symbol from its name, if specified.
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index cfcba0159db8d..8d5c26a4f2d58 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2348,6 +2348,124 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
queue, item, clauseOps);
}
+static mlir::FlatSymbolRefAttr
+genImplicitDefaultDeclareMapper(lower::AbstractConverter &converter,
+ mlir::Location loc, fir::RecordType recordType,
+ llvm::StringRef mapperNameStr) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ lower::StatementContext stmtCtx;
+
+ // Save current insertion point before moving to the module scope to create
+ // the DeclareMapperOp
+ mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+
+ firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
+ auto declMapperOp = firOpBuilder.create<mlir::omp::DeclareMapperOp>(
+ loc, mapperNameStr, recordType);
+ auto &region = declMapperOp.getRegion();
+ firOpBuilder.createBlock(&region);
+ auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
+
+ auto declareOp =
+ firOpBuilder.create<hlfir::DeclareOp>(loc, mapperArg, /*uniq_name=*/"");
+
+ const auto genBoundsOps = [&](mlir::Value mapVal,
+ llvm::SmallVectorImpl<mlir::Value> &bounds) {
+ fir::ExtendedValue extVal =
+ hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
+ hlfir::Entity{mapVal},
+ /*contiguousHint=*/true)
+ .first;
+ fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
+ firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
+ bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ firOpBuilder, info, extVal,
+ /*dataExvIsAssumedSize=*/false, mapVal.getLoc());
+ };
+
+ // Return a reference to the contents of a derived type with one field.
+ // Also return the field type.
+ const auto getFieldRef =
+ [&](mlir::Value rec,
+ unsigned index) -> std::tuple<mlir::Value, mlir::Type> {
+ auto recType = mlir::dyn_cast<fir::RecordType>(
+ fir::unwrapPassByRefType(rec.getType()));
+ auto [fieldName, fieldTy] = recType.getTypeList()[index];
+ mlir::Value field = firOpBuilder.create<fir::FieldIndexOp>(
+ loc, fir::FieldType::get(recType.getContext()), fieldName, recType,
+ fir::getTypeParams(rec));
+ return {firOpBuilder.create<fir::CoordinateOp>(
+ loc, firOpBuilder.getRefType(fieldTy), rec, field),
+ fieldTy};
+ };
+
+ mlir::omp::DeclareMapperInfoOperands clauseOps;
+ llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
+ llvm::SmallVector<mlir::Value> memberMapOps;
+
+ llvm::omp::OpenMPOffloadMappingFlags mapFlag =
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM |
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+ mlir::omp::VariableCaptureKind captureKind =
+ mlir::omp::VariableCaptureKind::ByRef;
+ int64_t index = 0;
+
+ // Populate the declareMapper region with the map information.
+ for (const auto &[memberName, memberType] :
+ mlir::dyn_cast<fir::RecordType>(recordType).getTypeList()) {
+ auto [ref, type] = getFieldRef(declareOp.getBase(), index);
+ mlir::FlatSymbolRefAttr mapperId;
+ if (auto recType = mlir::dyn_cast<fir::RecordType>(memberType)) {
+ std::string mapperIdName =
+ recType.getName().str() + ".omp.default.mapper";
+ if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
+ mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+ else if (auto *sym = converter.getCurrentScope().FindSymbol(memberName))
+ mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+
+ if (converter.getModuleOp().lookupSymbol(mapperIdName))
+ mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
+ mapperIdName);
+ else
+ mapperId = genImplicitDefaultDeclareMapper(converter, loc, recType,
+ mapperIdName);
+ }
+
+ llvm::SmallVector<mlir::Value> bounds;
+ genBoundsOps(ref, bounds);
+ mlir::Value mapOp = createMapInfoOp(
+ firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, "", bounds,
+ /*members=*/{},
+ /*membersIndex=*/mlir::ArrayAttr{},
+ static_cast<
+ std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ mapFlag),
+ captureKind, ref.getType(), /*partialMap=*/false, mapperId);
+ memberMapOps.emplace_back(mapOp);
+ memberPlacementIndices.emplace_back(llvm::SmallVector<int64_t>{index++});
+ }
+
+ llvm::SmallVector<mlir::Value> bounds;
+ genBoundsOps(declareOp.getOriginalBase(), bounds);
+ mlir::omp::MapInfoOp mapOp = createMapInfoOp(
+ firOpBuilder, loc, declareOp.getOriginalBase(),
+ /*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
+ firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices),
+ static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+ mapFlag),
+ captureKind, declareOp.getType(0),
+ /*partialMap=*/true);
+
+ clauseOps.mapVars.emplace_back(mapOp);
+
+ firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
+ // declMapperOp->dumpPretty();
+ return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
+ mapperNameStr);
+}
+
static mlir::omp::TargetOp
genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
lower::StatementContext &stmtCtx,
@@ -2420,15 +2538,26 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
name << sym.name().ToString();
mlir::FlatSymbolRefAttr mapperId;
- if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived) {
+ if (sym.GetType()->category() == semantics::DeclTypeSpec::TypeDerived &&
+ defaultMaps.empty()) {
auto &typeSpec = sym.GetType()->derivedTypeSpec();
std::string mapperIdName =
typeSpec.name().ToString() + ".omp.default.mapper";
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
+ else
+ mapperIdName =
+ converter.mangleName(mapperIdName, *typeSpec.GetScope());
+
if (converter.getModuleOp().lookupSymbol(mapperIdName))
mapperId = mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
mapperIdName);
+ else
+ mapperId = genImplicitDefaultDeclareMapper(
+ converter, loc,
+ mlir::cast<fir::RecordType>(
+ converter.genType(sym.GetType()->derivedTypeSpec())),
+ mapperIdName);
}
fir::factory::AddrAndBoundsInfo info =
@@ -4039,6 +4168,7 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processMap(loc, stmtCtx, clauseOps);
firOpBuilder.create<mlir::omp::DeclareMapperInfoOp>(loc, clauseOps.mapVars);
+ // declMapperOp->dumpPretty();
}
static void
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 86095d708f7e5..1ef6239e51fa8 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -377,7 +377,7 @@ class MapInfoFinalizationPass
getDescriptorMapType(mapType, target)),
op.getMapCaptureTypeAttr(), /*varPtrPtr=*/mlir::Value{}, newMembers,
newMembersAttr, /*bounds=*/mlir::SmallVector<mlir::Value>{},
- /*mapperId*/ mlir::FlatSymbolRefAttr(), op.getNameAttr(),
+ op.getMapperIdAttr(), op.getNameAttr(),
/*partial_map=*/builder.getBoolAttr(false));
op.replaceAllUsesWith(newDescParentMapOp.getResult());
op->erase();
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 322562b06b87f..318af99a57848 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -2163,7 +2163,7 @@ void AttrsVisitor::SetBindNameOn(Symbol &symbol) {
}
symbol.SetBindName(std::move(*label));
if (!oldBindName.empty()) {
- if (const std::string * newBindName{symbol.GetBindName()}) {
+ if (const std::string *newBindName{symbol.GetBindName()}) {
if (oldBindName != *newBindName) {
Say(symbol.name(),
"The entity '%s' has multiple BIND names ('%s' and '%s')"_err_en_US,
@@ -2285,7 +2285,7 @@ void DeclTypeSpecVisitor::Post(const parser::TypeSpec &typeSpec) {
// expression semantics if the DeclTypeSpec is a valid TypeSpec.
// The grammar ensures that it's an intrinsic or derived type spec,
// not TYPE(*) or CLASS(*) or CLASS(T).
- if (const DeclTypeSpec * spec{state_.declTypeSpec}) {
+ if (const DeclTypeSpec *spec{state_.declTypeSpec}) {
switch (spec->category()) {
case DeclTypeSpec::Numeric:
case DeclTypeSpec::Logical:
@@ -2293,7 +2293,7 @@ void DeclTypeSpecVisitor::Post(const parser::TypeSpec &typeSpec) {
typeSpec.declTypeSpec = spec;
break;
case DeclTypeSpec::TypeDerived:
- if (const DerivedTypeSpec * derived{spec->AsDerived()}) {
+ if (const DerivedTypeSpec *derived{spec->AsDerived()}) {
CheckForAbstractType(derived->typeSymbol()); // C703
typeSpec.declTypeSpec = spec;
}
@@ -2802,8 +2802,8 @@ Symbol &ScopeHandler::MakeSymbol(const parser::Name &name, Attrs attrs) {
Symbol &ScopeHandler::MakeHostAssocSymbol(
const parser::Name &name, const Symbol &hostSymbol) {
Symbol &symbol{*NonDerivedTypeScope()
- .try_emplace(name.source, HostAssocDetails{hostSymbol})
- .first->second};
+ .try_emplace(name.source, HostAssocDetails{hostSymbol})
+ .first->second};
name.symbol = &symbol;
symbol.attrs() = hostSymbol.attrs(); // TODO: except PRIVATE, PUBLIC?
// These attributes can be redundantly reapplied without error
@@ -2891,7 +2891,7 @@ void ScopeHandler::ApplyImplicitRules(
if (context().HasError(symbol) || !NeedsType(symbol)) {
return;
}
- if (const DeclTypeSpec * type{GetImplicitType(symbol)}) {
+ if (const DeclTypeSpec *type{GetImplicitType(symbol)}) {
if (!skipImplicitTyping_) {
symbol.set(Symbol::Flag::Implicit);
symbol.SetType(*type);
@@ -2991,7 +2991,7 @@ const DeclTypeSpec *ScopeHandler::GetImplicitType(
const auto *type{implicitRulesMap_->at(scope).GetType(
symbol.name(), respectImplicitNoneType)};
if (type) {
- if (const DerivedTypeSpec * derived{type->AsDerived()}) {
+ if (const DerivedTypeSpec *derived{type->AsDerived()}) {
// Resolve any forward-referenced derived type; a quick no-op else.
auto &instantiatable{*const_cast<DerivedTypeSpec *>(derived)};
instantiatable.Instantiate(currScope());
@@ -3706,10 +3706,10 @@ void ModuleVisitor::DoAddUse(SourceName location, SourceName localName,
} else if (IsPointer(p1) || IsPointer(p2)) {
return false;
} else if (const auto *subp{p1.detailsIf<SubprogramDetails>()};
- subp && !subp->isInterface()) {
+ subp && !subp->isInterface()) {
return false; // defined in module, not an external
} else if (const auto *subp{p2.detailsIf<SubprogramDetails>()};
- subp && !subp->isInterface()) {
+ subp && !subp->isInterface()) {
return false; // defined in module, not an external
} else {
// Both are external interfaces, perhaps to the same procedure
@@ -3969,7 +3969,7 @@ Scope *ModuleVisitor::FindModule(const parser::Name &name,
if (scope) {
if (DoesScopeContain(scope, currScope())) { // 14.2.2(1)
std::optional<SourceName> submoduleName;
- if (const Scope * container{FindModuleOrSubmoduleContaining(currScope())};
+ if (const Scope *container{FindModuleOrSubmoduleContaining(currScope())};
container && container->IsSubmodule()) {
submoduleName = container->GetName();
}
@@ -4074,7 +4074,7 @@ bool InterfaceVisitor::isAbstract() const {
void InterfaceVisitor::AddSpecificProcs(
const std::list<parser::Name> &names, ProcedureKind kind) {
- if (Symbol * symbol{GetGenericInfo().symbol};
+ if (Symbol *symbol{GetGenericInfo().symbol};
symbol && symbol->has<GenericDetails>()) {
for (const auto &name : names) {
specificsForGenericProcs_.emplace(symbol, std::make_pair(&name, kind));
@@ -4174,7 +4174,7 @@ void GenericHandler::DeclaredPossibleSpecificProc(Symbol &proc) {
}
void InterfaceVisitor::ResolveNewSpecifics() {
- if (Symbol * generic{genericInfo_.top().symbol};
+ if (Symbol *generic{genericInfo_.top().symbol};
generic && generic->has<GenericDetails>()) {
ResolveSpecificsInGeneric(*generic, false);
}
@@ -4259,7 +4259,7 @@ bool SubprogramVisitor::HandleStmtFunction(const parser::StmtFunctionStmt &x) {
name.source);
MakeSymbol(name, Attrs{}, UnknownDetails{});
} else if (auto *entity{ultimate.detailsIf<EntityDetails>()};
- entity && !ultimate.has<ProcEntityDetails>()) {
+ entity && !ultimate.has<ProcEntityDetails>()) {
resultType = entity->type();
ultimate.details() = UnknownDetails{}; // will be replaced below
} else {
@@ -4315,7 +4315,7 @@ bool SubprogramVisitor::Pre(const parser::Suffix &suffix) {
} else {
Message &msg{Say(*suffix.resultName,
"RESULT(%s) may appear only in a function"_err_en_US)};
- if (const Symbol * subprogram{InclusiveScope().symbol()}) {
+ if (const Symbol *subprogram{InclusiveScope().symbol()}) {
msg.Attach(subprogram->name(), "Containing subprogram"_en_US);
}
}
@@ -4833,7 +4833,7 @@ Symbol *ScopeHandler::FindSeparateModuleProcedureInterface(
symbol = generic->specific();
}
}
- if (const Symbol * defnIface{FindSeparateModuleSubprogramInterface(symbol)}) {
+ if (const Symbol *defnIface{FindSeparateModuleSubprogramInterface(symbol)}) {
// Error recovery in case of multiple definitions
symbol = const_cast<Symbol *>(defnIface);
}
@@ -4970,8 +4970,8 @@ bool SubprogramVisitor::HandlePreviousCalls(
return generic->specific() &&
HandlePreviousCalls(name, *generic->specific(), subpFlag);
} else if (const auto *proc{symbol.detailsIf<ProcEntityDetails>()}; proc &&
- !proc->isDummy() &&
- !symbol.attrs().HasAny(Attrs{Attr::INTRINSIC, Attr::POINTER})) {
+ !proc->isDummy() &&
+ !symbol.attrs().HasAny(Attrs{Attr::INTRINSIC, Attr::POINTER})) {
// There's a symbol created for previous calls to this subprogram or
// ENTRY's name. We have to replace that symbol in situ to avoid the
// obligation to rewrite symbol pointers in the parse tree.
@@ -5012,7 +5012,7 @@ void SubprogramVisitor::CheckExtantProc(
if (auto *prev{FindSymbol(name)}) {
if (IsDummy(*prev)) {
} else if (auto *entity{prev->detailsIf<EntityDetails>()};
- IsPointer(*prev) && entity && !entity->type()) {
+ IsPointer(*prev) && entity && !entity->type()) {
// POINTER attribute set before interface
} else if (inInterfaceBlock() && currScope() != prev->owner()) {
// Procedures in an INTERFACE block do not resolve to symbols
@@ -5068,7 +5068,7 @@ Symbol &SubprogramVisitor::PushSubprogramScope(const parser::Name &name,
}
set_inheritFromParent(false); // interfaces don't inherit, even if MODULE
}
- if (Symbol * found{FindSymbol(name)};
+ if (Symbol *found{FindSymbol(name)};
found && found->has<HostAssocDetails>()) {
found->set(subpFlag); // PushScope() created symbol
}
@@ -5910,9 +5910,9 @@ void DeclarationVisitor::Post(const parser::VectorTypeSpec &x) {
vectorDerivedType.CookParameters(GetFoldingContext());
}
- if (const DeclTypeSpec *
- extant{ppcBuiltinTypesScope->FindInstantiatedDerivedType(
- vectorDerivedType, DeclTypeSpec::Category::TypeDerived)}) {
+ if (const DeclTypeSpec *extant{
+ ppcBuiltinTypesScope->FindInstantiatedDerivedType(
+ vectorDerivedType, DeclTypeSpec::Category::TypeDerived)}) {
// This derived type and parameter expressions (if any) are already present
// in the __ppc_intrinsics scope.
SetDeclTypeSpec(*extant);
@@ -5934,7 +5934,7 @@ bool DeclarationVisitor::Pre(const parser::DeclarationTypeSpec::Type &) {
void DeclarationVisitor::Post(const parser::DeclarationTypeSpec::Type &type) {
const parser::Name &derivedName{std::get<parser::Name>(type.derived.t)};
- if (const Symbol * derivedSymbol{derivedName.symbol}) {
+ if (const Symbol *derivedSymbol{derivedName.symbol}) {
CheckForAbstractType(*derivedSymbol); // C706
}
}
@@ -6003,8 +6003,8 @@ void DeclarationVisitor::Post(const parser::DerivedTypeSpec &x) {
if (!spec->MightBeParameterized()) {
spec->EvaluateParameters(context());
}
- if (const DeclTypeSpec *
- extant{currScope().FindInstantiatedDerivedType(*spec, category)}) {
+ if (const DeclTypeSpec *extant{
+ currScope().FindInstantiatedDerivedType(*spec, category)}) {
// This derived type and parameter expressions (if any) are already present
// in this scope.
SetDeclTypeSpec(*extant);
@@ -6035,8 +6035,7 @@ void DeclarationVisitor::Post(const parser::DeclarationTypeSpec::Record &rec) {
if (auto spec{ResolveDerivedType(typeName)}) {
spec->CookParameters(GetFoldingContext());
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/140562
More information about the llvm-branch-commits
mailing list