[flang-commits] [flang] [OpenMP]Update use_device_clause lowering (PR #101703)

Sergio Afonso via flang-commits flang-commits at lists.llvm.org
Tue Aug 6 04:09:21 PDT 2024


================
@@ -1072,27 +1072,133 @@ bool ClauseProcessor::processEnter(
 }
 
 bool ClauseProcessor::processUseDeviceAddr(
+    Fortran::lower::StatementContext &stmtCtx,
     mlir::omp::UseDeviceAddrClauseOps &result,
     llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
     llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
-    llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
-  return findRepeatableClause<omp::clause::UseDeviceAddr>(
-      [&](const omp::clause::UseDeviceAddr &clause, const parser::CharBlock &) {
-        addUseDeviceClause(converter, clause.v, result.useDeviceAddrVars,
-                           useDeviceTypes, useDeviceLocs, useDeviceSyms);
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms)
+    const {
+  std::map<const Fortran::semantics::Symbol *,
+           llvm::SmallVector<OmpMapMemberIndicesData>>
+      parentMemberIndices;
+  bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
+      [&](const omp::clause::UseDeviceAddr &clause,
+          const Fortran::parser::CharBlock &) {
+        const Fortran::parser::CharBlock source;
+        mlir::Location location = converter.genLocation(source);
+        fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+        llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+        for (const omp::Object &object : clause.v) {
+          llvm::SmallVector<mlir::Value> bounds;
+          std::stringstream asFortran;
+
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::gatherDataOperandAddrAndBounds<
+                  mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
+                  converter, firOpBuilder, semaCtx, stmtCtx, *object.sym(),
+                  object.ref(), location, asFortran, bounds,
+                  treatIndexAsSection);
+
+          auto origSymbol = converter.getSymbolAddress(*object.sym());
+          mlir::Value symAddr = info.addr;
+          if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
+            symAddr = origSymbol;
+
+          // Explicit map captures are captured ByRef by default,
+          // optimisation passes may alter this to ByCopy or other capture
+          // types to optimise
+          mlir::omp::MapInfoOp mapOp = createMapInfoOp(
+              firOpBuilder, location, symAddr,
+              /*varPtrPtr=*/mlir::Value{}, asFortran.str(), bounds,
+              /*members=*/{}, /*membersIndex=*/mlir::DenseIntElementsAttr{},
+              static_cast<
+                  std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+                  mapTypeBits),
+              mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
+
+          if (object.sym()->owner().IsDerivedType()) {
+            addChildIndexAndMapToParent(object, parentMemberIndices, mapOp,
+                                        semaCtx);
+          } else {
+            useDeviceSyms.push_back(object.sym());
+            useDeviceTypes.push_back(symAddr.getType());
+            useDeviceLocs.push_back(symAddr.getLoc());
+            result.useDeviceAddrVars.push_back(mapOp);
+          }
+        }
----------------
skatrak wrote:

This block of code should be outlined into a shared helper and called from `processUseDeviceAddr`, `processUseDevicePtr`, `processMap` and `processMotionClauses`. I'm thinking of a private method of `ClauseProcessor` with this signature:
```c++
  void processMappedObjects(
      lower::StatementContext &stmtCtx, const omp::ObjectList &objects,
      mlir::Location clauseLocation,
      llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
      std::map<const semantics::Symbol *,
               llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
      llvm::SmallVectorImpl<mlir::Value> &mapVars,
      llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms,
      llvm::SmallVectorImpl<mlir::Type> *mapTypes = nullptr,
      llvm::SmallVectorImpl<mlir::Location> *mapLocs = nullptr);
```
Each caller can then populate `objects` and `mapTypeBits` in its own way, since that is where their behavior seems to diverge.

https://github.com/llvm/llvm-project/pull/101703


More information about the flang-commits mailing list