[flang-commits] [mlir] [flang] [flang][MLIR][OpenMP] Emit `UpdateDataOp` from `!$omp target update` (PR #75345)

via flang-commits flang-commits at lists.llvm.org
Wed Dec 13 06:28:35 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Kareem Ergawy (ergawy)

<details>
<summary>Changes</summary>

Emits the MLIR op corresponding to the `!$omp target update` directive. So far,
only the `to` and `from` motion types are supported. The motion modifiers
`present`, `mapper`, and `iterator` are not supported yet.

This is a follow-up to #<!-- -->75047 and #<!-- -->75159; only the last commit is relevant
to this PR.

---

Patch is 24.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75345.diff


8 Files Affected:

- (modified) flang/lib/Lower/OpenMP.cpp (+102-1) 
- (modified) flang/test/Lower/OpenMP/target.f90 (+88) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+50) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+42-1) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+21-3) 
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+60) 
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+12) 
- (modified) mlir/test/Target/LLVMIR/omptarget-llvm.mlir (+38) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index eeba87fcd15116..a86f35b7b1f818 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -605,6 +605,16 @@ class ClauseProcessor {
                       llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
                           &useDeviceSymbols) const;
 
+  bool
+  processToMotionClauses(Fortran::semantics::SemanticsContext &semanticsContext,
+                         Fortran::lower::StatementContext &stmtCtx,
+                         llvm::SmallVectorImpl<mlir::Value> &mapOperands);
+
+  bool processFromMotionClauses(
+      Fortran::semantics::SemanticsContext &semanticsContext,
+      Fortran::lower::StatementContext &stmtCtx,
+      llvm::SmallVectorImpl<mlir::Value> &mapOperands);
+
   // Call this method for these clauses that should be supported but are not
   // implemented yet. It triggers a compilation error if any of the given
   // clauses is found.
@@ -671,6 +681,12 @@ class ClauseProcessor {
     return false;
   }
 
+  template <typename T>
+  bool
+  processMotionClauses(Fortran::semantics::SemanticsContext &semanticsContext,
+                       Fortran::lower::StatementContext &stmtCtx,
+                       llvm::SmallVectorImpl<mlir::Value> &mapOperands);
+
   Fortran::lower::AbstractConverter &converter;
   const Fortran::parser::OmpClauseList &clauses;
 };
@@ -1890,6 +1906,62 @@ bool ClauseProcessor::processUseDevicePtr(
       });
 }
 
+bool ClauseProcessor::processToMotionClauses(
+    Fortran::semantics::SemanticsContext &semanticsContext,
+    Fortran::lower::StatementContext &stmtCtx,
+    llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
+  return processMotionClauses<ClauseProcessor::ClauseTy::To>(
+      semanticsContext, stmtCtx, mapOperands);
+}
+
+bool ClauseProcessor::processFromMotionClauses(
+    Fortran::semantics::SemanticsContext &semanticsContext,
+    Fortran::lower::StatementContext &stmtCtx,
+    llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
+  return processMotionClauses<ClauseProcessor::ClauseTy::From>(
+      semanticsContext, stmtCtx, mapOperands);
+}
+
+template <typename T>
+bool ClauseProcessor::processMotionClauses(
+    Fortran::semantics::SemanticsContext &semanticsContext,
+    Fortran::lower::StatementContext &stmtCtx,
+    llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
+  return findRepeatableClause<T>([&](const T *motionClause,
+                                     const Fortran::parser::CharBlock &source) {
+    mlir::Location clauseLocation = converter.genLocation(source);
+    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+    static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> ||
+                  std::is_same_v<T, ClauseProcessor::ClauseTy::From>);
+
+    // TODO Support motion modifiers: present, mapper, iterator.
+    constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
+        std::is_same_v<T, ClauseProcessor::ClauseTy::To>
+            ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
+            : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+
+    for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) {
+      llvm::SmallVector<mlir::Value> bounds;
+      std::stringstream asFortran;
+      mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
+          Fortran::parser::OmpObject, mlir::omp::DataBoundsType,
+          mlir::omp::DataBoundsOp>(converter, firOpBuilder, semanticsContext,
+                                   stmtCtx, ompObject, clauseLocation,
+                                   asFortran, bounds, treatIndexAsSection);
+
+      mlir::Value mapOp = createMapInfoOp(
+          firOpBuilder, clauseLocation, baseAddr, asFortran, bounds,
+          static_cast<
+              std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+              mapTypeBits),
+          mlir::omp::VariableCaptureKind::ByRef, baseAddr.getType());
+
+      mapOperands.push_back(mapOp);
+    }
+  });
+}
+
 template <typename... Ts>
 void ClauseProcessor::processTODO(mlir::Location currentLocation,
                                   llvm::omp::Directive directive) const {
@@ -2455,6 +2527,34 @@ genEnterExitDataOp(Fortran::lower::AbstractConverter &converter,
                                    deviceOperand, nowaitAttr, mapOperands);
 }
 
+mlir::omp::UpdateDataOp
+genUpdateDataOp(Fortran::lower::AbstractConverter &converter,
+                Fortran::semantics::SemanticsContext &semanticsContext,
+                mlir::Location currentLocation,
+                const Fortran::parser::OmpClauseList &clauseList) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  Fortran::lower::StatementContext stmtCtx;
+  mlir::Value ifClauseOperand, deviceOperand;
+  mlir::UnitAttr nowaitAttr;
+  llvm::SmallVector<mlir::Value> motionOperands;
+  llvm::SmallVector<mlir::Value> mapOperands;
+
+  ClauseProcessor cp(converter, clauseList);
+  cp.processIf(
+      Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate,
+      ifClauseOperand);
+  cp.processDevice(stmtCtx, deviceOperand);
+  cp.processNowait(nowaitAttr);
+  cp.processToMotionClauses(semanticsContext, stmtCtx, mapOperands);
+  cp.processFromMotionClauses(semanticsContext, stmtCtx, mapOperands);
+
+  cp.processTODO<Fortran::parser::OmpClause::Depend>(
+      currentLocation, llvm::omp::Directive::OMPD_target_update);
+
+  return firOpBuilder.create<mlir::omp::UpdateDataOp>(
+      currentLocation, ifClauseOperand, deviceOperand, nowaitAttr, mapOperands);
+}
+
 // This functions creates a block for the body of the targetOp's region. It adds
 // all the symbols present in mapSymbols as block arguments to this block.
 static void genBodyOfTargetOp(
@@ -2858,7 +2958,8 @@ genOmpSimpleStandalone(Fortran::lower::AbstractConverter &converter,
                                               currentLocation, opClauseList);
     break;
   case llvm::omp::Directive::OMPD_target_update:
-    TODO(currentLocation, "OMPD_target_update");
+    genUpdateDataOp(converter, semanticsContext, currentLocation, opClauseList);
+    break;
   case llvm::omp::Directive::OMPD_ordered:
     TODO(currentLocation, "OMPD_ordered");
   }
diff --git a/flang/test/Lower/OpenMP/target.f90 b/flang/test/Lower/OpenMP/target.f90
index 26b4d595d52291..5500097ee58a9d 100644
--- a/flang/test/Lower/OpenMP/target.f90
+++ b/flang/test/Lower/OpenMP/target.f90
@@ -134,6 +134,94 @@ subroutine omp_target_exit_device
    !$omp target exit data map(from: a) device(d)
 end subroutine omp_target_exit_device
 
+!===============================================================================
+! Target_Update `to` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_to() {
+subroutine omp_target_update_to
+   integer :: a(1024)
+
+   !CHECK-DAG: %[[A_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}})
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+
+   !CHECK: %[[TO_MAP:.*]] = omp.map_info var_ptr(%[[A_DECL]]#1 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>)
+   !CHECK-SAME: map_clauses(to) capture(ByRef)
+   !CHECK-SAME: bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+
+   !CHECK: omp.target_update_data motion_entries(%[[TO_MAP]] : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target update to(a)
+end subroutine omp_target_update_to
+
+!===============================================================================
+! Target_Update `from` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_from() {
+subroutine omp_target_update_from
+   integer :: a(1024)
+
+   !CHECK-DAG: %[[A_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}})
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+
+   !CHECK: %[[TO_MAP:.*]] = omp.map_info var_ptr(%[[A_DECL]]#1 : !fir.ref<!fir.array<1024xi32>>, !fir.array<1024xi32>)
+   !CHECK-SAME: map_clauses(from) capture(ByRef)
+   !CHECK-SAME: bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<1024xi32>> {name = "a"}
+
+   !CHECK: omp.target_update_data motion_entries(%[[TO_MAP]] : !fir.ref<!fir.array<1024xi32>>)
+   !$omp target update from(a)
+end subroutine omp_target_update_from
+
+!===============================================================================
+! Target_Update `if` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_if() {
+subroutine omp_target_update_if
+   integer :: a(1024)
+   logical :: i
+
+   !CHECK-DAG: %[[A_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}})
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+   !CHECK-DAG: %[[COND:.*]] = fir.convert %{{.*}} : (!fir.logical<4>) -> i1
+
+   !CHECK: omp.target_update_data if(%[[COND]] : i1) motion_entries
+   !$omp target update from(a) if(i)
+end subroutine omp_target_update_if
+
+!===============================================================================
+! Target_Update `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_device() {
+subroutine omp_target_update_device
+   integer :: a(1024)
+   logical :: i
+
+   !CHECK-DAG: %[[A_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}})
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+   !CHECK-DAG: %[[DEVICE:.*]] = arith.constant 1 : i32
+
+   !CHECK: omp.target_update_data device(%[[DEVICE]] : i32) motion_entries
+   !$omp target update from(a) device(1)
+end subroutine omp_target_update_device
+
+!===============================================================================
+! Target_Update `nowait` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_update_nowait() {
+subroutine omp_target_update_nowait
+   integer :: a(1024)
+   logical :: i
+
+   !CHECK-DAG: %[[A_DECL:.*]]:2 = hlfir.declare %{{.*}}(%{{.*}})
+   !CHECK-DAG: %[[BOUNDS:.*]] = omp.bounds
+
+   !CHECK: omp.target_update_data nowait motion_entries
+   !$omp target update from(a) nowait
+end subroutine omp_target_update_nowait
+
 !===============================================================================
 ! Target_Data with region
 !===============================================================================
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8ff5380f71ad45..b9989b335a2aef 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1370,6 +1370,56 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
   let hasVerifier = 1;
 }
 
+//===---------------------------------------------------------------------===//
+// 2.14.6 target update data Construct
+//===---------------------------------------------------------------------===//
+
+def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
+                                                 [AttrSizedOperandSegments]>{
+  let  summary = "target update data construct";
+  let description = [{
+    The target update directive makes the corresponding list items in the device
+    data environment consistent with their original list items, according to the
+    specified motion clauses. The target update construct is a stand-alone
+    directive.
+
+    The optional $if_expr parameter specifies a boolean result of a
+    conditional check. If this value is 1 or is not provided then the target
+    region runs on a device, if it is 0 then the target region is executed
+    on the host device.
+
+    The optional $device parameter specifies the device number for the
+    target region.
+
+    The optional $nowait eliminates the implicit barrier so the parent
+    task can make progress even if the target task is not yet completed.
+
+    We use `MapInfoOp` to model the motion clauses and their modifiers. Even
+    though the spec differentiates between map-types & map-type-modifiers vs.
+    motion-clauses & motion-modifiers, the motion clauses and their modifiers are
+    a subset of map types and their modifiers. The subset relation is handled
+    during verification to make sure the restrictions for target update are
+    respected.
+
+    TODO: depend clause
+  }];
+
+  let arguments = (ins Optional<I1>:$if_expr,
+                       Optional<AnyInteger>:$device,
+                       UnitAttr:$nowait,
+                       Variadic<OpenMP_PointerLikeType>:$motion_operands);
+
+  let assemblyFormat = [{
+    oilist(`if` `(` $if_expr `:` type($if_expr) `)`
+    | `device` `(` $device `:` type($device) `)`
+    | `nowait` $nowait
+    | `motion_entries` `(` $motion_operands `:` type($motion_operands) `)`
+    ) attr-dict
+   }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // 2.14.5 target construct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 20df0099cbd24d..a12378827b4348 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -691,6 +691,7 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
     if (parser.parseKeyword(&mapTypeMod))
       return failure();
 
+    // Map-type-modifiers
     if (mapTypeMod == "always")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
 
@@ -703,6 +704,7 @@ static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
     if (mapTypeMod == "present")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
 
+    // Map-types
     if (mapTypeMod == "to")
       mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
 
@@ -882,7 +884,6 @@ static ParseResult parseCaptureType(OpAsmParser &parser,
 }
 
 static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
-
   for (auto mapOp : mapOperands) {
     if (!mapOp.getDefiningOp())
       emitError(op->getLoc(), "missing map operation");
@@ -898,6 +899,7 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
 
       uint64_t mapTypeBits = MapInfoOp.getMapType().value();
 
+      // Map-types
       bool to = mapTypeToBitFlag(
           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
       bool from = mapTypeToBitFlag(
@@ -905,6 +907,14 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
       bool del = mapTypeToBitFlag(
           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
 
+      // Map-type-modifiers
+      bool always = mapTypeToBitFlag(
+          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
+      bool close = mapTypeToBitFlag(
+          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
+      bool implicit = mapTypeToBitFlag(
+          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
+
       if ((isa<DataOp>(op) || isa<TargetOp>(op)) && del)
         return emitError(op->getLoc(),
                          "to, from, tofrom and alloc map types are permitted");
@@ -915,6 +925,33 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
       if (isa<ExitDataOp>(op) && to)
         return emitError(op->getLoc(),
                          "from, release and delete map types are permitted");
+
+      if (isa<UpdateDataOp>(op)) {
+        if (del) {
+          return emitError(op->getLoc(),
+                           "at least one of to or from map types must be "
+                           "specified, other map types are not permitted");
+        }
+
+        if (to & from) {
+          return emitError(
+              op->getLoc(),
+              "either to or from map types can be specified, not both");
+        }
+
+        if (!to & !from) {
+          return emitError(op->getLoc(),
+                           "at least one of to or from map types must be "
+                           "specified, other map types are not permitted");
+        }
+      }
+
+      // Check UpdateDataOp's valid map-type-modifiers.
+      if (isa<UpdateDataOp>(op) && (always | close | implicit)) {
+        return emitError(
+            op->getLoc(),
+            "present, mapper and iterator map type modifiers are permitted");
+      }
     } else {
       emitError(op->getLoc(), "map argument is not a map entry operation");
     }
@@ -940,6 +977,10 @@ LogicalResult ExitDataOp::verify() {
   return verifyMapClause(*this, getMapOperands());
 }
 
+LogicalResult UpdateDataOp::verify() {
+  return verifyMapClause(*this, getMotionOperands());
+}
+
 LogicalResult TargetOp::verify() {
   return verifyMapClause(*this, getMapOperands());
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4f6200d29a70a6..088e7ae4231bef 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1915,6 +1915,23 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
             mapOperands = exitDataOp.getMapOperands();
             return success();
           })
+          .Case([&](omp::UpdateDataOp updateDataOp) {
+            if (updateDataOp.getNowait())
+              return failure();
+
+            if (auto ifExprVar = updateDataOp.getIfExpr())
+              ifCond = moduleTranslation.lookupValue(ifExprVar);
+
+            if (auto devId = updateDataOp.getDevice())
+              if (auto constOp =
+                      dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
+                if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
+                  deviceID = intAttr.getInt();
+
+            RTLFn = llvm::omp::OMPRTL___tgt_target_data_update_mapper;
+            mapOperands = updateDataOp.getMotionOperands();
+            return success();
+          })
           .Default([&](Operation *op) {
             return op->emitError("unsupported OpenMP operation: ")
                    << op->getName();
@@ -2748,9 +2765,10 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
       .Case([&](omp::ThreadprivateOp) {
         return convertOmpThreadprivate(*op, builder, moduleTranslation);
       })
-      .Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp>([&](auto op) {
-        return convertOmpTargetData(op, builder, moduleTranslation);
-      })
+      .Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp, omp::UpdateDataOp>(
+          [&](auto op) {
+            return convertOmpTargetData(op, builder, moduleTranslation);
+          })
       .Case([&](omp::TargetOp) {
         return convertOmpTarget(*op, builder, moduleTranslation);
       })
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e54808f6cfdee5..df11f7ad413c86 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1658,4 +1658,64 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
   return
 }
 
+// -----
+
+func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_type_2(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(delete) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_modifier(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/75345


More information about the flang-commits mailing list