[llvm-branch-commits] [OpenMP][MLIR] Extend explicit derived type member mapping support for OpenMP dialects lowering to LLVM-IR (PR #81510)

Kareem Ergawy via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Feb 15 08:02:46 PST 2024


================
@@ -1783,6 +1783,98 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
   }
 }
 
+/// Look up where \p memberOp sits in \p mapData.MapClause.
+///
+/// Returns the position of the matching entry, or -1 when there is none. If
+/// \p memberOp occurs more than once, the highest position is returned.
+static int getMapDataMemberIdx(MapInfoData &mapData,
+                               mlir::omp::MapInfoOp memberOp) {
+  // Scan from the back so that the first hit is the last matching position.
+  for (size_t pos = mapData.MapClause.size(); pos-- > 0;)
+    if (mapData.MapClause[pos] == memberOp)
+      return static_cast<int>(pos);
+  return -1;
+}
+
+/// Select one member map of \p mapInfo by its `members_index` entry.
+///
+/// The `members_index` attribute is expected to hold one integer per entry of
+/// the `members` operand list; the two are indexed in parallel. When \p first
+/// is true, the member with the smallest integer is returned. Otherwise the
+/// member with the largest integer is returned. On ties the earliest entry
+/// wins.
+///
+/// Returns a null MapInfoOp when:
+/// - there is no `members_index` attribute, or it is empty;
+/// - the selected position is out of range for `members`; or
+/// - the selected member is not produced by a MapInfoOp.
+static mlir::omp::MapInfoOp
+getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
+  auto membersIndex = mapInfo.getMembersIndex();
+  // The attribute is optional; bail out before dereferencing it or reading
+  // its first element.
+  if (!membersIndex || membersIndex->empty())
+    return {};
+
+  llvm::ArrayRef<mlir::Attribute> placements = membersIndex->getValue();
+  size_t curIdx = 0;
+  int64_t curPos = mlir::cast<mlir::IntegerAttr>(placements[0]).getInt();
+  for (size_t idx = 1, e = placements.size(); idx < e; ++idx) {
+    int64_t memberPlacement =
+        mlir::cast<mlir::IntegerAttr>(placements[idx]).getInt();
+    // Strict comparison keeps the earliest entry on ties.
+    if (first ? memberPlacement < curPos : memberPlacement > curPos) {
+      curIdx = idx;
+      curPos = memberPlacement;
+    }
+  }
+
+  mlir::OperandRange members = mapInfo.getMembers();
+  // The index list and the member list should have the same length; refuse to
+  // index out of range if they do not.
+  if (curIdx >= members.size())
+    return {};
+
+  // A member that is a block argument has no defining op, so use the
+  // null-tolerant cast (as calculateBoundsOffset does for bounds).
+  return mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
+      members[curIdx].getDefiningOp());
+}
+
+std::vector<llvm::Value *>
+calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
+                      llvm::IRBuilderBase &builder, bool isArrayTy,
+                      mlir::OperandRange bounds) {
+  std::vector<llvm::Value *> idx;
+  llvm::Value *offsetAddress = nullptr;
+  if (!bounds.empty()) {
+    idx.push_back(builder.getInt64(0));
+    if (isArrayTy) {
+      for (int i = bounds.size() - 1; i >= 0; --i) {
+        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::DataBoundsOp>(
+                bounds[i].getDefiningOp())) {
+          idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
+        }
+      }
+    } else {
+      std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
+      for (size_t i = 1; i < bounds.size(); ++i) {
+        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::DataBoundsOp>(
+                bounds[i].getDefiningOp())) {
+          dimensionIndexSizeOffset.push_back(builder.CreateMul(
+              moduleTranslation.lookupValue(boundOp.getExtent()),
+              dimensionIndexSizeOffset[i - 1]));
+        }
+      }
+
+      for (int i = bounds.size() - 1; i >= 0; --i) {
+        if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::DataBoundsOp>(
+                bounds[i].getDefiningOp())) {
+          if (!offsetAddress)
+            offsetAddress = builder.CreateMul(
+                moduleTranslation.lookupValue(boundOp.getLowerBound()),
+                dimensionIndexSizeOffset[i]);
+          else
+            offsetAddress = builder.CreateAdd(
+                offsetAddress, builder.CreateMul(moduleTranslation.lookupValue(
+                                                     boundOp.getLowerBound()),
+                                                 dimensionIndexSizeOffset[i]));
----------------
ergawy wrote:

I think you can get rid of `offsetAddress` this way. This assumes the `idx.push_back(builder.getInt64(0))` above is moved into the `isArrayTy` branch; otherwise `idx` is never empty at this point:
```suggestion
          if (idx.empty())
            idx.emplace_back(builder.CreateMul(
                moduleTranslation.lookupValue(boundOp.getLowerBound()),
                dimensionIndexSizeOffset[i]));
          else
            idx.back() = builder.CreateAdd(
                idx.back(), builder.CreateMul(moduleTranslation.lookupValue(
                                                  boundOp.getLowerBound()),
                                              dimensionIndexSizeOffset[i]));
```

https://github.com/llvm/llvm-project/pull/81510


More information about the llvm-branch-commits mailing list