[Mlir-commits] [flang] [llvm] [mlir] [Flang][OpenMP] Update MapInfoFinalization to use BlockArgs Interface… (PR #113919)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 28 07:48:25 PDT 2024
https://github.com/agozillon created https://github.com/llvm/llvm-project/pull/113919
… and modify use_device_ptr/addr to be order independent
This patch primarily updates the MapInfoFinalization pass to utilise the BlockArgument interface. It also moves the arguments newly added by the MapInfoFinalization pass to the end of the BlockArg/relevant MapInfo lists, instead of placing them immediately before the owning descriptor type.
While doing this, it was noted that the use_device_ptr/addr handling of target data was a little too order-dependent, so I've attempted to make it less so, as we cannot rely on future frontends ordering arguments the same way Fortran does.
>From edeb77b546f2a4d79069604811d75221b0148dc3 Mon Sep 17 00:00:00 2001
From: agozillon <Andrew.Gozillon at amd.com>
Date: Mon, 28 Oct 2024 09:45:28 -0500
Subject: [PATCH] [Flang][OpenMP] Update MapInfoFinalization to use BlockArgs
Interface and modify use_device_ptr/addr to be order independent
This patch primarily updates the MapInfoFinalization pass to
utilise the BlockArgument interface. It also moves the arguments
newly added by the MapInfoFinalization pass to the end of the
BlockArg/relevant MapInfo lists, instead of placing them
immediately before the owning descriptor type.
While doing this, it was noted that the use_device_ptr/addr
handling of target data was a little too order-dependent, so
I've attempted to make it less so, as we cannot rely on
future frontends ordering arguments the same way Fortran
does.
---
.../Optimizer/OpenMP/MapInfoFinalization.cpp | 101 +++++++++------
flang/test/Lower/OpenMP/allocatable-map.f90 | 2 +-
flang/test/Lower/OpenMP/array-bounds.f90 | 2 +-
flang/test/Lower/OpenMP/target.f90 | 4 +-
.../use-device-ptr-to-use-device-addr.f90 | 11 +-
.../Transforms/omp-map-info-finalization.fir | 2 +-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 118 +++++++++---------
mlir/test/Target/LLVMIR/omptarget-llvm.mlir | 16 +--
offload/test/Inputs/target-use-dev-ptr.c | 23 ++++
.../offloading/fortran/target-use-dev-ptr.f90 | 37 ++++++
10 files changed, 201 insertions(+), 115 deletions(-)
create mode 100644 offload/test/Inputs/target-use-dev-ptr.c
create mode 100644 offload/test/offloading/fortran/target-use-dev-ptr.f90
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 7ebeb51cf3dec7..6eb65b25594295 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -125,61 +125,82 @@ class MapInfoFinalizationPass
// TODO: map the addendum segment of the descriptor, similarly to the
// above base address/data pointer member.
- auto addOperands = [&](mlir::OperandRange &operandsArr,
- mlir::MutableOperandRange &mutableOpRange,
- auto directiveOp) {
- llvm::SmallVector<mlir::Value> newMapOps;
- for (size_t i = 0; i < operandsArr.size(); ++i) {
- if (operandsArr[i] == op) {
- // Push new implicit maps generated for the descriptor.
- newMapOps.push_back(baseAddr);
+ mlir::omp::MapInfoOp newDescParentMapOp =
+ builder.create<mlir::omp::MapInfoOp>(
+ op->getLoc(), op.getResult().getType(), descriptor,
+ mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
+ /*varPtrPtr=*/mlir::Value{},
+ /*members=*/mlir::SmallVector<mlir::Value>{baseAddr},
+ /*members_index=*/
+ mlir::DenseIntElementsAttr::get(
+ mlir::VectorType::get(
+ llvm::ArrayRef<int64_t>({1, 1}),
+ mlir::IntegerType::get(builder.getContext(), 32)),
+ llvm::ArrayRef<int32_t>({0})),
+ /*bounds=*/mlir::SmallVector<mlir::Value>{},
+ builder.getIntegerAttr(builder.getIntegerType(64, false),
+ op.getMapType().value()),
+ op.getMapCaptureTypeAttr(), op.getNameAttr(),
+ op.getPartialMapAttr());
+ op.replaceAllUsesWith(newDescParentMapOp.getResult());
+ op->erase();
- // for TargetOp's which have IsolatedFromAbove we must align the
- // new additional map operand with an appropriate BlockArgument,
- // as the printing and later processing currently requires a 1:1
- // mapping of BlockArgs to MapInfoOp's at the same placement in
- // each array (BlockArgs and MapOperands).
- if (directiveOp) {
- directiveOp.getRegion().insertArgument(i, baseAddr.getType(), loc);
+ auto addOperands = [&](mlir::OperandRange &mapVarsArr,
+ mlir::MutableOperandRange &mutableOpRange,
+ mlir::Operation *directiveOp,
+ mlir::omp::MapInfoOp newDesc,
+ unsigned blockArgInsertIndex = 0,
+ bool insertBlockArgs = true) {
+ if (llvm::is_contained(mapVarsArr, newDesc.getResult())) {
+ llvm::SmallVector<mlir::Value> newMapOps{mapVarsArr};
+ for (auto mapMember : newDesc.getMembers()) {
+ if (!llvm::is_contained(mapVarsArr, mapMember)) {
+ newMapOps.push_back(mapMember);
+ if (directiveOp && insertBlockArgs) {
+ directiveOp->getRegion(0).insertArgument(
+ blockArgInsertIndex, mapMember.getType(), mapMember.getLoc());
+ }
+ blockArgInsertIndex++;
}
}
- newMapOps.push_back(operandsArr[i]);
+ mutableOpRange.assign(newMapOps);
}
- mutableOpRange.assign(newMapOps);
};
+
+ auto argIface =
+ llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(target);
+
if (auto mapClauseOwner =
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
- mlir::OperandRange mapOperandsArr = mapClauseOwner.getMapVars();
+ mlir::OperandRange mapVarsArr = mapClauseOwner.getMapVars();
mlir::MutableOperandRange mapMutableOpRange =
mapClauseOwner.getMapVarsMutable();
- mlir::omp::TargetOp targetOp =
- llvm::dyn_cast<mlir::omp::TargetOp>(target);
- addOperands(mapOperandsArr, mapMutableOpRange, targetOp);
+ unsigned blockArgInsertIndex =
+ argIface
+ ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
+ : 0;
+ addOperands(mapVarsArr, mapMutableOpRange, argIface.getOperation(),
+ newDescParentMapOp, blockArgInsertIndex,
+ !llvm::isa<mlir::omp::TargetDataOp>(target));
}
+
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
mlir::OperandRange useDevAddrArr = targetDataOp.getUseDeviceAddrVars();
mlir::MutableOperandRange useDevAddrMutableOpRange =
targetDataOp.getUseDeviceAddrVarsMutable();
- addOperands(useDevAddrArr, useDevAddrMutableOpRange, targetDataOp);
- }
+ addOperands(useDevAddrArr, useDevAddrMutableOpRange, target,
+ newDescParentMapOp,
+ argIface.getUseDeviceAddrBlockArgsStart() +
+ argIface.numUseDeviceAddrBlockArgs());
- mlir::Value newDescParentMapOp = builder.create<mlir::omp::MapInfoOp>(
- op->getLoc(), op.getResult().getType(), descriptor,
- mlir::TypeAttr::get(fir::unwrapRefType(descriptor.getType())),
- /*varPtrPtr=*/mlir::Value{},
- /*members=*/mlir::SmallVector<mlir::Value>{baseAddr},
- /*members_index=*/
- mlir::DenseIntElementsAttr::get(
- mlir::VectorType::get(
- llvm::ArrayRef<int64_t>({1, 1}),
- mlir::IntegerType::get(builder.getContext(), 32)),
- llvm::ArrayRef<int32_t>({0})),
- /*bounds=*/mlir::SmallVector<mlir::Value>{},
- builder.getIntegerAttr(builder.getIntegerType(64, false),
- op.getMapType().value()),
- op.getMapCaptureTypeAttr(), op.getNameAttr(), op.getPartialMapAttr());
- op.replaceAllUsesWith(newDescParentMapOp);
- op->erase();
+ mlir::OperandRange useDevPtrArr = targetDataOp.getUseDevicePtrVars();
+ mlir::MutableOperandRange useDevPtrMutableOpRange =
+ targetDataOp.getUseDevicePtrVarsMutable();
+ addOperands(useDevPtrArr, useDevPtrMutableOpRange, target,
+ newDescParentMapOp,
+ argIface.getUseDevicePtrBlockArgsStart() +
+ argIface.numUseDevicePtrBlockArgs());
+ }
}
// We add all mapped record members not directly used in the target region
diff --git a/flang/test/Lower/OpenMP/allocatable-map.f90 b/flang/test/Lower/OpenMP/allocatable-map.f90
index a9f576a6f09992..c1f94f41901489 100644
--- a/flang/test/Lower/OpenMP/allocatable-map.f90
+++ b/flang/test/Lower/OpenMP/allocatable-map.f90
@@ -4,7 +4,7 @@
!HLFIRDIALECT: %[[BOX_OFF:.*]] = fir.box_offset %[[POINTER]]#1 base_addr : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> !fir.llvm_ptr<!fir.ref<i32>>
!HLFIRDIALECT: %[[POINTER_MAP_MEMBER:.*]] = omp.map.info var_ptr(%[[POINTER]]#1 : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr(%[[BOX_OFF]] : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!HLFIRDIALECT: %[[POINTER_MAP:.*]] = omp.map.info var_ptr(%[[POINTER]]#1 : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[POINTER_MAP_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "point"}
-!HLFIRDIALECT: omp.target map_entries(%[[POINTER_MAP_MEMBER]] -> {{.*}}, %[[POINTER_MAP]] -> {{.*}} : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+!HLFIRDIALECT: omp.target map_entries(%[[POINTER_MAP]] -> {{.*}}, %[[POINTER_MAP_MEMBER]] -> {{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
subroutine pointer_routine()
integer, pointer :: point
!$omp target map(tofrom:point)
diff --git a/flang/test/Lower/OpenMP/array-bounds.f90 b/flang/test/Lower/OpenMP/array-bounds.f90
index 09498ca6cdde99..40fd276f10462b 100644
--- a/flang/test/Lower/OpenMP/array-bounds.f90
+++ b/flang/test/Lower/OpenMP/array-bounds.f90
@@ -53,7 +53,7 @@ module assumed_array_routines
!HOST: %[[VAR_PTR_PTR:.*]] = fir.box_offset %0 base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
!HOST: %[[MAP_INFO_MEMBER:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.array<?xi32>) var_ptr_ptr(%[[VAR_PTR_PTR]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
!HOST: %[[MAP:.*]] = omp.map.info var_ptr(%[[INTERMEDIATE_ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_INFO_MEMBER]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
-!HOST: omp.target map_entries(%[[MAP_INFO_MEMBER]] -> %{{.*}}, %[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>) {
+!HOST: omp.target map_entries(%[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}}, %[[MAP_INFO_MEMBER]] -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
subroutine assumed_shape_array(arr_read_write)
integer, intent(inout) :: arr_read_write(:)
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 63a43e750979d5..f8bbde93072e91 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -528,9 +528,9 @@ subroutine omp_target_device_addr
!CHECK: %[[MAP:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[MAP_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
!CHECK: %[[DEV_ADDR_MEMBERS:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, i32) var_ptr_ptr({{.*}} : !fir.llvm_ptr<!fir.ref<i32>>) map_clauses(tofrom) capture(ByRef) -> !fir.llvm_ptr<!fir.ref<i32>> {name = ""}
!CHECK: %[[DEV_ADDR:.*]] = omp.map.info var_ptr({{.*}} : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.box<!fir.ptr<i32>>) map_clauses(tofrom) capture(ByRef) members(%[[DEV_ADDR_MEMBERS]] : [0] : !fir.llvm_ptr<!fir.ref<i32>>) -> !fir.ref<!fir.box<!fir.ptr<i32>>> {name = "a"}
- !CHECK: omp.target_data map_entries(%[[MAP_MEMBERS]], %[[MAP]] : {{.*}}) use_device_addr(%[[DEV_ADDR_MEMBERS]] -> %[[ARG_0:.*]], %[[DEV_ADDR]] -> %[[ARG_1:.*]] : !fir.llvm_ptr<!fir.ref<i32>>, !fir.ref<!fir.box<!fir.ptr<i32>>>) {
+ !CHECK: omp.target_data map_entries(%[[MAP]], %[[MAP_MEMBERS]] : {{.*}}) use_device_addr(%[[DEV_ADDR]] -> %[[ARG_0:.*]], %[[DEV_ADDR_MEMBERS]] -> %[[ARG_1:.*]] : !fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.llvm_ptr<!fir.ref<i32>>) {
!$omp target data map(tofrom: a) use_device_addr(a)
- !CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_1]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
+ !CHECK: %[[VAL_1_DECL:.*]]:2 = hlfir.declare %[[ARG_0]] {fortran_attrs = #fir.var_attrs<pointer>, uniq_name = "_QFomp_target_device_addrEa"} : (!fir.ref<!fir.box<!fir.ptr<i32>>>) -> (!fir.ref<!fir.box<!fir.ptr<i32>>>, !fir.ref<!fir.box<!fir.ptr<i32>>>)
!CHECK: %[[C10:.*]] = arith.constant 10 : i32
!CHECK: %[[A_BOX:.*]] = fir.load %[[VAL_1_DECL]]#0 : !fir.ref<!fir.box<!fir.ptr<i32>>>
!CHECK: %[[A_ADDR:.*]] = fir.box_addr %[[A_BOX]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
diff --git a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90 b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
index cb26246a6e80f0..8c1abad8eaa8d5 100644
--- a/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
+++ b/flang/test/Lower/OpenMP/use-device-ptr-to-use-device-addr.f90
@@ -6,7 +6,8 @@
! use_device_ptr to use_device_addr works, without breaking any functionality.
!CHECK: func.func @{{.*}}only_use_device_ptr()
-!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) use_device_ptr(%{{.*}} -> %{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine only_use_device_ptr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -18,7 +19,7 @@ subroutine only_use_device_ptr
end subroutine
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr()
-!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+!CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr({{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine mix_use_device_ptr_and_addr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -30,7 +31,7 @@ subroutine mix_use_device_ptr_and_addr
end subroutine
!CHECK: func.func @{{.*}}only_use_device_addr()
- !CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
+ !CHECK: omp.target_data use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
subroutine only_use_device_addr
use iso_c_binding
integer, pointer, dimension(:) :: array
@@ -42,7 +43,7 @@ subroutine only_use_device_addr
end subroutine
!CHECK: func.func @{{.*}}mix_use_device_ptr_and_addr_and_map()
- !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
+ !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}} : !fir.ref<i32>, !fir.ref<i32>) use_device_addr(%{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>) use_device_ptr(%{{.*}} : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>) {
subroutine mix_use_device_ptr_and_addr_and_map
use iso_c_binding
integer :: i, j
@@ -55,7 +56,7 @@ subroutine mix_use_device_ptr_and_addr_and_map
end subroutine
!CHECK: func.func @{{.*}}only_use_map()
- !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>) {
+ !CHECK: omp.target_data map_entries(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>, !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.ref<!fir.box<!fir.ptr<!fir.array<?xi32>>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xf32>>>, !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) {
subroutine only_use_map
use iso_c_binding
integer, pointer, dimension(:) :: array
diff --git a/flang/test/Transforms/omp-map-info-finalization.fir b/flang/test/Transforms/omp-map-info-finalization.fir
index fa7b65d41929b7..de0ad2143fc853 100644
--- a/flang/test/Transforms/omp-map-info-finalization.fir
+++ b/flang/test/Transforms/omp-map-info-finalization.fir
@@ -39,7 +39,7 @@ module attributes {omp.is_target_device = false} {
// CHECK: %[[BASE_ADDR_OFF_2:.*]] = fir.box_offset %[[ALLOCA]] base_addr : (!fir.ref<!fir.box<!fir.array<?xi32>>>) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>
// CHECK: %[[DESC_MEMBER_MAP_2:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.array<?xi32>) var_ptr_ptr(%[[BASE_ADDR_OFF_2]] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) map_clauses(from) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>> {name = ""}
// CHECK: %[[DESC_PARENT_MAP_2:.*]] = omp.map.info var_ptr(%[[ALLOCA]] : !fir.ref<!fir.box<!fir.array<?xi32>>>, !fir.box<!fir.array<?xi32>>) map_clauses(from) capture(ByRef) members(%[[DESC_MEMBER_MAP_2]] : [0] : !fir.llvm_ptr<!fir.ref<!fir.array<?xi32>>>) -> !fir.ref<!fir.array<?xi32>>
-// CHECK: omp.target map_entries(%[[DESC_MEMBER_MAP]] -> %[[ARG1:.*]], %[[DESC_PARENT_MAP]] -> %[[ARG2:.*]], %[[DESC_MEMBER_MAP_2]] -> %[[ARG3:.*]], %[[DESC_PARENT_MAP_2]] -> %[[ARG4:.*]] : {{.*}}) {
+// CHECK: omp.target map_entries(%[[DESC_PARENT_MAP]] -> %[[ARG1:.*]], %[[DESC_PARENT_MAP_2]] -> %[[ARG2:.*]], %[[DESC_MEMBER_MAP]] -> %[[ARG3:.*]], %[[DESC_MEMBER_MAP_2]] -> %[[ARG4:.*]] : {{.*}}) {
// -----
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 27cd38dc3c62d9..fbbaa909fd3dd8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2327,21 +2327,20 @@ static void collectMapDataFromMapOperands(
mapData.IsAMember.push_back(checkIsAMember(mapVars, mapOp));
}
- auto findMapInfo = [&mapData](llvm::Value *val,
- llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
- unsigned index = 0;
- bool found = false;
- for (llvm::Value *basePtr : mapData.OriginalValue) {
- if (basePtr == val && mapData.IsAMapping[index]) {
- found = true;
- mapData.Types[index] |=
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
- mapData.DevicePointers[index] = devInfoTy;
- }
- index++;
- }
- return found;
- };
+ // This function alters the original mapped pointers type if it was present in
+ // a map clause as well as being present in a useDevAddr/Ptr clause.
+ auto alterAndCreateUseDevMapType =
+ [&mapData](llvm::Value *val,
+ llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
+ for (auto [i, origVal] : llvm::enumerate(mapData.OriginalValue)) {
+ if (origVal == val && mapData.IsAMapping[i]) {
+ mapData.Types[i] |=
+ llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+ mapData.DevicePointers[i] = devInfoTy;
+ }
+ }
+ return llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
+ };
// Process useDevPtr(Addr)Operands
auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
@@ -2351,25 +2350,22 @@ static void collectMapDataFromMapOperands(
Value offloadPtr =
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
-
- // Check if map info is already present for this entry.
- if (!findMapInfo(origValue, devInfoTy)) {
- mapData.OriginalValue.push_back(origValue);
- mapData.Pointers.push_back(mapData.OriginalValue.back());
- mapData.IsDeclareTarget.push_back(false);
- mapData.BasePointers.push_back(mapData.OriginalValue.back());
- mapData.BaseType.push_back(
- moduleTranslation.convertType(mapOp.getVarType()));
- mapData.Sizes.push_back(builder.getInt64(0));
- mapData.MapClause.push_back(mapOp.getOperation());
- mapData.Types.push_back(
- llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
- mapData.Names.push_back(LLVM::createMappingInformation(
- mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
- mapData.DevicePointers.push_back(devInfoTy);
- mapData.IsAMapping.push_back(false);
- mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
- }
+ llvm::omp::OpenMPOffloadMappingFlags mapType =
+ alterAndCreateUseDevMapType(origValue, devInfoTy);
+ mapData.OriginalValue.push_back(origValue);
+ mapData.Pointers.push_back(mapData.OriginalValue.back());
+ mapData.IsDeclareTarget.push_back(false);
+ mapData.BasePointers.push_back(mapData.OriginalValue.back());
+ mapData.BaseType.push_back(
+ moduleTranslation.convertType(mapOp.getVarType()));
+ mapData.Sizes.push_back(builder.getInt64(0));
+ mapData.MapClause.push_back(mapOp.getOperation());
+ mapData.Types.push_back(mapType);
+ mapData.Names.push_back(LLVM::createMappingInformation(
+ mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
+ mapData.DevicePointers.push_back(devInfoTy);
+ mapData.IsAMapping.push_back(false);
+ mapData.IsAMember.push_back(checkIsAMember(useDevOperands, mapOp));
}
};
@@ -2850,6 +2846,15 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
if (!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
createAlteredByCaptureMap(mapData, moduleTranslation, builder);
+ auto useDevAndMapped = [&mapData](unsigned mapIdx) {
+ if (!mapData.IsAMapping[mapIdx])
+ for (auto [idx, origValue] : llvm::enumerate(mapData.OriginalValue))
+ if (origValue == mapData.OriginalValue[mapIdx] &&
+ mapData.IsAMapping[idx])
+ return true;
+ return false;
+ };
+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
// We operate under the assumption that all vectors that are
@@ -2863,6 +2868,9 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
if (mapData.IsAMember[i])
continue;
+ if (useDevAndMapped(i))
+ continue;
+
auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
if (!mapInfoOp.getMembers().empty()) {
processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
@@ -2999,23 +3007,21 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
[&moduleTranslation](
llvm::OpenMPIRBuilder::DeviceInfoTy type,
llvm::ArrayRef<BlockArgument> blockArgs,
- llvm::OpenMPIRBuilder::MapValuesArrayTy &basePointers,
- llvm::OpenMPIRBuilder::MapDeviceInfoArrayTy &devicePointers,
+ llvm::SmallVectorImpl<Value> &devicePtrVars, MapInfoData &mapInfoData,
llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
- // Get a range to iterate over `basePointers` after filtering based on
- // `devicePointers` and the given device info type.
- auto basePtrRange = llvm::map_range(
- llvm::make_filter_range(
- llvm::zip_equal(basePointers, devicePointers),
- [type](auto x) { return std::get<1>(x) == type; }),
- [](auto x) { return std::get<0>(x); });
-
- // Map block arguments to the corresponding processed base pointer. If
- // a mapper is not specified, map the block argument to the base pointer
- // directly.
- for (auto [arg, basePointer] : llvm::zip_equal(blockArgs, basePtrRange))
- moduleTranslation.mapValue(arg, mapper ? mapper(basePointer)
- : basePointer);
+ for (auto [arg, devPtrVar] :
+ llvm::zip_equal(blockArgs, devicePtrVars)) {
+ for (size_t i = 0; i < mapInfoData.MapClause.size(); ++i) {
+ if (mapInfoData.MapClause[i] == devPtrVar.getDefiningOp()) {
+ if (mapInfoData.DevicePointers[i] == type) {
+ moduleTranslation.mapValue(
+ arg, mapper ? mapper(mapInfoData.BasePointers[i])
+ : mapInfoData.BasePointers[i]);
+ }
+ break;
+ }
+ }
+ }
};
using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
@@ -3030,19 +3036,17 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
// Check if any device ptr/addr info is available
if (!info.DevicePtrInfoMap.empty()) {
builder.restoreIP(codeGenIP);
-
mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
blockArgIface.getUseDeviceAddrBlockArgs(),
- combinedInfo.BasePointers, combinedInfo.DevicePointers,
+ useDeviceAddrVars, mapData,
[&](llvm::Value *basePointer) -> llvm::Value * {
return builder.CreateLoad(
builder.getPtrTy(),
info.DevicePtrInfoMap[basePointer].second);
});
mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
- blockArgIface.getUseDevicePtrBlockArgs(),
- combinedInfo.BasePointers, combinedInfo.DevicePointers,
- [&](llvm::Value *basePointer) {
+ blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
+ mapData, [&](llvm::Value *basePointer) {
return info.DevicePtrInfoMap[basePointer].second;
});
@@ -3061,10 +3065,10 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
if (ompBuilder->Config.IsTargetDevice.value_or(false)) {
mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
blockArgIface.getUseDeviceAddrBlockArgs(),
- mapData.BasePointers, mapData.DevicePointers);
+ useDeviceAddrVars, mapData);
mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
blockArgIface.getUseDevicePtrBlockArgs(),
- mapData.BasePointers, mapData.DevicePointers);
+ useDevicePtrVars, mapData);
}
bodyGenStatus = inlineConvertOmpRegions(region, "omp.data.region",
diff --git a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
index 654763c577d1af..7f21095763a397 100644
--- a/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-llvm.mlir
@@ -237,14 +237,14 @@ llvm.func @_QPopenmp_target_use_dev_ptr() {
// CHECK: store ptr null, ptr %[[VAL_9]], align 8
// CHECK: %[[VAL_10:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0
// CHECK: %[[VAL_11:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0
-// CHECK: call void @__tgt_target_data_begin_mapper(ptr @2, i64 -1, i32 1, ptr %[[VAL_10]], ptr %[[VAL_11]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr @{{.*}}, i64 -1, i32 1, ptr %[[VAL_10]], ptr %[[VAL_11]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
// CHECK: %[[VAL_12:.*]] = load ptr, ptr %[[VAL_7]], align 8
// CHECK: store ptr %[[VAL_12]], ptr %[[VAL_3]], align 8
// CHECK: %[[VAL_13:.*]] = load ptr, ptr %[[VAL_3]], align 8
// CHECK: store i32 10, ptr %[[VAL_13]], align 4
// CHECK: %[[VAL_14:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0
// CHECK: %[[VAL_15:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0
-// CHECK: call void @__tgt_target_data_end_mapper(ptr @2, i64 -1, i32 1, ptr %[[VAL_14]], ptr %[[VAL_15]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr @{{.*}}, i64 -1, i32 1, ptr %[[VAL_14]], ptr %[[VAL_15]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
// CHECK: ret void
// -----
@@ -280,13 +280,13 @@ llvm.func @_QPopenmp_target_use_dev_addr() {
// CHECK: store ptr null, ptr %[[VAL_8]], align 8
// CHECK: %[[VAL_9:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0
// CHECK: %[[VAL_10:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0
-// CHECK: call void @__tgt_target_data_begin_mapper(ptr @2, i64 -1, i32 1, ptr %[[VAL_9]], ptr %[[VAL_10]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr @{{.*}}, i64 -1, i32 1, ptr %[[VAL_9]], ptr %[[VAL_10]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
// CHECK: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_6]], align 8
// CHECK: %[[VAL_12:.*]] = load ptr, ptr %[[VAL_11]], align 8
// CHECK: store i32 10, ptr %[[VAL_12]], align 4
// CHECK: %[[VAL_13:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0
// CHECK: %[[VAL_14:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0
-// CHECK: call void @__tgt_target_data_end_mapper(ptr @2, i64 -1, i32 1, ptr %[[VAL_13]], ptr %[[VAL_14]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr @{{.*}}, i64 -1, i32 1, ptr %[[VAL_13]], ptr %[[VAL_14]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
// CHECK: ret void
// -----
@@ -321,12 +321,12 @@ llvm.func @_QPopenmp_target_use_dev_addr_no_ptr() {
// CHECK: store ptr null, ptr %[[VAL_8]], align 8
// CHECK: %[[VAL_9:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0
// CHECK: %[[VAL_10:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0
-// CHECK: call void @__tgt_target_data_begin_mapper(ptr @2, i64 -1, i32 1, ptr %[[VAL_9]], ptr %[[VAL_10]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr @{{.*}}, i64 -1, i32 1, ptr %[[VAL_9]], ptr %[[VAL_10]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
// CHECK: %[[VAL_11:.*]] = load ptr, ptr %[[VAL_6]], align 8
// CHECK: store i32 10, ptr %[[VAL_11]], align 4
// CHECK: %[[VAL_12:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_0]], i32 0, i32 0
// CHECK: %[[VAL_13:.*]] = getelementptr inbounds [1 x ptr], ptr %[[VAL_1]], i32 0, i32 0
-// CHECK: call void @__tgt_target_data_end_mapper(ptr @2, i64 -1, i32 1, ptr %[[VAL_12]], ptr %[[VAL_13]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr @{{.*}}, i64 -1, i32 1, ptr %[[VAL_12]], ptr %[[VAL_13]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
// CHECK: ret void
// -----
@@ -433,7 +433,7 @@ llvm.func @_QPopenmp_target_use_dev_both() {
// CHECK: store ptr null, ptr %[[VAL_13]], align 8
// CHECK: %[[VAL_14:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_0]], i32 0, i32 0
// CHECK: %[[VAL_15:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_1]], i32 0, i32 0
-// CHECK: call void @__tgt_target_data_begin_mapper(ptr @3, i64 -1, i32 2, ptr %[[VAL_14]], ptr %[[VAL_15]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
+// CHECK: call void @__tgt_target_data_begin_mapper(ptr @{{.*}}, i64 -1, i32 2, ptr %[[VAL_14]], ptr %[[VAL_15]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
// CHECK: %[[VAL_16:.*]] = load ptr, ptr %[[VAL_8]], align 8
// CHECK: store ptr %[[VAL_16]], ptr %[[VAL_3]], align 8
// CHECK: %[[VAL_17:.*]] = load ptr, ptr %[[VAL_11]], align 8
@@ -443,7 +443,7 @@ llvm.func @_QPopenmp_target_use_dev_both() {
// CHECK: store i32 20, ptr %[[VAL_19]], align 4
// CHECK: %[[VAL_20:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_0]], i32 0, i32 0
// CHECK: %[[VAL_21:.*]] = getelementptr inbounds [2 x ptr], ptr %[[VAL_1]], i32 0, i32 0
-// CHECK: call void @__tgt_target_data_end_mapper(ptr @3, i64 -1, i32 2, ptr %[[VAL_20]], ptr %[[VAL_21]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
+// CHECK: call void @__tgt_target_data_end_mapper(ptr @{{.*}}, i64 -1, i32 2, ptr %[[VAL_20]], ptr %[[VAL_21]], ptr @.offload_sizes, ptr @.offload_maptypes, ptr @.offload_mapnames, ptr null)
// CHECK: ret void
// -----
diff --git a/offload/test/Inputs/target-use-dev-ptr.c b/offload/test/Inputs/target-use-dev-ptr.c
new file mode 100644
index 00000000000000..e1430a93fbc7dc
--- /dev/null
+++ b/offload/test/Inputs/target-use-dev-ptr.c
@@ -0,0 +1,23 @@
+// Helper functions used in the offload Fortran test
+// target-use-dev-ptr.f90 to allocate data and
+// check resulting addresses.
+
+#include <assert.h>
+#include <malloc.h>
+#include <stdio.h>
+
+int *get_ptr() { /* Returns a freshly malloc'd int; asserts the allocation succeeded. */
+ int *ptr = (int *)malloc(sizeof *ptr);
+ assert(ptr && "malloc returned null");
+ return ptr;
+}
+
+int check_result(int *host_ptr, int *dev_ptr) { /* SUCCESS/0 only for a non-null dev_ptr distinct from host_ptr, else FAILURE/-1 */
+ const int is_dev_addr = dev_ptr != NULL && dev_ptr != host_ptr;
+ if (!is_dev_addr) {
+ printf("FAILURE\n");
+ return -1;
+ }
+ printf("SUCCESS\n");
+ return 0;
+}
diff --git a/offload/test/offloading/fortran/target-use-dev-ptr.f90 b/offload/test/offloading/fortran/target-use-dev-ptr.f90
new file mode 100644
index 00000000000000..4476f45699d6ec
--- /dev/null
+++ b/offload/test/offloading/fortran/target-use-dev-ptr.f90
@@ -0,0 +1,37 @@
+! Basic test of use_device_ptr, checking if the appropriate
+! addresses are maintained across target boundaries
+! REQUIRES: clang, flang, amdgcn-amd-amdhsa
+
+! RUN: %clang -c -fopenmp -fopenmp-targets=amdgcn-amd-amdhsa \
+! RUN: %S/../../Inputs/target-use-dev-ptr.c -o target-use-dev-ptr_c.o
+! RUN: %libomptarget-compile-fortran-generic target-use-dev-ptr_c.o
+! RUN: %t | %fcheck-generic
+
+program use_device_test ! Checks that use_device_ptr(x) exposes an address different from the host one
+ use iso_c_binding
+ interface
+ type(c_ptr) function get_ptr() BIND(C)
+ USE, intrinsic :: iso_c_binding
+ implicit none
+ end function get_ptr
+
+ integer(c_int) function check_result(host, dev) BIND(C)
+ USE, intrinsic :: iso_c_binding
+ implicit none
+ type(c_ptr), value, intent(in) :: host, dev ! VALUE: C side takes the pointers (int *), not their addresses (int **)
+ end function check_result
+ end interface
+
+ type(c_ptr) :: device_ptr, x
+
+ x = get_ptr()
+ device_ptr = x ! Seed with the host address so an untranslated pointer reports FAILURE
+
+ !$omp target data map(tofrom: x) use_device_ptr(x)
+ device_ptr = x
+ !$omp end target data
+
+ print *, check_result(x, device_ptr)
+end program use_device_test
+
+! CHECK: SUCCESS
More information about the Mlir-commits
mailing list