[Mlir-commits] [mlir] [OpenMP][MLIR] Add `private` clause to `omp.target` (PR #91202)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon May 6 06:15:44 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openmp

Author: Kareem Ergawy (ergawy)

<details>
<summary>Changes</summary>

This PR starts the effort to support delayed privatization for `omp.target`. It extends the `omp.target` MLIR op with a `private` clause, similar to the one `omp.parallel` already has, in order to model privatized variables.

---
Full diff: https://github.com/llvm/llvm-project/pull/91202.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+5-1) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+51-4) 
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+1-1) 
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+39-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a40676d071e620..a641588eaa8d42 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1787,7 +1787,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
                        UnitAttr:$nowait,
                        Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
                        Variadic<OpenMP_PointerLikeType>:$has_device_addr,
-                       Variadic<AnyType>:$map_operands);
+                       Variadic<AnyType>:$map_operands,
+                       Variadic<AnyType>:$private_vars,
+                       OptionalAttr<SymbolRefArrayAttr>:$privatizers);
+
   let regions = (region AnyRegion:$region);
 
   let builders = [
@@ -1802,6 +1805,7 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
     | `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
     | `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
     | `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
+    | `private` `(` custom<PrivateList>($private_vars, type($private_vars), $privatizers) `)`
     | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) $region attr-dict
   }];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 0799090cdea981..cedcc40864d663 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -469,14 +469,18 @@ ParseResult parseClauseWithRegionArgs(
 static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
                                       ValueRange argsSubrange,
                                       StringRef clauseName, ValueRange operands,
-                                      TypeRange types, ArrayAttr symbols) {
-  p << clauseName << "(";
+                                      TypeRange types, ArrayAttr symbols,
+                                      bool printPrefixSuffix = true) {
+  if (printPrefixSuffix)
+    p << clauseName << "(";
+
   llvm::interleaveComma(
       llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
         auto [sym, op, arg, type] = t;
         p << sym << " " << op << " -> " << arg << " : " << type;
       });
-  p << ") ";
+  if (printPrefixSuffix)
+    p << ") ";
 }
 
 static ParseResult parseParallelRegion(
@@ -1048,6 +1052,48 @@ static void printMapEntries(OpAsmPrinter &p, Operation *op,
   }
 }
 
+static ParseResult parsePrivateList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
+    SmallVectorImpl<Type> &privateOperandTypes, ArrayAttr &privatizerSymbols) {
+  SmallVector<SymbolRefAttr> privateSymRefs;
+  SmallVector<OpAsmParser::Argument> regionPrivateArgs;
+
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseAttribute(privateSymRefs.emplace_back()) ||
+            parser.parseOperand(privateOperands.emplace_back()) ||
+            parser.parseArrow() ||
+            parser.parseArgument(regionPrivateArgs.emplace_back()) ||
+            parser.parseColonType(privateOperandTypes.emplace_back()))
+          return failure();
+        return success();
+      })))
+    return failure();
+
+  SmallVector<Attribute> privateSymAttrs(privateSymRefs.begin(),
+                                         privateSymRefs.end());
+  privatizerSymbols = ArrayAttr::get(parser.getContext(), privateSymAttrs);
+
+  return success();
+}
+
+static void printPrivateList(OpAsmPrinter &p, Operation *op,
+                             ValueRange privateVarOperands,
+                             TypeRange privateVarTypes,
+                             ArrayAttr privatizerSymbols) {
+  auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
+  assert(targetOp);
+
+  auto &region = op->getRegion(0);
+  auto *argsBegin = region.front().getArguments().begin();
+  MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
+                               argsBegin + targetOp.getMapOperands().size() +
+                                   privateVarTypes.size());
+  printClauseWithRegionArgs(p, op, argsSubrange, "private", privateVarOperands,
+                            privateVarTypes, privatizerSymbols,
+                            /*printPrefixSuffix=*/false);
+}
+
 static void printCaptureType(OpAsmPrinter &p, Operation *op,
                              VariableCaptureKindAttr mapCaptureType) {
   std::string typeCapStr;
@@ -1262,7 +1308,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
       builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
       makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
       clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
-      clauses.mapVars);
+      clauses.mapVars, clauses.privateVars,
+      ArrayAttr::get(builder.getContext(), clauses.privatizers));
 }
 
 LogicalResult TargetOp::verify() {
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 511e7d396c6875..138c2c9d418dc3 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2087,7 +2087,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
   // expected-error @below {{op expected as many depend values as depend variables}}
     "omp.target"(%data_var) ({
       "omp.terminator"() : () -> ()
-    }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
+    }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0, 0>} : (memref<i32>) -> ()
    "func.return"() : () -> ()
 }
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60fc10f9d64b73..f0b76c117a4568 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -737,7 +737,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %devic
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : ( i1, si32, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0, 0>} : ( i1, si32, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2550,3 +2550,41 @@ func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !
   }
   return
 }
+
+// CHECK-LABEL: omp_target_private
+func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+  %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+  %mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // CHECK: omp.target
+  // CHECK-SAME: private(
+  // CHECK-SAME:   @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : !llvm.ptr
+  // CHECK-SAME: )
+  omp.target private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+  // CHECK: ^bb0(%[[PRIV_ARG]]: !llvm.ptr):
+  ^bb0(%priv_arg: !llvm.ptr):
+    omp.terminator
+  }
+
+  // CHECK: omp.target
+
+  // CHECK-SAME: map_entries(
+  // CHECK-SAME:   %[[MAP1_VAR:[^[:space:]]+]] -> %[[MAP1_ARG:[^[:space:]]+]],
+  // CHECK-SAME:   %[[MAP2_VAR:[^[:space:]]+]] -> %[[MAP2_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : memref<?xi32>, memref<?xi32>
+  // CHECK-SAME: )
+
+  // CHECK-SAME: private(
+  // CHECK-SAME:   @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : !llvm.ptr
+  // CHECK-SAME: )
+  omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+  // CHECK: ^bb0(%[[MAP1_ARG]]: memref<?xi32>, %[[MAP2_ARG]]: memref<?xi32>
+  // CHECK-SAME: , %[[PRIV_ARG]]: !llvm.ptr):
+  ^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>, %priv_arg: !llvm.ptr):
+    omp.terminator
+  }
+
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/91202


More information about the Mlir-commits mailing list