[llvm-branch-commits] [Flang][OpenMP] Derived type explicit allocatable member mapping (PR #111192)

Kareem Ergawy via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Oct 8 04:21:16 PDT 2024


================
@@ -145,11 +146,174 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
       builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
       builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
       builder.getStringAttr(name), builder.getBoolAttr(partialMap));
-
   return op;
 }
 
-static int
+// Builds the chain of objects from `obj` up through each successive base
+// object reported by getBaseObject, stopping when no further base exists.
+// The chain is returned reversed, i.e. outermost base object first and `obj`
+// itself last (presumably parent record first for a%b%c -> a, b, c).
+omp::ObjectList gatherObjects(omp::Object obj,
+                              semantics::SemanticsContext &semaCtx) {
+  omp::ObjectList chain;
+  for (std::optional<omp::Object> current = obj; current;
+       current = getBaseObject(*current, semaCtx))
+    chain.push_back(*current);
+  return omp::ObjectList{llvm::reverse(chain)};
+}
+
+// Returns true when `memberIndices` matches, element for element and in
+// length, one of the index paths already recorded in
+// `parentMembers.memberPlacementIndices`; used to avoid emitting a second map
+// for the same member path.
+bool isDuplicateMemberMapInfo(OmpMapParentAndMemberData &parentMembers,
+                              llvm::SmallVectorImpl<int64_t> &memberIndices) {
+  // Bind by const reference so each recorded path is not copied per iteration.
+  for (const auto &memberData : parentMembers.memberPlacementIndices)
+    // The four-iterator std::equal also requires equal lengths: a shorter
+    // recorded path is never read past its end, and a longer path that only
+    // starts with memberIndices is not mistaken for a duplicate.
+    if (std::equal(memberIndices.begin(), memberIndices.end(),
+                   memberData.begin(), memberData.end()))
+      return true;
+  return false;
+}
+
+// When `object` is an array element reference, lowers each of its subscripts
+// to an index-typed value reduced by 1 and appends it to `indices`, one entry
+// per subscript. Objects that are not array references leave `indices`
+// untouched. Triplet (section) subscripts are not supported and terminate
+// compilation with a fatal error.
+//
+// NOTE(review): the constant 1 subtracted below assumes each dimension has a
+// lower bound of 1; arrays declared with another lower bound may need their
+// actual lbound here - confirm.
+static void generateArrayIndices(lower::AbstractConverter &converter,
+                                 fir::FirOpBuilder &firOpBuilder,
+                                 lower::StatementContext &stmtCtx,
+                                 mlir::Location clauseLocation,
+                                 llvm::SmallVectorImpl<mlir::Value> &indices,
+                                 omp::Object object) {
+  auto maybeRef = evaluate::ExtractDataRef(*object.ref());
+  if (!maybeRef)
+    return;
+
+  // Inspect the extracted DataRef in place instead of copying it.
+  auto *arr = std::get_if<evaluate::ArrayRef>(&maybeRef->u);
+  if (!arr)
+    return;
+
+  // Bind each subscript (and its expression) by const reference to avoid
+  // copying them on every iteration.
+  for (const auto &v : arr->subscript()) {
+    // Triplets presumably reach here from user-written map clauses, so report
+    // a real error; llvm_unreachable is undefined behaviour in release builds
+    // if it is ever reached.
+    if (std::holds_alternative<Triplet>(v.u))
+      llvm::report_fatal_error(
+          "Triplet indexing in map clause is unsupported");
+
+    const auto &expr =
+        std::get<Fortran::evaluate::IndirectSubscriptIntegerExpr>(v.u);
+    mlir::Value subscript = fir::getBase(
+        converter.genExprValue(toEvExpr(expr.value()), stmtCtx));
+    mlir::Value one = firOpBuilder.createIntegerConstant(
+        clauseLocation, firOpBuilder.getIndexType(), 1);
+    subscript = firOpBuilder.createConvert(
+        clauseLocation, firOpBuilder.getIndexType(), subscript);
+    indices.push_back(firOpBuilder.create<mlir::arith::SubIOp>(
+        clauseLocation, subscript, one));
+  }
+}
+
+// When mapping members of derived types, one of the members along the path
+// to a mapped member may be a descriptor. In that case we must also generate
+// a map for each such intermediate member; otherwise we will be missing a
+// chunk of the data required to map the member to the device. This function
+// generates these maps and the data accesses needed to create them. It avoids
+// creating duplicate maps, because for the runtime duplicates are, in many
+// cases, just as harmful as unmapped descriptor data (and unnecessary data
+// movement should be avoided where possible).
+mlir::Value createParentSymAndGenIntermediateMaps(
+    mlir::Location clauseLocation, lower::AbstractConverter &converter,
+    semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx,
+    omp::ObjectList &objectList, llvm::SmallVector<int64_t> &indices,
+    OmpMapParentAndMemberData &parentMemberIndices, std::string asFortran,
+    llvm::omp::OpenMPOffloadMappingFlags mapTypeBits) {
+
+  auto arrayExprWithSubscript = [](omp::Object obj) {
+    if (auto maybeRef = evaluate::ExtractDataRef(*obj.ref())) {
+      evaluate::DataRef ref = *maybeRef;
+      if (auto *arr = std::get_if<evaluate::ArrayRef>(&ref.u))
+        return !arr->subscript().empty();
+    }
+    return false;
+  };
+
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  lower::AddrAndBoundsInfo parentBaseAddr = lower::getDataOperandBaseAddr(
+      converter, firOpBuilder, *objectList[0].sym(), clauseLocation);
+  mlir::Value curValue = parentBaseAddr.addr;
+
+  // Iterate over all objects in the objectList, this should consist of all
+  // record types between the parent and the member being mapped (including
+  // the parent). The object list may also contain array objects as well,
+  // this can occur when specifying bounds or a specific element access
+  // within a member map, we skip these.
+  size_t currentIndex = 0;
+  for (size_t i = 0; i < objectList.size(); ++i) {
----------------
ergawy wrote:

We can use a range-based loop here.

https://github.com/llvm/llvm-project/pull/111192


More information about the llvm-branch-commits mailing list