[Mlir-commits] [mlir] [mlir][openmp] - Add the depend clause to omp.target and related offloading directives (PR #81081)

Pranav Bhandarkar llvmlistbot at llvm.org
Wed Feb 7 18:54:53 PST 2024


https://github.com/bhandarkar-pranav created https://github.com/llvm/llvm-project/pull/81081

This patch adds support for the depend clause in a number of OpenMP directives/constructs related to offloading. Specifically, it adds handling of the depend clause when it is used with the following constructs:

- target
- target enter data
- target update
- target exit data

>From 79f720064392432542fd41f0f0fec89c14002dea Mon Sep 17 00:00:00 2001
From: Pranav Bhandarkar <pranav.bhandarkar at amd.com>
Date: Wed, 7 Feb 2024 20:44:38 -0600
Subject: [PATCH] [mlir][openmp] - Add the depend clause to omp.target and
 related offloading directives

This patch adds support for the depend clause in a number of OpenMP directives/constructs related to offloading.
Specifically, it adds handling of the depend clause when it is used with the following constructs:

- target
- target enter data
- target update
- target exit data
---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 37 +++++++++--
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 22 +++++--
 mlir/test/Dialect/OpenMP/invalid.mlir         | 37 +++++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 63 ++++++++++++++++++-
 4 files changed, 148 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ca36350548577..3c9b59cea181a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -693,7 +693,7 @@ def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;
 
 def ClauseTaskDepend : I32EnumAttr<
     "ClauseTaskDepend",
-    "task depend clause",
+    "depend clause in a target or task construct",
     [ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::omp";
@@ -1359,11 +1359,17 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
 
     The $map_types specifies the types and modifiers for the map clause.
 
-    TODO:  depend clause and map_type_modifier values iterator and mapper.
+    The `depends` attribute lists the dependence kind for each entry of the
+    variadic `depend_vars` operand list; together they specify the dependencies
+    of this particular target task in relation to other tasks.
+
+    TODO:  map_type_modifier values iterator and mapper.
   }];
 
   let arguments = (ins Optional<I1>:$if_expr,
                        Optional<AnyInteger>:$device,
+                       OptionalAttr<TaskDependArrayAttr>:$depends,
+                       Variadic<OpenMP_PointerLikeType>:$depend_vars,
                        UnitAttr:$nowait,
                        Variadic<AnyType>:$map_operands);
 
@@ -1372,6 +1378,7 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
     | `device` `(` $device `:` type($device) `)`
     | `nowait` $nowait
     | `map_entries` `(` $map_operands `:` type($map_operands) `)`
+    | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) attr-dict
    }];
 
@@ -1406,11 +1413,17 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
 
     The $map_types specifies the types and modifiers for the map clause.
 
-    TODO:  depend clause and map_type_modifier values iterator and mapper.
+    The `depends` attribute lists the dependence kind for each entry of the
+    variadic `depend_vars` operand list; together they specify the dependencies
+    of this particular target task in relation to other tasks.
+
+    TODO: map_type_modifier values iterator and mapper.
   }];
 
   let arguments = (ins Optional<I1>:$if_expr,
                        Optional<AnyInteger>:$device,
+                       OptionalAttr<TaskDependArrayAttr>:$depends,
+                       Variadic<OpenMP_PointerLikeType>:$depend_vars,
                        UnitAttr:$nowait,
                        Variadic<AnyType>:$map_operands);
 
@@ -1419,6 +1432,7 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
     | `device` `(` $device `:` type($device) `)`
     | `nowait` $nowait
     | `map_entries` `(` $map_operands `:` type($map_operands) `)`
+    | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) attr-dict
    }];
 
@@ -1457,11 +1471,16 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
     during verification to make sure the restrictions for target update are
     respected.
 
-    TODO: depend clause
+    The `depends` attribute lists the dependence kind for each entry of the
+    variadic `depend_vars` operand list; together they specify the dependencies
+    of this particular target task in relation to other tasks.
+
   }];
 
   let arguments = (ins Optional<I1>:$if_expr,
                        Optional<AnyInteger>:$device,
+                       OptionalAttr<TaskDependArrayAttr>:$depends,
+                       Variadic<OpenMP_PointerLikeType>:$depend_vars,
                        UnitAttr:$nowait,
                        Variadic<OpenMP_PointerLikeType>:$map_operands);
 
@@ -1470,6 +1489,7 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
     | `device` `(` $device `:` type($device) `)`
     | `nowait` $nowait
     | `motion_entries` `(` $map_operands `:` type($map_operands) `)`
+    | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) attr-dict
    }];
 
@@ -1499,13 +1519,19 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
     The optional $nowait elliminates the implicit barrier so the parent task can make progress
     even if the target task is not yet completed.
 
-    TODO:  is_device_ptr, depend, defaultmap, in_reduction
+    The `depends` attribute lists the dependence kind for each entry of the
+    variadic `depend_vars` operand list; together they specify the dependencies
+    of this particular target task in relation to other tasks.
+
+    TODO:  is_device_ptr, defaultmap, in_reduction
 
   }];
 
   let arguments = (ins Optional<I1>:$if_expr,
                        Optional<AnyInteger>:$device,
                        Optional<AnyInteger>:$thread_limit,
+                       OptionalAttr<TaskDependArrayAttr>:$depends,
+                       Variadic<OpenMP_PointerLikeType>:$depend_vars,
                        UnitAttr:$nowait,
                        Variadic<AnyType>:$map_operands);
 
@@ -1517,6 +1543,7 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
     | `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
     | `nowait` $nowait
     | `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
+    | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) $region attr-dict
   }];
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 381f17d080419..11800b3cd6ce5 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -561,7 +561,7 @@ static LogicalResult verifyDependVarList(Operation *op,
       return op->emitOpError() << "expected as many depend values"
                                   " as depend variables";
   } else {
-    if (depends)
+    if (depends && !depends->empty())
       return op->emitOpError() << "unexpected depend values";
     return success();
   }
@@ -965,19 +965,31 @@ LogicalResult DataOp::verify() {
 }
 
 LogicalResult EnterDataOp::verify() {
-  return verifyMapClause(*this, getMapOperands());
+  LogicalResult verifyDependVars =
+      verifyDependVarList(*this, getDepends(), getDependVars());
+  return failed(verifyDependVars) ? verifyDependVars
+                                  : verifyMapClause(*this, getMapOperands());
 }
 
 LogicalResult ExitDataOp::verify() {
-  return verifyMapClause(*this, getMapOperands());
+  LogicalResult verifyDependVars =
+      verifyDependVarList(*this, getDepends(), getDependVars());
+  return failed(verifyDependVars) ? verifyDependVars
+                                  : verifyMapClause(*this, getMapOperands());
 }
 
 LogicalResult UpdateDataOp::verify() {
-  return verifyMapClause(*this, getMapOperands());
+  LogicalResult verifyDependVars =
+      verifyDependVarList(*this, getDepends(), getDependVars());
+  return failed(verifyDependVars) ? verifyDependVars
+                                  : verifyMapClause(*this, getMapOperands());
 }
 
 LogicalResult TargetOp::verify() {
-  return verifyMapClause(*this, getMapOperands());
+  LogicalResult verifyDependVars =
+      verifyDependVarList(*this, getDepends(), getDependVars());
+  return failed(verifyDependVars) ? verifyDependVars
+                                  : verifyMapClause(*this, getMapOperands());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 812b79e35595f..481aa950ceeef 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1651,6 +1651,15 @@ func.func @omp_target_enter_data(%map1: memref<?xi32>) {
 
 // -----
 
+func.func @omp_target_enter_data_depend(%a: memref<?xi32>) {
+  %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+  // expected-error @below {{op expected as many depend values as depend variables}}
+  omp.target_enter_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
+  return
+}
+
+// -----
+
 func.func @omp_target_exit_data(%map1: memref<?xi32>) {
   %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
   // expected-error @below {{from, release and delete map types are permitted}}
@@ -1660,6 +1669,15 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
 
 // -----
 
+func.func @omp_target_exit_data_depend(%a: memref<?xi32>) {
+  %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  // expected-error @below {{op expected as many depend values as depend variables}}
+  omp.target_exit_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
+  return
+}
+
+// -----
+
 func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
   %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
 
@@ -1732,6 +1750,25 @@ llvm.mlir.global internal @_QFsubEx() : i32
 
 // -----
 
+func.func @omp_target_update_data_depend(%a: memref<?xi32>) {
+  %0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+  // expected-error @below {{op expected as many depend values as depend variables}}
+  omp.target_update_data motion_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
+  return
+}
+
+// -----
+
+func.func @omp_target_depend(%data_var: memref<i32>) {
+  // expected-error @below {{op expected as many depend values as depend variables}}
+    "omp.target"(%data_var) ({
+      "omp.terminator"() : () -> ()
+    }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0>} : (memref<i32>) -> ()
+   "func.return"() : () -> ()
+}
+
+// -----
+
 func.func @omp_distribute(%data_var : memref<i32>) -> () {
   // expected-error @below {{expected equal sizes for allocate and allocator variables}}
   "omp.distribute"(%data_var) <{operandSegmentSizes = array<i32: 0, 1, 0>}> ({
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 65a704d18107b..950558483d21f 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -517,7 +517,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %map1:
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0>} : ( i1, si32, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0>} : ( i1, si32, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1710,6 +1710,18 @@ func.func @omp_task_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
   return
 }
 
+
+// CHECK-LABEL: @omp_target_depend
+// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
+func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
+  // CHECK:  omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+  omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
+    // CHECK: omp.terminator
+    omp.terminator
+  } {operandSegmentSizes = array<i32: 0,0,0,3,0>}
+  return
+}
+
 func.func @omp_threadprivate() {
   %0 = arith.constant 1 : i32
   %1 = arith.constant 2 : i32
@@ -2138,3 +2150,52 @@ func.func @omp_targets_is_allocatable(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
   }
   return
 }
+
+// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
+// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
+func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
+// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
+// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
+  %map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
+  %map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  %map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
+
+  // Do some work on the host that writes to 'a'
+  omp.task depend(taskdependout -> %a : memref<?xi32>) {
+    "test.foo"(%a) : (memref<?xi32>) -> ()
+    omp.terminator
+  }
+
+  // Then map that over to the target
+  // CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+  omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin ->  %a: memref<?xi32>)
+
+  // Compute 'b' on the target and copy it back
+  // CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) {
+  omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) {
+    ^bb0(%arg0: memref<?xi32>) :
+      "test.foo"(%arg0) : (memref<?xi32>) -> ()
+      omp.terminator
+  }
+
+  // Update 'a' on the host using 'b'
+  omp.task depend(taskdependout -> %a: memref<?xi32>){
+    "test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> ()
+  }
+
+  // Copy the updated 'a' onto the target
+  // CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
+  omp.target_update_data motion_entries(%map_a :  memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
+
+  // Compute 'c' on the target and copy it back
+  %map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
+  omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
+  ^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
+    "test.foobar"() : ()->()
+    omp.terminator
+  }
+  // CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>) depend(taskdependin -> [[ARG2]] : memref<?xi32>)
+  omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
+  return
+}



More information about the Mlir-commits mailing list