[Mlir-commits] [mlir] e532241 - Re-apply (#117867): [flang][OpenMP] Implicitly map allocatable record fields (#120374)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 18 00:19:51 PST 2024


Author: Kareem Ergawy
Date: 2024-12-18T09:19:45+01:00
New Revision: e532241b021cd48bad303721757c1194bc844775

URL: https://github.com/llvm/llvm-project/commit/e532241b021cd48bad303721757c1194bc844775
DIFF: https://github.com/llvm/llvm-project/commit/e532241b021cd48bad303721757c1194bc844775.diff

LOG: Re-apply (#117867): [flang][OpenMP] Implicitly map allocatable record fields (#120374)

This re-applies #117867 with a small fix intended to prevent the build
bot failures. The fix avoids calling `dyn_cast` on the result of
`getOperation()`; instead, the result is assigned to `mlir::ModuleOp`
directly, since the type of the operation is known statically (it is the
`OpT` template parameter of `OperationPass`).

Added: 
    flang/include/flang/Lower/DirectivesCommon.h
    flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
    offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90
    offload/test/offloading/fortran/implicit-record-field-mapping.f90

Modified: 
    flang/lib/Lower/Bridge.cpp
    flang/lib/Lower/OpenACC.cpp
    flang/lib/Lower/OpenMP/ClauseProcessor.h
    flang/lib/Lower/OpenMP/OpenMP.cpp
    flang/lib/Lower/OpenMP/Utils.cpp
    flang/lib/Optimizer/OpenMP/CMakeLists.txt
    flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
    mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Removed: 
    flang/lib/Lower/DirectivesCommon.h


################################################################################
diff  --git a/flang/lib/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
similarity index 97%
rename from flang/lib/Lower/DirectivesCommon.h
rename to flang/include/flang/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
   }
 }
 
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
-                       fir::FirOpBuilder &builder,
-                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
-  mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+                                                mlir::Value symAddr,
+                                                bool isOptional,
+                                                mlir::Location loc) {
   mlir::Value rawInput = symAddr;
   if (auto declareOp =
           mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     rawInput = declareOp.getResults()[1];
   }
 
-  // TODO: Might need revisiting to handle for non-shared clauses
-  if (!symAddr) {
-    if (const auto *details =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
-      symAddr = converter.getSymbolAddress(details->symbol());
-      rawInput = symAddr;
-    }
-  }
-
   if (!symAddr)
     llvm::report_fatal_error("could not retrieve symbol address");
 
   mlir::Value isPresent;
-  if (Fortran::semantics::IsOptional(sym))
+  if (isOptional)
     isPresent =
         builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
 
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     // all address/dimension retrievals. For Fortran optional though, leave
     // the load generation for later so it can be done in the appropriate
     // if branches.
-    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
-        !Fortran::semantics::IsOptional(sym)) {
+    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
       return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
     }
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
   return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
 }
 
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+                       fir::FirOpBuilder &builder,
+                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
+  return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+                                Fortran::semantics::IsOptional(sym), loc);
+}
+
 template <typename BoundsOp, typename BoundsType>
 llvm::SmallVector<mlir::Value>
 gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
 
   return info;
 }
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+                     fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+                     mlir::Location loc) {
+  llvm::SmallVector<mlir::Value> bounds;
+
+  mlir::Value baseOp = info.rawInput;
+  if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+    bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+                                                              dataExv, info);
+  if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+    bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+        builder, loc, dataExv, dataExvIsAssumedSize);
+  }
+
+  return bounds;
+}
 } // namespace lower
 } // namespace Fortran
 

diff  --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0650433dbaf394..f5883dcedb2b67 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -11,7 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/Bridge.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/Version.h"
 #include "flang/Lower/Allocatable.h"
 #include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
 #include "flang/Lower/ConvertType.h"
 #include "flang/Lower/ConvertVariable.h"
 #include "flang/Lower/Cuda.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/HostAssociations.h"
 #include "flang/Lower/IO.h"
 #include "flang/Lower/IterationSpace.h"

diff  --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 75dcf6ec3e1107..ed18ad89c16ef5 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,10 +11,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/OpenACC.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertType.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/Mangler.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Lower/StatementContext.h"

diff  --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 3942c54e6e935d..7b047d4a7567ad 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,11 +13,11 @@
 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H
 
 #include "Clauses.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Lower/Bridge.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Parser/dump-parse-tree.h"
 #include "flang/Parser/parse-tree.h"

diff  --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c61ab67d95a957..b07e89d201d198 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -16,7 +16,6 @@
 #include "Clauses.h"
 #include "DataSharingProcessor.h"
 #include "Decomposer.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Common/OpenMP-utils.h"
@@ -24,6 +23,7 @@
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertExpr.h"
 #include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
@@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       if (const auto *details =
               sym.template detailsIf<semantics::HostAssocDetails>())
         converter.copySymbolBinding(details->symbol(), sym);
-      llvm::SmallVector<mlir::Value> bounds;
       std::stringstream name;
       fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
       name << sym.name().ToString();
 
       lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
           converter, firOpBuilder, sym, converter.getCurrentLocation());
-      mlir::Value baseOp = info.rawInput;
-      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
-        bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
-                                            mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv, info);
-      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
-        bool dataExvIsAssumedSize =
-            semantics::IsAssumedSizeArray(sym.GetUltimate());
-        bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
-                                         mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv,
-            dataExvIsAssumedSize);
-      }
+      llvm::SmallVector<mlir::Value> bounds =
+          lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                      mlir::omp::MapBoundsType>(
+              firOpBuilder, info, dataExv,
+              semantics::IsAssumedSizeArray(sym.GetUltimate()),
+              converter.getCurrentLocation());
 
       llvm::omp::OpenMPOffloadMappingFlags mapFlag =
           llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
       mlir::omp::VariableCaptureKind captureKind =
           mlir::omp::VariableCaptureKind::ByRef;
 
+      mlir::Value baseOp = info.rawInput;
       mlir::Type eleType = baseOp.getType();
       if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
         eleType = refType.getElementType();

diff  --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 5340dd8c5fb9a2..9971dc8e0b0014 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,10 +13,10 @@
 #include "Utils.h"
 
 #include "Clauses.h"
-#include <DirectivesCommon.h>
 
 #include <flang/Lower/AbstractConverter.h>
 #include <flang/Lower/ConvertType.h>
+#include <flang/Lower/DirectivesCommon.h>
 #include <flang/Lower/PFTBuilder.h>
 #include <flang/Optimizer/Builder/FIRBuilder.h>
 #include <flang/Optimizer/Builder/Todo.h>

diff  --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 51ecbe1a664f92..4f23b2b970fa44 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
   FIRDialect
   HLFIROpsIncGen
   FlangOpenMPPassesIncGen
+  ${dialect_libs}
 
   LINK_LIBS
   FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
   MLIRIR
   MLIRPass
   MLIRTransformUtils
+  ${dialect_libs}
 )

diff  --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..ad7b806ae262ae 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -24,10 +24,14 @@
 /// indirectly via a parent object.
 //===----------------------------------------------------------------------===//
 
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -411,10 +415,10 @@ class MapInfoFinalizationPass
           argIface
               ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
               : 0;
-      addOperands(
-          mapMutableOpRange,
-          llvm::dyn_cast_or_null<mlir::omp::TargetOp>(argIface.getOperation()),
-          blockArgInsertIndex);
+      addOperands(mapMutableOpRange,
+                  llvm::dyn_cast_if_present<mlir::omp::TargetOp>(
+                      argIface.getOperation()),
+                  blockArgInsertIndex);
     }
 
     if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
@@ -466,8 +470,7 @@ class MapInfoFinalizationPass
   // operation (usually function) containing the MapInfoOp because this pass
   // will mutate siblings of MapInfoOp.
   void runOnOperation() override {
-    mlir::ModuleOp module =
-        mlir::dyn_cast_or_null<mlir::ModuleOp>(getOperation());
+    mlir::ModuleOp module = getOperation();
     if (!module)
       module = getOperation()->getParentOfType<mlir::ModuleOp>();
     fir::KindMapping kindMap = fir::getKindMapping(module);
@@ -486,6 +489,160 @@ class MapInfoFinalizationPass
       // iterations from previous function scopes.
       localBoxAllocas.clear();
 
+      // First, walk `omp.map.info` ops to see if any record members should be
+      // implicitly mapped.
+      func->walk([&](mlir::omp::MapInfoOp op) {
+        mlir::Type underlyingType =
+            fir::unwrapRefType(op.getVarPtr().getType());
+
+        // TODO Test with and support more complicated cases; like arrays for
+        // records, for example.
+        if (!fir::isRecordWithAllocatableMember(underlyingType))
+          return mlir::WalkResult::advance();
+
+        // TODO For now, only consider `omp.target` ops. Other ops that support
+        // `map` clauses will follow later.
+        mlir::omp::TargetOp target =
+            mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+                getFirstTargetUser(op));
+
+        if (!target)
+          return mlir::WalkResult::advance();
+
+        auto mapClauseOwner =
+            llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+        int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
+        assert(mapVarIdx >= 0 &&
+               mapVarIdx <
+                   static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
+
+        auto argIface =
+            llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+        // TODO How should `map` block argument that correspond to: `private`,
+        // `use_device_addr`, `use_device_ptr`, be handled?
+        mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+        llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+        mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+        mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+          // TODO Support coordinate_of ops.
+          //
+          // TODO Support call ops by recursively examining the forward slice of
+          // the corresponding parameter to the field in the called function.
+          return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+        });
+
+        auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+        llvm::SmallVector<mlir::Value> newMapOpsForFields;
+        llvm::SmallVector<int64_t> fieldIndicies;
+
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          bool shouldMapField =
+              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+                if (!fir::isAllocatableType(memTy))
+                  return false;
+
+                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+                if (!designateOp)
+                  return false;
+
+                return designateOp.getComponent() &&
+                       designateOp.getComponent()->strref() == field;
+              }) != mapVarForwardSlice.end();
+
+          // TODO Handle recursive record types. Adapting
+          // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
+          // entities might be helpful here.
+
+          if (!shouldMapField)
+            continue;
+
+          int64_t fieldIdx = recordType.getFieldIndex(field);
+          bool alreadyMapped = [&]() {
+            if (op.getMembersIndexAttr())
+              for (auto indexList : op.getMembersIndexAttr()) {
+                auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+                if (indexListAttr.size() == 1 &&
+                    mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+                        fieldIdx)
+                  return true;
+              }
+
+            return false;
+          }();
+
+          if (alreadyMapped)
+            continue;
+
+          builder.setInsertionPoint(op);
+          mlir::Value fieldIdxVal = builder.createIntegerConstant(
+              op.getLoc(), mlir::IndexType::get(builder.getContext()),
+              fieldIdx);
+          auto fieldCoord = builder.create<fir::CoordinateOp>(
+              op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              fieldIdxVal);
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::getDataOperandBaseAddr(
+                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+          llvm::SmallVector<mlir::Value> bounds =
+              Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                                   mlir::omp::MapBoundsType>(
+                  builder, info,
+                  hlfir::translateToExtendedValue(op.getLoc(), builder,
+                                                  hlfir::Entity{fieldCoord})
+                      .first,
+                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+          mlir::omp::MapInfoOp fieldMapOp =
+              builder.create<mlir::omp::MapInfoOp>(
+                  op.getLoc(), fieldCoord.getResult().getType(),
+                  fieldCoord.getResult(),
+                  mlir::TypeAttr::get(
+                      fir::unwrapRefType(fieldCoord.getResult().getType())),
+                  /*varPtrPtr=*/mlir::Value{},
+                  /*members=*/mlir::ValueRange{},
+                  /*members_index=*/mlir::ArrayAttr{},
+                  /*bounds=*/bounds, op.getMapTypeAttr(),
+                  builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+                      mlir::omp::VariableCaptureKind::ByRef),
+                  builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                        field + ".implicit_map"),
+                  /*partial_map=*/builder.getBoolAttr(false));
+          newMapOpsForFields.emplace_back(fieldMapOp);
+          fieldIndicies.emplace_back(fieldIdx);
+        }
+
+        if (newMapOpsForFields.empty())
+          return mlir::WalkResult::advance();
+
+        op.getMembersMutable().append(newMapOpsForFields);
+        llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+        mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+        if (oldMembersIdxAttr)
+          for (mlir::Attribute indexList : oldMembersIdxAttr) {
+            llvm::SmallVector<int64_t> listVec;
+
+            for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
+              listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+            newMemberIndices.emplace_back(std::move(listVec));
+          }
+
+        for (int64_t newFieldIdx : fieldIndicies)
+          newMemberIndices.emplace_back(
+              llvm::SmallVector<int64_t>(1, newFieldIdx));
+
+        op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
+        op.setPartialMap(true);
+
+        return mlir::WalkResult::advance();
+      });
+
       func->walk([&](mlir::omp::MapInfoOp op) {
         // TODO: Currently only supports a single user for the MapInfoOp. This
         // is fine for the moment, as the Fortran frontend will generate a

diff  --git a/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
new file mode 100644
index 00000000000000..bcf8b63075dbf8
--- /dev/null
+++ b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
@@ -0,0 +1,63 @@
+// Tests that we implicitly map allocatable fields of a record when they are
+// referenced in a target region.
+
+// RUN: fir-opt --split-input-file --omp-map-info-finalization %s | FileCheck %s
+
+!record_t = !fir.type<_QFTrecord_t{
+  not_to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>,
+  to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>
+}>
+
+fir.global internal @_QFEdst_record : !record_t {
+  %0 = fir.undefined !record_t
+  fir.has_value %0 : !record_t
+}
+
+func.func @_QQmain() {
+  %6 = fir.address_of(@_QFEdst_record) : !fir.ref<!record_t>
+  %7:2 = hlfir.declare %6 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+  %16 = omp.map.info var_ptr(%7#1 : !fir.ref<!record_t>, !record_t) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!record_t> {name = "dst_record"}
+  omp.target map_entries(%16 -> %arg0 : !fir.ref<!record_t>) {
+    %20:2 = hlfir.declare %arg0 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+    %23 = hlfir.designate %20#0{"to_implicitly_map"}   {fortran_attrs = #fir.var_attrs<allocatable>} : (!fir.ref<!record_t>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+    omp.terminator
+  }
+  return
+}
+
+// CHECK: %[[RECORD_DECL:.*]]:2 = hlfir.declare %0 {uniq_name = "_QFEdst_record"}
+// CHECK: %[[FIELD_COORD:.*]] = fir.coordinate_of %[[RECORD_DECL]]#1, %{{c1.*}}
+
+// CHECK: %[[UPPER_BOUND:.*]] = arith.subi %{{.*}}#1, %{{c1.*}} : index
+
+// CHECK: %[[BOUNDS:.*]] = omp.map.bounds 
+// CHECK-SAME: lower_bound(%{{c0.*}} : index) upper_bound(%[[UPPER_BOUND]] : index)
+// CHECK-SAME: extent(%{{.*}}#1 : index) stride(%{{.*}}#2 : index)
+// CHECK-SAME: start_idx(%{{.*}}#0 : index) {stride_in_bytes = true}
+
+// CHECK: %[[BASE_ADDR:.*]] = fir.box_offset %[[FIELD_COORD]] base_addr
+// CHECK: %[[FIELD_BASE_ADDR_MAP:.*]] = omp.map.info var_ptr(
+// CHECK-SAME: %[[FIELD_COORD]] : {{.*}}) var_ptr_ptr(
+// CHECK-SAME: %[[BASE_ADDR]] : {{.*}}) map_clauses(
+// CHECK-SAME: implicit, tofrom) capture(ByRef) bounds(
+// CHECK-SAME: %[[BOUNDS]])
+
+// CHECK: %[[FIELD_MAP:.*]] = omp.map.info var_ptr(
+// CHECK-SAME: %[[FIELD_COORD]] : {{.*}}) map_clauses(
+// CHECK-SAME: implicit, to) capture(ByRef) ->
+// CHECK-SAME: {{.*}} {name = "dst_record.to_implicitly_map.implicit_map"}
+
+// CHECK: %[[RECORD_MAP:.*]] = omp.map.info var_ptr(
+// CHECK-SAME: %[[RECORD_DECL]]#1 : {{.*}}) map_clauses(
+// CHECK-SAME: implicit, tofrom) capture(ByRef) members(
+// CHECK-SAME: %[[FIELD_MAP]], %[[FIELD_BASE_ADDR_MAP]] :
+// CHECK-SAME: [1], [1, 0] : {{.*}}) -> {{.*}}> {name =
+// CHECK-SAME: "dst_record", partial_map = true}
+
+// CHECK: omp.target map_entries(
+// CHECK-SAME: %[[RECORD_MAP]] -> %{{[^[:space:]]+}},
+// CHECK-SAME: %[[FIELD_MAP]] -> %{{[^[:space:]]+}},
+// CHECK-SAME: %[[FIELD_BASE_ADDR_MAP]] -> %{{[^[:space:]]+}}
+// CHECK-SAME: : {{.*}})

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 8b72689dc3fd87..c4cf0f7afb3a34 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -193,6 +193,13 @@ def MapClauseOwningOpInterface : OpInterface<"MapClauseOwningOpInterface"> {
       (ins), [{
         return $_op.getMapVarsMutable();
       }]>,
+      InterfaceMethod<"Get operand index for a map clause",
+                      "int64_t",
+                      "getOperandIndexForMap",
+      (ins "::mlir::Value":$map), [{
+         return std::distance($_op.getMapVars().begin(),
+                              llvm::find($_op.getMapVars(), map));
+      }]>,
   ];
 }
 

diff  --git a/offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90 b/offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90
new file mode 100644
index 00000000000000..b619774514b2c0
--- /dev/null
+++ b/offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90
@@ -0,0 +1,83 @@
+! REQUIRES: flang, amdgpu
+
+! RUN: %libomptarget-compile-fortran-generic
+! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic
+module test
+implicit none
+
+TYPE field_type
+  REAL, DIMENSION(:,:), ALLOCATABLE :: density0, density1
+END TYPE field_type
+
+TYPE tile_type
+  TYPE(field_type) :: field
+  INTEGER          :: tile_neighbours(4)
+END TYPE tile_type
+
+TYPE chunk_type
+  INTEGER                                    :: filler
+  TYPE(tile_type), DIMENSION(:), ALLOCATABLE :: tiles
+END TYPE chunk_type
+
+end module test
+
+program reproducer
+  use test
+  implicit none
+  integer          :: i, j
+  TYPE(chunk_type) :: chunk
+
+  allocate(chunk%tiles(2))
+  do i = 1, 2
+    allocate(chunk%tiles(i)%field%density0(2, 2))
+    allocate(chunk%tiles(i)%field%density1(2, 2))
+    do j = 1, 4
+      chunk%tiles(i)%tile_neighbours(j) = j * 10
+    end do
+  end do
+
+  !$omp target enter data map(alloc:       &
+  !$omp  chunk%tiles(2)%field%density0)
+
+  !$omp target
+    chunk%tiles(2)%field%density0(1,1) = 25
+    chunk%tiles(2)%field%density0(1,2) = 50
+    chunk%tiles(2)%field%density0(2,1) = 75
+    chunk%tiles(2)%field%density0(2,2) = 100
+  !$omp end target
+
+  !$omp target exit data map(from:         &
+  !$omp  chunk%tiles(2)%field%density0)
+
+  if (chunk%tiles(2)%field%density0(1,1) /= 25) then
+    print*, "======= Test Failed! ======="
+    stop 1
+  end if
+
+  if (chunk%tiles(2)%field%density0(1,2) /= 50) then
+    print*, "======= Test Failed! ======="
+    stop 1
+  end if
+
+  if (chunk%tiles(2)%field%density0(2,1) /= 75) then
+    print*, "======= Test Failed! ======="
+    stop 1
+  end if
+
+  if (chunk%tiles(2)%field%density0(2,2) /= 100) then
+    print*, "======= Test Failed! ======="
+    stop 1
+  end if
+
+  do j = 1, 4
+    if (chunk%tiles(2)%tile_neighbours(j) /= j * 10) then
+      print*, "======= Test Failed! ======="
+      stop 1
+    end if
+  end do
+
+  print *, "======= Test Passed! ======="
+end program reproducer
+
+! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}}
+! CHECK: ======= Test Passed! =======

diff  --git a/offload/test/offloading/fortran/implicit-record-field-mapping.f90 b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
new file mode 100644
index 00000000000000..77b13bed707c71
--- /dev/null
+++ b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
@@ -0,0 +1,52 @@
+! Test implicit mapping of allocatable record fields.
+
+! REQUIRES: flang, amdgpu
+
+! This test fails only because it needs the Fortran runtime built for the
+! device. If that runtime is available, the test succeeds when run.
+! XFAIL: *
+
+! RUN: %libomptarget-compile-fortran-generic
+! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic
+program test_implicit_field_mapping
+  implicit none
+
+  type record_t
+    real, allocatable :: not_to_implicitly_map(:)
+    real, allocatable :: to_implicitly_map(:)
+  end type
+
+  type(record_t) :: dst_record
+  real :: src_array(10)
+  real :: dst_sum, src_sum
+  integer :: i
+
+  call random_number(src_array)
+  dst_sum = 0
+  src_sum = 0
+
+  do i=1,10
+    src_sum = src_sum + src_array(i)
+  end do
+  print *, "src_sum=", src_sum
+
+  !$omp target map(from: dst_sum)
+    dst_record%to_implicitly_map = src_array
+    dst_sum = 0
+
+    do i=1,10
+      dst_sum = dst_sum + dst_record%to_implicitly_map(i)
+    end do
+  !$omp end target
+
+  print *, "dst_sum=", dst_sum
+
+  if (src_sum == dst_sum) then
+    print *, "Test succeeded!"
+  else
+    print *, "Test failed!", " dst_sum=", dst_sum, "vs. src_sum=", src_sum
+  endif
+end program
+
+! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}}
+! CHECK: Test succeeded!


        


More information about the Mlir-commits mailing list