[flang-commits] [flang] [llvm] [mlir] [flang][OpenMP] Implicitly map allocatable record fields (PR #117867)
Kareem Ergawy via flang-commits
flang-commits at lists.llvm.org
Fri Nov 29 01:14:29 PST 2024
https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/117867
>From 298c1fb895b27780c638db6709de41a456d0aeac Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Tue, 26 Nov 2024 22:17:34 -0600
Subject: [PATCH] [flang][OpenMP] Implicitly map allocatable record fields
This is a starting PR to implicitly map allocatable record fields.
This PR contains the following changes:
1. Re-purposes some of the utils used in `Lower/OpenMP.cpp` so that
these utils work on the `mlir::Value` level rather than the
`semantics::Symbol` level. This takes one step towards enabling
MLIR passes to more easily do some lowering themselves (e.g. creating
`omp.map.bounds` ops for implicitly captured data like this PR
does).
2. Adds support for implicitly capturing and mapping allocatable fields
in record types.
There is still quite some distance to cover to have full support for
this. I added a number of TODOs to guide further development.
---
.../flang}/Lower/DirectivesCommon.h | 50 ++++--
flang/lib/Lower/Bridge.cpp | 3 +-
flang/lib/Lower/OpenACC.cpp | 3 +-
flang/lib/Lower/OpenMP/ClauseProcessor.h | 2 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 23 +--
flang/lib/Lower/OpenMP/Utils.cpp | 2 +-
flang/lib/Optimizer/OpenMP/CMakeLists.txt | 2 +
.../Optimizer/OpenMP/MapInfoFinalization.cpp | 155 ++++++++++++++++++
...p-map-info-finalization-implicit-field.fir | 63 +++++++
.../Dialect/OpenMP/OpenMPOpsInterfaces.td | 7 +
.../fortran/implicit-record-field-mapping.f90 | 52 ++++++
11 files changed, 326 insertions(+), 36 deletions(-)
rename flang/{lib => include/flang}/Lower/DirectivesCommon.h (97%)
create mode 100644 flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
create mode 100644 offload/test/offloading/fortran/implicit-record-field-mapping.f90
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
similarity index 97%
rename from flang/lib/Lower/DirectivesCommon.h
rename to flang/include/flang/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
}
}
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
- fir::FirOpBuilder &builder,
- Fortran::lower::SymbolRef sym, mlir::Location loc) {
- mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+ mlir::Value symAddr,
+ bool isOptional,
+ mlir::Location loc) {
mlir::Value rawInput = symAddr;
if (auto declareOp =
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
rawInput = declareOp.getResults()[1];
}
- // TODO: Might need revisiting to handle for non-shared clauses
- if (!symAddr) {
- if (const auto *details =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
- symAddr = converter.getSymbolAddress(details->symbol());
- rawInput = symAddr;
- }
- }
-
if (!symAddr)
llvm::report_fatal_error("could not retrieve symbol address");
mlir::Value isPresent;
- if (Fortran::semantics::IsOptional(sym))
+ if (isOptional)
isPresent =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
// all address/dimension retrievals. For Fortran optional though, leave
// the load generation for later so it can be done in the appropriate
// if branches.
- if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
- !Fortran::semantics::IsOptional(sym)) {
+ if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
}
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
}
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+ fir::FirOpBuilder &builder,
+ Fortran::lower::SymbolRef sym, mlir::Location loc) {
+ return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+ Fortran::semantics::IsOptional(sym), loc);
+}
+
template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
return info;
}
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+ fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+ mlir::Location loc) {
+ llvm::SmallVector<mlir::Value> bounds;
+
+ mlir::Value baseOp = info.rawInput;
+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+ bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+ dataExv, info);
+ if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+ bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+ builder, loc, dataExv, dataExvIsAssumedSize);
+ }
+
+ return bounds;
+}
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 77003eff190e26..51df81238e4c00 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//
#include "flang/Lower/Bridge.h"
-#include "DirectivesCommon.h"
+
#include "flang/Common/Version.h"
#include "flang/Lower/Allocatable.h"
#include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/Cuda.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/HostAssociations.h"
#include "flang/Lower/IO.h"
#include "flang/Lower/IterationSpace.h"
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 878dccc4ecbc4b..7712627e767961 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,10 +11,11 @@
//===----------------------------------------------------------------------===//
#include "flang/Lower/OpenACC.h"
-#include "DirectivesCommon.h"
+
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertType.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 217d7c6917bd61..e726a7bacef4e6 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,11 +13,11 @@
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
#include "Clauses.h"
-#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 3bb43b766bcebf..669884d646b144 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -16,7 +16,6 @@
#include "Clauses.h"
#include "DataSharingProcessor.h"
#include "Decomposer.h"
-#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Common/OpenMP-utils.h"
@@ -24,6 +23,7 @@
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
@@ -1731,32 +1731,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
if (const auto *details =
sym.template detailsIf<semantics::HostAssocDetails>())
converter.copySymbolBinding(details->symbol(), sym);
- llvm::SmallVector<mlir::Value> bounds;
std::stringstream name;
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
name << sym.name().ToString();
lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
converter, firOpBuilder, sym, converter.getCurrentLocation());
- mlir::Value baseOp = info.rawInput;
- if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
- bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- firOpBuilder, converter.getCurrentLocation(), dataExv, info);
- if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
- bool dataExvIsAssumedSize =
- semantics::IsAssumedSizeArray(sym.GetUltimate());
- bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- firOpBuilder, converter.getCurrentLocation(), dataExv,
- dataExvIsAssumedSize);
- }
+ llvm::SmallVector<mlir::Value> bounds =
+ lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ firOpBuilder, info, dataExv,
+ semantics::IsAssumedSizeArray(sym.GetUltimate()),
+ converter.getCurrentLocation());
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;
+ mlir::Value baseOp = info.rawInput;
mlir::Type eleType = baseOp.getType();
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
eleType = refType.getElementType();
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 5340dd8c5fb9a2..9971dc8e0b0014 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,10 +13,10 @@
#include "Utils.h"
#include "Clauses.h"
-#include <DirectivesCommon.h>
#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
+#include <flang/Lower/DirectivesCommon.h>
#include <flang/Lower/PFTBuilder.h>
#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Builder/Todo.h>
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 51ecbe1a664f92..4f23b2b970fa44 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
FIRDialect
HLFIROpsIncGen
FlangOpenMPPassesIncGen
+ ${dialect_libs}
LINK_LIBS
FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
MLIRIR
MLIRPass
MLIRTransformUtils
+ ${dialect_libs}
)
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..5c8bb92d5fa62b 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -24,10 +24,14 @@
/// indirectly via a parent object.
//===----------------------------------------------------------------------===//
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinDialect.h"
@@ -486,6 +490,157 @@ class MapInfoFinalizationPass
// iterations from previous function scopes.
localBoxAllocas.clear();
+ // First, walk `omp.map.info` ops to see if any record members should be
+ // implicitly mapped.
+ func->walk([&](mlir::omp::MapInfoOp op) {
+ mlir::Type underlyingType =
+ fir::unwrapRefType(op.getVarPtr().getType());
+
+ if (!fir::isRecordWithAllocatableMember(underlyingType))
+ return mlir::WalkResult::advance();
+
+ // TODO For now, only consider `omp.target` ops. Other ops that support
+ // `map` clauses will follow later.
+ mlir::omp::TargetOp target =
+ mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+ getFirstTargetUser(op));
+
+ if (!target)
+ return mlir::WalkResult::advance();
+
+ auto mapClauseOwner =
+ llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+ int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
+ assert(mapVarIdx >= 0 &&
+ mapVarIdx <
+ static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
+
+ auto argIface =
+ llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+ // TODO How should `map` block arguments that correspond to: `private`,
+ // `use_device_addr`, `use_device_ptr`, be handled?
+ mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+ llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+ mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+ mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+ // TODO Support coordinate_of ops.
+ //
+ // TODO Support call ops by recursively examining the forward slice of
+ // the corresponding parameter to the field in the called function.
+ return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+ });
+
+ auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+ llvm::SmallVector<mlir::Value> newMapOpsForFields;
+ llvm::SmallVector<int64_t> fieldIdices;
+
+ for (auto fieldMemTyPair : recordType.getTypeList()) {
+ auto &field = fieldMemTyPair.first;
+ auto memTy = fieldMemTyPair.second;
+
+ bool shouldMapField =
+ llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+ if (!fir::isAllocatableType(memTy))
+ return false;
+
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+ if (!designateOp)
+ return false;
+
+ return designateOp.getComponent() &&
+ designateOp.getComponent()->strref() == field;
+ }) != mapVarForwardSlice.end();
+
+ // TODO Handle recursive record types. Adapting
+ // `createParentSymAndGenIntermediateMaps` to work directly on MLIR
+ // entities might be helpful here.
+
+ if (!shouldMapField)
+ continue;
+
+ int64_t fieldIdx = recordType.getFieldIndex(field);
+ bool alreadyMapped = [&]() {
+ if (op.getMembersIndexAttr())
+ for (auto indexList : op.getMembersIndexAttr()) {
+ auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+ if (indexListAttr.size() == 1 &&
+ mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+ fieldIdx)
+ return true;
+ }
+
+ return false;
+ }();
+
+ if (alreadyMapped)
+ continue;
+
+ builder.setInsertionPoint(op);
+ mlir::Value fieldIdxVal = builder.createIntegerConstant(
+ op.getLoc(), mlir::IndexType::get(builder.getContext()),
+ fieldIdx);
+ auto fieldCoord = builder.create<fir::CoordinateOp>(
+ op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+ fieldIdxVal);
+ Fortran::lower::AddrAndBoundsInfo info =
+ Fortran::lower::getDataOperandBaseAddr(
+ builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+ llvm::SmallVector<mlir::Value> bounds =
+ Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ builder, info,
+ hlfir::translateToExtendedValue(op.getLoc(), builder,
+ hlfir::Entity{fieldCoord})
+ .first,
+ /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+ mlir::omp::MapInfoOp fieldMapOp =
+ builder.create<mlir::omp::MapInfoOp>(
+ op.getLoc(), fieldCoord.getResult().getType(),
+ fieldCoord.getResult(),
+ mlir::TypeAttr::get(
+ fir::unwrapRefType(fieldCoord.getResult().getType())),
+ /*varPtrPtr=*/mlir::Value{},
+ /*members=*/mlir::ValueRange{},
+ /*members_index=*/mlir::ArrayAttr{},
+ /*bounds=*/bounds, op.getMapTypeAttr(),
+ builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+ mlir::omp::VariableCaptureKind::ByRef),
+ builder.getStringAttr(op.getNameAttr().strref() + "." +
+ field + ".implicit_map"),
+ /*partial_map=*/builder.getBoolAttr(false));
+ newMapOpsForFields.emplace_back(fieldMapOp);
+ fieldIdices.emplace_back(fieldIdx);
+ }
+
+ if (newMapOpsForFields.empty())
+ return mlir::WalkResult::advance();
+
+ op.getMembersMutable().append(newMapOpsForFields);
+ llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+ mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+ if (oldMembersIdxAttr)
+ for (mlir::Attribute indexList : oldMembersIdxAttr) {
+ llvm::SmallVector<int64_t> listVec;
+
+ for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
+ listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+ newMemberIndices.emplace_back(std::move(listVec));
+ }
+
+ for (int64_t newFieldIdx : fieldIdices)
+ newMemberIndices.emplace_back(
+ llvm::SmallVector<int64_t>(1, newFieldIdx));
+
+ op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
+
+ return mlir::WalkResult::advance();
+ });
+
func->walk([&](mlir::omp::MapInfoOp op) {
// TODO: Currently only supports a single user for the MapInfoOp. This
// is fine for the moment, as the Fortran frontend will generate a
diff --git a/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
new file mode 100644
index 00000000000000..73364569d76336
--- /dev/null
+++ b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
@@ -0,0 +1,63 @@
+// Tests that we implicitly map allocatable fields of a record when referenced in
+// a target region.
+
+// RUN: fir-opt --split-input-file --omp-map-info-finalization %s | FileCheck %s
+
+!record_t = !fir.type<_QFTrecored_t{
+ not_to_implicitly_map:
+ !fir.box<!fir.heap<!fir.array<?xf32>>>,
+ to_implicitly_map:
+ !fir.box<!fir.heap<!fir.array<?xf32>>>
+}>
+
+fir.global internal @_QFEdst_record : !record_t {
+ %0 = fir.undefined !record_t
+ fir.has_value %0 : !record_t
+}
+
+func.func @_QQmain() {
+ %6 = fir.address_of(@_QFEdst_record) : !fir.ref<!record_t>
+ %7:2 = hlfir.declare %6 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+ %16 = omp.map.info var_ptr(%7#1 : !fir.ref<!record_t>, !record_t) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!record_t> {name = "dst_record"}
+ omp.target map_entries(%16 -> %arg0 : !fir.ref<!record_t>) {
+ %20:2 = hlfir.declare %arg0 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+ %23 = hlfir.designate %20#0{"to_implicitly_map"} {fortran_attrs = #fir.var_attrs<allocatable>} : (!fir.ref<!record_t>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+ omp.terminator
+ }
+ return
+}
+
+// CHECK: %[[RECORD_DECL:.*]]:2 = hlfir.declare %0 {uniq_name = "_QFEdst_record"}
+// CHECK: %[[FIELD_COORD:.*]] = fir.coordinate_of %[[RECORD_DECL]]#1, %{{c1.*}}
+
+// CHECK: %[[UPPER_BOUND:.*]] = arith.subi %{{.*}}#1, %{{c1.*}} : index
+
+// CHECK: %[[BOUNDS:.*]] = omp.map.bounds
+// CHECK-SAME: lower_bound(%{{c0.*}} : index) upper_bound(%[[UPPER_BOUND]] : index)
+// CHECK-SAME: extent(%{{.*}}#1 : index) stride(%{{.*}}#2 : index)
+// CHECK-SAME: start_idx(%{{.*}}#0 : index) {stride_in_bytes = true}
+
+// CHECK: %[[BASE_ADDR:.*]] = fir.box_offset %[[FIELD_COORD]] base_addr
+// CHECK: %[[FIELD_BASE_ADDR_MAP:.*]] = omp.map.info var_ptr(
+// CHECK-SAME: %[[FIELD_COORD]] : {{.*}}) var_ptr_ptr(
+// CHECK-SAME: %[[BASE_ADDR]] : {{.*}}) map_clauses(
+// CHECK-SAME: implicit, tofrom) capture(ByRef) bounds(
+// CHECK-SAME: %[[BOUNDS]])
+
+// CHECK: %[[FIELD_MAP:.*]] = omp.map.info var_ptr(
+// CHECK-SAME: %[[FIELD_COORD]] : {{.*}}) map_clauses(
+// CHECK-SAME: implicit, to) capture(ByRef) ->
+// CHECK-SAME: {{.*}} {name = "dst_record.to_implicitly_map.implicit_map"}
+
+// CHECK: %[[RECORD_MAP:.*]] = omp.map.info var_ptr(
+// CHECK-SAME: %[[RECORD_DECL]]#1 : {{.*}}) map_clauses(
+// CHECK-SAME: implicit, tofrom) capture(ByRef) members(
+// CHECK-SAME: %[[FIELD_MAP]], %[[FIELD_BASE_ADDR_MAP]] :
+// CHECK-SAME: [1], [1, 0] : {{.*}}) -> {{.*}}> {name =
+// CHECK-SAME: "dst_record"}
+
+// CHECK: omp.target map_entries(
+// CHECK-SAME: %[[RECORD_MAP]] -> %{{[^[:space:]]+}},
+// CHECK-SAME: %[[FIELD_MAP]] -> %{{[^[:space:]]+}},
+// CHECK-SAME: %[[FIELD_BASE_ADDR_MAP]] -> %{{[^[:space:]]+}}
+// CHECK-SAME: : {{.*}})
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 8b72689dc3fd87..c4cf0f7afb3a34 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -193,6 +193,13 @@ def MapClauseOwningOpInterface : OpInterface<"MapClauseOwningOpInterface"> {
(ins), [{
return $_op.getMapVarsMutable();
}]>,
+ InterfaceMethod<"Get operand index for a map clause",
+ "int64_t",
+ "getOperandIndexForMap",
+ (ins "::mlir::Value":$map), [{
+ return std::distance($_op.getMapVars().begin(),
+ llvm::find($_op.getMapVars(), map));
+ }]>,
];
}
diff --git a/offload/test/offloading/fortran/implicit-record-field-mapping.f90 b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
new file mode 100644
index 00000000000000..0d49335903f8b6
--- /dev/null
+++ b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
@@ -0,0 +1,52 @@
+! Test implicit mapping of allocatable record fields.
+
+! REQUIRES: flang, amdgpu
+
+! This fails only because it needs the Fortran runtime built for device. If this
+! is available, this test succeeds when run.
+! XFAIL: *
+
+! RUN: %libomptarget-compile-fortran-generic
+! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic
+program test_implicit_field_mapping
+ implicit none
+
+ type recored_t
+ real, allocatable :: not_to_implicitly_map(:)
+ real, allocatable :: to_implicitly_map(:)
+ end type
+
+ type(recored_t) :: dst_record
+ real :: src_array(10)
+ real :: dst_sum, src_sum
+ integer :: i
+
+ call random_number(src_array)
+ dst_sum = 0
+ src_sum = 0
+
+ do i=1,10
+ src_sum = src_sum + src_array(i)
+ end do
+ print *, "src_sum=", src_sum
+
+ !$omp target map(from: dst_sum)
+ dst_record%to_implicitly_map = src_array
+ dst_sum = 0
+
+ do i=1,10
+ dst_sum = dst_sum + dst_record%to_implicitly_map(i)
+ end do
+ !$omp end target
+
+ print *, "dst_sum=", dst_sum
+
+ if (src_sum == dst_sum) then
+ print *, "Test succeeded!"
+ else
+ print *, "Test failed!", " dst_sum=", dst_sum, "vs. src_sum=", src_sum
+ endif
+end program
+
+! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}}
+! CHECK: Test succeeded!
More information about the flang-commits
mailing list