[Mlir-commits] [mlir] [OpenMP][MLIR] Add "IsolatedFromAbove" and "OutlineableOpenMPOpInterface" trait to omp.target (PR #67318)

Akash Banerjee llvmlistbot at llvm.org
Mon Sep 25 04:55:48 PDT 2023


https://github.com/TIFitis created https://github.com/llvm/llvm-project/pull/67318

This patch adds the MLIR translation changes required for adding the IsolatedFromAbove and OutlineableOpenMPOpInterface traits to omp.target. It links the newly added block arguments to their corresponding LLVM values.

Depends on #67164.

>From 4bcfbcaca1ba1a98db7eb97daed2f0d89b92aa7e Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 22 Sep 2023 18:12:06 +0100
Subject: [PATCH 1/2] [OpenMP][MLIR] Add "IsolatedFromAbove" and
 "OutlineableOpenMPOpInterface" trait to omp.target

This patch adds the PFT lowering changes required for adding the IsolatedFromAbove and OutlineableOpenMPOpInterface traits to omp.target.

Key Changes:
	- Add the IsolatedFromAbove and OutlineableOpenMPOpInterface traits to the target op in MLIR.
	- The main reason for this change is to prevent CSE and other similar optimisations from crossing the region boundaries of target operations. The link below points to the Discourse discussion surrounding this issue.
	- Move implicit operand capturing to the PFT lowering stage.
	- Update related tests.

Related discussion: https://discourse.llvm.org/t/rfc-prevent-cse-from-removing-expressions-inside-some-non-isolatedfromabove-operation-regions/73150
---
 flang/lib/Lower/OpenMP.cpp                    | 127 +++++++++++++++---
 .../Fir/convert-to-llvm-openmp-and-fir.fir    |  10 +-
 flang/test/Lower/OpenMP/FIR/location.f90      |   2 +-
 flang/test/Lower/OpenMP/FIR/target.f90        |  36 ++---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   8 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  12 ++
 6 files changed, 153 insertions(+), 42 deletions(-)

diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 5f5e968eaaa6414..c369cba4255d4f4 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -538,7 +538,11 @@ class ClauseProcessor {
                   const llvm::omp::Directive &directive,
                   Fortran::semantics::SemanticsContext &semanticsContext,
                   Fortran::lower::StatementContext &stmtCtx,
-                  llvm::SmallVectorImpl<mlir::Value> &mapOperands) const;
+                  llvm::SmallVectorImpl<mlir::Value> &mapOperands,
+                  llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr,
+                  llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
+                  llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
+                      *mapSymbols = nullptr) const;
   bool processReduction(
       mlir::Location currentLocation,
       llvm::SmallVectorImpl<mlir::Value> &reductionVars,
@@ -1665,7 +1669,7 @@ static mlir::omp::MapInfoOp
 createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
                 mlir::Value baseAddr, std::stringstream &name,
                 mlir::SmallVector<mlir::Value> bounds, uint64_t mapType,
-                mlir::omp::VariableCaptureKind mapCaptureType, bool implicit,
+                mlir::omp::VariableCaptureKind mapCaptureType,
                 mlir::Type retTy) {
   mlir::Value varPtrPtr;
   if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
@@ -1676,7 +1680,6 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
   mlir::omp::MapInfoOp op =
       builder.create<mlir::omp::MapInfoOp>(loc, retTy, baseAddr);
   op.setNameAttr(builder.getStringAttr(name.str()));
-  op.setImplicit(implicit);
   op.setMapType(mapType);
   op.setMapCaptureType(mapCaptureType);
 
@@ -1695,7 +1698,11 @@ bool ClauseProcessor::processMap(
     mlir::Location currentLocation, const llvm::omp::Directive &directive,
     Fortran::semantics::SemanticsContext &semanticsContext,
     Fortran::lower::StatementContext &stmtCtx,
-    llvm::SmallVectorImpl<mlir::Value> &mapOperands) const {
+    llvm::SmallVectorImpl<mlir::Value> &mapOperands,
+    llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
+    llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
+    llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
+    const {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   return findRepeatableClause<ClauseTy::Map>(
       [&](const ClauseTy::Map *mapClause,
@@ -1755,13 +1762,20 @@ bool ClauseProcessor::processMap(
           // Explicit map captures are captured ByRef by default,
           // optimisation passes may alter this to ByCopy or other capture
           // types to optimise
-          mapOperands.push_back(createMapInfoOp(
+          mlir::Value mapOp = createMapInfoOp(
               firOpBuilder, clauseLocation, baseAddr, asFortran, bounds,
               static_cast<
                   std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
                   mapTypeBits),
-              mlir::omp::VariableCaptureKind::ByRef, false,
-              baseAddr.getType()));
+              mlir::omp::VariableCaptureKind::ByRef, baseAddr.getType());
+
+          mapOperands.push_back(mapOp);
+          if (mapSymTypes)
+            mapSymTypes->push_back(mapOp.getType());
+          if (mapSymLocs)
+            mapSymLocs->push_back(mapOp.getLoc());
+          if (mapSymbols)
+            mapSymbols->push_back(getOmpObjectSymbol(ompObject));
         }
       });
 }
@@ -2142,7 +2156,7 @@ static void createBodyOfOp(
   }
 }
 
-static void createBodyOfTargetDataOp(
+static void genBodyOfTargetDataOp(
     Fortran::lower::AbstractConverter &converter, mlir::omp::DataOp &dataOp,
     const llvm::SmallVector<mlir::Type> &useDeviceTypes,
     const llvm::SmallVector<mlir::Location> &useDeviceLocs,
@@ -2356,8 +2370,8 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
   auto dataOp = converter.getFirOpBuilder().create<mlir::omp::DataOp>(
       currentLocation, ifClauseOperand, deviceOperand, devicePtrOperands,
       deviceAddrOperands, mapOperands);
-  createBodyOfTargetDataOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
-                           useDeviceSymbols, currentLocation);
+  genBodyOfTargetDataOp(converter, dataOp, useDeviceTypes, useDeviceLocs,
+                        useDeviceSymbols, currentLocation);
   return dataOp;
 }
 
@@ -2400,6 +2414,52 @@ genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
                                    deviceOperand, nowaitAttr, mapOperands);
 }
 
+static void genBodyOfTargetOp(
+    Fortran::lower::AbstractConverter &converter, mlir::omp::TargetOp &targetOp,
+    const llvm::SmallVector<mlir::Type> &mapSymTypes,
+    const llvm::SmallVector<mlir::Location> &mapSymLocs,
+    const llvm::SmallVector<const Fortran::semantics::Symbol *> &mapSymbols,
+    const mlir::Location &currentLocation) {
+  assert(mapSymTypes.size() == mapSymLocs.size() &&
+         mapSymTypes.size() == mapSymbols.size());
+
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Region &region = targetOp.getRegion();
+
+  firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
+  firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
+  firOpBuilder.setInsertionPointToStart(&region.front());
+
+  unsigned argIndex = 0;
+  for (const Fortran::semantics::Symbol *sym : mapSymbols) {
+    const mlir::BlockArgument &arg = region.front().getArgument(argIndex);
+    fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
+    mlir::Value val = fir::getBase(arg);
+    extVal.match(
+        [&](const fir::BoxValue &v) {
+          converter.bindSymbol(*sym, fir::BoxValue(val, v.getLBounds(),
+                                                   v.getExplicitParameters(),
+                                                   v.getExplicitExtents()));
+        },
+        [&](const fir::MutableBoxValue &v) {
+          converter.bindSymbol(*sym,
+                               fir::MutableBoxValue(val, v.getLBounds(),
+                                                    v.getMutableProperties()));
+        },
+        [&](const fir::ArrayBoxValue &v) {
+          converter.bindSymbol(*sym, fir::ArrayBoxValue(val, v.getExtents(),
+                                                        v.getLBounds(),
+                                                        v.getSourceBox()));
+        },
+        [&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, val); },
+        [&](const auto &) {
+          TODO(converter.getCurrentLocation(),
+               "target map clause operand unsupported type");
+        });
+    argIndex++;
+  }
+}
+
 static mlir::omp::TargetOp
 genTargetOp(Fortran::lower::AbstractConverter &converter,
             Fortran::lower::pft::Evaluation &eval,
@@ -2411,6 +2471,9 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   mlir::Value ifClauseOperand, deviceOperand, threadLimitOperand;
   mlir::UnitAttr nowaitAttr;
   llvm::SmallVector<mlir::Value> mapOperands;
+  llvm::SmallVector<mlir::Type> mapSymTypes;
+  llvm::SmallVector<mlir::Location> mapSymLocs;
+  llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
 
   ClauseProcessor cp(converter, clauseList);
   cp.processIf(stmtCtx,
@@ -2420,7 +2483,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
   cp.processThreadLimit(stmtCtx, threadLimitOperand);
   cp.processNowait(nowaitAttr);
   cp.processMap(currentLocation, directive, semanticsContext, stmtCtx,
-                mapOperands);
+                mapOperands, &mapSymTypes, &mapSymLocs, &mapSymbols);
   cp.processTODO<Fortran::parser::OmpClause::Private,
                  Fortran::parser::OmpClause::Depend,
                  Fortran::parser::OmpClause::Firstprivate,
@@ -2433,10 +2496,44 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
                  Fortran::parser::OmpClause::Defaultmap>(
       currentLocation, llvm::omp::Directive::OMPD_target);
 
-  return genOpWithBody<mlir::omp::TargetOp>(
-      converter, eval, currentLocation, outerCombined, &clauseList,
-      ifClauseOperand, deviceOperand, threadLimitOperand, nowaitAttr,
-      mapOperands);
+  auto captureImplicitMap = [&](const Fortran::semantics::Symbol &sym) {
+    if (llvm::find(mapSymbols, &sym) == mapSymbols.end()) {
+      mlir::Value baseOp = converter.getSymbolAddress(sym);
+      if (!baseOp)
+        if (const auto *details = sym.template detailsIf<
+                                  Fortran::semantics::HostAssocDetails>()) {
+          baseOp = converter.getSymbolAddress(details->symbol());
+          converter.copySymbolBinding(details->symbol(), sym);
+        }
+
+      if (baseOp) {
+        llvm::SmallVector<mlir::Value> bounds;
+        std::stringstream name;
+        name << sym.name().ToString();
+        mlir::Value mapOp = createMapInfoOp(
+            converter.getFirOpBuilder(), baseOp.getLoc(), baseOp, name, bounds,
+            static_cast<
+                std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+                llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL |
+                llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
+            mlir::omp::VariableCaptureKind::ByCopy, baseOp.getType());
+        mapOperands.push_back(mapOp);
+        mapSymTypes.push_back(baseOp.getType());
+        mapSymLocs.push_back(baseOp.getLoc());
+        mapSymbols.push_back(&sym);
+      }
+    }
+  };
+  Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);
+
+  auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
+      currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
+      nowaitAttr, mapOperands);
+
+  genBodyOfTargetOp(converter, targetOp, mapSymTypes, mapSymLocs, mapSymbols,
+                    currentLocation);
+
+  return targetOp;
 }
 
 static mlir::omp::TeamsOp
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index 501a2bee29321a3..4054b57b02a90f7 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -420,7 +420,7 @@ func.func @_QPomp_target_data_empty() {
 
 // CHECK-LABEL:   llvm.func @_QPomp_target_data_empty
 // CHECK: omp.target_data   use_device_addr(%1 : !llvm.ptr<array<1024 x i32>>) {
-// CHECK: }
+// CHECK: } 
 
 // -----
 
@@ -434,11 +434,12 @@ func.func @_QPomp_target() {
   %2 = omp.bounds   lower_bound(%c0 : index) upper_bound(%1 : index) extent(%c512 : index) stride(%c1 : index) start_idx(%c1 : index)
   %3 = omp.map_info var_ptr(%0 : !fir.ref<!fir.array<512xi32>>)   map_clauses(tofrom) capture(ByRef) bounds(%2) -> !fir.ref<!fir.array<512xi32>> {name = "a"}
   omp.target   thread_limit(%c64_i32 : i32) map_entries(%3 : !fir.ref<!fir.array<512xi32>>) {
+    ^bb0(%arg0: !fir.ref<!fir.array<512xi32>>):
     %c10_i32 = arith.constant 10 : i32
     %c1_i64 = arith.constant 1 : i64
     %c1_i64_0 = arith.constant 1 : i64
     %4 = arith.subi %c1_i64, %c1_i64_0 : i64
-    %5 = fir.coordinate_of %0, %4 : (!fir.ref<!fir.array<512xi32>>, i64) -> !fir.ref<i32>
+    %5 = fir.coordinate_of %arg0, %4 : (!fir.ref<!fir.array<512xi32>>, i64) -> !fir.ref<i32>
     fir.store %c10_i32 to %5 : !fir.ref<i32>
     omp.terminator
   }
@@ -456,11 +457,12 @@ func.func @_QPomp_target() {
 // CHECK:           %[[BOUNDS:.*]] = omp.bounds   lower_bound(%[[LOWER]] : i64) upper_bound(%[[UPPER]] : i64) extent(%[[EXTENT]] : i64) stride(%[[STRIDE]] : i64) start_idx(%[[STRIDE]] : i64)
 // CHECK:           %[[MAP:.*]] = omp.map_info var_ptr(%2 : !llvm.ptr<array<512 x i32>>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !llvm.ptr<array<512 x i32>> {name = "a"}
 // CHECK:           omp.target   thread_limit(%[[VAL_2]] : i32) map_entries(%[[MAP]] : !llvm.ptr<array<512 x i32>>) {
+// CHECK:           ^bb0(%[[ARG_0:.*]]: !llvm.ptr<array<512 x i32>>):
 // CHECK:             %[[VAL_3:.*]] = llvm.mlir.constant(10 : i32) : i32
 // CHECK:             %[[VAL_4:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:             %[[VAL_5:.*]] = llvm.mlir.constant(1 : i64) : i64
 // CHECK:             %[[VAL_6:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK:             %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_1]][0, %[[VAL_6]]] : (!llvm.ptr<array<512 x i32>>, i64) -> !llvm.ptr<i32>
+// CHECK:             %[[VAL_7:.*]] = llvm.getelementptr %[[ARG_0]][0, %[[VAL_6]]] : (!llvm.ptr<array<512 x i32>>, i64) -> !llvm.ptr<i32>
 // CHECK:             llvm.store %[[VAL_3]], %[[VAL_7]] : !llvm.ptr<i32>
 // CHECK:             omp.terminator
 // CHECK:           }
@@ -827,4 +829,4 @@ func.func @sub_() {
     omp.terminator
   }
   return
-} 
+}
diff --git a/flang/test/Lower/OpenMP/FIR/location.f90 b/flang/test/Lower/OpenMP/FIR/location.f90
index 0e36e09b19e1942..84bbd1605179262 100644
--- a/flang/test/Lower/OpenMP/FIR/location.f90
+++ b/flang/test/Lower/OpenMP/FIR/location.f90
@@ -17,7 +17,7 @@ subroutine sub_parallel()
 !CHECK-LABEL: sub_target
 subroutine sub_target()
   print *, x
-!CHECK: omp.target  {
+!CHECK: omp.target
   !$omp target
     print *, x
 !CHECK:   omp.terminator loc(#[[TAR_LOC:.*]])
diff --git a/flang/test/Lower/OpenMP/FIR/target.f90 b/flang/test/Lower/OpenMP/FIR/target.f90
index 9b1fb5c15ac1d2d..749ddff523500c5 100644
--- a/flang/test/Lower/OpenMP/FIR/target.f90
+++ b/flang/test/Lower/OpenMP/FIR/target.f90
@@ -190,12 +190,13 @@ subroutine omp_target
    !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound({{.*}}) upper_bound({{.*}}) extent({{.*}}) stride({{.*}}) start_idx({{.*}})
    !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
    !CHECK: omp.target   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) {
+   !CHECK: ^bb0(%[[ARG_0:.*]]: !fir.ref<!fir.array<1024xi32>>):
    !$omp target map(tofrom: a)
       !CHECK: %[[VAL_1:.*]] = arith.constant 10 : i32
       !CHECK: %[[VAL_2:.*]] = arith.constant 1 : i64
       !CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
       !CHECK: %[[VAL_4:.*]] = arith.subi %[[VAL_2]], %[[VAL_3]] : i64
-      !CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
+      !CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG_0]], %[[VAL_4]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
       !CHECK: fir.store %[[VAL_1]] to %[[VAL_5]] : !fir.ref<i32>
       a(1) = 10
    !CHECK: omp.terminator
@@ -213,6 +214,7 @@ subroutine omp_target_thread_limit
    !CHECK: %[[VAL_1:.*]] = arith.constant 64 : i32
    !CHECK: %[[MAP:.*]] = omp.map_info var_ptr({{.*}})   map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "a"}
    !CHECK: omp.target   thread_limit(%[[VAL_1]] : i32) map_entries(%[[MAP]] : !fir.ref<i32>) {
+   !CHECK: ^bb0(%{{.*}}: !fir.ref<i32>):
    !$omp target map(tofrom: a) thread_limit(64)
       a = 10
    !CHECK: omp.terminator
@@ -274,23 +276,25 @@ subroutine omp_target_parallel_do
    !CHECK: %[[C0:.*]] = arith.constant 0 : index
    !CHECK: %[[SUB:.*]] = arith.subi %[[C1024]], %[[C1]] : index
    !CHECK: %[[BOUNDS:.*]] = omp.bounds   lower_bound(%[[C0]] : index) upper_bound(%[[SUB]] : index) extent(%[[C1024]] : index) stride(%[[C1]] : index) start_idx(%[[C1]] : index)
-   !CHECK: %[[MAP:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
-   !CHECK: omp.target   map_entries(%[[MAP]] : !fir.ref<!fir.array<1024xi32>>) {
+   !CHECK: %[[MAP1:.*]] = omp.map_info var_ptr(%[[VAL_0]] : !fir.ref<!fir.array<1024xi32>>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+   !CHECK: %[[MAP2:.*]] = omp.map_info var_ptr(%[[VAL_1]] : !fir.ref<i32>)   map_clauses(literal, implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
+   !CHECK: omp.target   map_entries(%[[MAP1]], %[[MAP2]] : !fir.ref<!fir.array<1024xi32>>, !fir.ref<i32>) {
+   !CHECK: ^bb0(%[[VAL_2:.*]]: !fir.ref<!fir.array<1024xi32>>, %[[VAL_3:.*]]: !fir.ref<i32>):
       !CHECK-NEXT: omp.parallel
       !$omp target parallel do map(tofrom: a)
-         !CHECK: %[[VAL_2:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
-         !CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
-         !CHECK: %[[VAL_4:.*]] = arith.constant 1024 : i32
+         !CHECK: %[[VAL_4:.*]] = fir.alloca i32 {adapt.valuebyref, pinned}
          !CHECK: %[[VAL_5:.*]] = arith.constant 1 : i32
-         !CHECK: omp.wsloop   for  (%[[VAL_6:.*]]) : i32 = (%[[VAL_3]]) to (%[[VAL_4]]) inclusive step (%[[VAL_5]]) {
-         !CHECK: fir.store %[[VAL_6]] to %[[VAL_2]] : !fir.ref<i32>
-         !CHECK: %[[VAL_7:.*]] = arith.constant 10 : i32
-         !CHECK: %[[VAL_8:.*]] = fir.load %[[VAL_2]] : !fir.ref<i32>
-         !CHECK: %[[VAL_9:.*]] = fir.convert %[[VAL_8]] : (i32) -> i64
-         !CHECK: %[[VAL_10:.*]] = arith.constant 1 : i64
-         !CHECK: %[[VAL_11:.*]] = arith.subi %[[VAL_9]], %[[VAL_10]] : i64
-         !CHECK: %[[VAL_12:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_11]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
-         !CHECK: fir.store %[[VAL_7]] to %[[VAL_12]] : !fir.ref<i32>
+         !CHECK: %[[VAL_6:.*]] = arith.constant 1024 : i32
+         !CHECK: %[[VAL_7:.*]] = arith.constant 1 : i32
+         !CHECK: omp.wsloop   for  (%[[VAL_8:.*]]) : i32 = (%[[VAL_5]]) to (%[[VAL_6]]) inclusive step (%[[VAL_7]]) {
+         !CHECK: fir.store %[[VAL_8]] to %[[VAL_4]] : !fir.ref<i32>
+         !CHECK: %[[VAL_9:.*]] = arith.constant 10 : i32
+         !CHECK: %[[VAL_10:.*]] = fir.load %[[VAL_4]] : !fir.ref<i32>
+         !CHECK: %[[VAL_11:.*]] = fir.convert %[[VAL_10]] : (i32) -> i64
+         !CHECK: %[[VAL_12:.*]] = arith.constant 1 : i64
+         !CHECK: %[[VAL_13:.*]] = arith.subi %[[VAL_11]], %[[VAL_12]] : i64
+         !CHECK: %[[VAL_14:.*]] = fir.coordinate_of %[[VAL_2]], %[[VAL_13]] : (!fir.ref<!fir.array<1024xi32>>, i64) -> !fir.ref<i32>
+         !CHECK: fir.store %[[VAL_9]] to %[[VAL_14]] : !fir.ref<i32>
          do i = 1, 1024
             a(i) = 10
          end do
@@ -301,4 +305,4 @@ subroutine omp_target_parallel_do
    !CHECK: omp.terminator
    !CHECK: }
    !$omp end target parallel do
-end subroutine omp_target_parallel_do
+ end subroutine omp_target_parallel_do
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index bcd60d8046c8925..248b92a08649c88 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1149,7 +1149,6 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
                        Variadic<DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
                        OptionalAttr<UI64Attr>:$map_type,
                        OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,
-                       DefaultValuedAttr<BoolAttr, "false">:$implicit,
                        OptionalAttr<StrAttr>:$name);
   let results = (outs OpenMP_PointerLikeType:$omp_ptr);
 
@@ -1177,7 +1176,7 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
     ```
     =>
     ```mlir
-    omp.map_info var_ptr(%index_ssa) map_type(to) map_capture_type(ByRef) implicit(false)
+    omp.map_info var_ptr(%index_ssa) map_type(to) map_capture_type(ByRef)
       name(index)
     ```
 
@@ -1189,9 +1188,6 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
       objects (e.g. derived types or classes), indicates the bounds to be copied
       of the variable. When it's an array slice it is in rank order where rank 0
       is the inner-most dimension.
-    - `implicit`: indicates where the map item has been specified explicitly in a
-      map clause or captured implicitly by being used in a target region with no
-      map or other data mapping construct.
     - 'map_clauses': OpenMP map type for this map capture, for example: from, to and
       always. It's a bitfield composed of the OpenMP runtime flags stored in
       OpenMPOffloadMappingFlags.
@@ -1375,7 +1371,7 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
 // 2.14.5 target construct
 //===----------------------------------------------------------------------===//
 
-def TargetOp : OpenMP_Op<"target",[AttrSizedOperandSegments]> {
+def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterface, AttrSizedOperandSegments]> {
   let summary = "target construct";
   let description = [{
     The target construct includes a region of code which is to be executed
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 2bf9355ed62676b..bb4fc191432957c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -693,6 +693,12 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
     if (mapTypeMod == "always")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
 
+    if (mapTypeMod == "literal")
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
+
+    if (mapTypeMod == "implicit")
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
+
     if (mapTypeMod == "close")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
 
@@ -740,6 +746,12 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
   if (mapTypeToBitFlag(mapTypeBits,
                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
     mapTypeStrs.push_back("always");
+  if (mapTypeToBitFlag(mapTypeBits,
+                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL))
+    mapTypeStrs.push_back("literal");
+  if (mapTypeToBitFlag(mapTypeBits,
+                       llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
+    mapTypeStrs.push_back("implicit");
   if (mapTypeToBitFlag(mapTypeBits,
                        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
     mapTypeStrs.push_back("close");

>From 928d071fbf717289e1d01c9d82f5618c1d56bc86 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 22 Sep 2023 18:13:46 +0100
Subject: [PATCH 2/2] [OpenMP][MLIR] Add "IsolatedFromAbove" and
 "OutlineableOpenMPOpInterface" trait to omp.target

This patch adds the MLIR translation changes required for adding the IsolatedFromAbove and OutlineableOpenMPOpInterface traits to omp.target. It links the newly added block arguments to their corresponding LLVM values.
---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 27 ++++++++++++++-----
 .../OpenMPToLLVM/convert-to-llvmir.mlir       | 14 ++++++----
 mlir/test/Dialect/OpenMP/canonicalize.mlir    |  5 ++--
 .../LLVMIR/omptarget-region-device-llvm.mlir  |  7 ++---
 .../omptarget-region-llvm-target-device.mlir  |  7 ++---
 .../Target/LLVMIR/omptarget-region-llvm.mlir  |  7 ++---
 .../omptarget-region-parallel-llvm.mlir       |  7 ++---
 7 files changed, 48 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8f7f1963b3e5a4f..f796ace1ab29b9e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2011,6 +2011,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
   auto targetOp = cast<omp::TargetOp>(opInst);
   auto &targetRegion = targetOp.getRegion();
+  DataLayout DL = DataLayout(opInst.getParentOfType<ModuleOp>());
+  SmallVector<Value> mapOperands = targetOp.getMapOperands();
 
   // This function filters out kernel data that will not show up as kernel
   // input arguments to the generated kernel function but will still need
@@ -2018,7 +2020,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   // (declare target). It also prepares some data used for generating the
   // kernel and populating the associated OpenMP runtime data structures.
   auto getKernelArguments =
-      [&](const llvm::SetVector<Value> &operandSet,
+      [&](const llvm::SmallVectorImpl<Value> &operandSet,
           llvm::SmallVectorImpl<llvm::Value *> &llvmInputs) {
         for (Value operand : operandSet) {
           if (!getRefPtrIfDeclareTarget(operand, moduleTranslation))
@@ -2026,11 +2028,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
         }
       };
 
-  llvm::SetVector<Value> operandSet;
-  getUsedValuesDefinedAbove(targetRegion, operandSet);
+  llvm::SmallVector<Value> mapVals;
+  for (Value mapOp : mapOperands) {
+    auto mapInfoOp =
+        mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp());
+    mapVals.push_back(mapInfoOp.getVarPtr());
+  }
 
   llvm::SmallVector<llvm::Value *> inputs;
-  getKernelArguments(operandSet, inputs);
+  getKernelArguments(mapVals, inputs);
 
   LogicalResult bodyGenStatus = success();
 
@@ -2038,6 +2044,16 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   auto bodyCB = [&](InsertPointTy allocaIP,
                     InsertPointTy codeGenIP) -> InsertPointTy {
     builder.restoreIP(codeGenIP);
+    unsigned argIndex = 0;
+    for (auto &mapOp : mapOperands) {
+      auto mapInfoOp =
+          mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp());
+      llvm::Value *mapOpValue =
+          moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
+      const auto &arg = targetRegion.front().getArgument(argIndex);
+      moduleTranslation.mapValue(arg, mapOpValue);
+      argIndex++;
+    }
     llvm::BasicBlock *exitBlock = convertOmpOpRegions(
         targetRegion, "omp.target", builder, moduleTranslation, bodyGenStatus);
     builder.SetInsertPoint(exitBlock);
@@ -2065,9 +2081,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
 
-  DataLayout DL = DataLayout(opInst.getParentOfType<ModuleOp>());
-  SmallVector<Value> mapOperands = targetOp.getMapOperands();
-
   auto getMapTypes = [](mlir::OperandRange mapOperands,
                         mlir::MLIRContext *ctx) {
     SmallVector<mlir::Attribute> mapTypes;
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index 1df27dd9957e594..b290c69a14cfba2 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -244,10 +244,12 @@ llvm.func @_QPomp_target_data_region(%a : !llvm.ptr<array<1024 x i32>>, %i : !ll
 // CHECK:                             %[[ARG_0:.*]]: !llvm.ptr<array<1024 x i32>>,
 // CHECK:                             %[[ARG_1:.*]]: !llvm.ptr<i32>) {
 // CHECK:           %[[VAL_0:.*]] = llvm.mlir.constant(64 : i32) : i32
-// CHECK:           %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG_0]] : !llvm.ptr<array<1024 x i32>>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<array<1024 x i32>> {name = ""}
-// CHECK:           omp.target   thread_limit(%[[VAL_0]] : i32) map_entries(%[[MAP]] : !llvm.ptr<array<1024 x i32>>) {
+// CHECK:           %[[MAP1:.*]] = omp.map_info var_ptr(%[[ARG_0]] : !llvm.ptr<array<1024 x i32>>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<array<1024 x i32>> {name = ""}
+// CHECK:           %[[MAP2:.*]] = omp.map_info var_ptr(%[[ARG_1]] : !llvm.ptr<i32>)   map_clauses(literal, implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr<i32> {name = ""}
+// CHECK:           omp.target   thread_limit(%[[VAL_0]] : i32) map_entries(%[[MAP1]], %[[MAP2]] : !llvm.ptr<array<1024 x i32>>, !llvm.ptr<i32>) {
+// CHECK:           ^bb0(%[[BB_ARG0:.*]]: !llvm.ptr<array<1024 x i32>>, %[[BB_ARG1:.*]]: !llvm.ptr<i32>):
 // CHECK:             %[[VAL_1:.*]] = llvm.mlir.constant(10 : i32) : i32
-// CHECK:             llvm.store %[[VAL_1]], %[[ARG_1]] : !llvm.ptr<i32>
+// CHECK:             llvm.store %[[VAL_1]], %[[BB_ARG1]] : !llvm.ptr<i32>
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           llvm.return
@@ -256,9 +258,11 @@ llvm.func @_QPomp_target_data_region(%a : !llvm.ptr<array<1024 x i32>>, %i : !ll
 llvm.func @_QPomp_target(%a : !llvm.ptr<array<1024 x i32>>, %i : !llvm.ptr<i32>) {
   %0 = llvm.mlir.constant(64 : i32) : i32
   %1 = omp.map_info var_ptr(%a : !llvm.ptr<array<1024 x i32>>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<array<1024 x i32>> {name = ""}
-  omp.target   thread_limit(%0 : i32) map_entries(%1 : !llvm.ptr<array<1024 x i32>>) {
+  %3 = omp.map_info var_ptr(%i : !llvm.ptr<i32>)   map_clauses(literal, implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr<i32> {name = ""}
+  omp.target   thread_limit(%0 : i32) map_entries(%1, %3 : !llvm.ptr<array<1024 x i32>>, !llvm.ptr<i32>) {
+    ^bb0(%arg0: !llvm.ptr<array<1024 x i32>>, %arg1: !llvm.ptr<i32>):
     %2 = llvm.mlir.constant(10 : i32) : i32
-    llvm.store %2, %i : !llvm.ptr<i32>
+    llvm.store %2, %arg1 : !llvm.ptr<i32>
     omp.terminator
   }
   llvm.return
diff --git a/mlir/test/Dialect/OpenMP/canonicalize.mlir b/mlir/test/Dialect/OpenMP/canonicalize.mlir
index 68f5bacb1def178..4ecbb027ba47f24 100644
--- a/mlir/test/Dialect/OpenMP/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenMP/canonicalize.mlir
@@ -131,8 +131,9 @@ func.func private @foo() -> ()
 
 func.func @constant_hoisting_target(%x : !llvm.ptr<i32>) {
   omp.target {
+    ^bb0(%arg0: !llvm.ptr<i32>):
     %c1 = arith.constant 10 : i32
-    llvm.store %c1, %x : i32, !llvm.ptr<i32>
+    llvm.store %c1, %arg0 : i32, !llvm.ptr<i32>
     omp.terminator
   }
   return
@@ -141,4 +142,4 @@ func.func @constant_hoisting_target(%x : !llvm.ptr<i32>) {
 // CHECK-LABEL: func.func @constant_hoisting_target
 // CHECK-NOT: arith.constant
 // CHECK: omp.target
-// CHECK-NEXT: arith.constant
+// CHECK: arith.constant
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
index cf70469e7484f64..dbc5f2d3475d01c 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
@@ -16,10 +16,11 @@ module attributes {omp.is_target_device = true} {
     %map2 = omp.map_info var_ptr(%5 : !llvm.ptr<i32>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
     %map3 = omp.map_info var_ptr(%7 : !llvm.ptr<i32>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
     omp.target map_entries(%map1, %map2, %map3 : !llvm.ptr<i32>, !llvm.ptr<i32>, !llvm.ptr<i32>) {
-      %8 = llvm.load %3 : !llvm.ptr<i32>
-      %9 = llvm.load %5 : !llvm.ptr<i32>
+    ^bb0(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>, %arg2: !llvm.ptr<i32>):
+      %8 = llvm.load %arg0 : !llvm.ptr<i32>
+      %9 = llvm.load %arg1 : !llvm.ptr<i32>
       %10 = llvm.add %8, %9  : i32
-      llvm.store %10, %7 : !llvm.ptr<i32>
+      llvm.store %10, %arg2 : !llvm.ptr<i32>
       omp.terminator
     }
     llvm.return
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
index 01a3bd556294f3e..bea3e481cae3496 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
@@ -3,10 +3,11 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
 module attributes {omp.is_target_device = true} {
-  llvm.func @writeindex_omp_outline_0_(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>) attributes {omp.outline_parent_name = "writeindex_"} {
-    %0 = omp.map_info var_ptr(%arg0 : !llvm.ptr<i32>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
-    %1 = omp.map_info var_ptr(%arg1 : !llvm.ptr<i32>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
+  llvm.func @writeindex_omp_outline_0_(%val0: !llvm.ptr<i32>, %val1: !llvm.ptr<i32>) attributes {omp.outline_parent_name = "writeindex_"} {
+    %0 = omp.map_info var_ptr(%val0 : !llvm.ptr<i32>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
+    %1 = omp.map_info var_ptr(%val1 : !llvm.ptr<i32>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
     omp.target   map_entries(%0, %1 : !llvm.ptr<i32>, !llvm.ptr<i32>) {
+    ^bb0(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>):
       %2 = llvm.mlir.constant(20 : i32) : i32
       %3 = llvm.mlir.constant(10 : i32) : i32
       llvm.store %3, %arg0 : !llvm.ptr<i32>
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
index 51506386d83782d..8d11a684f7139c2 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
@@ -16,10 +16,11 @@ module attributes {omp.is_target_device = false} {
     %map2 = omp.map_info var_ptr(%5 : !llvm.ptr<i32>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
     %map3 = omp.map_info var_ptr(%7 : !llvm.ptr<i32>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
     omp.target map_entries(%map1, %map2, %map3 : !llvm.ptr<i32>, !llvm.ptr<i32>, !llvm.ptr<i32>) {
-      %8 = llvm.load %3 : !llvm.ptr<i32>
-      %9 = llvm.load %5 : !llvm.ptr<i32>
+    ^bb0(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>, %arg2: !llvm.ptr<i32>):
+      %8 = llvm.load %arg0 : !llvm.ptr<i32>
+      %9 = llvm.load %arg1 : !llvm.ptr<i32>
       %10 = llvm.add %8, %9  : i32
-      llvm.store %10, %7 : !llvm.ptr<i32>
+      llvm.store %10, %arg2 : !llvm.ptr<i32>
       omp.terminator
     }
     llvm.return
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
index f0bd37ca36e93b4..a0c57ddc8148aed 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
@@ -16,11 +16,12 @@ module attributes {omp.is_target_device = false} {
     %map2 = omp.map_info var_ptr(%5 : !llvm.ptr<i32>)   map_clauses(tofrommap_info) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
     %map3 = omp.map_info var_ptr(%7 : !llvm.ptr<i32>)   map_clauses(tofrommap_info) capture(ByRef) -> !llvm.ptr<i32> {name = ""}
     omp.target map_entries( %map1, %map2, %map3 : !llvm.ptr<i32>, !llvm.ptr<i32>, !llvm.ptr<i32>) {
+    ^bb0(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>, %arg2: !llvm.ptr<i32>):
       omp.parallel {
-        %8 = llvm.load %3 : !llvm.ptr<i32>
-        %9 = llvm.load %5 : !llvm.ptr<i32>
+        %8 = llvm.load %arg0 : !llvm.ptr<i32>
+        %9 = llvm.load %arg1 : !llvm.ptr<i32>
         %10 = llvm.add %8, %9  : i32
-        llvm.store %10, %7 : !llvm.ptr<i32>
+        llvm.store %10, %arg2 : !llvm.ptr<i32>
         omp.terminator
         }
       omp.terminator



More information about the Mlir-commits mailing list