[flang-commits] [flang] [OpenMP][MLIR] Add "IsolatedFromAbove" and "OutlineableOpenMPOpInterface" trait to omp.target (PR #67164)

Akash Banerjee via flang-commits flang-commits at lists.llvm.org
Mon Sep 25 04:43:35 PDT 2023


https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/67164

>From 4bcfbcaca1ba1a98db7eb97daed2f0d89b92aa7e Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 22 Sep 2023 18:12:06 +0100
Subject: [PATCH] [OpenMP][MLIR] Add "IsolatedFromAbove" and
 "OutlineableOpenMPOpInterface" trait to omp.target

This patch makes the PFT lowering changes required to add the IsolatedFromAbove and OutlineableOpenMPOpInterface traits to omp.target.

Key Changes:
	- Add IsolatedFromAbove and OutlineableOpenMPOpInterface traits to target op in MLIR.
	- The main reason for this change is to prevent CSE and other similar optimisations from crossing region boundaries for target operations. The Discourse thread linked below covers the discussion of this issue.
	- Move implicit operand capturing to the PFT lowering stage.
	- Update related tests.

Related discussion: https://discourse.llvm.org/t/rfc-prevent-cse-from-removing-expressions-inside-some-non-isolatedfromabove-operation-regions/73150
---
 flang/lib/Lower/OpenMP.cpp                    | 127 +++++++++++++++---
 .../Fir/convert-to-llvm-openmp-and-fir.fir    |  10 +-
 flang/test/Lower/OpenMP/FIR/location.f90      |   2 +-
 flang/test/Lower/OpenMP/FIR/target.f90        |  36 ++---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   8 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  12 ++
 6 files changed, 153 insertions(+), 42 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 5f5e968eaaa6414..c369cba4255d4f4 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -538,7 +538,11 @@ class ClauseProcessor {
                   const llvm::omp::Directive &directive,
                   Fortran::semantics::SemanticsContext &semanticsContext,
                   Fortran::lower::StatementContext &stmtCtx,
-                  llvm::SmallVectorImpl<mlir::Value> &mapOperands) const;
+                  llvm::SmallVectorImpl<mlir::Value> &mapOperands,
+                  llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
+                  llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
+                  llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+                      *mapSymbols = nullptr) const;
   bool processReduction(
       mlir::Location currentLocation,
       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
@@ -1665,7 +1669,7 @@ static mlir::omp::MapInfoOp
 createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
                 mlir::Value baseAddr, std::stringstream &name,
                 mlir::SmallVector<mlir::Value> bounds, uint64_t mapType,
-                mlir::omp::VariableCaptureKind mapCaptureType, bool implicit,
+                mlir::omp::VariableCaptureKind mapCaptureType,
                 mlir::Type retTy) {
   mlir::Value varPtrPtr;
   if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
@@ -1676,7 +1680,6 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
   mlir::omp::MapInfoOp op =
       builder.create<mlir::omp::MapInfoOp>(loc, retTy, baseAddr);
   op.setNameAttr(builder.getStringAttr(name.str()));
-  op.setImplicit(implicit);
   op.setMapType(mapType);
   op.setMapCaptureType(mapCaptureType);
 
@@ -1695,7 +1698,11 @@ bool ClauseProcessor::processMap(
     mlir::Location currentLocation, const llvm::omp::Directive &directive,
     Fortran::semantics::SemanticsContext &semanticsContext,
     Fortran::lower::StatementContext &stmtCtx,
-    llvm::SmallVectorImpl<mlir::Value> &mapOperands) const {
+    llvm::SmallVectorImpl<mlir::Value> &mapOperands,
+    llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
+    llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
+    const {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   return findRepeatableClause<ClauseTy::Map>(
       [&](const ClauseTy::Map *mapClause,
@@ -1755,13 +1762,20 @@ bool ClauseProcessor::processMap(
           // Explicit map captures are captured ByRef by default,
           // optimisation passes may alter this to ByCopy or other capture
           // types to optimise
-          mapOperands.push_back(createMapInfoOp(
+          mlir::Value mapOp = createMapInfoOp(
               firOpBuilder, clauseLocation, baseAddr, asFortran, bounds,
               static_cast<
                   std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
                   mapTypeBits),
-              mlir::omp::VariableCaptureKind::ByRef, false,
-              baseAddr.getType()));
+              mlir::omp::VariableCaptureKind::ByRef, baseAddr.getType());
+
+          mapOperands.push_back(mapOp);
+          if (mapSymTypes)
+            mapSymTypes->push_back(mapOp.getType());
+          if (mapSymLocs)
+            mapSymLocs->push_back(mapOp.getLoc());
+          if (mapSymbols)
+            mapSymbols->push_back(getOmpObjectSymbol(ompObject));
         }
       });
 }
@@ -2142,7 +2156,7 @@ static void createBodyOfOp(
   }
 }
 
-static void createBodyOfTargetDataOp(
+static void genBodyOfTargetDataOp(
     Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp,
     const llvm::SmallVector<mlir::Type> &useDeviceTypes,
     const llvm::SmallVector<mlir::Location> &useDeviceLocs,
@@ -2356,8 +2370,8 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
   auto dataOp = converter.getFirOpBuilder().create<mlir::omp::DataOp>(
       currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
       deviceAddrOperands, mapOperands);
-  createBodyOfTargetDataOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
-                           useDeviceSymbols, currentLocation);
+  genBodyOfTargetDataOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
+                        useDeviceSymbols, currentLocation);
   return dataOp;
 }
 
@@ -2400,6 +2414,52 @@ genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
                                    deviceOperand, nowaitAttr, mapOperands);
 }
 
+static void genBodyOfTargetOp(
+    Fortran::lower::AbstractConverter &converter, mlir::omp::TargetOp &targetOp,
+    const llvm::SmallVector<mlir::Type> &mapSymTypes,
+    const llvm::SmallVector<mlir::Location> &mapSymLocs,
+    const llvm::SmallVector<const Fortran::semantics::Symbol *> &mapSymbols,
+    const mlir::Location &currentLocation) {
+  assert(mapSymTypes.size() == mapSymLocs.size() &&
+         mapSymTypes.size() == mapSymbols.size());
+
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Region &region = targetOp.getRegion();
+
+  firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
+  firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+  firOpBuilder.setInsertionPointToStart(&region.front());
+
+  unsigned argIndex = 0;
+  for (const Fortran::semantics::Symbol *sym : mapSymbols) {
+    const mlir::BlockArgument &arg = region.front().getArgument(argIndex);
+    fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
+    mlir::Value val = fir::getBase(arg);
+    extVal.match(
+        [&](const fir::BoxValue &v) {
+          converter.bindSymbol(*sym, fir::BoxValue(val, v.getLBounds(),
+                                                   v.getExplicitParameters(),
+                                                   v.getExplicitExtents()));
+        },
+        [&](const fir::MutableBoxValue &v) {
+          converter.bindSymbol(*sym,
+                               fir::MutableBoxValue(val, v.getLBounds(),
+                                                    v.getMutableProperties()));
+        },
+        [&](const fir::ArrayBoxValue &v) {
+          converter.bindSymbol(*sym, fir::ArrayBoxValue(val, v.getExtents(),
+                                                        v.getLBounds(),
+                                                        v.getSourceBox()));
+        },
+        [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, val); },
+        [&](const auto &) {
+          TODO(converter.getCurrentLocation(),
+               "target map clause operand unsupported type");
+        });
+    argIndex++;
+  }
+}
+
 static mlir::omp::TargetOp
 genTargetOp(Fortran::lower::AbstractConverter &converter,
             Fortran::lower::pft::Evaluation &eval,
@@ -2411,6 +2471,9 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
   mlir::UnitAttr nowaitAttr;
   llvm::SmallVector<mlir::Value> mapOperands;
+  llvm::SmallVector<mlir::Type> mapSymTypes;
+  llvm::SmallVector<mlir::Location> mapSymLocs;
+  llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
 
   ClauseProcessor cp(converter, clauseList);
   cp.processIf(stmtCtx,
@@ -2420,7 +2483,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   cp.processThreadLimit(stmtCtx, threadLimitOperand);
   cp.processNowait(nowaitAttr);
   cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
-                mapOperands);
+                mapOperands, &mapSymTypes, &mapSymLocs, &mapSymbols);
   cp.processTODO<Fortran::parser::OmpClause::Private,
                  Fortran::parser::OmpClause::Depend,
                  Fortran::parser::OmpClause::Firstprivate,
@@ -2433,10 +2496,44 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
                  Fortran::parser::OmpClause::Defaultmap>(
       currentLocation, llvm::omp::Directive::OMPD_target);
 
-  return genOpWithBody<mlir::omp::TargetOp>(
-      converter, eval, currentLocation, outerCombined, &clauseList,
-      ifClauseOperand, deviceOperand, threadLimitOperand, nowaitAttr,
-      mapOperands);
+  auto captureImplicitMap = [&](const Fortran::semantics::Symbol &sym) {
+    if (llvm::find(mapSymbols, &sym) == mapSymbols.end()) {
+      mlir::Value baseOp = converter.getSymbolAddress(sym);
+      if (!baseOp)
+        if (const auto *details = sym.template detailsIf<
+                                  Fortran::semantics::HostAssocDetails>()) {
+          baseOp = converter.getSymbolAddress(details->symbol());
+          converter.copySymbolBinding(details->symbol(), sym);
+        }
+
+      if (baseOp) {
+        llvm::SmallVector<mlir::Value> bounds;
+        std::stringstream name;
+        name << sym.name().ToString();
+        mlir::Value mapOp = createMapInfoOp(
+            converter.getFirOpBuilder(), baseOp.getLoc(), baseOp, name, bounds,
+            static_cast<
+                std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+                llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL |
+                llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
+            mlir::omp::VariableCaptureKind::ByCopy, baseOp.getType());
+        mapOperands.push_back(mapOp);
+        mapSymTypes.push_back(baseOp.getType());
+        mapSymLocs.push_back(baseOp.getLoc());
+        mapSymbols.push_back(&sym);
+      }
+    }
+  };
+  Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
+
+  auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
+      currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
+      nowaitAttr, mapOperands);
+
+  genBodyOfTargetOp(converter, targetOp, mapSymTypes, mapSymLocs, mapSymbols,
+                    currentLocation);
+
+  return targetOp;
 }
 
 static mlir::omp::TeamsOp
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 501a2bee29321a3..4054b57b02a90f7 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -420,7 +420,7 @@ func.func @_QPomp_target_data_empty() {
 
 // CHECK-LABEL:   llvm.func @_QPomp_target_data_empty
 // CHECK: omp.target_data   use_device_addr(%1 : !llvm.ptr<array<1024 x i32>>) {
-// CHECK: }
+// CHECK: } 
 
 // -----
 
@@ -434,11 +434,12 @@ func.func @_QPomp_target() {
   %2 = omp.bounds   lower_bound(%c0 : index) upper_bound(%1 : index) extent(%c512 : index) stride(%c1 : index) start_idx(%c1 : index)
   %3 = omp.map_info var_ptr(%0 : !fir.ref<!fir.array<512xi32>>)   map_clauses(tofrom) capture(ByRef) bounds(%2) -> !fir.ref<!fir.array<512xi32>> {name = "a"}
   omp.target   thread_limit(%c64_i32 : i32) map_entries(%3 : !fir.ref<!fir.array<512xi32>>) {
+    ^bb0(%arg0: !fir.ref<!fir.array<512xi32>>):
     %c10_i32 = arith.constant 10 : i32
     %c1_i64 = arith.constant 1 : i64
     %c1_i64_0 = arith.constant 1 : i64
     %4 = arith.subi %c1_i64, %c1_i64_0 : i64
-    %5 = fir.coordinate_of %0, %4 : (!fir.ref<!fir.array<512xi32>>, i64) -> !fir.ref<i32>
+    %5 = fir.coordinate_of %arg0, %4 : (!fir.ref<!fir.array<512xi32>>, i64) -> !fir.ref<i32>
     fir.store %c10_i32 to %5 : !fir.ref<i32>
     omp.terminator
   }
@@ -456,11 +457,12 @@ func.func @_QPomp_target() {
 // CHECK:           %[[BOUNDS:.*]] = omp.bounds   lower_bound(%[[LOWER]] : i64) upper_bound(%[[UPPER]] : i64) extent(%[[EXTENT]] : i64) stride(%[[STRIDE]] : i64) start_idx(%[[STRIDE]] : i64)
 // CHECK:           %[[MAP:.*]] = omp.map_info var_ptr(%2 : !llvm.ptr<array<512 x i32>>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !llvm.ptr<array<512 x i32>> {name = "a"}
 // CHECK:           omp.target   thread_limit(%[[VAL_2]] : i32) map_entries(%[[MAP]] : !llvm.ptr<array<512 x i32>>) {
+// CHECK:           ^bb0(%[[ARG_0:.*]]: !llvm.ptr<array<512 x i32>>):
 // CHECK:             %[[VAL_3:.*]] = llvm.mlir.constant(10 : i32) : i32
 // CHECK:             %[[VAL_4:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:             %[[VAL_5:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:             %[[VAL_6:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK:             %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_1]][0, %[[VAL_6]]] : (!llvm.ptr<array<512 x i32>>, i64) -> !llvm.ptr<i32>
+// CHECK:             %[[VAL_7:.*]] = llvm.getelementptr %[[ARG_0]][0, %[[VAL_6]]] : (!llvm.ptr<array<512 x i32>>, i64) -> !llvm.ptr<i32>
 // CHECK:             llvm.store %[[VAL_3]], %[[VAL_7]] : !llvm.ptr<i32>
 // CHECK:             omp.terminator
 // CHECK:           }
@@ -827,4 +829,4 @@ func.func @sub_() {
     omp.terminator
   }
   return
-} 
+}
diff --git a/flang/test/Lower/OpenMP/FIR/location.f90 b/flang/test/Lower/OpenMP/FIR/location.f90
index 0e36e09b19e1942..84bbd1605179262 100644
--- a/flang/test/Lower/OpenMP/FIR/location.f90
+++ b/flang/test/Lower/OpenMP/FIR/location.f90
@@ -17,7 +17,7 @@ subroutine sub_parallel()
 !CHECK-LABEL: sub_target
 subroutine sub_target()
   print *, x
-!CHECK: omp.target  {
+!CHECK: omp.target
   !$omp target
     print *, x
 !CHECK:   omp.terminator loc(#[[TAR_LOC:.*]])
diff --git a/flang/test/Lower/OpenMP/FIR/target.f90 b/flang/test/Lower/OpenMP/FIR/target.f90
index 9b1fb5c15ac1d2d..749ddff523500c5 100644
--- a/flang/test/Lower/OpenMP/FIR/target.f90
+++ b/flang/test/Lower/OpenMP/FIR/target.f90
@@ -190,12 +190,13 @@ subroutine omp_target
    !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
    !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
    !CHECK: omp.target   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) {
+   !CHECK: ^bb0(%[[ARG_0:.*]]: !fir.ref<!fir.array<1024xi32>>):
    !$omp target map(tofrom: a)
       !CHECK: %[[VAL_1:.*]] = arith.constant 10 : i32
       !CHECK: %[[VAL_2:.*]] = arith.constant 1 : i64
       !CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
       !CHECK: %[[VAL_4:.*]] = arith.subi %[[VAL_2]], %[[VAL_3]] : i64
-      !CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
+      !CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG_0]], %[[VAL_4]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
       !CHECK: fir.store %[[VAL_1]] to %[[VAL_5]] : !fir.ref<i32>
       a(1) = 10
    !CHECK: omp.terminator
@@ -213,6 +214,7 @@ subroutine omp_target_thread_limit
    !CHECK: %[[VAL_1:.*]] = arith.constant 64 : i32
    !CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}})   map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "a"}
    !CHECK: omp.target   thread_limit(%[[VAL_1]] : i32) map_entries(%[[MAP]] : !fir.ref<i32>) {
+   !CHECK: ^bb0(%{{.*}}: !fir.ref<i32>):
    !$omp target map(tofrom: a) thread_limit(64)
       a = 10
    !CHECK: omp.terminator
@@ -274,23 +276,25 @@ subroutine omp_target_parallel_do
    !CHECK: %[[C0:.*]] = arith.constant 0 : index
    !CHECK: %[[SUB:.*]] = arith.subi %[[C1024]], %[[C1]] : index
    !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound(%[[C0]] : index) upper_bound(%[[SUB]] : index) extent(%[[C1024]] : index) stride(%[[C1]] : index) start_idx(%[[C1]] : index)
-   !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) {
+   !CHECK: %[[MAP1:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+   !CHECK: %[[MAP2:.*]] = omp.map_info var_ptr(%[[VAL_1]] : !fir.ref<i32>)   map_clauses(literal, implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
+   !CHECK: omp.target   map_entries(%[[MAP1]], %[[MAP2]] : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>) {
+   !CHECK: ^bb0(%[[VAL_2:.*]]: !fir.ref<!fir.array<1024xi32>>, %[[VAL_3:.*]]: !fir.ref<i32>):
       !CHECK-NEXT: omp.parallel
       !$omp target parallel do map(tofrom: a)
-         !CHECK: %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
-         !CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
-         !CHECK: %[[VAL_4:.*]] = arith.constant 1024 : i32
+         !CHECK: %[[VAL_4:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
          !CHECK: %[[VAL_5:.*]] = arith.constant 1 : i32
-         !CHECK: omp.wsloop   for  (%[[VAL_6:.*]]) : i32 = (%[[VAL_3]]) to (%[[VAL_4]]) inclusive step (%[[VAL_5]]) {
-         !CHECK: fir.store %[[VAL_6]] to %[[VAL_2]] : !fir.ref<i32>
-         !CHECK: %[[VAL_7:.*]] = arith.constant 10 : i32
-         !CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32>
-         !CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (i32) -> i64
-         !CHECK: %[[VAL_10:.*]] = arith.constant 1 : i64
-         !CHECK: %[[VAL_11:.*]] = arith.subi %[[VAL_9]], %[[VAL_10]] : i64
-         !CHECK: %[[VAL_12:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_11]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
-         !CHECK: fir.store %[[VAL_7]] to %[[VAL_12]] : !fir.ref<i32>
+         !CHECK: %[[VAL_6:.*]] = arith.constant 1024 : i32
+         !CHECK: %[[VAL_7:.*]] = arith.constant 1 : i32
+         !CHECK: omp.wsloop   for  (%[[VAL_8:.*]]) : i32 = (%[[VAL_5]]) to (%[[VAL_6]]) inclusive step (%[[VAL_7]]) {
+         !CHECK: fir.store %[[VAL_8]] to %[[VAL_4]] : !fir.ref<i32>
+         !CHECK: %[[VAL_9:.*]] = arith.constant 10 : i32
+         !CHECK: %[[VAL_10:.*]] = fir.load %[[VAL_4]] : !fir.ref<i32>
+         !CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (i32) -> i64
+         !CHECK: %[[VAL_12:.*]] = arith.constant 1 : i64
+         !CHECK: %[[VAL_13:.*]] = arith.subi %[[VAL_11]], %[[VAL_12]] : i64
+         !CHECK: %[[VAL_14:.*]] = fir.coordinate_of %[[VAL_2]], %[[VAL_13]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
+         !CHECK: fir.store %[[VAL_9]] to %[[VAL_14]] : !fir.ref<i32>
          do i = 1, 1024
             a(i) = 10
          end do
@@ -301,4 +305,4 @@ subroutine omp_target_parallel_do
    !CHECK: omp.terminator
    !CHECK: }
    !$omp end target parallel do
-end subroutine omp_target_parallel_do
+ end subroutine omp_target_parallel_do
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index bcd60d8046c8925..248b92a08649c88 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1149,7 +1149,6 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
                        Variadic<DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
                        OptionalAttr<UI64Attr>:$map_type,
                        OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,
-                       DefaultValuedAttr<BoolAttr, "false">:$implicit,
                        OptionalAttr<StrAttr>:$name);
   let results = (outs OpenMP_PointerLikeType:$omp_ptr);
 
@@ -1177,7 +1176,7 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
     ```
     =>
     ```mlir
-    omp.map_info var_ptr(%index_ssa) map_type(to) map_capture_type(ByRef) implicit(false)
+    omp.map_info var_ptr(%index_ssa) map_type(to) map_capture_type(ByRef)
       name(index)
     ```
 
@@ -1189,9 +1188,6 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
       objects (e.g. derived types or classes), indicates the bounds to be copied
       of the variable. When it's an array slice it is in rank order where rank 0
       is the inner-most dimension.
-    - `implicit`: indicates where the map item has been specified explicitly in a
-      map clause or captured implicitly by being used in a target region with no
-      map or other data mapping construct.
     - 'map_clauses': OpenMP map type for this map capture, for example: from, to and
       always. It's a bitfield composed of the OpenMP runtime flags stored in
       OpenMPOffloadMappingFlags.
@@ -1375,7 +1371,7 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
 // 2.14.5 target construct
 //===----------------------------------------------------------------------===//
 
-def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> {
+def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterface, AttrSizedOperandSegments]> {
   let summary = "target construct";
   let description = [{
     The target construct includes a region of code which is to be executed
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 2bf9355ed62676b..bb4fc191432957c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -693,6 +693,12 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
     if (mapTypeMod == "always")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
 
+    if (mapTypeMod == "literal")
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
+
+    if (mapTypeMod == "implicit")
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+
     if (mapTypeMod == "close")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
 
@@ -740,6 +746,12 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
   if (mapTypeToBitFlag(mapTypeBits,
                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
     mapTypeStrs.push_back("always");
+  if (mapTypeToBitFlag(mapTypeBits,
+                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL))
+    mapTypeStrs.push_back("literal");
+  if (mapTypeToBitFlag(mapTypeBits,
+                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
+    mapTypeStrs.push_back("implicit");
   if (mapTypeToBitFlag(mapTypeBits,
                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
     mapTypeStrs.push_back("close");



More information about the flang-commits mailing list