[Mlir-commits] [mlir] [flang][MLIR][OpenMP] Add support for `target update` directive. (PR #75047)

Kareem Ergawy llvmlistbot at llvm.org
Wed Dec 13 20:51:56 PST 2023


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/75047

>From 19ec0f4cf2f521e1fd627365ff78ed16642a0e99 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Mon, 11 Dec 2023 07:27:47 -0600
Subject: [PATCH] [flang][MLIR][OpenMP] Add support for `target update`
 directive.

Summary:

Add an op in the OMP dialect to model the `target update` directive. This
change reuses the `MapInfoOp` used by other device directives to model
`map` clauses but verifies that the restrictions imposed by the `target
update` directive are respected.
---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 50 ++++++++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 41 ++++++++++++-
 mlir/test/Dialect/OpenMP/invalid.mlir         | 60 +++++++++++++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 12 ++++
 4 files changed, 162 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 8ff5380f71ad45..b9989b335a2aef 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1370,6 +1370,56 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
   let hasVerifier = 1;
 }
 
+//===---------------------------------------------------------------------===//
+// 2.14.6 target update data Construct
+//===---------------------------------------------------------------------===//
+
+def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
+                                                 [AttrSizedOperandSegments]>{
+  let  summary = "target update data construct";
+  let description = [{
+    The target update directive makes the corresponding list items in the device
+    data environment consistent with their original list items, according to the
+    specified motion clauses. The target update construct is a stand-alone
+    directive.
+
+    The optional $if_expr parameter specifies a boolean result of a
+    conditional check. If this value is 1 or is not provided then the
+    data motion specified by the motion clauses is performed; if it is 0
+    then no assignments occur.
+
+    The optional $device parameter specifies the device number of the
+    device whose data environment is updated.
+
+    The optional $nowait eliminates the implicit barrier so the parent
+    task can make progress even if the target task is not yet completed.
+
+    We use `MapInfoOp` to model the motion clauses and their modifiers. Even
+    though the spec differentiates between map-types & map-type-modifiers vs.
+    motion-clauses & motion-modifiers, the motion clauses and their modifiers are
+    a subset of map types and their modifiers. The subset relation is handled
+    during verification to make sure the restrictions for target update are
+    respected.
+
+    TODO: depend clause
+  }];
+
+  let arguments = (ins Optional<I1>:$if_expr,
+                       Optional<AnyInteger>:$device,
+                       UnitAttr:$nowait,
+                       Variadic<OpenMP_PointerLikeType>:$motion_operands);
+
+  let assemblyFormat = [{
+    oilist(`if` `(` $if_expr `:` type($if_expr) `)`
+    | `device` `(` $device `:` type($device) `)`
+    | `nowait` $nowait
+    | `motion_entries` `(` $motion_operands `:` type($motion_operands) `)`
+    ) attr-dict
+   }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // 2.14.5 target construct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 20df0099cbd24d..2fe0f6672646c2 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -882,7 +882,6 @@ static ParseResult parseCaptureType(OpAsmParser &parser,
 }
 
 static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
-
   for (auto mapOp : mapOperands) {
     if (!mapOp.getDefiningOp())
       emitError(op->getLoc(), "missing map operation");
@@ -898,6 +897,7 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
 
       uint64_t mapTypeBits = MapInfoOp.getMapType().value();
 
+      // Map-types
       bool to = mapTypeToBitFlag(
           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
       bool from = mapTypeToBitFlag(
@@ -905,6 +905,14 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
       bool del = mapTypeToBitFlag(
           mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
 
+      // Map-type-modifiers
+      bool always = mapTypeToBitFlag(
+          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
+      bool close = mapTypeToBitFlag(
+          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
+      bool implicit = mapTypeToBitFlag(
+          mapTypeBits, llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
+
       if ((isa<DataOp>(op) || isa<TargetOp>(op)) && del)
         return emitError(op->getLoc(),
                          "to, from, tofrom and alloc map types are permitted");
@@ -915,6 +923,33 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
       if (isa<ExitDataOp>(op) && to)
         return emitError(op->getLoc(),
                          "from, release and delete map types are permitted");
+
+      if (isa<UpdateDataOp>(op)) {
+        if (del) {
+          return emitError(op->getLoc(),
+                           "at least one of to or from map types must be "
+                           "specified, other map types are not permitted");
+        }
+
+        if (to && from) {
+          return emitError(
+              op->getLoc(),
+              "either to or from map types can be specified, not both");
+        }
+
+        if (!to && !from) {
+          return emitError(op->getLoc(),
+                           "at least one of to or from map types must be "
+                           "specified, other map types are not permitted");
+        }
+      }
+
+      // Check UpdateDataOp's valid map-type-modifiers.
+      if (isa<UpdateDataOp>(op) && (always || close || implicit)) {
+        return emitError(
+            op->getLoc(),
+            "present, mapper and iterator map type modifiers are permitted");
+      }
     } else {
       emitError(op->getLoc(), "map argument is not a map entry operation");
     }
@@ -940,6 +975,10 @@ LogicalResult ExitDataOp::verify() {
   return verifyMapClause(*this, getMapOperands());
 }
 
+LogicalResult UpdateDataOp::verify() {
+  return verifyMapClause(*this, getMotionOperands());
+}
+
 LogicalResult TargetOp::verify() {
   return verifyMapClause(*this, getMapOperands());
 }
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e54808f6cfdee5..df11f7ad413c86 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1658,4 +1658,64 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
   return
 }
 
+// -----
+
+func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_type_2(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(delete) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{at least one of to or from map types must be specified, other map types are not permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_modifier(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_modifier_2(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(close, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_modifier_3(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(implicit, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{present, mapper and iterator map type modifiers are permitted}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
+// -----
+
+func.func @omp_target_update_invalid_motion_modifier_4(%map1 : memref<?xi32>) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(implicit, tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // expected-error @below {{either to or from map types can be specified, not both}}
+  omp.target_update_data motion_entries(%mapv : memref<?xi32>)
+  return
+}
+
 llvm.mlir.global internal @_QFsubEx() : i32
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4d88d9ac86fe16..b0a6a5ac0a3fb9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2082,3 +2082,15 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
 
     return
 }
+
+// CHECK-LABEL: omp_target_update_data
+func.func @omp_target_update_data (%if_cond : i1, %device : si32, %map1: memref<?xi32>) -> () {
+    %mapv_from = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32> {name = ""}
+
+    %mapv_to = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(present, to) capture(ByRef) -> memref<?xi32> {name = ""}
+
+    // CHECK: omp.target_update_data if(%[[VAL_0:.*]] : i1) device(%[[VAL_1:.*]] : si32) nowait motion_entries(%{{.*}}, %{{.*}} : memref<?xi32>, memref<?xi32>)
+    omp.target_update_data if(%if_cond : i1) device(%device : si32) nowait motion_entries(%mapv_from , %mapv_to : memref<?xi32>, memref<?xi32>)
+    return
+}
+



More information about the Mlir-commits mailing list