[flang-commits] [flang] [openmp] [mlir] [Flang][OpenMP] Remove use of non reference values from MapInfoOp (PR #72444)

Akash Banerjee via flang-commits flang-commits at lists.llvm.org
Wed Nov 15 14:17:35 PST 2023


https://github.com/TIFitis created https://github.com/llvm/llvm-project/pull/72444

This patch removes the val field from the `MapInfoOp`.

Previously, when lowering `TargetOp`, the bounds information for the `BoxValues` was also being mapped. Instead, the ops that define these bounds are now cloned inside the target region to prevent mapping of non-reference-typed values.

>From 9917cea771c074a4e76c9329f9e44e995f5d1c13 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Wed, 15 Nov 2023 22:12:31 +0000
Subject: [PATCH] [Flang][OpenMP] Remove use of non reference values from
 MapInfoOp

This patch removes the `val` field from the MapInfoOp.

Previously, when lowering TargetOp, the bounds information for the BoxValues was also being mapped. Instead, the ops that define these bounds are now duplicated inside the target region to prevent mapping of non-reference-typed values.
---
 flang/lib/Lower/OpenMP.cpp                    | 161 ++++++++----------
 flang/test/Lower/OpenMP/FIR/array-bounds.f90  |   4 +-
 flang/test/Lower/OpenMP/FIR/target.f90        |  74 ++++++--
 flang/test/Lower/OpenMP/array-bounds.f90      |   4 +-
 flang/test/Lower/OpenMP/target.f90            |  71 +++++---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  15 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |   9 -
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  41 ++---
 mlir/test/Dialect/OpenMP/invalid.mlir         |  36 ----
 .../offloading/fortran/constant-arr-index.f90 |   6 +-
 10 files changed, 198 insertions(+), 223 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 00f16026d9b4bb8..30d2be33357c8c0 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -28,6 +28,7 @@
 #include "flang/Semantics/tools.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Support/CommandLine.h"
 
@@ -1709,9 +1710,9 @@ static mlir::omp::MapInfoOp
 createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
                 mlir::Value baseAddr, std::stringstream &name,
                 mlir::SmallVector<mlir::Value> bounds, uint64_t mapType,
-                mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
-                bool isVal = false) {
-  mlir::Value val, varPtr, varPtrPtr;
+                mlir::omp::VariableCaptureKind mapCaptureType,
+                mlir::Type retTy) {
+  mlir::Value varPtr, varPtrPtr;
   mlir::TypeAttr varType;
 
   if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
@@ -1719,16 +1720,12 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
     retTy = baseAddr.getType();
   }
 
-  if (isVal)
-    val = baseAddr;
-  else
-    varPtr = baseAddr;
-
-  if (auto ptrType = llvm::dyn_cast<mlir::omp::PointerLikeType>(retTy))
-    varType = mlir::TypeAttr::get(ptrType.getElementType());
+  varPtr = baseAddr;
+  varType = mlir::TypeAttr::get(
+      llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
 
   mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
-      loc, retTy, val, varPtr, varType, varPtrPtr, bounds,
+      loc, retTy, varPtr, varType, varPtrPtr, bounds,
       builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
       builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
       builder.getStringAttr(name.str()));
@@ -2489,21 +2486,20 @@ static void genBodyOfTargetOp(
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   mlir::Region &region = targetOp.getRegion();
 
-  firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
+  auto *regionBlock =
+      firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
 
   unsigned argIndex = 0;
-  unsigned blockArgsIndex = mapSymbols.size();
-
-  // The block arguments contain the map_operands followed by the bounds in
-  // order. This returns a vector containing the next 'n' block arguments for
-  // the bounds.
-  auto extractBoundArgs = [&](auto n) {
-    llvm::SmallVector<mlir::Value> argExtents;
-    while (n--) {
-      argExtents.push_back(fir::getBase(region.getArgument(blockArgsIndex)));
-      blockArgsIndex++;
+
+  // Clones the `bounds` and returns them.
+  auto cloneBounds = [&](llvm::ArrayRef<mlir::Value> bounds) {
+    llvm::SmallVector<mlir::Value> clonedBounds;
+    for (mlir::Value bound : bounds) {
+      mlir::Operation *clonedOp = bound.getDefiningOp()->clone();
+      regionBlock->push_back(clonedOp);
+      clonedBounds.emplace_back(clonedOp->getResult(0));
     }
-    return argExtents;
+    return clonedBounds;
   };
 
   // Bind the symbols to their corresponding block arguments.
@@ -2512,34 +2508,35 @@ static void genBodyOfTargetOp(
     fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
     extVal.match(
         [&](const fir::BoxValue &v) {
-          converter.bindSymbol(
-              *sym,
-              fir::BoxValue(arg, extractBoundArgs(v.getLBounds().size()),
-                            v.getExplicitParameters(), v.getExplicitExtents()));
+          converter.bindSymbol(*sym,
+                               fir::BoxValue(arg, cloneBounds(v.getLBounds()),
+                                             v.getExplicitParameters(),
+                                             v.getExplicitExtents()));
         },
         [&](const fir::MutableBoxValue &v) {
           converter.bindSymbol(
-              *sym,
-              fir::MutableBoxValue(arg, extractBoundArgs(v.getLBounds().size()),
-                                   v.getMutableProperties()));
+              *sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()),
+                                         v.getMutableProperties()));
         },
         [&](const fir::ArrayBoxValue &v) {
           converter.bindSymbol(
-              *sym,
-              fir::ArrayBoxValue(arg, extractBoundArgs(v.getExtents().size()),
-                                 extractBoundArgs(v.getLBounds().size()),
-                                 v.getSourceBox()));
+              *sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()),
+                                       cloneBounds(v.getLBounds()),
+                                       v.getSourceBox()));
         },
         [&](const fir::CharArrayBoxValue &v) {
+          mlir::Operation *clonedLen = v.getLen().getDefiningOp()->clone();
+          regionBlock->push_back(clonedLen);
           converter.bindSymbol(
-              *sym,
-              fir::CharArrayBoxValue(arg, extractBoundArgs(1).front(),
-                                     extractBoundArgs(v.getExtents().size()),
-                                     extractBoundArgs(v.getLBounds().size())));
+              *sym, fir::CharArrayBoxValue(arg, clonedLen->getResult(0),
+                                           cloneBounds(v.getExtents()),
+                                           cloneBounds(v.getLBounds())));
         },
         [&](const fir::CharBoxValue &v) {
-          converter.bindSymbol(
-              *sym, fir::CharBoxValue(arg, extractBoundArgs(1).front()));
+          mlir::Operation *clonedLen = v.getLen().getDefiningOp()->clone();
+          regionBlock->push_back(clonedLen);
+          converter.bindSymbol(*sym,
+                               fir::CharBoxValue(arg, clonedLen->getResult(0)));
         },
         [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, arg); },
         [&](const auto &) {
@@ -2549,6 +2546,43 @@ static void genBodyOfTargetOp(
     argIndex++;
   }
 
+  // Check if cloning the bounds introduced any dependency on the outer region.
+  // If so, then either clone them as well if trivial, or else add them to the
+  // map and block_argument lists.
+  llvm::SetVector<mlir::Value> valuesDefinedAbove;
+  mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
+  while (!valuesDefinedAbove.empty()) {
+    for (mlir::Value val : valuesDefinedAbove) {
+      if (fir::isa_trivial(val.getType())) {
+        mlir::Operation *clonedVal = val.getDefiningOp()->clone();
+        regionBlock->push_front(clonedVal);
+        val.replaceUsesWithIf(
+            clonedVal->getResult(0), [regionBlock](mlir::OpOperand &use) {
+              return use.getOwner()->getBlock() == regionBlock;
+            });
+      } else {
+        llvm::SmallVector<mlir::Value> bounds;
+        std::stringstream name;
+        auto savedIP = firOpBuilder.getInsertionPoint();
+        firOpBuilder.setInsertionPoint(targetOp);
+        mlir::Value mapOp = createMapInfoOp(
+            firOpBuilder, val.getLoc(), val, name, bounds,
+            static_cast<
+                std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+                llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
+            mlir::omp::VariableCaptureKind::ByCopy, val.getType());
+        targetOp.getMapOperandsMutable().append(mapOp);
+        mlir::Value clonedVal = region.addArgument(val.getType(), val.getLoc());
+        val.replaceUsesWithIf(clonedVal, [regionBlock](mlir::OpOperand &use) {
+          return use.getOwner()->getBlock() == regionBlock;
+        });
+        firOpBuilder.setInsertionPoint(regionBlock, savedIP);
+      }
+    }
+    valuesDefinedAbove.clear();
+    mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
+  }
+
   // Insert dummy instruction to remember the insertion position. The
   // marker will be deleted since there are not uses.
   // In the HLFIR flow there are hlfir.declares inserted above while
@@ -2671,53 +2705,6 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   };
   Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
 
-  // Add the bounds and extents for box values to mapOperands
-  auto addMapInfoForBounds = [&](const auto &bounds) {
-    for (auto &val : bounds) {
-      mapSymLocs.push_back(val.getLoc());
-      mapSymTypes.push_back(val.getType());
-
-      llvm::SmallVector<mlir::Value> bounds;
-      std::stringstream name;
-
-      mlir::Value mapOp = createMapInfoOp(
-          converter.getFirOpBuilder(), val.getLoc(), val, name, bounds,
-          static_cast<
-              std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
-              llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
-          mlir::omp::VariableCaptureKind::ByCopy, val.getType(), true);
-      mapOperands.push_back(mapOp);
-    }
-  };
-
-  for (const Fortran::semantics::Symbol *sym : mapSymbols) {
-    fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
-    extVal.match(
-        [&](const fir::BoxValue &v) { addMapInfoForBounds(v.getLBounds()); },
-        [&](const fir::MutableBoxValue &v) {
-          addMapInfoForBounds(v.getLBounds());
-        },
-        [&](const fir::ArrayBoxValue &v) {
-          addMapInfoForBounds(v.getExtents());
-          addMapInfoForBounds(v.getLBounds());
-        },
-        [&](const fir::CharArrayBoxValue &v) {
-          llvm::SmallVector<mlir::Value> len;
-          len.push_back(v.getLen());
-          addMapInfoForBounds(len);
-          addMapInfoForBounds(v.getExtents());
-          addMapInfoForBounds(v.getLBounds());
-        },
-        [&](const fir::CharBoxValue &v) {
-          llvm::SmallVector<mlir::Value> len;
-          len.push_back(v.getLen());
-          addMapInfoForBounds(len);
-        },
-        [&](const auto &) {
-          // Nothing to do for non-box values.
-        });
-  }
-
   auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
       currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
       nowaitAttr, mapOperands);
diff --git a/flang/test/Lower/OpenMP/FIR/array-bounds.f90 b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
index 02b5ebcee022675..8acc9738821464f 100644
--- a/flang/test/Lower/OpenMP/FIR/array-bounds.f90
+++ b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
@@ -16,7 +16,7 @@
 !ALL:  %[[BOUNDS1:.*]] = omp.bounds   lower_bound(%[[C5]] : index) upper_bound(%[[C6]] : index) stride(%[[C4]] : index) start_idx(%[[C4]] : index)
 !ALL:  %[[MAP1:.*]] = omp.map_info var_ptr(%[[WRITE]] : !fir.ref<!fir.array<10xi32>>, !fir.array<10xi32>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS1]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_write(2:5)"}
 !ALL:  %[[MAP2:.*]] = omp.map_info var_ptr(%[[ITER]] : !fir.ref<i32>, i32)   map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
-!ALL: omp.target map_entries(%[[MAP0]] -> %{{.*}}, %[[MAP1]] -> %{{.*}}, %[[MAP2]] -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, !fir.ref<i32>, index, index) {
+!ALL: omp.target map_entries(%[[MAP0]] -> %{{.*}}, %[[MAP1]] -> %{{.*}}, %[[MAP2]] -> %{{.*}} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, !fir.ref<i32>) {
 
 subroutine read_write_section()
     integer :: sp_read(10) = (/1,2,3,4,5,6,7,8,9,10/)
@@ -64,7 +64,7 @@ end subroutine assumed_shape_array
 !ALL: %[[BOUNDS:.*]]  = omp.bounds   lower_bound(%[[C1]] : index) upper_bound(%[[C2]] : index) stride(%[[C0]] : index) start_idx(%[[C0]] : index)
 !ALL: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG0]] : !fir.ref<!fir.array<?xi32>>, !fir.array<?xi32>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
 !ALL: %[[MAP2:.*]] = omp.map_info var_ptr(%[[ALLOCA]] : !fir.ref<i32>, i32)   map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
-!ALL: omp.target map_entries(%[[MAP]] -> %{{.*}}, %[[MAP2]] -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>, index) {
+!ALL: omp.target map_entries(%[[MAP]] -> %{{.*}}, %[[MAP2]] -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>) {
 
         subroutine assumed_size_array(arr_read_write)
             integer, intent(inout) :: arr_read_write(*)
diff --git a/flang/test/Lower/OpenMP/FIR/target.f90 b/flang/test/Lower/OpenMP/FIR/target.f90
index cf3b6fff0cb7498..20de9726360bb25 100644
--- a/flang/test/Lower/OpenMP/FIR/target.f90
+++ b/flang/test/Lower/OpenMP/FIR/target.f90
@@ -189,8 +189,8 @@ subroutine omp_target
    integer :: a(1024)
    !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
    !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target   map_entries(%[[MAP]] -> %[[ARG_0:.*]], %{{.*}} -> %{{.*}} : !fir.ref<!fir.array<1024xi32>>, index) {
-   !CHECK: ^bb0(%[[ARG_0]]: !fir.ref<!fir.array<1024xi32>>, %{{.*}}: index):
+   !CHECK: omp.target   map_entries(%[[MAP]] -> %[[ARG_0:.*]] : !fir.ref<!fir.array<1024xi32>>) {
+   !CHECK: ^bb0(%[[ARG_0]]: !fir.ref<!fir.array<1024xi32>>):
    !$omp target map(tofrom: a)
       !CHECK: %[[VAL_1:.*]] = arith.constant 10 : i32
       !CHECK: %[[VAL_2:.*]] = arith.constant 1 : i64
@@ -213,8 +213,8 @@ subroutine omp_target_implicit
    !CHECK: %[[VAL_0:.*]] = fir.alloca !fir.array<1024xi32> {bindc_name = "a", uniq_name = "_QFomp_target_implicitEa"}
    integer :: a(1024)
    !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>)   map_clauses(implicit, tofrom) capture(ByRef) bounds(%{{.*}}) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target   map_entries(%[[MAP]] -> %[[ARG_0:.*]], %{{.*}} -> %{{.*}} : !fir.ref<!fir.array<1024xi32>>, index) {
-   !CHECK: ^bb0(%[[ARG_0]]: !fir.ref<!fir.array<1024xi32>>, %{{.*}}: index):
+   !CHECK: omp.target   map_entries(%[[MAP]] -> %[[ARG_0:.*]] : !fir.ref<!fir.array<1024xi32>>) {
+   !CHECK: ^bb0(%[[ARG_0]]: !fir.ref<!fir.array<1024xi32>>):
    !$omp target
       !CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG_0]], %{{.*}} : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
       a(1) = 10
@@ -249,22 +249,60 @@ end subroutine omp_target_implicit_nested
 ! Target implicit capture with bounds
 !===============================================================================
 
-!CHECK-LABEL: func.func @_QPomp_target_implicit_bounds(%{{.*}}: !fir.ref<i32> {fir.bindc_name = "n"}) {
+
+!CHECK-LABEL: func.func @_QPomp_target_implicit_bounds(
+!CHECK: %[[VAL_0:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
 subroutine omp_target_implicit_bounds(n)
-   !CHECK: %[[VAL_1:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index
-   !CHECK: %[[VAL_2:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index
-   !CHECK: %[[VAL_3:.*]] = fir.alloca !fir.array<?x1024xi32>, %[[VAL_1]] {bindc_name = "a", uniq_name = "_QFomp_target_implicit_boundsEa"}
+   !CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
+   !CHECK: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (i32) -> i64
+   !CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (i64) -> index
+   !CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+   !CHECK: %[[VAL_5:.*]] = arith.cmpi sgt, %[[VAL_3]], %[[VAL_4]] : index
+   !CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_5]], %[[VAL_3]], %[[VAL_4]] : index
+   !CHECK: %[[VAL_7:.*]] = arith.constant 1024 : i64
+   !CHECK: %[[VAL_8:.*]] = fir.convert %[[VAL_7]] : (i64) -> index
+   !CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
+   !CHECK: %[[VAL_10:.*]] = arith.cmpi sgt, %[[VAL_8]], %[[VAL_9]] : index
+   !CHECK: %[[VAL_11:.*]] = arith.select %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : index
+   !CHECK: %[[VAL_12:.*]] = fir.alloca !fir.array<?x1024xi32>, %[[VAL_6]] {bindc_name = "a", uniq_name = "_QFomp_target_implicit_boundsEa"}
    integer :: n
    integer :: a(n, 1024)
-   !CHECK: %[[VAL_4:.*]] = omp.map_info var_ptr(%[[VAL_3]] : !fir.ref<!fir.array<?x1024xi32>>, !fir.array<?x1024xi32>)   map_clauses(implicit, tofrom) capture(ByRef) bounds(%{{.*}}) -> !fir.ref<!fir.array<?x1024xi32>> {name = "a"}
-   !CHECK: %[[VAL_5:.*]] = omp.map_info val(%[[VAL_1]] : index)  map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> index {name = ""}
-   !CHECK: %[[VAL_6:.*]] = omp.map_info val(%[[VAL_2]] : index)  map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> index {name = ""}
-   !CHECK: omp.target   map_entries(%[[VAL_4]] -> %[[ARG_1:.*]], %[[VAL_5]] -> %[[ARG_2:.*]], %[[VAL_6]] -> %[[ARG_3:.*]] : !fir.ref<!fir.array<?x1024xi32>>, index, index) {
-   !CHECK: ^bb0(%[[ARG_1]]: !fir.ref<!fir.array<?x1024xi32>>, %[[ARG_2]]: index, %[[ARG_3]]: index):
+   !CHECK: %[[VAL_13:.*]] = arith.constant 1 : index
+   !CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
+   !CHECK: %[[VAL_15:.*]] = arith.subi %[[VAL_6]], %[[VAL_13]] : index
+   !CHECK: %[[VAL_16:.*]] = omp.bounds lower_bound(%[[VAL_14]] : index) upper_bound(%[[VAL_15]] : index) extent(%[[VAL_6]] : index) stride(%[[VAL_13]] : index) start_idx(%[[VAL_13]] : index)
+   !CHECK: %[[VAL_17:.*]] = arith.constant 0 : index
+   !CHECK: %[[VAL_18:.*]] = arith.subi %[[VAL_11]], %[[VAL_13]] : index
+   !CHECK: %[[VAL_19:.*]] = omp.bounds lower_bound(%[[VAL_17]] : index) upper_bound(%[[VAL_18]] : index) extent(%[[VAL_11]] : index) stride(%[[VAL_13]] : index) start_idx(%[[VAL_13]] : index)
+   !CHECK: %[[VAL_20:.*]] = omp.map_info var_ptr(%[[VAL_12]] : !fir.ref<!fir.array<?x1024xi32>>, !fir.array<?x1024xi32>) map_clauses(implicit, tofrom) capture(ByRef) bounds(%[[VAL_16]], %[[VAL_19]]) -> !fir.ref<!fir.array<?x1024xi32>> {name = "a"}
+   !CHECK: %[[VAL_21:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = ""}
+   !CHECK: omp.target map_entries(%[[VAL_20]] -> %[[VAL_22:.*]], %[[VAL_21]] -> %[[VAL_23:.*]] : !fir.ref<!fir.array<?x1024xi32>>, !fir.ref<i32>) {
+   !CHECK: ^bb0(%[[VAL_22]]: !fir.ref<!fir.array<?x1024xi32>>, %[[VAL_23]]: !fir.ref<i32>):
    !$omp target
-      !CHECK: %{{.*}} = fir.convert %[[ARG_1]] : (!fir.ref<!fir.array<?x1024xi32>>) -> !fir.ref<!fir.array<?xi32>>
-      !CHECK: %{{.*}} = arith.muli %{{.*}}, %[[ARG_2]] : index
-      a(11,22) = 33
+      !CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_23]] : !fir.ref<i32>
+      !CHECK: %[[VAL_25:.*]] = fir.convert %[[VAL_24]] : (i32) -> i64
+      !CHECK: %[[VAL_26:.*]] = arith.constant 0 : index
+      !CHECK: %[[VAL_27:.*]] = fir.convert %[[VAL_25]] : (i64) -> index
+      !CHECK: %[[VAL_28:.*]] = arith.cmpi sgt, %[[VAL_27]], %[[VAL_26]] : index
+      !CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_28]], %[[VAL_27]], %[[VAL_26]] : index
+      !CHECK: %[[VAL_30:.*]] = arith.constant 33 : i32
+      !CHECK: %[[VAL_31:.*]] = fir.convert %[[VAL_22]] : (!fir.ref<!fir.array<?x1024xi32>>) -> !fir.ref<!fir.array<?xi32>>
+      !CHECK: %[[VAL_32:.*]] = arith.constant 1 : index
+      !CHECK: %[[VAL_33:.*]] = arith.constant 0 : index
+      !CHECK: %[[VAL_34:.*]] = arith.constant 11 : i64
+      !CHECK: %[[VAL_35:.*]] = fir.convert %[[VAL_34]] : (i64) -> index
+      !CHECK: %[[VAL_36:.*]] = arith.subi %[[VAL_35]], %[[VAL_32]] : index
+      !CHECK: %[[VAL_37:.*]] = arith.muli %[[VAL_32]], %[[VAL_36]] : index
+      !CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_37]], %[[VAL_33]] : index
+      !CHECK: %[[VAL_39:.*]] = arith.muli %[[VAL_32]], %[[VAL_29]] : index
+      !CHECK: %[[VAL_40:.*]] = arith.constant 22 : i64
+      !CHECK: %[[VAL_41:.*]] = fir.convert %[[VAL_40]] : (i64) -> index
+      !CHECK: %[[VAL_42:.*]] = arith.subi %[[VAL_41]], %[[VAL_32]] : index
+      !CHECK: %[[VAL_43:.*]] = arith.muli %[[VAL_39]], %[[VAL_42]] : index
+      !CHECK: %[[VAL_44:.*]] = arith.addi %[[VAL_43]], %[[VAL_38]] : index
+      !CHECK: %[[VAL_45:.*]] = fir.coordinate_of %[[VAL_31]], %[[VAL_44]] : (!fir.ref<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+      !CHECK: fir.store %[[VAL_30]] to %[[VAL_45]] : !fir.ref<i32>
+      a(11, 22) = 33
       !CHECK: omp.terminator
    !$omp end target
 !CHECK: }
@@ -344,8 +382,8 @@ subroutine omp_target_parallel_do
    !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound(%[[C0]] : index) upper_bound(%[[SUB]] : index) extent(%[[C1024]] : index) stride(%[[C1]] : index) start_idx(%[[C1]] : index)
    !CHECK: %[[MAP1:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
    !CHECK: %[[MAP2:.*]] = omp.map_info var_ptr(%[[VAL_1]] : !fir.ref<i32>, i32)   map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
-   !CHECK: omp.target map_entries(%[[MAP1]] -> %[[VAL_2:.*]], %[[MAP2]] -> %[[VAL_3:.*]], %{{.*}} -> %{{.*}} : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>, index) {
-   !CHECK: ^bb0(%[[VAL_2]]: !fir.ref<!fir.array<1024xi32>>, %[[VAL_3]]: !fir.ref<i32>, %{{.*}}: index):
+   !CHECK: omp.target map_entries(%[[MAP1]] -> %[[VAL_2:.*]], %[[MAP2]] -> %[[VAL_3:.*]] : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>) {
+   !CHECK: ^bb0(%[[VAL_2]]: !fir.ref<!fir.array<1024xi32>>, %[[VAL_3]]: !fir.ref<i32>):
       !CHECK-NEXT: omp.parallel
       !$omp target parallel do map(tofrom: a)
          !CHECK: %[[VAL_4:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
diff --git a/flang/test/Lower/OpenMP/array-bounds.f90 b/flang/test/Lower/OpenMP/array-bounds.f90
index d0c584bec6044ad..c0e50acb8e37c42 100644
--- a/flang/test/Lower/OpenMP/array-bounds.f90
+++ b/flang/test/Lower/OpenMP/array-bounds.f90
@@ -22,7 +22,7 @@
 !HOST:  %[[C6:.*]] = arith.constant 4 : index
 !HOST:  %[[BOUNDS1:.*]] = omp.bounds   lower_bound(%[[C5]] : index) upper_bound(%[[C6]] : index) stride(%[[C4]] : index) start_idx(%[[C4]] : index)
 !HOST:  %[[MAP1:.*]] = omp.map_info var_ptr(%[[WRITE_DECL]]#1 : !fir.ref<!fir.array<10xi32>>, !fir.array<10xi32>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS1]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_write(2:5)"}
-!HOST:  omp.target map_entries(%[[MAP0]] -> %{{.*}}, %[[MAP1]] -> %{{.*}}, {{.*}} -> {{.*}}, {{.*}} -> {{.*}}, {{.*}} -> {{.*}} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, !fir.ref<i32>, index, index) {
+!HOST:  omp.target map_entries(%[[MAP0]] -> %{{.*}}, %[[MAP1]] -> %{{.*}}, {{.*}} -> {{.*}} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, !fir.ref<i32>) {
 
 subroutine read_write_section()
     integer :: sp_read(10) = (/1,2,3,4,5,6,7,8,9,10/)
@@ -73,7 +73,7 @@ end subroutine assumed_shape_array
 !HOST: %[[C2:.*]] = arith.constant 4 : index
 !HOST: %[[BOUNDS:.*]]  = omp.bounds   lower_bound(%[[C1]] : index) upper_bound(%[[C2]] : index) stride(%[[C0]] : index) start_idx(%[[C0]] : index)
 !HOST: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG0_DECL]]#1 : !fir.ref<!fir.array<?xi32>>, !fir.array<?xi32>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
-!HOST: omp.target map_entries(%[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}}, {{.*}} -> {{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>, index) {
+!HOST: omp.target map_entries(%[[MAP]] -> %{{.*}}, {{.*}} -> {{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>) {
         subroutine assumed_size_array(arr_read_write)
             integer, intent(inout) :: arr_read_write(*)
 
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 86f456b847df90a..11e043241db9d2f 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -191,10 +191,11 @@ subroutine omp_target
    integer :: a(1024)
    !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
    !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[VAL_1]]#1 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target   map_entries(%[[MAP]] -> %[[ARG_0:.*]], %{{.*}} -> %[[ARG_1:.*]] : !fir.ref<!fir.array<1024xi32>>, index) {
-   !CHECK: ^bb0(%[[ARG_0]]: !fir.ref<!fir.array<1024xi32>>, %[[ARG_1]]: index):
+   !CHECK: omp.target   map_entries(%[[MAP]] -> %[[ARG_0:.*]] : !fir.ref<!fir.array<1024xi32>>) {
+   !CHECK: ^bb0(%[[ARG_0]]: !fir.ref<!fir.array<1024xi32>>):
    !$omp target map(tofrom: a)
-      !CHECK: %[[VAL_2:.*]] = fir.shape %[[ARG_1]] : (index) -> !fir.shape<1>
+      !CHECK: %[[VAL_7:.*]] = arith.constant 1024 : index
+      !CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
       !CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[ARG_0]](%[[VAL_2]]) {uniq_name = "_QFomp_targetEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
       !CHECK: %[[VAL_4:.*]] = arith.constant 10 : i32
       !CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
@@ -218,10 +219,10 @@ subroutine omp_target_implicit
    !CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_2]]) {uniq_name = "_QFomp_target_implicitEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
    integer :: a(1024)
    !CHECK: %[[VAL_4:.*]] = omp.map_info var_ptr(%[[VAL_3]]#1 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>)   map_clauses(implicit, tofrom) capture(ByRef) bounds(%{{.*}}) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: %[[VAL_5:.*]] = omp.map_info val(%[[VAL_0]] : index)   map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> index {name = ""}
-   !CHECK: omp.target   map_entries(%[[VAL_4]] -> %[[VAL_6:.*]], %[[VAL_5]] -> %[[VAL_7:.*]] : !fir.ref<!fir.array<1024xi32>>, index) {
-   !CHECK: ^bb0(%[[VAL_6]]: !fir.ref<!fir.array<1024xi32>>, %[[VAL_7]]: index):
+   !CHECK: omp.target   map_entries(%[[VAL_4]] -> %[[VAL_6:.*]] : !fir.ref<!fir.array<1024xi32>>) {
+   !CHECK: ^bb0(%[[VAL_6]]: !fir.ref<!fir.array<1024xi32>>):
    !$omp target
+      !CHECK: %[[VAL_7:.*]] = arith.constant 1024 : index
       !CHECK: %[[VAL_8:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
       !CHECK: %[[VAL_9:.*]]:2 = hlfir.declare %[[VAL_6]](%[[VAL_8]]) {uniq_name = "_QFomp_target_implicitEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
       !CHECK: %[[VAL_10:.*]] = arith.constant 10 : i32
@@ -265,29 +266,43 @@ end subroutine omp_target_implicit_nested
 ! Target implicit capture with bounds
 !===============================================================================
 
-!CHECK-LABEL: func.func @_QPomp_target_implicit_bounds(%{{.*}}: !fir.ref<i32> {fir.bindc_name = "n"}) {
+!CHECK-LABEL: func.func @_QPomp_target_implicit_bounds(
+!CHECK: %[[VAL_0:.*]]: !fir.ref<i32> {fir.bindc_name = "n"}) {
 subroutine omp_target_implicit_bounds(n)
-   !CHECK: %[[VAL_1:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index
-   !CHECK: %[[VAL_2:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index
-   !CHECK: %[[VAL_3:.*]] = fir.alloca !fir.array<?x1024xi32>, %[[VAL_1]] {bindc_name = "a", uniq_name = "_QFomp_target_implicit_boundsEa"}
-   !CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_1]], %[[VAL_2]] : (index, index) -> !fir.shape<2>
-   !CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %[[VAL_3]](%[[VAL_4]]) {uniq_name = "_QFomp_target_implicit_boundsEa"} : (!fir.ref<!fir.array<?x1024xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x1024xi32>>, !fir.ref<!fir.array<?x1024xi32>>)
+   !CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFomp_target_implicit_boundsEn"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+   !CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_1]]#0 : !fir.ref<i32>
+   !CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (i32) -> i64
+   !CHECK: %[[VAL_4:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
+   !CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+   !CHECK: %[[VAL_6:.*]] = arith.cmpi sgt, %[[VAL_4]], %[[VAL_5]] : index
+   !CHECK: %[[VAL_7:.*]] = arith.select %[[VAL_6]], %[[VAL_4]], %[[VAL_5]] : index
+   !CHECK: %[[VAL_8:.*]] = fir.alloca !fir.array<?xi32>, %[[VAL_7]] {bindc_name = "a", uniq_name = "_QFomp_target_implicit_boundsEa"}
+   !CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
+   !CHECK: %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_8]](%[[VAL_9]]) {uniq_name = "_QFomp_target_implicit_boundsEa"} : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<?xi32>>)
+   !CHECK: %[[VAL_11:.*]] = arith.constant 1 : index
+   !CHECK: %[[VAL_12:.*]] = arith.constant 0 : index
+   !CHECK: %[[VAL_13:.*]] = arith.subi %[[VAL_7]], %[[VAL_11]] : index
    integer :: n
-   integer :: a(n, 1024)
-   !CHECK: %[[VAL_6:.*]] = omp.map_info var_ptr(%[[VAL_5]]#1 : !fir.ref<!fir.array<?x1024xi32>>, !fir.array<?x1024xi32>)   map_clauses(implicit, tofrom) capture(ByRef) bounds(%{{.*}}) -> !fir.ref<!fir.array<?x1024xi32>> {name = "a"}
-   !CHECK: %[[VAL_7:.*]] = omp.map_info val(%[[VAL_1]] : index)   map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> index {name = ""}
-   !CHECK: %[[VAL_8:.*]] = omp.map_info val(%[[VAL_2]] : index)   map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> index {name = ""}
-   !CHECK: omp.target   map_entries(%[[VAL_6]] -> %[[ARG_1:.*]], %[[VAL_7]] -> %[[ARG_2:.*]], %[[VAL_8]] -> %[[ARG_3:.*]] : !fir.ref<!fir.array<?x1024xi32>>, index, index) {
-   !CHECK: ^bb0(%[[ARG_1]]: !fir.ref<!fir.array<?x1024xi32>>, %[[ARG_2]]: index, %[[ARG_3]]: index):
+   integer :: a(n)
+   !CHECK: %[[VAL_14:.*]] = omp.bounds lower_bound(%[[VAL_12]] : index) upper_bound(%[[VAL_13]] : index) extent(%[[VAL_7]] : index) stride(%[[VAL_11]] : index) start_idx(%[[VAL_11]] : index)
+   !CHECK: %[[VAL_15:.*]] = omp.map_info var_ptr(%[[VAL_10]]#1 : !fir.ref<!fir.array<?xi32>>, !fir.array<?xi32>) map_clauses(implicit, tofrom) capture(ByRef) bounds(%[[VAL_14]]) -> !fir.ref<!fir.array<?xi32>> {name = "a"}
+   !CHECK: %[[VAL_16:.*]] = omp.map_info var_ptr(%[[VAL_1]]#0 : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = ""}
+   !CHECK: omp.target map_entries(%[[VAL_15]] -> %[[VAL_17:.*]], %[[VAL_16]] -> %[[VAL_18:.*]] : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>) {
+   !CHECK: ^bb0(%[[VAL_17]]: !fir.ref<!fir.array<?xi32>>, %[[VAL_18]]: !fir.ref<i32>):
    !$omp target
-      !CHECK: %[[VAL_9:.*]] = fir.shape %[[ARG_2]], %[[ARG_3]] : (index, index) -> !fir.shape<2>
-      !CHECK: %[[VAL_10:.*]]:2 = hlfir.declare %[[ARG_1]](%[[VAL_9]]) {uniq_name = "_QFomp_target_implicit_boundsEa"} : (!fir.ref<!fir.array<?x1024xi32>>, !fir.shape<2>) -> (!fir.box<!fir.array<?x1024xi32>>, !fir.ref<!fir.array<?x1024xi32>>)
-      !CHECK: %[[VAL_11:.*]] = arith.constant 33 : i32
-      !CHECK: %[[VAL_12:.*]] = arith.constant 11 : index
-      !CHECK: %[[VAL_13:.*]] = arith.constant 22 : index
-      !CHECK: %[[VAL_14:.*]] = hlfir.designate %[[VAL_10]]#0 (%[[VAL_12]], %[[VAL_13]])  : (!fir.box<!fir.array<?x1024xi32>>, index, index) -> !fir.ref<i32>
-      !CHECK: hlfir.assign %[[VAL_11]] to %[[VAL_14]] : i32, !fir.ref<i32>
-      a(11, 22) = 33
+      !CHECK: %[[VAL_19:.*]] = fir.load %[[VAL_18]] : !fir.ref<i32>
+      !CHECK: %[[VAL_20:.*]] = fir.convert %[[VAL_19]] : (i32) -> i64
+      !CHECK: %[[VAL_21:.*]] = arith.constant 0 : index
+      !CHECK: %[[VAL_22:.*]] = fir.convert %[[VAL_20]] : (i64) -> index
+      !CHECK: %[[VAL_23:.*]] = arith.cmpi sgt, %[[VAL_22]], %[[VAL_21]] : index
+      !CHECK: %[[VAL_24:.*]] = arith.select %[[VAL_23]], %[[VAL_22]], %[[VAL_21]] : index
+      !CHECK: %[[VAL_25:.*]] = fir.shape %[[VAL_24]] : (index) -> !fir.shape<1>
+      !CHECK: %[[VAL_26:.*]]:2 = hlfir.declare %[[VAL_17]](%[[VAL_25]]) {uniq_name = "_QFomp_target_implicit_boundsEa"} : (!fir.ref<!fir.array<?xi32>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xi32>>, !fir.ref<!fir.array<?xi32>>)
+      !CHECK: %[[VAL_27:.*]] = arith.constant 22 : i32
+      !CHECK: %[[VAL_28:.*]] = arith.constant 11 : index
+      !CHECK: %[[VAL_29:.*]] = hlfir.designate %[[VAL_26]]#0 (%[[VAL_28]])  : (!fir.box<!fir.array<?xi32>>, index) -> !fir.ref<i32>
+      !CHECK: hlfir.assign %[[VAL_27]] to %[[VAL_29]] : i32, !fir.ref<i32>
+      a(11) = 22
       !CHECK: omp.terminator
    !$omp end target
 !CHECK: }
@@ -396,8 +411,8 @@ subroutine omp_target_parallel_do
    !CHECK: %[[SUB:.*]] = arith.subi %[[C1024]], %[[C1]] : index
    !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound(%[[C0]] : index) upper_bound(%[[SUB]] : index) extent(%[[C1024]] : index) stride(%[[C1]] : index) start_idx(%[[C1]] : index)
    !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[VAL_0_DECL]]#1 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target   map_entries(%[[MAP]] -> %[[ARG_0:.*]], %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>, index) {
-   !CHECK: ^bb0(%[[ARG_0]]: !fir.ref<!fir.array<1024xi32>>, %{{.*}}: !fir.ref<i32>, %{{.*}}: index):
+   !CHECK: omp.target   map_entries(%[[MAP]] -> %[[ARG_0:.*]], %{{.*}} -> %{{.*}} : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>) {
+   !CHECK: ^bb0(%[[ARG_0]]: !fir.ref<!fir.array<1024xi32>>, %{{.*}}: !fir.ref<i32>):
       !CHECK: %[[VAL_0_DECL:.*]]:2 = hlfir.declare %[[ARG_0]](%{{.*}}) {uniq_name = "_QFomp_target_parallel_doEa"} : (!fir.ref<!fir.array<1024xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<1024xi32>>, !fir.ref<!fir.array<1024xi32>>)
       !CHECK: omp.parallel
       !$omp target parallel do map(tofrom: a)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index bfb58b98884c721..8ff5380f71ad453 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1145,15 +1145,14 @@ def DataBoundsOp : OpenMP_Op<"bounds",
 }
 
 def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
-  let arguments = (ins Optional<AnyType>:$val,
-                       Optional<OpenMP_PointerLikeType>:$var_ptr,
-                       OptionalAttr<TypeAttr>:$var_type,
+  let arguments = (ins OpenMP_PointerLikeType:$var_ptr,
+                       TypeAttr:$var_type,
                        Optional<OpenMP_PointerLikeType>:$var_ptr_ptr,
                        Variadic<DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
                        OptionalAttr<UI64Attr>:$map_type,
                        OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,
                        OptionalAttr<StrAttr>:$name);
-  let results = (outs AnyType:$omp_ptr);
+  let results = (outs OpenMP_PointerLikeType:$omp_ptr);
 
   let description = [{
     The MapInfoOp captures information relating to individual OpenMP map clauses
@@ -1184,9 +1183,8 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
     ```
 
     Description of arguments:
-    - `val`: The value to copy.
     - `var_ptr`: The address of variable to copy.
-    - `var_type`: The type of the variable/value to copy.
+    - `var_type`: The type of the variable to copy.
     - `var_ptr_ptr`: Used when the variable copied is a member of a class, structure
       or derived type and refers to the originating struct.
     - `bounds`: Used when copying slices of array's, pointers or pointer members of
@@ -1202,10 +1200,9 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
   }];
 
   let assemblyFormat = [{
+    `var_ptr` `(` $var_ptr `:` type($var_ptr) `,` $var_type `)`
     oilist(
-        `val` `(` $val `:` type($val) `)`
-      | `var_ptr` `(` $var_ptr `:` type($var_ptr) `,` $var_type `)`
-      | `var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)`
+        `var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)`
       | `map_clauses` `(` custom<MapClause>($map_type) `)`
       | `capture` `(` custom<CaptureType>($map_capture_type) `)`
       | `bounds` `(` $bounds `)`
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 480af0e1307c158..20df0099cbd24d1 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -890,15 +890,6 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
     if (auto MapInfoOp =
             mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
 
-      if (MapInfoOp.getVal() && MapInfoOp.getVarPtr())
-        emitError(op->getLoc(), "only one of val or var_ptr must be used");
-
-      if (!MapInfoOp.getVal() && !MapInfoOp.getVarPtr())
-        emitError(op->getLoc(), "missing val or var_ptr");
-
-      if (!MapInfoOp.getVarPtr() && MapInfoOp.getVarType().has_value())
-        emitError(op->getLoc(), "var_type supplied without var_ptr");
-
       if (!MapInfoOp.getMapType().has_value())
         emitError(op->getLoc(), "missing map type for map operand");
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index ae13a745196c42d..0de4bfdadcb5a41 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1733,13 +1733,13 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
            "missing map info operation or incorrect map info operation type");
     if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
             mapValue.getDefiningOp())) {
-      mapData.OriginalValue.push_back(moduleTranslation.lookupValue(
-          mapOp.getVarPtr() ? mapOp.getVarPtr() : mapOp.getVal()));
+      mapData.OriginalValue.push_back(
+          moduleTranslation.lookupValue(mapOp.getVarPtr()));
       mapData.Pointers.push_back(mapData.OriginalValue.back());
 
-      if (llvm::Value *refPtr = getRefPtrIfDeclareTarget(
-              mapOp.getVarPtr() ? mapOp.getVarPtr() : mapOp.getVal(),
-              moduleTranslation)) { // declare target
+      if (llvm::Value *refPtr =
+              getRefPtrIfDeclareTarget(mapOp.getVarPtr(),
+                                       moduleTranslation)) { // declare target
         mapData.IsDeclareTarget.push_back(true);
         mapData.BasePointers.push_back(refPtr);
       } else { // regular mapped variable
@@ -1747,14 +1747,10 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
         mapData.BasePointers.push_back(mapData.OriginalValue.back());
       }
 
-      mapData.Sizes.push_back(
-          getSizeInBytes(dl,
-                         mapOp.getVal() ? mapOp.getVal().getType()
-                                        : mapOp.getVarType().value(),
-                         mapOp, builder, moduleTranslation));
-      mapData.BaseType.push_back(moduleTranslation.convertType(
-          mapOp.getVal() ? mapOp.getVal().getType()
-                         : mapOp.getVarType().value()));
+      mapData.Sizes.push_back(getSizeInBytes(dl, mapOp.getVarType(), mapOp,
+                                             builder, moduleTranslation));
+      mapData.BaseType.push_back(
+          moduleTranslation.convertType(mapOp.getVarType()));
       mapData.MapClause.push_back(mapOp.getOperation());
       mapData.Types.push_back(
           llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
@@ -1804,8 +1800,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
     if (auto mapInfoOp = dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]))
       if (mapInfoOp.getMapCaptureType().value() ==
               mlir::omp::VariableCaptureKind::ByCopy &&
-          !(mapInfoOp.getVarType().has_value() &&
-            mapInfoOp.getVarType()->isa<LLVM::LLVMPointerType>()))
+          !mapInfoOp.getVarType().isa<LLVM::LLVMPointerType>())
         mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
 
     combinedInfo.BasePointers.emplace_back(mapData.BasePointers[i]);
@@ -2337,19 +2332,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
   SmallVector<Value> mapOperands = targetOp.getMapOperands();
 
-  // Remove mapOperands/blockArgs that have no use inside the region.
-  assert(mapOperands.size() == targetRegion.getNumArguments() &&
-         "Number of mapOperands must be same as block_arguments");
-  for (size_t i = 0; i < mapOperands.size(); i++) {
-    if (targetRegion.getArgument(i).use_empty()) {
-      targetRegion.eraseArgument(i);
-      mapOperands.erase(&mapOperands[i]);
-      i--;
-    }
-  }
-
   LogicalResult bodyGenStatus = success();
-
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   auto bodyCB = [&](InsertPointTy allocaIP,
                     InsertPointTy codeGenIP) -> InsertPointTy {
@@ -2358,8 +2341,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
     for (auto &mapOp : mapOperands) {
       auto mapInfoOp =
           mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp());
-      llvm::Value *mapOpValue = moduleTranslation.lookupValue(
-          mapInfoOp.getVarPtr() ? mapInfoOp.getVarPtr() : mapInfoOp.getVal());
+      llvm::Value *mapOpValue =
+          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
       const auto &arg = targetRegion.front().getArgument(argIndex);
       moduleTranslation.mapValue(arg, mapOpValue);
       argIndex++;
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 42e9fb1c64baec7..e54808f6cfdee58 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1658,40 +1658,4 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
   return
 }
 
-// -----
-
-func.func @omp_map1(%map1: memref<?xi32>, %map2: i32) {
-  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) val(%map2 : i32)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
-  // expected-error @below {{only one of val or var_ptr must be used}}
-  omp.target map_entries(%mapv -> %arg0: memref<?xi32>) {
-    ^bb0(%arg0: memref<?xi32>):
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_map2(%map1: memref<?xi32>, %map2: i32) {
-  %mapv = omp.map_info var_ptr( : , tensor<?xi32>) val(%map2 : i32)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
-  // expected-error @below {{var_type supplied without var_ptr}}
-  omp.target map_entries(%mapv -> %arg0: memref<?xi32>) {
-    ^bb0(%arg0: memref<?xi32>):
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_map3(%map1: memref<?xi32>, %map2: i32) {
-  %mapv = omp.map_info   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
-  // expected-error @below {{missing val or var_ptr}}
-  omp.target map_entries(%mapv -> %arg0: memref<?xi32>) {
-    ^bb0(%arg0: memref<?xi32>):
-    omp.terminator
-  }
-  return
-}
-
 llvm.mlir.global internal @_QFsubEx() : i32
diff --git a/openmp/libomptarget/test/offloading/fortran/constant-arr-index.f90 b/openmp/libomptarget/test/offloading/fortran/constant-arr-index.f90
index 91dc30cf9604c4b..9064f60896f105f 100644
--- a/openmp/libomptarget/test/offloading/fortran/constant-arr-index.f90
+++ b/openmp/libomptarget/test/offloading/fortran/constant-arr-index.f90
@@ -1,5 +1,5 @@
 ! Basic offloading test with a target region
-! that checks constant indexing on device 
+! that checks constant indexing on device
 ! correctly works (regression test for prior
 ! bug).
 ! REQUIRES: flang, amdgcn-amd-amdhsa
@@ -19,8 +19,8 @@ program main
      sp(5) = 10
   !$omp end target
 
-   ! print *, sp(1)
-   ! print *, sp(5)
+   print *, sp(1)
+   print *, sp(5)
 end program
 
 ! CHECK: 20



More information about the flang-commits mailing list