[flang] [llvm] [mlir] [flang][OpenMP] Implicitly map allocatable record fields (PR #117867)
Michael Klemm via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 29 10:26:57 PST 2024
================
@@ -485,6 +492,153 @@ class MapInfoFinalizationPass
// clear all local allocations we made for any boxes in any prior
// iterations from previous function scopes.
localBoxAllocas.clear();
+ func->walk([&](mlir::omp::MapInfoOp op) {
+ mlir::Type underlyingType =
+ fir::unwrapRefType(op.getVarPtr().getType());
+
+ if (!fir::isRecordWithAllocatableMember(underlyingType))
+ return mlir::WalkResult::advance();
+
+ mlir::omp::TargetOp target =
+ mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+ getFirstTargetUser(op));
+
+ if (!target)
+ return mlir::WalkResult::advance();
+
+ auto mapClauseOwner =
+ llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+ // TODO Add as a method to MapClauseOwningOpInterface.
+ unsigned mapVarIdx = 0;
+ for (auto [idx, mapOp] : llvm::enumerate(mapClauseOwner.getMapVars())) {
+ if (mapOp == op) {
+ mapVarIdx = idx;
+ break;
+ }
+ }
+
+ auto argIface =
+ llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+ mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+ llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+ mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+ mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+ // TODO Support coordinate_of ops.
+ //
+ // TODO Support call ops by recursively examining the forward slice of
+ // the corresponding parameter to the field.
+ return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+ });
+
+ auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+ llvm::SmallVector<mlir::Value> newMapOpsForFields;
+ llvm::SmallVector<int64_t> fieldIdices;
+
+ for (auto fieldMemTyPair : recordType.getTypeList()) {
+ auto &field = fieldMemTyPair.first;
+ auto memTy = fieldMemTyPair.second;
+
+ bool shouldMapField =
+ llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+ if (!fir::isAllocatableType(memTy))
+ return false;
+
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+ if (!designateOp)
+ return false;
+
+ return designateOp.getComponent() &&
+ designateOp.getComponent()->strref() == field;
+ }) != mapVarForwardSlice.end();
+
+ // TODO Handle recursive record types.
+
+ if (!shouldMapField)
+ continue;
+
+ int64_t fieldIdx = recordType.getFieldIndex(field);
+ bool alreadyMapped = false;
+
+ if (op.getMembersIndexAttr())
+ for (auto indexList : op.getMembersIndexAttr()) {
+ auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+ if (indexListAttr.size() == 1 &&
+ mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+ fieldIdx)
+ alreadyMapped = true;
+ }
+
+ if (alreadyMapped)
+ continue;
+
+ builder.setInsertionPoint(op);
+ mlir::Value fieldIdxVal = builder.createIntegerConstant(
+ op.getLoc(), mlir::IndexType::get(builder.getContext()),
+ fieldIdx);
+ auto fieldCoord = builder.create<fir::CoordinateOp>(
+ op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+ fieldIdxVal);
+ Fortran::lower::AddrAndBoundsInfo info =
+ Fortran::lower::getDataOperandBaseAddr(
+ builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+ llvm::SmallVector<mlir::Value> bounds =
+ Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ builder, info,
+ hlfir::translateToExtendedValue(op.getLoc(), builder,
+ hlfir::Entity{fieldCoord})
+ .first,
+ /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+ mlir::omp::MapInfoOp fieldMapOp =
+ builder.create<mlir::omp::MapInfoOp>(
+ op.getLoc(), fieldCoord.getResult().getType(),
+ fieldCoord.getResult(),
+ mlir::TypeAttr::get(
+ fir::unwrapRefType(fieldCoord.getResult().getType())),
+ /*varPtrPtr=*/mlir::Value{},
+ /*members=*/mlir::ValueRange{},
+ /*members_index=*/mlir::ArrayAttr{},
+ /*bounds=*/bounds, op.getMapTypeAttr(),
+ builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
----------------
mjklemm wrote:
Yes, I think that's the right behavior. However, we need to bear in mind that a user might have used `declare mapper` to define a mapper for a derived type. In that case, we might need to honor the user-defined mapper instead.
https://github.com/llvm/llvm-project/pull/117867
More information about the llvm-commits
mailing list