[Mlir-commits] [mlir] 55d6643 - [mlir][openmp] - Add the depend clause to omp.target and related offloading directives (#81081)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 13 03:15:55 PST 2024
Author: Pranav Bhandarkar
Date: 2024-02-13T11:15:51Z
New Revision: 55d6643ccf6f9394d88d3d6359492000c58c2357
URL: https://github.com/llvm/llvm-project/commit/55d6643ccf6f9394d88d3d6359492000c58c2357
DIFF: https://github.com/llvm/llvm-project/commit/55d6643ccf6f9394d88d3d6359492000c58c2357.diff
LOG: [mlir][openmp] - Add the depend clause to omp.target and related offloading directives (#81081)
This patch adds support for the depend clause in a number of OpenMP
directives/constructs related to offloading. Specifically, it adds the
handling of the depend clause when it is used with the following
constructs:
- target
- target enter data
- target update (the `omp.target_update_data` op)
- target exit data
Added:
Modified:
flang/lib/Lower/OpenMP.cpp
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/test/Dialect/OpenMP/invalid.mlir
mlir/test/Dialect/OpenMP/ops.mlir
Removed:
################################################################################
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index e5887620d503b9..06850bebd7d05a 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2825,7 +2825,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
directive);
return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
- deviceOperand, nowaitAttr, mapOperands);
+ deviceOperand, nullptr, mlir::ValueRange(),
+ nowaitAttr, mapOperands);
}
// This functions creates a block for the body of the targetOp's region. It adds
@@ -3090,7 +3091,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
- nowaitAttr, mapOperands);
+ nullptr, mlir::ValueRange(), nowaitAttr, mapOperands);
genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
mapSymLocs, mapSymbols, currentLocation);
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 44f3e5b8dbc361..c7a32de256e2a5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -781,7 +781,7 @@ def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;
def ClauseTaskDepend : I32EnumAttr<
"ClauseTaskDepend",
- "task depend clause",
+ "depend clause in a target or task construct",
[ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::omp";
@@ -1447,11 +1447,17 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
The $map_types specifies the types and modifiers for the map clause.
- TODO: depend clause and map_type_modifier values iterator and mapper.
+ The `depends` and `depend_vars` arguments are variadic lists of values
+ that specify the dependencies of this particular target task in relation to
+ other tasks.
+
+ TODO: map_type_modifier values iterator and mapper.
}];
let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
+ OptionalAttr<TaskDependArrayAttr>:$depends,
+ Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);
@@ -1460,6 +1466,7 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `map_entries` `(` $map_operands `:` type($map_operands) `)`
+ | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];
@@ -1494,11 +1501,17 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
The $map_types specifies the types and modifiers for the map clause.
- TODO: depend clause and map_type_modifier values iterator and mapper.
+ The `depends` and `depend_vars` arguments are variadic lists of values
+ that specify the dependencies of this particular target task in relation to
+ other tasks.
+
+ TODO: map_type_modifier values iterator and mapper.
}];
let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
+ OptionalAttr<TaskDependArrayAttr>:$depends,
+ Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);
@@ -1507,6 +1520,7 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `map_entries` `(` $map_operands `:` type($map_operands) `)`
+ | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];
@@ -1545,11 +1559,16 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
during verification to make sure the restrictions for target update are
respected.
- TODO: depend clause
+ The `depends` and `depend_vars` arguments are variadic lists of values
+ that specify the dependencies of this particular target task in relation to
+ other tasks.
+
}];
let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
+ OptionalAttr<TaskDependArrayAttr>:$depends,
+ Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$map_operands);
@@ -1558,6 +1577,7 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `motion_entries` `(` $map_operands `:` type($map_operands) `)`
+ | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];
@@ -1587,13 +1607,19 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
The optional $nowait elliminates the implicit barrier so the parent task can make progress
even if the target task is not yet completed.
- TODO: is_device_ptr, depend, defaultmap, in_reduction
+ The `depends` and `depend_vars` arguments are variadic lists of values
+ that specify the dependencies of this particular target task in relation to
+ other tasks.
+
+ TODO: is_device_ptr, defaultmap, in_reduction
}];
let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
Optional<AnyInteger>:$thread_limit,
+ OptionalAttr<TaskDependArrayAttr>:$depends,
+ Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);
@@ -1605,6 +1631,7 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
| `nowait` $nowait
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
+ | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) $region attr-dict
}];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index ef08bd87efc93a..849449f9127dd8 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -628,7 +628,7 @@ static LogicalResult verifyDependVarList(Operation *op,
return op->emitOpError() << "expected as many depend values"
" as depend variables";
} else {
- if (depends)
+ if (depends && !depends->empty())
return op->emitOpError() << "unexpected depend values";
return success();
}
@@ -1032,19 +1032,31 @@ LogicalResult DataOp::verify() {
}
LogicalResult EnterDataOp::verify() {
- return verifyMapClause(*this, getMapOperands());
+ LogicalResult verifyDependVars =
+ verifyDependVarList(*this, getDepends(), getDependVars());
+ return failed(verifyDependVars) ? verifyDependVars
+ : verifyMapClause(*this, getMapOperands());
}
LogicalResult ExitDataOp::verify() {
- return verifyMapClause(*this, getMapOperands());
+ LogicalResult verifyDependVars =
+ verifyDependVarList(*this, getDepends(), getDependVars());
+ return failed(verifyDependVars) ? verifyDependVars
+ : verifyMapClause(*this, getMapOperands());
}
LogicalResult UpdateDataOp::verify() {
- return verifyMapClause(*this, getMapOperands());
+ LogicalResult verifyDependVars =
+ verifyDependVarList(*this, getDepends(), getDependVars());
+ return failed(verifyDependVars) ? verifyDependVars
+ : verifyMapClause(*this, getMapOperands());
}
LogicalResult TargetOp::verify() {
- return verifyMapClause(*this, getMapOperands());
+ LogicalResult verifyDependVars =
+ verifyDependVarList(*this, getDepends(), getDependVars());
+ return failed(verifyDependVars) ? verifyDependVars
+ : verifyMapClause(*this, getMapOperands());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 59b42390b206f1..1c1b6ea58e02ee 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1651,6 +1651,15 @@ func.func @omp_target_enter_data(%map1: memref<?xi32>) {
// -----
+func.func @omp_target_enter_data_depend(%a: memref<?xi32>) {
+ %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+ // expected-error @below {{op expected as many depend values as depend variables}}
+ omp.target_enter_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
+ return
+}
+
+// -----
+
func.func @omp_target_exit_data(%map1: memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
// expected-error @below {{from, release and delete map types are permitted}}
@@ -1660,6 +1669,15 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
// -----
+func.func @omp_target_exit_data_depend(%a: memref<?xi32>) {
+ %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ // expected-error @below {{op expected as many depend values as depend variables}}
+ omp.target_exit_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
+ return
+}
+
+// -----
+
func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1732,6 +1750,25 @@ llvm.mlir.global internal @_QFsubEx() : i32
// -----
+func.func @omp_target_update_data_depend(%a: memref<?xi32>) {
+ %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+ // expected-error @below {{op expected as many depend values as depend variables}}
+ omp.target_update_data motion_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
+ return
+}
+
+// -----
+
+func.func @omp_target_depend(%data_var: memref<i32>) {
+ // expected-error @below {{op expected as many depend values as depend variables}}
+ "omp.target"(%data_var) ({
+ "omp.terminator"() : () -> ()
+ }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0>} : (memref<i32>) -> ()
+ "func.return"() : () -> ()
+}
+
+// -----
+
func.func @omp_distribute(%data_var : memref<i32>) -> () {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.distribute"(%data_var) <{operandSegmentSizes = array<i32: 0, 1, 0>}> ({
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 651405964c0675..3bb4a288376ede 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -517,7 +517,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1:
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0>} : ( i1, si32, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0>} : ( i1, si32, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1717,6 +1717,18 @@ func.func @omp_task_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
return
}
+
+// CHECK-LABEL: @omp_target_depend
+// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
+func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
+ // CHECK: omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+ omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+ // CHECK: omp.terminator
+ omp.terminator
+ } {operandSegmentSizes = array<i32: 0,0,0,3,0>}
+ return
+}
+
func.func @omp_threadprivate() {
%0 = arith.constant 1 : i32
%1 = arith.constant 2 : i32
@@ -2145,3 +2157,52 @@ func.func @omp_targets_is_allocatable(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
}
return
}
+
+// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
+// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
+func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
+// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
+ %map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+ %map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ %map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
+
+ // Do some work on the host that writes to 'a'
+ omp.task depend(taskdependout -> %a : memref<?xi32>) {
+ "test.foo"(%a) : (memref<?xi32>) -> ()
+ omp.terminator
+ }
+
+ // Then map that over to the target
+ // CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+ omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin -> %a: memref<?xi32>)
+
+ // Compute 'b' on the target and copy it back
+ // CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) {
+ omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) {
+ ^bb0(%arg0: memref<?xi32>) :
+ "test.foo"(%arg0) : (memref<?xi32>) -> ()
+ omp.terminator
+ }
+
+ // Update 'a' on the host using 'b'
+ omp.task depend(taskdependout -> %a: memref<?xi32>){
+ "test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> ()
+ }
+
+ // Copy the updated 'a' onto the target
+ // CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+ omp.target_update_data motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
+
+ // Compute 'c' on the target and copy it back
+ %map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+ omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+ ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+ "test.foobar"() : ()->()
+ omp.terminator
+ }
+ // CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>) depend(taskdependin -> [[ARG2]] : memref<?xi32>)
+ omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
+ return
+}
More information about the Mlir-commits
mailing list