[Mlir-commits] [mlir] [mlir][openmp] - Add the depend clause to omp.target and related offloading directives (PR #81081)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Feb 7 18:55:37 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Pranav Bhandarkar (bhandarkar-pranav)
<details>
<summary>Changes</summary>
This patch adds support for the depend clause to several OpenMP constructs related to offloading. Specifically, it handles the depend clause when it is used with the following constructs:
- target
- target enter data
- target update
- target exit data
---
Full diff: https://github.com/llvm/llvm-project/pull/81081.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+32-5)
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+17-5)
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+37)
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+62-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ca363505485773..3c9b59cea181a7 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -693,7 +693,7 @@ def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;
def ClauseTaskDepend : I32EnumAttr<
"ClauseTaskDepend",
- "task depend clause",
+ "depend clause in a target or task construct",
[ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::omp";
@@ -1359,11 +1359,17 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
The $map_types specifies the types and modifiers for the map clause.
- TODO: depend clause and map_type_modifier values iterator and mapper.
+ The `depends` attribute and the `depend_vars` operands together specify the
+ dependencies of this particular target task in relation to other tasks; each
+ `depends` entry is the dependence type of the corresponding `depend_vars` value.
+
+ TODO: map_type_modifier values iterator and mapper.
}];
let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
+ OptionalAttr<TaskDependArrayAttr>:$depends,
+ Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);
@@ -1372,6 +1378,7 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `map_entries` `(` $map_operands `:` type($map_operands) `)`
+ | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];
@@ -1406,11 +1413,17 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
The $map_types specifies the types and modifiers for the map clause.
- TODO: depend clause and map_type_modifier values iterator and mapper.
+ The `depends` attribute and the `depend_vars` operands together specify the
+ dependencies of this particular target task in relation to other tasks; each
+ `depends` entry is the dependence type of the corresponding `depend_vars` value.
+
+ TODO: map_type_modifier values iterator and mapper.
}];
let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
+ OptionalAttr<TaskDependArrayAttr>:$depends,
+ Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);
@@ -1419,6 +1432,7 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `map_entries` `(` $map_operands `:` type($map_operands) `)`
+ | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];
@@ -1457,11 +1471,16 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
during verification to make sure the restrictions for target update are
respected.
- TODO: depend clause
+ The `depends` attribute and the `depend_vars` operands together specify the
+ dependencies of this particular target task in relation to other tasks; each
+ `depends` entry is the dependence type of the corresponding `depend_vars` value.
+
}];
let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
+ OptionalAttr<TaskDependArrayAttr>:$depends,
+ Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$map_operands);
@@ -1470,6 +1489,7 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `motion_entries` `(` $map_operands `:` type($map_operands) `)`
+ | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];
@@ -1499,13 +1519,19 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
The optional $nowait elliminates the implicit barrier so the parent task can make progress
even if the target task is not yet completed.
- TODO: is_device_ptr, depend, defaultmap, in_reduction
+ The `depends` attribute and the `depend_vars` operands together specify the
+ dependencies of this particular target task in relation to other tasks; each
+ `depends` entry is the dependence type of the corresponding `depend_vars` value.
+
+ TODO: is_device_ptr, defaultmap, in_reduction
}];
let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
Optional<AnyInteger>:$thread_limit,
+ OptionalAttr<TaskDependArrayAttr>:$depends,
+ Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);
@@ -1517,6 +1543,7 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
| `nowait` $nowait
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
+ | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) $region attr-dict
}];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 381f17d0804191..11800b3cd6ce55 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -561,7 +561,7 @@ static LogicalResult verifyDependVarList(Operation *op,
return op->emitOpError() << "expected as many depend values"
" as depend variables";
} else {
- if (depends)
+ if (depends && !depends->empty())
return op->emitOpError() << "unexpected depend values";
return success();
}
@@ -965,19 +965,31 @@ LogicalResult DataOp::verify() {
}
LogicalResult EnterDataOp::verify() {
- return verifyMapClause(*this, getMapOperands());
+ if (failed(verifyDependVarList(*this, getDepends(), getDependVars())))
+ return failure();
+
+ return verifyMapClause(*this, getMapOperands());
}
LogicalResult ExitDataOp::verify() {
- return verifyMapClause(*this, getMapOperands());
+ if (failed(verifyDependVarList(*this, getDepends(), getDependVars())))
+ return failure();
+
+ return verifyMapClause(*this, getMapOperands());
}
LogicalResult UpdateDataOp::verify() {
- return verifyMapClause(*this, getMapOperands());
+ if (failed(verifyDependVarList(*this, getDepends(), getDependVars())))
+ return failure();
+
+ return verifyMapClause(*this, getMapOperands());
}
LogicalResult TargetOp::verify() {
- return verifyMapClause(*this, getMapOperands());
+ if (failed(verifyDependVarList(*this, getDepends(), getDependVars())))
+ return failure();
+
+ return verifyMapClause(*this, getMapOperands());
}
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 812b79e35595f0..481aa950ceeef3 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1651,6 +1651,15 @@ func.func @omp_target_enter_data(%map1: memref<?xi32>) {
// -----
+func.func @omp_target_enter_data_depend(%a: memref<?xi32>) {
+ %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+ // expected-error @below {{op expected as many depend values as depend variables}}
+ omp.target_enter_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>} // segments: if_expr, device, depend_vars, map_operands
+ return
+}
+
+// -----
+
func.func @omp_target_exit_data(%map1: memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
// expected-error @below {{from, release and delete map types are permitted}}
@@ -1660,6 +1669,15 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
// -----
+func.func @omp_target_exit_data_depend(%a: memref<?xi32>) {
+ %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ // expected-error @below {{op expected as many depend values as depend variables}}
+ omp.target_exit_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>} // segments: if_expr, device, depend_vars, map_operands
+ return
+}
+
+// -----
+
func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1732,6 +1750,25 @@ llvm.mlir.global internal @_QFsubEx() : i32
// -----
+func.func @omp_target_update_data_depend(%a: memref<?xi32>) {
+ %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+ // expected-error @below {{op expected as many depend values as depend variables}}
+ omp.target_update_data motion_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>} // segments: if_expr, device, depend_vars, map_operands
+ return
+}
+
+// -----
+
+func.func @omp_target_depend(%data_var: memref<i32>) {
+ // expected-error @below {{op expected as many depend values as depend variables}}
+ "omp.target"(%data_var) ({
+ "omp.terminator"() : () -> ()
+ }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0>} : (memref<i32>) -> () // segments: if_expr, device, thread_limit, depend_vars, map_operands
+ "func.return"() : () -> ()
+}
+
+// -----
+
func.func @omp_distribute(%data_var : memref<i32>) -> () {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.distribute"(%data_var) <{operandSegmentSizes = array<i32: 0, 1, 0>}> ({
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 65a704d18107b5..950558483d21f4 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -517,7 +517,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1:
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0>} : ( i1, si32, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0>} : ( i1, si32, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1710,6 +1710,18 @@ func.func @omp_task_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
return
}
+
+// CHECK-LABEL: @omp_target_depend
+// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
+func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
+ // CHECK: omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+ omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+ // CHECK: omp.terminator
+ omp.terminator
+ } {operandSegmentSizes = array<i32: 0,0,0,3,0>}
+ return
+}
+
func.func @omp_threadprivate() {
%0 = arith.constant 1 : i32
%1 = arith.constant 2 : i32
@@ -2138,3 +2150,54 @@ func.func @omp_targets_is_allocatable(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
}
return
}
+
+// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
+// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
+func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
+// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
+ %map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+ %map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ %map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
+
+ // Do some work on the host that writes to 'a'
+ omp.task depend(taskdependout -> %a : memref<?xi32>) {
+ "test.foo"(%a) : (memref<?xi32>) -> ()
+ omp.terminator
+ }
+
+ // Then map that over to the target
+ // CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+ omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin -> %a: memref<?xi32>)
+
+ // Compute 'b' on the target and copy it back
+ // CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) {
+ omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) {
+ ^bb0(%arg0: memref<?xi32>) :
+ "test.foo"(%arg0) : (memref<?xi32>) -> ()
+ omp.terminator
+ }
+
+ // Update 'a' on the host using 'b'
+ omp.task depend(taskdependout -> %a: memref<?xi32>) {
+ "test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> ()
+ omp.terminator
+ }
+
+ // Copy the updated 'a' onto the target
+ // CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+ omp.target_update_data motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
+
+ // Compute 'c' on the target and copy it back
+ %map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ // CHECK: omp.target map_entries([[MAP0]] -> {{%.*}}, {{%.*}} -> {{%.*}} : memref<?xi32>, memref<?xi32>) depend(taskdependout -> [[ARG2]] : memref<?xi32>) {
+ omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+ ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+ "test.foobar"() : ()->()
+ omp.terminator
+ }
+ // CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>) depend(taskdependin -> [[ARG2]] : memref<?xi32>)
+ omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/81081
More information about the Mlir-commits
mailing list