[llvm-branch-commits] [flang] [llvm] [mlir] [MLIR][OpenMP] Introduce overlapped record type map support (PR #119588)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 24 06:38:58 PST 2025


================
@@ -2979,39 +2979,61 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
   return std::distance(mapData.MapClause.begin(), res);
 }
 
-static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
-                                                    bool first) {
-  ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
-  // Only 1 member has been mapped, we can return it.
-  if (indexAttr.size() == 1)
-    return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
+static void sortMapIndices(llvm::SmallVector<size_t> &indices,
+                           mlir::omp::MapInfoOp mapInfo,
+                           bool ascending = true) {
+  mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
+  if (indexAttr.empty() || indexAttr.size() == 1 || indices.empty() ||
+      indices.size() == 1)
+    return;
 
-  llvm::SmallVector<size_t> indices(indexAttr.size());
-  std::iota(indices.begin(), indices.end(), 0);
+  llvm::sort(
+      indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
+        auto memberIndicesA = mlir::cast<mlir::ArrayAttr>(indexAttr[a]);
+        auto memberIndicesB = mlir::cast<mlir::ArrayAttr>(indexAttr[b]);
+
+        size_t smallestMember = memberIndicesA.size() < memberIndicesB.size()
+                                    ? memberIndicesA.size()
+                                    : memberIndicesB.size();
 
-  llvm::sort(indices.begin(), indices.end(),
-             [&](const size_t a, const size_t b) {
-               auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
-               auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
-               for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
-                 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
-                 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
+        for (size_t i = 0; i < smallestMember; ++i) {
+          int64_t aIndex =
+              mlir::cast<mlir::IntegerAttr>(memberIndicesA.getValue()[i])
+                  .getInt();
+          int64_t bIndex =
+              mlir::cast<mlir::IntegerAttr>(memberIndicesB.getValue()[i])
+                  .getInt();
 
-                 if (aIndex == bIndex)
-                   continue;
+          if (aIndex == bIndex)
+            continue;
 
-                 if (aIndex < bIndex)
-                   return first;
+          if (aIndex < bIndex)
+            return ascending;
 
-                 if (aIndex > bIndex)
-                   return !first;
-               }
+          if (aIndex > bIndex)
+            return !ascending;
+        }
 
-               // Iterated the up until the end of the smallest member and
-               // they were found to be equal up to that point, so select
-               // the member with the lowest index count, so the "parent"
-               return memberIndicesA.size() < memberIndicesB.size();
-             });
+        // Iterated up until the end of the smallest member and
+        // they were found to be equal up to that point, so select
+        // the member with the lowest index count, so the "parent"
+        return memberIndicesA.size() < memberIndicesB.size();
+      });
+}
+
+static mlir::omp::MapInfoOp
+getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
+  mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
+  // Only 1 member has been mapped, we can return it.
+  if (indexAttr.size() == 1)
+    if (auto mapOp =
+            dyn_cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()))
----------------
skatrak wrote:

Let me know if I understood this wrong, but it seems like there is nothing preventing the `llvm::cast` call at the end of this function from triggering an assert if a single member was mapped that wasn't defined by an `omp.map.info`.

I don't know whether this function can be expected to return `null`, in which case we could replace the `cast` below with a `dyn_cast`, or if this check here should be replaced with `return cast<omp::MapInfoOp>(...)`.

https://github.com/llvm/llvm-project/pull/119588


More information about the llvm-branch-commits mailing list