[flang-commits] [flang] [llvm] [Proof-of-Concept][flang][OpenMP] Implicitly map allocatable record … (PR #117867)

Kareem Ergawy via flang-commits flang-commits at lists.llvm.org
Wed Nov 27 02:52:43 PST 2024


https://github.com/ergawy created https://github.com/llvm/llvm-project/pull/117867

…fields

This is a proof-of-concept to implicitly map allocatable record fields. This PR contains the following changes:
1. Re-purposes some of the utils used in `Lower/OpenMP.cpp` so that these utils work on the `mlir::Value` level rather than the `semantics::Symbol` level. This takes one step towards enabling MLIR passes to more easily do some lowering themselves (e.g. creating `omp.map.bounds` ops for implicitly captured data like this PR does).
2. Adds support for implicitly capturing and mapping allocatable fields in record types.

If the approach taken by this PR is acceptable, I will convert it to a proper PR and add more extensive testing.

>From b9a9f8aff3c64b3fb3ec1274358749bea666079e Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Tue, 26 Nov 2024 22:17:34 -0600
Subject: [PATCH] [Proof-of-Concept][flang][OpenMP] Implicitly map allocatable
 record fields

This is a proof-of-concept to implicitly map allocatable record fields.
This PR contains the following changes:
1. Re-purposes some of the utils used in `Lower/OpenMP.cpp` so that
   these utils work on the `mlir::Value` level rather than the
   `semantics::Symbol` level. This takes one step towards enabling
   MLIR passes to more easily do some lowering themselves (e.g. creating
   `omp.map.bounds` ops for implicitly captured data like this PR
   does).
2. Adds support for implicitly capturing and mapping allocatable fields
   in record types.

If the approach taken by this PR is acceptable, I will convert it to a
proper PR and add more extensive testing.
---
 flang/lib/Lower/DirectivesCommon.h            |  50 ++++--
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  21 +--
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   2 +
 .../Optimizer/OpenMP/MapInfoFinalization.cpp  | 154 ++++++++++++++++++
 .../fortran/implicit-record-field-mapping.f90 |  63 +++++++
 5 files changed, 259 insertions(+), 31 deletions(-)
 create mode 100644 offload/test/offloading/fortran/implicit-record-field-mapping.f90

diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/lib/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/lib/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
   }
 }
 
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
-                       fir::FirOpBuilder &builder,
-                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
-  mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+                                                mlir::Value symAddr,
+                                                bool isOptional,
+                                                mlir::Location loc) {
   mlir::Value rawInput = symAddr;
   if (auto declareOp =
           mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     rawInput = declareOp.getResults()[1];
   }
 
-  // TODO: Might need revisiting to handle for non-shared clauses
-  if (!symAddr) {
-    if (const auto *details =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
-      symAddr = converter.getSymbolAddress(details->symbol());
-      rawInput = symAddr;
-    }
-  }
-
   if (!symAddr)
     llvm::report_fatal_error("could not retrieve symbol address");
 
   mlir::Value isPresent;
-  if (Fortran::semantics::IsOptional(sym))
+  if (isOptional)
     isPresent =
         builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
 
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     // all address/dimension retrievals. For Fortran optional though, leave
     // the load generation for later so it can be done in the appropriate
     // if branches.
-    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
-        !Fortran::semantics::IsOptional(sym)) {
+    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
       return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
     }
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
   return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
 }
 
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+                       fir::FirOpBuilder &builder,
+                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
+  return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+                                Fortran::semantics::IsOptional(sym), loc);
+}
+
 template <typename BoundsOp, typename BoundsType>
 llvm::SmallVector<mlir::Value>
 gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
 
   return info;
 }
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+                     fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+                     mlir::Location loc) {
+  llvm::SmallVector<mlir::Value> bounds;
+
+  mlir::Value baseOp = info.rawInput;
+  if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+    bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+                                                              dataExv, info);
+  if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+    bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+        builder, loc, dataExv, dataExvIsAssumedSize);
+  }
+
+  return bounds;
+}
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index ea2d1eb66bfea5..ddf15ce7f3c3dd 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1817,32 +1817,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       if (const auto *details =
               sym.template detailsIf<semantics::HostAssocDetails>())
         converter.copySymbolBinding(details->symbol(), sym);
-      llvm::SmallVector<mlir::Value> bounds;
       std::stringstream name;
       fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
       name << sym.name().ToString();
 
       lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
           converter, firOpBuilder, sym, converter.getCurrentLocation());
-      mlir::Value baseOp = info.rawInput;
-      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
-        bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
-                                            mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv, info);
-      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
-        bool dataExvIsAssumedSize =
-            semantics::IsAssumedSizeArray(sym.GetUltimate());
-        bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
-                                         mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv,
-            dataExvIsAssumedSize);
-      }
+      llvm::SmallVector<mlir::Value> bounds =
+          lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                      mlir::omp::MapBoundsType>(
+              firOpBuilder, info, dataExv,
+              semantics::IsAssumedSizeArray(sym.GetUltimate()),
+              converter.getCurrentLocation());
 
       llvm::omp::OpenMPOffloadMappingFlags mapFlag =
           llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
       mlir::omp::VariableCaptureKind captureKind =
           mlir::omp::VariableCaptureKind::ByRef;
 
+      mlir::Value baseOp = info.rawInput;
       mlir::Type eleType = baseOp.getType();
       if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
         eleType = refType.getElementType();
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index b1e0dbf6e707e5..0b78773a252cd2 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -11,6 +11,7 @@ add_flang_library(FlangOpenMPTransforms
   FIRDialect
   HLFIROpsIncGen
   FlangOpenMPPassesIncGen
+  ${dialect_libs}
 
   LINK_LIBS
   FIRAnalysis
@@ -25,4 +26,5 @@ add_flang_library(FlangOpenMPTransforms
   HLFIRDialect
   MLIRIR
   MLIRPass
+  ${dialect_libs}
 )
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..717d0dcd3fdfba 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -25,9 +25,12 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -43,6 +46,10 @@
 #include <iterator>
 #include <numeric>
 
+// Move to `flang/include/flang/Common/OpenMP-utils.h` once
+// https://github.com/llvm/llvm-project/pull/115443 is merged.
+#include "../../Lower/DirectivesCommon.h"
+
 namespace flangomp {
 #define GEN_PASS_DEF_MAPINFOFINALIZATIONPASS
 #include "flang/Optimizer/OpenMP/Passes.h.inc"
@@ -485,6 +492,153 @@ class MapInfoFinalizationPass
       // clear all local allocations we made for any boxes in any prior
       // iterations from previous function scopes.
       localBoxAllocas.clear();
+      func->walk([&](mlir::omp::MapInfoOp op) {
+        mlir::Type underlyingType =
+            fir::unwrapRefType(op.getVarPtr().getType());
+
+        if (!fir::isRecordWithAllocatableMember(underlyingType))
+          return mlir::WalkResult::advance();
+
+        mlir::omp::TargetOp target =
+            mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+                getFirstTargetUser(op));
+
+        if (!target)
+          return mlir::WalkResult::advance();
+
+        auto mapClauseOwner =
+            llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+        // TODO Add as a method to MapClauseOwningOpInterface.
+        unsigned mapVarIdx = 0;
+        for (auto [idx, mapOp] : llvm::enumerate(mapClauseOwner.getMapVars())) {
+          if (mapOp == op) {
+            mapVarIdx = idx;
+            break;
+          }
+        }
+
+        auto argIface =
+            llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+        mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+        llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+        mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+        mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+          // TODO Support coordinate_of ops.
+          //
+          // TODO Support call ops by recursively examining the forward slice of
+          // the corresponding parameter to the field.
+          return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+        });
+
+        auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+        llvm::SmallVector<mlir::Value> newMapOpsForFields;
+        llvm::SmallVector<int64_t> fieldIdices;
+
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          bool shouldMapField =
+              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+                if (!fir::isAllocatableType(memTy))
+                  return false;
+
+                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+                if (!designateOp)
+                  return false;
+
+                return designateOp.getComponent() &&
+                       designateOp.getComponent()->strref() == field;
+              }) != mapVarForwardSlice.end();
+
+          // TODO Handle recursive record types.
+
+          if (!shouldMapField)
+            continue;
+
+          int64_t fieldIdx = recordType.getFieldIndex(field);
+          bool alreadyMapped = false;
+
+          if (op.getMembersIndexAttr())
+            for (auto indexList : op.getMembersIndexAttr()) {
+              auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+              if (indexListAttr.size() == 1 &&
+                  mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+                      fieldIdx)
+                alreadyMapped = true;
+            }
+
+          if (alreadyMapped)
+            continue;
+
+          builder.setInsertionPoint(op);
+          mlir::Value fieldIdxVal = builder.createIntegerConstant(
+              op.getLoc(), mlir::IndexType::get(builder.getContext()),
+              fieldIdx);
+          auto fieldCoord = builder.create<fir::CoordinateOp>(
+              op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              fieldIdxVal);
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::getDataOperandBaseAddr(
+                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+          llvm::SmallVector<mlir::Value> bounds =
+              Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                                   mlir::omp::MapBoundsType>(
+                  builder, info,
+                  hlfir::translateToExtendedValue(op.getLoc(), builder,
+                                                  hlfir::Entity{fieldCoord})
+                      .first,
+                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+          mlir::omp::MapInfoOp fieldMapOp =
+              builder.create<mlir::omp::MapInfoOp>(
+                  op.getLoc(), fieldCoord.getResult().getType(),
+                  fieldCoord.getResult(),
+                  mlir::TypeAttr::get(
+                      fir::unwrapRefType(fieldCoord.getResult().getType())),
+                  /*varPtrPtr=*/mlir::Value{},
+                  /*members=*/mlir::ValueRange{},
+                  /*members_index=*/mlir::ArrayAttr{},
+                  /*bounds=*/bounds, op.getMapTypeAttr(),
+                  builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+                      mlir::omp::VariableCaptureKind::ByRef),
+                  builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                        field + ".implicit_map"),
+                  /*partial_map=*/builder.getBoolAttr(false));
+          newMapOpsForFields.emplace_back(fieldMapOp);
+          fieldIdices.emplace_back(fieldIdx);
+        }
+
+        if (!newMapOpsForFields.empty()) {
+          op.getMembersMutable().append(newMapOpsForFields);
+          llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+          mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+          if (oldMembersIdxAttr)
+            for (mlir::Attribute indexList : oldMembersIdxAttr) {
+              llvm::SmallVector<int64_t> listVec;
+
+              for (mlir::Attribute index :
+                   mlir::cast<mlir::ArrayAttr>(indexList))
+                listVec.push_back(
+                    mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+              newMemberIndices.emplace_back(std::move(listVec));
+            }
+
+          for (int64_t newFieldIdx : fieldIdices) {
+            newMemberIndices.emplace_back(
+                llvm::SmallVector<int64_t>(1, newFieldIdx));
+          }
+
+          op.setMembersIndexAttr(
+              builder.create2DI64ArrayAttr(newMemberIndices));
+        }
+
+        return mlir::WalkResult::advance();
+      });
 
       func->walk([&](mlir::omp::MapInfoOp op) {
         // TODO: Currently only supports a single user for the MapInfoOp. This
diff --git a/offload/test/offloading/fortran/implicit-record-field-mapping.f90 b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
new file mode 100644
index 00000000000000..8bde033d4ef4c3
--- /dev/null
+++ b/offload/test/offloading/fortran/implicit-record-field-mapping.f90
@@ -0,0 +1,63 @@
+! Test implicit mapping of allocatable record fields.
+
+! REQUIRES: flang, amdgpu
+
+! This fails only because it needs the Fortran runtime built for device. If this
+! is available, this test succeeds when run.
+! XFAIL: *
+
+! RUN: %libomptarget-compile-fortran-generic
+! RUN: env LIBOMPTARGET_INFO=16 %libomptarget-run-generic 2>&1 | %fcheck-generic
+module record_m
+  implicit none
+
+  private
+  public :: recored_t
+
+  type recored_t
+    real, allocatable :: not_to_implicitly_map(:)
+    real, allocatable :: to_implicitly_map(:)
+  end type
+
+  interface recored_t
+  end interface
+end module record_m
+
+program concurrent_inferences
+  use record_m, only : recored_t
+  implicit none
+
+  type(recored_t) :: dst_record
+  real :: src_array(10)
+  real :: dst_sum, src_sum
+  integer :: i
+
+  call random_number(src_array)
+  dst_sum = 0
+  src_sum = 0
+
+  do i=1,10
+    src_sum = src_sum + src_array(i)
+  end do
+  print *, "src_sum=", src_sum
+
+  !$omp target map(from: dst_sum)
+    dst_record%to_implicitly_map = src_array
+    dst_sum = 0
+
+    do i=1,10
+      dst_sum = dst_sum + dst_record%to_implicitly_map(i)
+    end do
+  !$omp end target
+
+  print *, "dst_sum=", dst_sum
+
+  if (src_sum == dst_sum) then
+    print *, "Test succeeded!"
+  else
+    print *, "Test failed!", " dst_sum=", dst_sum, "vs. src_sum=", src_sum
+  endif
+end program
+
+! CHECK: "PluginInterface" device {{[0-9]+}} info: Launching kernel {{.*}}
+! CHECK: Test succeeded!



More information about the flang-commits mailing list