[flang] [llvm] [mlir] [Flang][OpenMP] Update MapInfoFinalization to use BlockArgs Interface and modify use_device_ptr/addr to be order independent (PR #113919)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 07:49:00 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: None (agozillon)

<details>
<summary>Changes</summary>

This patch primarily updates the MapInfoFinalization pass to utilise the BlockArgument interface. It also moves the arguments newly added by the MapInfoFinalization pass to the end of the BlockArg/relevant MapInfo lists, instead of inserting each one immediately before its owning descriptor type.

While doing this, it was noted that the use_device_ptr/addr handling of target data was a little too order-dependent, so I've attempted to make it less so, as we cannot rely on the argument ordering being the same as Fortran's for any future frontends.

---

Patch is 35.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/113919.diff


10 Files Affected:

- (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+61-40) 
- (modified) flang/test/Lower/OpenMP/allocatable-map.f90 (+1-1) 
- (modified) flang/test/Lower/OpenMP/array-bounds.f90 (+1-1) 
- (modified) flang/test/Lower/OpenMP/target.f90 (+2-2) 
- (modified) flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 (+6-5) 
- (modified) flang/test/Transforms/omp-map-info-finalization.fir (+1-1) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+61-57) 
- (modified) mlir/test/Target/LLVMIR/omptarget-llvm.mlir (+8-8) 
- (added) offload/test/Inputs/target-use-dev-ptr.c (+23) 
- (added) offload/test/offloading/fortran/target-use-dev-ptr.f90 (+37) 


``````````diff
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 7ebeb51cf3dec7..6eb65b25594295 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -125,61 +125,82 @@ class MapInfoFinalizationPass
     // TODO: map the addendum segment of the descriptor, similarly to the
     // above base address/data pointer member.
 
-    auto addOperands = [&](mlir::OperandRange &operandsArr,
-                           mlir::MutableOperandRange &mutableOpRange,
-                           auto directiveOp) {
-      llvm::SmallVector<mlir::Value> newMapOps;
-      for (size_t i = 0; i < operandsArr.size(); ++i) {
-        if (operandsArr[i] == op) {
-          // Push new implicit maps generated for the descriptor.
-          newMapOps.push_back(baseAddr);
+    mlir::omp::MapInfoOp newDescParentMapOp =
+        builder.create<mlir::omp::MapInfoOp>(
+            op->getLoc(), op.getResult().getType(), descriptor,
+            mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
+            /*varPtrPtr=*/mlir::Value{},
+            /*members=*/mlir::SmallVector<mlir::Value>{baseAddr},
+            /*members_index=*/
+            mlir::DenseIntElementsAttr::get(
+                mlir::VectorType::get(
+                    llvm::ArrayRef<int64_t>({1, 1}),
+                    mlir::IntegerType::get(builder.getContext(), 32)),
+                llvm::ArrayRef<int32_t>({0})),
+            /*bounds=*/mlir::SmallVector<mlir::Value>{},
+            builder.getIntegerAttr(builder.getIntegerType(64, false),
+                                   op.getMapType().value()),
+            op.getMapCaptureTypeAttr(), op.getNameAttr(),
+            op.getPartialMapAttr());
+    op.replaceAllUsesWith(newDescParentMapOp.getResult());
+    op->erase();
 
-          // for TargetOp's which have IsolatedFromAbove we must align the
-          // new additional map operand with an appropriate BlockArgument,
-          // as the printing and later processing currently requires a 1:1
-          // mapping of BlockArgs to MapInfoOp's at the same placement in
-          // each array (BlockArgs and MapOperands).
-          if (directiveOp) {
-            directiveOp.getRegion().insertArgument(i, baseAddr.getType(), loc);
+    auto addOperands = [&](mlir::OperandRange &mapVarsArr,
+                           mlir::MutableOperandRange &mutableOpRange,
+                           mlir::Operation *directiveOp,
+                           mlir::omp::MapInfoOp newDesc,
+                           unsigned blockArgInsertIndex = 0,
+                           bool insertBlockArgs = true) {
+      if (llvm::is_contained(mapVarsArr, newDesc.getResult())) {
+        llvm::SmallVector<mlir::Value> newMapOps{mapVarsArr};
+        for (auto mapMember : newDesc.getMembers()) {
+          if (!llvm::is_contained(mapVarsArr, mapMember)) {
+            newMapOps.push_back(mapMember);
+            if (directiveOp && insertBlockArgs) {
+              directiveOp->getRegion(0).insertArgument(
+                  blockArgInsertIndex, mapMember.getType(), mapMember.getLoc());
+            }
+            blockArgInsertIndex++;
           }
         }
-        newMapOps.push_back(operandsArr[i]);
+        mutableOpRange.assign(newMapOps);
       }
-      mutableOpRange.assign(newMapOps);
     };
+
+    auto argIface =
+        llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(target);
+
     if (auto mapClauseOwner =
             llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
-      mlir::OperandRange mapOperandsArr = mapClauseOwner.getMapVars();
+      mlir::OperandRange mapVarsArr = mapClauseOwner.getMapVars();
       mlir::MutableOperandRange mapMutableOpRange =
           mapClauseOwner.getMapVarsMutable();
-      mlir::omp::TargetOp targetOp =
-          llvm::dyn_cast<mlir::omp::TargetOp>(target);
-      addOperands(mapOperandsArr, mapMutableOpRange, targetOp);
+      unsigned blockArgInsertIndex =
+          argIface
+              ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
+              : 0;
+      addOperands(mapVarsArr, mapMutableOpRange, argIface.getOperation(),
+                  newDescParentMapOp, blockArgInsertIndex,
+                  !llvm::isa<mlir::omp::TargetDataOp>(target));
     }
+
     if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
       mlir::OperandRange useDevAddrArr = targetDataOp.getUseDeviceAddrVars();
       mlir::MutableOperandRange useDevAddrMutableOpRange =
           targetDataOp.getUseDeviceAddrVarsMutable();
-      addOperands(useDevAddrArr, useDevAddrMutableOpRange, targetDataOp);
-    }
+      addOperands(useDevAddrArr, useDevAddrMutableOpRange, target,
+                  newDescParentMapOp,
+                  argIface.getUseDeviceAddrBlockArgsStart() +
+                      argIface.numUseDeviceAddrBlockArgs());
 
-    mlir::Value newDescParentMapOp = builder.create<mlir::omp::MapInfoOp>(
-        op->getLoc(), op.getResult().getType(), descriptor,
-        mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
-        /*varPtrPtr=*/mlir::Value{},
-        /*members=*/mlir::SmallVector<mlir::Value>{baseAddr},
-        /*members_index=*/
-        mlir::DenseIntElementsAttr::get(
-            mlir::VectorType::get(
-                llvm::ArrayRef<int64_t>({1, 1}),
-                mlir::IntegerType::get(builder.getContext(), 32)),
-            llvm::ArrayRef<int32_t>({0})),
-        /*bounds=*/mlir::SmallVector<mlir::Value>{},
-        builder.getIntegerAttr(builder.getIntegerType(64, false),
-                               op.getMapType().value()),
-        op.getMapCaptureTypeAttr(), op.getNameAttr(), op.getPartialMapAttr());
-    op.replaceAllUsesWith(newDescParentMapOp);
-    op->erase();
+      mlir::OperandRange useDevPtrArr = targetDataOp.getUseDevicePtrVars();
+      mlir::MutableOperandRange useDevPtrMutableOpRange =
+          targetDataOp.getUseDevicePtrVarsMutable();
+      addOperands(useDevPtrArr, useDevPtrMutableOpRange, target,
+                  newDescParentMapOp,
+                  argIface.getUseDevicePtrBlockArgsStart() +
+                      argIface.numUseDevicePtrBlockArgs());
+    }
   }
 
   // We add all mapped record members not directly used in the target region
diff --git a/flang/test/Lower/OpenMP/allocatable-map.f90 b/flang/test/Lower/OpenMP/allocatable-map.f90
index a9f576a6f09992..c1f94f41901489 100644
--- a/flang/test/Lower/OpenMP/allocatable-map.f90
+++ b/flang/test/Lower/OpenMP/allocatable-map.f90
@@ -4,7 +4,7 @@
 !HLFIRDIALECT: %[[BOX_OFF:.*]] = fir.box_offset %[[POINTER]]#1 base_addr : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> !fir.llvm_ptr<!fir.ref<i32>>
 !HLFIRDIALECT: %[[POINTER_MAP_MEMBER:.*]] = omp.map.info var_ptr(%[[POINTER]]#1 : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr(%[[BOX_OFF]] : !fir.llvm_ptr<!fir.ref<i32>>)  map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
 !HLFIRDIALECT: %[[POINTER_MAP:.*]] = omp.map.info var_ptr(%[[POINTER]]#1 : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[POINTER_MAP_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "point"}
-!HLFIRDIALECT: omp.target map_entries(%[[POINTER_MAP_MEMBER]] -> {{.*}}, %[[POINTER_MAP]] -> {{.*}} : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+!HLFIRDIALECT: omp.target map_entries(%[[POINTER_MAP]] -> {{.*}}, %[[POINTER_MAP_MEMBER]] -> {{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
 subroutine pointer_routine()
     integer, pointer :: point 
 !$omp target map(tofrom:point)
diff --git a/flang/test/Lower/OpenMP/array-bounds.f90 b/flang/test/Lower/OpenMP/array-bounds.f90
index 09498ca6cdde99..40fd276f10462b 100644
--- a/flang/test/Lower/OpenMP/array-bounds.f90
+++ b/flang/test/Lower/OpenMP/array-bounds.f90
@@ -53,7 +53,7 @@ module assumed_array_routines
 !HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %0 base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
 !HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
 !HOST: %[[MAP:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
-!HOST: omp.target   map_entries(%[[MAP_INFO_MEMBER]] -> %{{.*}}, %[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>) {
+!HOST: omp.target   map_entries(%[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}}, %[[MAP_INFO_MEMBER]] -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
     subroutine assumed_shape_array(arr_read_write)
             integer, intent(inout) :: arr_read_write(:)
 
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 63a43e750979d5..f8bbde93072e91 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -528,9 +528,9 @@ subroutine omp_target_device_addr
    !CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
    !CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
    !CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
-   !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]] -> %[[ARG_0:.*]], %[[DEV_ADDR]] -> %[[ARG_1:.*]] : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+   !CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
    !$omp target data map(tofrom: a) use_device_addr(a)
-   !CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_1]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
+   !CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
    !CHECK: %[[C10:.*]] = arith.constant 10 : i32
    !CHECK: %[[A_BOX:.*]] = fir.load %[[VAL_1_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
    !CHECK: %[[A_ADDR:.*]] = fir.box_addr %[[A_BOX]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
diff --git a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
index cb26246a6e80f0..8c1abad8eaa8d5 100644
--- a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
+++ b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
@@ -6,7 +6,8 @@
 ! use_device_ptr to use_device_addr works, without breaking any functionality.
 
 !CHECK: func.func @{{.*}}only_use_device_ptr()
-!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
 subroutine only_use_device_ptr
     use iso_c_binding
     integer, pointer, dimension(:) :: array
@@ -18,7 +19,7 @@ subroutine only_use_device_ptr
      end subroutine
 
 !CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr()
-!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
 subroutine mix_use_device_ptr_and_addr
     use iso_c_binding
     integer, pointer, dimension(:) :: array
@@ -30,7 +31,7 @@ subroutine mix_use_device_ptr_and_addr
      end subroutine
 
      !CHECK: func.func @{{.*}}only_use_device_addr()
-     !CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
+     !CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
      subroutine only_use_device_addr
         use iso_c_binding
         integer, pointer, dimension(:) :: array
@@ -42,7 +43,7 @@ subroutine only_use_device_addr
      end subroutine
 
      !CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
-     !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+     !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
      subroutine mix_use_device_ptr_and_addr_and_map
         use iso_c_binding
         integer :: i, j
@@ -55,7 +56,7 @@ subroutine mix_use_device_ptr_and_addr_and_map
      end subroutine
 
      !CHECK: func.func @{{.*}}only_use_map()
-     !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
+     !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
      subroutine only_use_map
         use iso_c_binding
         integer, pointer, dimension(:) :: array
diff --git a/flang/test/Transforms/omp-map-info-finalization.fir b/flang/test/Transforms/omp-map-info-finalization.fir
index fa7b65d41929b7..de0ad2143fc853 100644
--- a/flang/test/Transforms/omp-map-info-finalization.fir
+++ b/flang/test/Transforms/omp-map-info-finalization.fir
@@ -39,7 +39,7 @@ module attributes {omp.is_target_device = false} {
 // CHECK: %[[BASE_ADDR_OFF_2:.*]] = fir.box_offset %[[ALLOCA]] base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
 // CHECK: %[[DESC_MEMBER_MAP_2:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.array<?xi32>) var_ptr_ptr(%[[BASE_ADDR_OFF_2]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
 // CHECK: %[[DESC_PARENT_MAP_2:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(from) capture(ByRef) members(%[[DESC_MEMBER_MAP_2]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>>
-// CHECK: omp.target map_entries(%[[DESC_MEMBER_MAP]] -> %[[ARG1:.*]], %[[DESC_PARENT_MAP]] -> %[[ARG2:.*]], %[[DESC_MEMBER_MAP_2]] -> %[[ARG3:.*]], %[[DESC_PARENT_MAP_2]] -> %[[ARG4:.*]] : {{.*}}) {
+// CHECK: omp.target map_entries(%[[DESC_PARENT_MAP]] -> %[[ARG1:.*]], %[[DESC_PARENT_MAP_2]] -> %[[ARG2:.*]], %[[DESC_MEMBER_MAP]] -> %[[ARG3:.*]], %[[DESC_MEMBER_MAP_2]] -> %[[ARG4:.*]] : {{.*}}) {
 
 // -----
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 27cd38dc3c62d9..fbbaa909fd3dd8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2327,21 +2327,20 @@ static void collectMapDataFromMapOperands(
     mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
   }
 
-  auto findMapInfo = [&mapData](llvm::Value *val,
-                                llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
-    unsigned index = 0;
-    bool found = false;
-    for (llvm::Value *basePtr : mapData.OriginalValue) {
-      if (basePtr == val && mapData.IsAMapping[index]) {
-        found = true;
-        mapData.Types[index] |=
-            llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
-        mapData.DevicePointers[index] = devInfoTy;
-      }
-      index++;
-    }
-    return found;
-  };
+  // This function alters the original mapped pointers type if it was present in
+  // a map clause as well as being present in a useDevAddr/Ptr clause.
+  auto alterAndCreateUseDevMapType =
+      [&mapData](llvm::Value *val,
+                 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
+        for (auto [i, origVal] : llvm::enumerate(mapData.OriginalValue)) {
+          if (origVal == val && mapData.IsAMapping[i]) {
+            mapData.Types[i] |=
+                llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+            mapData.Devi...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/113919


More information about the llvm-commits mailing list