[Mlir-commits] [flang] [llvm] [mlir] [flang][OpenMP] Implicitly map allocatable record fields (PR #117867)

Kareem Ergawy llvmlistbot at llvm.org
Sun Dec 15 23:59:40 PST 2024


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/117867

>From 8af3b52c7461a87886fffab8821ec0e06649fbcb Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Tue, 26 Nov 2024 22:17:34 -0600
Subject: [PATCH] [flang][OpenMP] Implicitly map allocatable record fields

This is a starting PR to implicitly map allocatable record fields.

This PR contains the following changes:
1. Re-purposes some of the utils used in `Lower/OpenMP.cpp` so that
   these utils work on the `mlir::Value` level rather than the
   `semantics::Symbol` level. This takes one step towards enabling
   MLIR passes to more easily do some lowering themselves (e.g. creating
   `omp.map.bounds` ops for implicitly captured data like this PR
   does).
2. Adds support for implicitly capturing and mapping allocatable fields
   in record types.

There is still quite some distance to cover to have full support for
this. I added a number of todos to guide further development.

Co-authored-by: Andrew Gozillon <andrew.gozillon at amd.com>
---
 .../flang}/Lower/DirectivesCommon.h           |  50 ++++--
 flang/lib/Lower/Bridge.cpp                    |   3 +-
 flang/lib/Lower/OpenACC.cpp                   |   3 +-
 flang/lib/Lower/OpenMP/ClauseProcessor.h      |   2 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  23 +--
 flang/lib/Lower/OpenMP/Utils.cpp              |   2 +-
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   2 +
 .../Optimizer/OpenMP/MapInfoFinalization.cpp  | 158 ++++++++++++++++++
 ...p-map-info-finalization-implicit-field.fir |  63 +++++++
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     |   7 +
 ...icit-and-implicit-record-field-mapping.f90 |  83 +++++++++
 .../fortran/implicit-record-field-mapping.f90 |  52 ++++++
 12 files changed, 412 insertions(+), 36 deletions(-)
 rename flang/{lib => include/flang}/Lower/DirectivesCommon.h (97%)
 create mode 100644 flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
 create mode 100644 offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90
 create mode 100644 offload/test/offloading/fortran/implicit-record-field-mapping.f90

diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
similarity index 97%
rename from flang/lib/Lower/DirectivesCommon.h
rename to flang/include/flang/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
   }
 }
 
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
-                       fir::FirOpBuilder &builder,
-                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
-  mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+                                                mlir::Value symAddr,
+                                                bool isOptional,
+                                                mlir::Location loc) {
   mlir::Value rawInput = symAddr;
   if (auto declareOp =
           mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     rawInput = declareOp.getResults()[1];
   }
 
-  // TODO: Might need revisiting to handle for non-shared clauses
-  if (!symAddr) {
-    if (const auto *details =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
-      symAddr = converter.getSymbolAddress(details->symbol());
-      rawInput = symAddr;
-    }
-  }
-
   if (!symAddr)
     llvm::report_fatal_error("could not retrieve symbol address");
 
   mlir::Value isPresent;
-  if (Fortran::semantics::IsOptional(sym))
+  if (isOptional)
     isPresent =
         builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
 
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     // all address/dimension retrievals. For Fortran optional though, leave
     // the load generation for later so it can be done in the appropriate
     // if branches.
-    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
-        !Fortran::semantics::IsOptional(sym)) {
+    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
       return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
     }
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
   return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
 }
 
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+                       fir::FirOpBuilder &builder,
+                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
+  return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+                                Fortran::semantics::IsOptional(sym), loc);
+}
+
 template <typename BoundsOp, typename BoundsType>
 llvm::SmallVector<mlir::Value>
 gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
 
   return info;
 }
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+                     fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+                     mlir::Location loc) {
+  llvm::SmallVector<mlir::Value> bounds;
+
+  mlir::Value baseOp = info.rawInput;
+  if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+    bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+                                                              dataExv, info);
+  if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+    bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+        builder, loc, dataExv, dataExvIsAssumedSize);
+  }
+
+  return bounds;
+}
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index de2b941b688bee..cb27f4ec5572ed 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -11,7 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/Bridge.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/Version.h"
 #include "flang/Lower/Allocatable.h"
 #include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
 #include "flang/Lower/ConvertType.h"
 #include "flang/Lower/ConvertVariable.h"
 #include "flang/Lower/Cuda.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/HostAssociations.h"
 #include "flang/Lower/IO.h"
 #include "flang/Lower/IterationSpace.h"
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 75dcf6ec3e1107..ed18ad89c16ef5 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,10 +11,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/OpenACC.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertType.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/Mangler.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Lower/StatementContext.h"
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 3942c54e6e935d..7b047d4a7567ad 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,11 +13,11 @@
 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H
 
 #include "Clauses.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Lower/Bridge.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Parser/dump-parse-tree.h"
 #include "flang/Parser/parse-tree.h"
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c61ab67d95a957..b07e89d201d198 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -16,7 +16,6 @@
 #include "Clauses.h"
 #include "DataSharingProcessor.h"
 #include "Decomposer.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Common/OpenMP-utils.h"
@@ -24,6 +23,7 @@
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertExpr.h"
 #include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
@@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       if (const auto *details =
               sym.template detailsIf<semantics::HostAssocDetails>())
         converter.copySymbolBinding(details->symbol(), sym);
-      llvm::SmallVector<mlir::Value> bounds;
       std::stringstream name;
       fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
       name << sym.name().ToString();
 
       lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
           converter, firOpBuilder, sym, converter.getCurrentLocation());
-      mlir::Value baseOp = info.rawInput;
-      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
-        bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
-                                            mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv, info);
-      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
-        bool dataExvIsAssumedSize =
-            semantics::IsAssumedSizeArray(sym.GetUltimate());
-        bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
-                                         mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv,
-            dataExvIsAssumedSize);
-      }
+      llvm::SmallVector<mlir::Value> bounds =
+          lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                      mlir::omp::MapBoundsType>(
+              firOpBuilder, info, dataExv,
+              semantics::IsAssumedSizeArray(sym.GetUltimate()),
+              converter.getCurrentLocation());
 
       llvm::omp::OpenMPOffloadMappingFlags mapFlag =
           llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
       mlir::omp::VariableCaptureKind captureKind =
           mlir::omp::VariableCaptureKind::ByRef;
 
+      mlir::Value baseOp = info.rawInput;
       mlir::Type eleType = baseOp.getType();
       if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
         eleType = refType.getElementType();
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 5340dd8c5fb9a2..9971dc8e0b0014 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,10 +13,10 @@
 #include "Utils.h"
 
 #include "Clauses.h"
-#include <DirectivesCommon.h>
 
 #include <flang/Lower/AbstractConverter.h>
 #include <flang/Lower/ConvertType.h>
+#include <flang/Lower/DirectivesCommon.h>
 #include <flang/Lower/PFTBuilder.h>
 #include <flang/Optimizer/Builder/FIRBuilder.h>
 #include <flang/Optimizer/Builder/Todo.h>
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 51ecbe1a664f92..4f23b2b970fa44 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
   FIRDialect
   HLFIROpsIncGen
   FlangOpenMPPassesIncGen
+  ${dialect_libs}
 
   LINK_LIBS
   FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
   MLIRIR
   MLIRPass
   MLIRTransformUtils
+  ${dialect_libs}
 )
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..df7f6129c70741 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -24,10 +24,14 @@
 /// indirectly via a parent object.
 //===----------------------------------------------------------------------===//
 
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -486,6 +490,160 @@ class MapInfoFinalizationPass
       // iterations from previous function scopes.
       localBoxAllocas.clear();
 
+      // First, walk `omp.map.info` ops to see if any record members should be
+      // implicitly mapped.
+      func->walk([&](mlir::omp::MapInfoOp op) {
+        mlir::Type underlyingType =
+            fir::unwrapRefType(op.getVarPtr().getType());
+
+        // TODO Test with and support more complicated cases; like arrays for
+        // records, for example.
+        if (!fir::isRecordWithAllocatableMember(underlyingType))
+          return mlir::WalkResult::advance();
+
+        // TODO For now, only consider `omp.target` ops. Other ops that support
+        // `map` clauses will follow later.
+        mlir::omp::TargetOp target =
+            mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+                getFirstTargetUser(op));
+
+        if (!target)
+          return mlir::WalkResult::advance();
+
+        auto mapClauseOwner =
+            llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+        int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
+        assert(mapVarIdx >= 0 &&
+               mapVarIdx <
+                   static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
+
+        auto argIface =
+            llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+        // TODO How should `map` block argument that correspond to: `private`,
+        // `use_device_addr`, `use_device_ptr`, be handled?
+        mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+        llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+        mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+        mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+          // TODO Support coordinate_of ops.
+          //
+          // TODO Support call ops by recursively examining the forward slice of
+          // the corresponding parameter to the field in the called function.
+          return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+        });
+
+        auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+        llvm::SmallVector<mlir::Value> newMapOpsForFields;
+        llvm::SmallVector<int64_t> fieldIndicies;
+
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          bool shouldMapField =
+              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+                if (!fir::isAllocatableType(memTy))
+                  return false;
+
+                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+                if (!designateOp)
+                  return false;
+
+                return designateOp.getComponent() &&
+                       designateOp.getComponent()->strref() == field;
+              }) != mapVarForwardSlice.end();
+
+          // TODO Handle recursive record types. Adapting
+          // `createParentSymAndGenIntermediateMaps` to work directly on MLIR
+          // entities might be helpful here.
+
+          if (!shouldMapField)
+            continue;
+
+          int64_t fieldIdx = recordType.getFieldIndex(field);
+          bool alreadyMapped = [&]() {
+            if (op.getMembersIndexAttr())
+              for (auto indexList : op.getMembersIndexAttr()) {
+                auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+                if (indexListAttr.size() == 1 &&
+                    mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+                        fieldIdx)
+                  return true;
+              }
+
+            return false;
+          }();
+
+          if (alreadyMapped)
+            continue;
+
+          builder.setInsertionPoint(op);
+          mlir::Value fieldIdxVal = builder.createIntegerConstant(
+              op.getLoc(), mlir::IndexType::get(builder.getContext()),
+              fieldIdx);
+          auto fieldCoord = builder.create<fir::CoordinateOp>(
+              op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              fieldIdxVal);
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::getDataOperandBaseAddr(
+                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+          llvm::SmallVector<mlir::Value> bounds =
+              Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                                   mlir::omp::MapBoundsType>(
+                  builder, info,
+                  hlfir::translateToExtendedValue(op.getLoc(), builder,
+                                                  hlfir::Entity{fieldCoord})
+                      .first,
+                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+          mlir::omp::MapInfoOp fieldMapOp =
+              builder.create<mlir::omp::MapInfoOp>(
+                  op.getLoc(), fieldCoord.getResult().getType(),
+                  fieldCoord.getResult(),
+                  mlir::TypeAttr::get(
+                      fir::unwrapRefType(fieldCoord.getResult().getType())),
+                  /*varPtrPtr=*/mlir::Value{},
+                  /*members=*/mlir::ValueRange{},
+                  /*members_index=*/mlir::ArrayAttr{},
+                  /*bounds=*/bounds, op.getMapTypeAttr(),
+                  builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+                      mlir::omp::VariableCaptureKind::ByRef),
+                  builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                        field + ".implicit_map"),
+                  /*partial_map=*/builder.getBoolAttr(false));
+          newMapOpsForFields.emplace_back(fieldMapOp);
+          fieldIndicies.emplace_back(fieldIdx);
+        }
+
+        if (newMapOpsForFields.empty())
+          return mlir::WalkResult::advance();
+
+        op.getMembersMutable().append(newMapOpsForFields);
+        llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+        mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+        if (oldMembersIdxAttr)
+          for (mlir::Attribute indexList : oldMembersIdxAttr) {
+            llvm::SmallVector<int64_t> listVec;
+
+            for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
+              listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+            newMemberIndices.emplace_back(std::move(listVec));
+          }
+
+        for (int64_t newFieldIdx : fieldIndicies)
+          newMemberIndices.emplace_back(
+              llvm::SmallVector<int64_t>(1, newFieldIdx));
+
+        op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
+        op.setPartialMap(true);
+
+        return mlir::WalkResult::advance();
+      });
+
       func->walk([&](mlir::omp::MapInfoOp op) {
         // TODO: Currently only supports a single user for the MapInfoOp. This
         // is fine for the moment, as the Fortran frontend will generate a
diff --git a/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
new file mode 100644
index 00000000000000..bcf8b63075dbf8
--- /dev/null
+++ b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
@@ -0,0 +1,63 @@
+// Tests that we implicitly map allocatable fields of a record when referenced in
+// a target region.
+
+// RUN: fir-opt --split-input-file --omp-map-info-finalization %s | FileCheck %s
+
+!record_t = !fir.type<_QFTrecord_t{
+  not_to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>,
+  to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>
+}>
+
+fir.global internal @_QFEdst_record : !record_t {
+  %0 = fir.undefined !record_t
+  fir.has_value %0 : !record_t
+}
+
+func.func @_QQmain() {
+  %6 = fir.address_of(@_QFEdst_record) : !fir.ref<!record_t>
+  %7:2 = hlfir.declare %6 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+  %16 = omp.map.info var_ptr(%7#1 : !fir.ref<!record_t>, !record_t) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!record_t> {name = "dst_record"}
+  omp.target map_entries(%16 -> %arg0 : !fir.ref<!record_t>) {
+    %20:2 = hlfir.declare %arg0 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+    %23 = hlfir.designate %20#0{"to_implicitly_map"}   {fortran_attrs = #fir.var_attrs<allocatable>} : (!fir.ref<!record_t>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+    omp.terminator
+  }
+  return
+}
+
+// CHECK: %[[RECORD_DECL:.*]]:2 = hlfir.declare %0 {uniq_name = "_QFEdst_record"}
+// CHECK: %[[FIELD_COORD:.*]] = fir.coordinate_of %[[RECORD_DECL]]#1, %{{c1.*}}
+
+// CHECK: %[[UPPER_BOUND:.*]] = arith.subi %{{.*}}#1, %{{c1.*}} : index
+
+// CHECK: %[[BOUNDS:.*]] = omp.map.bounds 
+// CHECK-SAME: lower_bound(%{{c0.*}} : index) upper_bound(%[[UPPER_BOUND]] : index)
+// CHECK-SAME: extent(%{{.*}}#1 : index) stride(%{{.*}}#2 : index)
+// CHECK-SAME: start_idx(%{{.*}}#0 : index) {stride_in_bytes = true}
+
+// CHECK: %[[BASE_ADDR:.*]] = fir.box_offset %[[FIELD_COORD]] base_addr
+// CHECK: %[[FIELD_BASE_ADDR_MAP:.*]] = omp.map.info var_ptr(
+// CHECK-SAME: %[[FIELD_COORD]] : {{.*}}) var_ptr_ptr(
+// CHECK-SAME: %[[BASE_ADDR]] : {{.*}}) map_clauses(
+// CHECK-SAME: implicit, tofrom) capture(ByRef) bounds(
+// CHECK-SAME: %[[BOUNDS]])
+
+// CHECK: %[[FIELD_MAP:.*]] = omp.map.info var_ptr(
+// CHECK-SAME: %[[FIELD_COORD]] : {{.*}}) map_clauses(
+// CHECK-SAME: implicit, to) capture(ByRef) ->
+// CHECK-SAME: {{.*}} {name = "dst_record.to_implicitly_map.implicit_map"}
+
+// CHECK: %[[RECORD_MAP:.*]] = omp.map.info var_ptr(
+// CHECK-SAME: %[[RECORD_DECL]]#1 : {{.*}}) map_clauses(
+// CHECK-SAME: implicit, tofrom) capture(ByRef) members(
+// CHECK-SAME: %[[FIELD_MAP]], %[[FIELD_BASE_ADDR_MAP]] :
+// CHECK-SAME: [1], [1, 0] : {{.*}}) -> {{.*}}> {name =
+// CHECK-SAME: "dst_record", partial_map = true}
+
+// CHECK: omp.target map_entries(
+// CHECK-SAME: %[[RECORD_MAP]] -> %{{[^[:space:]]+}},
+// CHECK-SAME: %[[FIELD_MAP]] -> %{{[^[:space:]]+}},
+// CHECK-SAME: %[[FIELD_BASE_ADDR_MAP]] -> %{{[^[:space:]]+}}
+// CHECK-SAME: : {{.*}})
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 8b72689dc3fd87..c4cf0f7afb3a34 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -193,6 +193,13 @@ def MapClauseOwningOpInterface : OpInterface<"MapClauseOwningOpInterface"> {
       (ins), [{
         return $_op.getMapVarsMutable();
       }]>,
+      InterfaceMethod<"Get operand index for a map clause",
+                      "int64_t",
+                      "getOperandIndexForMap",
+      (ins "::mlir::Value":$map), [{
+         return std::distance($_op.getMapVars().begin(),
+                              llvm::find($_op.getMapVars(), map));
+      }]>,
   ];
 }
 
diff --git a/offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90 b/offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90
new file mode 100644
index 00000000000000..b619774514b2c0
--- /dev/null
+++ b/offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90
@@ -0,0 +1,83 @@
+! REQUIRES: flang, amdgpu
+
+! RUN: %libomptarget-compile-fortran-generic
+! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic
+module test
+implicit none
+
+TYPE field_type
+  REAL, DIMENSION(:,:), ALLOCATABLE :: density0, density1
+END TYPE field_type
+
+TYPE tile_type
+  TYPE(field_type) :: field
+  INTEGER          :: tile_neighbours(4)
+END TYPE tile_type
+
+TYPE chunk_type
+  INTEGER                                    :: filler
+  TYPE(tile_type), DIMENSION(:), ALLOCATABLE :: tiles
+END TYPE chunk_type
+
+end module test
+
+program reproducer
+  use test
+  implicit none
+  integer          :: i, j
+  TYPE(chunk_type) :: chunk
+
+  allocate(chunk%tiles(2))
+  do i = 1, 2
+    allocate(chunk%tiles(i)%field%density0(2, 2))
+    allocate(chunk%tiles(i)%field%density1(2, 2))
+    do j = 1, 4
+      chunk%tiles(i)%tile_neighbours(j) = j * 10
+    end do
+  end do
+
+  !$omp target enter data map(alloc:       &
+  !$omp  chunk%tiles(2)%field%density0)
+
+  !$omp target
+    chunk%tiles(2)%field%density0(1,1) = 25
+    chunk%tiles(2)%field%density0(1,2) = 50
+    chunk%tiles(2)%field%density0(2,1) = 75
+    chunk%tiles(2)%field%density0(2,2) = 100
+  !$omp end target
+
+  !$omp target exit data map(from:         &
+  !$omp  chunk%tiles(2)%field%density0)
+
+  if (chunk%tiles(2)%field%density0(1,1) /= 25) then
+    print*, "======= Test Failed! ======="
+    stop 1
+  end if
+
+  if (chunk%tiles(2)%field%density0(1,2) /= 50) then
+    print*, "======= Test Failed! ======="
+    stop 1
+  end if
+
+  if (chunk%tiles(2)%field%density0(2,1) /= 75) then
+    print*, "======= Test Failed! ======="
+    stop 1
+  end if
+
+  if (chunk%tiles(2)%field%density0(2,2) /= 100) then
+    print*, "======= Test Failed! ======="
+    stop 1
+  end if
+
+  do j = 1, 4
+    if (chunk%tiles(2)%tile_neighbours(j) /= j * 10) then
+      print*, "======= Test Failed! ======="
+      stop 1
+    end if
+  end do
+
+  print *, "======= Test Passed! ======="
+end program reproducer
+
+! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}}
+! CHECK: ======= Test Passed! =======
diff --git a/offload/test/offloading/fortran/implicit-record-field-mapping.f90 b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
new file mode 100644
index 00000000000000..77b13bed707c71
--- /dev/null
+++ b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
@@ -0,0 +1,52 @@
+! Test implicit mapping of allocatable record fields.
+
+! REQUIRES: flang, amdgpu
+
+! This fails only because it needs the Fortran runtime built for device. If this
+! is available, this test succeeds when run.
+! XFAIL: *
+
+! RUN: %libomptarget-compile-fortran-generic
+! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic
+program test_implicit_field_mapping
+  implicit none
+
+  type record_t
+    real, allocatable :: not_to_implicitly_map(:)
+    real, allocatable :: to_implicitly_map(:)
+  end type
+
+  type(record_t) :: dst_record
+  real :: src_array(10)
+  real :: dst_sum, src_sum
+  integer :: i
+
+  call random_number(src_array)
+  dst_sum = 0
+  src_sum = 0
+
+  do i=1,10
+    src_sum = src_sum + src_array(i)
+  end do
+  print *, "src_sum=", src_sum
+
+  !$omp target map(from: dst_sum)
+    dst_record%to_implicitly_map = src_array
+    dst_sum = 0
+
+    do i=1,10
+      dst_sum = dst_sum + dst_record%to_implicitly_map(i)
+    end do
+  !$omp end target
+
+  print *, "dst_sum=", dst_sum
+
+  if (src_sum == dst_sum) then
+    print *, "Test succeeded!"
+  else
+    print *, "Test failed!", " dst_sum=", dst_sum, "vs. src_sum=", src_sum
+  endif
+end program
+
+! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}}
+! CHECK: Test succeeded!



More information about the Mlir-commits mailing list