[llvm-branch-commits] [OpenMP][MLIR] Descriptor explicit member map lowering changes (PR #96265)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Aug 12 05:12:54 PDT 2024


================
@@ -2261,47 +2261,47 @@ static int getMapDataMemberIdx(MapInfoData &mapData,
 
 static mlir::omp::MapInfoOp
 getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
-  mlir::DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr();
-
+  mlir::ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
   // Only 1 member has been mapped, we can return it.
   if (indexAttr.size() == 1)
     if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
             mapInfo.getMembers()[0].getDefiningOp()))
       return mapOp;
 
-  llvm::ArrayRef<int64_t> shape = indexAttr.getShapedType().getShape();
-  llvm::SmallVector<size_t> indices(shape[0]);
+  llvm::SmallVector<size_t> indices(indexAttr.size());
   std::iota(indices.begin(), indices.end(), 0);
 
-  llvm::sort(indices.begin(), indices.end(),
-             [&](const size_t a, const size_t b) {
-               auto indexValues = indexAttr.getValues<int32_t>();
-               for (int i = 0; i < shape[1]; ++i) {
-                 int aIndex = indexValues[a * shape[1] + i];
-                 int bIndex = indexValues[b * shape[1] + i];
-
-                 if (aIndex == bIndex)
-                   continue;
-
-                 if (aIndex != -1 && bIndex == -1)
-                   return false;
-
-                 if (aIndex == -1 && bIndex != -1)
-                   return true;
+  llvm::sort(
+      indices.begin(), indices.end(), [&](const size_t a, const size_t b) {
+        auto memberIndicesA = mlir::cast<mlir::ArrayAttr>(indexAttr[a]);
+        auto memberIndicesB = mlir::cast<mlir::ArrayAttr>(indexAttr[b]);
+
+        size_t smallestMember = memberIndicesA.size() < memberIndicesB.size()
+                                    ? memberIndicesA.size()
+                                    : memberIndicesB.size();
+        for (size_t i = 0; i < smallestMember; ++i) {
----------------
skatrak wrote:

Nit: I think `llvm::zip` could simplify this a bit, since it already iterates over two ranges of potentially different sizes, stopping when the end of the shorter one is reached.

https://github.com/llvm/llvm-project/pull/96265


More information about the llvm-branch-commits mailing list