[flang-commits] [flang] 55d6643 - [mlir][openmp] - Add the depend clause to omp.target and related offloading directives (#81081)

via flang-commits flang-commits at lists.llvm.org
Tue Feb 13 03:15:56 PST 2024


Author: Pranav Bhandarkar
Date: 2024-02-13T11:15:51Z
New Revision: 55d6643ccf6f9394d88d3d6359492000c58c2357

URL: https://github.com/llvm/llvm-project/commit/55d6643ccf6f9394d88d3d6359492000c58c2357
DIFF: https://github.com/llvm/llvm-project/commit/55d6643ccf6f9394d88d3d6359492000c58c2357.diff

LOG: [mlir][openmp] - Add the depend clause to omp.target and related offloading directives (#81081)

This patch adds support for the depend clause in a number of OpenMP
directives/constructs related to offloading. Specifically, it adds
handling of the depend clause when it is used with the following
constructs:

- target
- target enter data
- target update data
- target exit data

Added: 
    

Modified: 
    flang/lib/Lower/OpenMP.cpp
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index e5887620d503b9..06850bebd7d05a 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2825,7 +2825,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
                                                      directive);
 
   return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
-                                   deviceOperand, nowaitAttr, mapOperands);
+                                   deviceOperand, nullptr, mlir::ValueRange(),
+                                   nowaitAttr, mapOperands);
 }
 
 // This functions creates a block for the body of the targetOp's region. It adds
@@ -3090,7 +3091,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
 
   auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
       currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
-      nowaitAttr, mapOperands);
+      nullptr, mlir::ValueRange(), nowaitAttr, mapOperands);
 
   genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
                     mapSymLocs, mapSymbols, currentLocation);

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 44f3e5b8dbc361..c7a32de256e2a5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -781,7 +781,7 @@ def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;
 
 def ClauseTaskDepend : I32EnumAttr<
     "ClauseTaskDepend",
-    "task depend clause",
+    "depend clause in a target or task construct",
     [ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::omp";
@@ -1447,11 +1447,17 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
 
     The $map_types specifies the types and modifiers for the map clause.
 
-    TODO:  depend clause and map_type_modifier values iterator and mapper.
+    The `depends` and `depend_vars` arguments are variadic lists of values
+    that specify the dependencies of this particular target task in relation to
+    other tasks.
+
+    TODO:  map_type_modifier values iterator and mapper.
   }];
 
   let arguments = (ins Optional<I1>:$if_expr,
                        Optional<AnyInteger>:$device,
+                       OptionalAttr<TaskDependArrayAttr>:$depends,
+                       Variadic<OpenMP_PointerLikeType>:$depend_vars,
                        UnitAttr:$nowait,
                        Variadic<AnyType>:$map_operands);
 
@@ -1460,6 +1466,7 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
     | `device` `(` $device `:` type($device) `)`
     | `nowait` $nowait
     | `map_entries` `(` $map_operands `:` type($map_operands) `)`
+    | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) attr-dict
    }];
 
@@ -1494,11 +1501,17 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
 
     The $map_types specifies the types and modifiers for the map clause.
 
-    TODO:  depend clause and map_type_modifier values iterator and mapper.
+    The `depends` and `depend_vars` arguments are variadic lists of values
+    that specify the dependencies of this particular target task in relation to
+    other tasks.
+
+    TODO: map_type_modifier values iterator and mapper.
   }];
 
   let arguments = (ins Optional<I1>:$if_expr,
                        Optional<AnyInteger>:$device,
+                       OptionalAttr<TaskDependArrayAttr>:$depends,
+                       Variadic<OpenMP_PointerLikeType>:$depend_vars,
                        UnitAttr:$nowait,
                        Variadic<AnyType>:$map_operands);
 
@@ -1507,6 +1520,7 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
     | `device` `(` $device `:` type($device) `)`
     | `nowait` $nowait
     | `map_entries` `(` $map_operands `:` type($map_operands) `)`
+    | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) attr-dict
    }];
 
@@ -1545,11 +1559,16 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
     during verification to make sure the restrictions for target update are
     respected.
 
-    TODO: depend clause
+    The `depends` and `depend_vars` arguments are variadic lists of values
+    that specify the dependencies of this particular target task in relation to
+    other tasks.
+
   }];
 
   let arguments = (ins Optional<I1>:$if_expr,
                        Optional<AnyInteger>:$device,
+                       OptionalAttr<TaskDependArrayAttr>:$depends,
+                       Variadic<OpenMP_PointerLikeType>:$depend_vars,
                        UnitAttr:$nowait,
                        Variadic<OpenMP_PointerLikeType>:$map_operands);
 
@@ -1558,6 +1577,7 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
     | `device` `(` $device `:` type($device) `)`
     | `nowait` $nowait
     | `motion_entries` `(` $map_operands `:` type($map_operands) `)`
+    | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) attr-dict
    }];
 
@@ -1587,13 +1607,19 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
     The optional $nowait elliminates the implicit barrier so the parent task can make progress
     even if the target task is not yet completed.
 
-    TODO:  is_device_ptr, depend, defaultmap, in_reduction
+    The `depends` and `depend_vars` arguments are variadic lists of values
+    that specify the dependencies of this particular target task in relation to
+    other tasks.
+
+    TODO:  is_device_ptr, defaultmap, in_reduction
 
   }];
 
   let arguments = (ins Optional<I1>:$if_expr,
                        Optional<AnyInteger>:$device,
                        Optional<AnyInteger>:$thread_limit,
+                       OptionalAttr<TaskDependArrayAttr>:$depends,
+                       Variadic<OpenMP_PointerLikeType>:$depend_vars,
                        UnitAttr:$nowait,
                        Variadic<AnyType>:$map_operands);
 
@@ -1605,6 +1631,7 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
     | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
     | `nowait` $nowait
     | `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
+    | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) $region attr-dict
   }];
 

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index ef08bd87efc93a..849449f9127dd8 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -628,7 +628,7 @@ static LogicalResult verifyDependVarList(Operation *op,
       return op->emitOpError() << "expected as many depend values"
                                   " as depend variables";
   } else {
-    if (depends)
+    if (depends && !depends->empty())
       return op->emitOpError() << "unexpected depend values";
     return success();
   }
@@ -1032,19 +1032,31 @@ LogicalResult DataOp::verify() {
 }
 
 LogicalResult EnterDataOp::verify() {
-  return verifyMapClause(*this, getMapOperands());
+  LogicalResult verifyDependVars =
+      verifyDependVarList(*this, getDepends(), getDependVars());
+  return failed(verifyDependVars) ? verifyDependVars
+                                  : verifyMapClause(*this, getMapOperands());
 }
 
 LogicalResult ExitDataOp::verify() {
-  return verifyMapClause(*this, getMapOperands());
+  LogicalResult verifyDependVars =
+      verifyDependVarList(*this, getDepends(), getDependVars());
+  return failed(verifyDependVars) ? verifyDependVars
+                                  : verifyMapClause(*this, getMapOperands());
 }
 
 LogicalResult UpdateDataOp::verify() {
-  return verifyMapClause(*this, getMapOperands());
+  LogicalResult verifyDependVars =
+      verifyDependVarList(*this, getDepends(), getDependVars());
+  return failed(verifyDependVars) ? verifyDependVars
+                                  : verifyMapClause(*this, getMapOperands());
 }
 
 LogicalResult TargetOp::verify() {
-  return verifyMapClause(*this, getMapOperands());
+  LogicalResult verifyDependVars =
+      verifyDependVarList(*this, getDepends(), getDependVars());
+  return failed(verifyDependVars) ? verifyDependVars
+                                  : verifyMapClause(*this, getMapOperands());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 59b42390b206f1..1c1b6ea58e02ee 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1651,6 +1651,15 @@ func.func @omp_target_enter_data(%map1: memref<?xi32>) {
 
 // -----
 
+func.func @omp_target_enter_data_depend(%a: memref<?xi32>) {
+  %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+  // expected-error @below {{op expected as many depend values as depend variables}}
+  omp.target_enter_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
+  return
+}
+
+// -----
+
 func.func @omp_target_exit_data(%map1: memref<?xi32>) {
   %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
   // expected-error @below {{from, release and delete map types are permitted}}
@@ -1660,6 +1669,15 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
 
 // -----
 
+func.func @omp_target_exit_data_depend(%a: memref<?xi32>) {
+  %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  // expected-error @below {{op expected as many depend values as depend variables}}
+  omp.target_exit_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
+  return
+}
+
+// -----
+
 func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
   %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
 
@@ -1732,6 +1750,25 @@ llvm.mlir.global internal @_QFsubEx() : i32
 
 // -----
 
+func.func @omp_target_update_data_depend(%a: memref<?xi32>) {
+  %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+  // expected-error @below {{op expected as many depend values as depend variables}}
+  omp.target_update_data motion_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
+  return
+}
+
+// -----
+
+func.func @omp_target_depend(%data_var: memref<i32>) {
+  // expected-error @below {{op expected as many depend values as depend variables}}
+    "omp.target"(%data_var) ({
+      "omp.terminator"() : () -> ()
+    }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0>} : (memref<i32>) -> ()
+   "func.return"() : () -> ()
+}
+
+// -----
+
 func.func @omp_distribute(%data_var : memref<i32>) -> () {
   // expected-error @below {{expected equal sizes for allocate and allocator variables}}
   "omp.distribute"(%data_var) <{operandSegmentSizes = array<i32: 0, 1, 0>}> ({

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 651405964c0675..3bb4a288376ede 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -517,7 +517,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %map1:
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0>} : ( i1, si32, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0>} : ( i1, si32, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1717,6 +1717,18 @@ func.func @omp_task_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
   return
 }
 
+
+// CHECK-LABEL: @omp_target_depend
+// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
+func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
+  // CHECK:  omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+  omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+    // CHECK: omp.terminator
+    omp.terminator
+  } {operandSegmentSizes = array<i32: 0,0,0,3,0>}
+  return
+}
+
 func.func @omp_threadprivate() {
   %0 = arith.constant 1 : i32
   %1 = arith.constant 2 : i32
@@ -2145,3 +2157,52 @@ func.func @omp_targets_is_allocatable(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
   }
   return
 }
+
+// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
+// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
+func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
+// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
+  %map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+  %map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  %map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
+
+  // Do some work on the host that writes to 'a'
+  omp.task depend(taskdependout -> %a : memref<?xi32>) {
+    "test.foo"(%a) : (memref<?xi32>) -> ()
+    omp.terminator
+  }
+
+  // Then map that over to the target
+  // CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+  omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin ->  %a: memref<?xi32>)
+
+  // Compute 'b' on the target and copy it back
+  // CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) {
+  omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) {
+    ^bb0(%arg0: memref<?xi32>) :
+      "test.foo"(%arg0) : (memref<?xi32>) -> ()
+      omp.terminator
+  }
+
+  // Update 'a' on the host using 'b'
+  omp.task depend(taskdependout -> %a: memref<?xi32>){
+    "test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> ()
+  }
+
+  // Copy the updated 'a' onto the target
+  // CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+  omp.target_update_data motion_entries(%map_a :  memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
+
+  // Compute 'c' on the target and copy it back
+  %map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+  ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+    "test.foobar"() : ()->()
+    omp.terminator
+  }
+  // CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>) depend(taskdependin -> [[ARG2]] : memref<?xi32>)
+  omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
+  return
+}


        


More information about the flang-commits mailing list