[llvm-branch-commits] [flang] [Flang][OpenMP] Derived type explicit allocatable member mapping (PR #96266)

Kareem Ergawy via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Jul 15 01:07:17 PDT 2024


================
@@ -141,6 +143,110 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
   return op;
 }
 
+// Collects the chain of base objects of `obj`, as reported by repeated calls to
+// getBaseObject, and returns it in reverse walk order. Element 0 is the last
+// base reached (the one for which getBaseObject yields no further base). The
+// final element is the immediate base of `obj`. For a component reference this
+// presumably puts the root variable first; callers index element 0 for the
+// parent symbol. `obj` itself is never part of the result. The list is empty
+// when `obj` has no base object.
+omp::ObjectList gatherObjects(omp::Object obj,
+                              semantics::SemanticsContext &semaCtx) {
+  omp::ObjectList walkOrder;
+  for (std::optional<omp::Object> base = getBaseObject(obj, semaCtx);
+       base.has_value(); base = getBaseObject(*base, semaCtx))
+    walkOrder.push_back(*base);
+
+  // The walk visits the immediate base first; flip it for the caller.
+  return omp::ObjectList{llvm::reverse(walkOrder)};
+}
+
+// Returns true if `memberIndices` is equal to one of the index lists already
+// stored in `parentMembers.memberPlacementIndices`. Lists of different lengths
+// are compared as if the shorter one were padded with -1.
+bool duplicateMemberMapInfo(OmpMapMemberIndicesData &parentMembers,
+                            llvm::SmallVectorImpl<int> &memberIndices) {
+  // A variation of std::equal that supports non-equal length index lists for
+  // our specific use-case. Once one list is exhausted, -1 (the default filler
+  // element) stands in for its missing entries.
+  //
+  // The loop runs until BOTH lists are exhausted, and each iterator is only
+  // advanced while it is still before its end. As a result:
+  //  - neither list is ever over-indexed or dereferenced past its end;
+  //  - trailing entries of the longer list are still compared against -1.
+  //    For example, [0] and [0, -1] are equal, but [0] and [0, 1] are not.
+  //
+  // This removes the need to pad intermediate index lists we would discard.
+  auto isEqual = [](auto first1, auto last1, auto first2, auto last2) {
+    while (first1 != last1 || first2 != last2) {
+      int v1 = (first1 == last1) ? -1 : *first1++;
+      int v2 = (first2 == last2) ? -1 : *first2++;
+      if (v1 != v2)
+        return false;
+    }
+    return true;
+  };
+
+  // Bind each stored index list by reference; copying it per iteration is
+  // unnecessary.
+  for (const auto &memberData : parentMembers.memberPlacementIndices)
+    if (isEqual(memberData.begin(), memberData.end(), memberIndices.begin(),
+                memberIndices.end()))
+      return true;
+  return false;
+}
+
+// When mapping members of derived types, there is a chance that one of the
+// members along the way to a mapped member is a descriptor, in which case
+// we have to make sure we generate a map for those along the way; otherwise
+// we will be missing a chunk of data required to actually map the member
+// type to the device. This function generates these maps and the
+// appropriate data accesses required to generate them. It avoids
+// creating duplicate maps, as duplicates are just as bad as unmapped
+// descriptor data in a lot of cases for the runtime (and unnecessary
+// data movement should be avoided where possible).
+mlir::Value createParentSymAndGenIntermediateMaps(
+    mlir::Location clauseLocation, Fortran::lower::AbstractConverter &converter,
+    omp::ObjectList &objectList, llvm::SmallVector<int> &indices,
+    OmpMapMemberIndicesData &parentMemberIndices, std::string asFortran,
+    llvm::omp::OpenMPOffloadMappingFlags mapTypeBits) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  Fortran::lower::AddrAndBoundsInfo parentBaseAddr =
+      Fortran::lower::getDataOperandBaseAddr(
+          converter, firOpBuilder, *objectList[0].sym(), clauseLocation);
+  mlir::Value curValue = parentBaseAddr.addr;
+
+  for (size_t i = 0; i < objectList.size(); ++i) {
+    mlir::Type unwrappedTy =
+        fir::unwrapSequenceType(fir::unwrapPassByRefType(curValue.getType()));
+    if (fir::RecordType recordType =
+            mlir::dyn_cast_or_null<fir::RecordType>(unwrappedTy)) {
+      mlir::Value idxConst = firOpBuilder.createIntegerConstant(
+          clauseLocation, firOpBuilder.getIndexType(), indices[i]);
+      mlir::Type memberTy = recordType.getTypeList().at(indices[i]).second;
+      curValue = firOpBuilder.create<fir::CoordinateOp>(
+          clauseLocation, firOpBuilder.getRefType(memberTy), curValue,
+          idxConst);
+
+      if ((i != indices.size() - 1) && fir::isTypeWithDescriptor(memberTy)) {
+        llvm::SmallVector<int> intermIndices = indices;
+        std::fill(std::next(intermIndices.begin(), i + 1), intermIndices.end(),
+                  -1);
+        if (!duplicateMemberMapInfo(parentMemberIndices, intermIndices)) {
+          // TODO: Perhaps generate bounds for these intermediate maps, as it
+          // may be required for cases such as:
+          //    dtype(1)%second(3)%array
+          // where second is an allocatable (and dtype may be an allocatable as
+          // well, although in this case I am not sure the fortran syntax would
+          // be legal)
+          mlir::omp::MapInfoOp mapOp = createMapInfoOp(
+              firOpBuilder, clauseLocation, curValue,
+              /*varPtrPtr=*/mlir::Value{}, asFortran,
+              /*bounds=*/llvm::SmallVector<mlir::Value>{},
+              /*members=*/{},
+              /*membersIndex=*/mlir::DenseIntElementsAttr{},
+              static_cast<
+                  std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+                  mapTypeBits),
+              mlir::omp::VariableCaptureKind::ByRef, curValue.getType());
+
+          parentMemberIndices.memberPlacementIndices.push_back(intermIndices);
+          parentMemberIndices.memberMap.push_back(mapOp);
+        }
+
+        if (i != indices.size() - 1)
----------------
ergawy wrote:

We already know this is true from the enclosing `if`, which already checks `i != indices.size() - 1`, so this check is redundant.

https://github.com/llvm/llvm-project/pull/96266


More information about the llvm-branch-commits mailing list