[Mlir-commits] [mlir] 427beff - [OpenMP][MLIR] Add `private` clause to `omp.target` (#91202)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 9 19:20:47 PDT 2024


Author: Kareem Ergawy
Date: 2024-05-10T04:20:43+02:00
New Revision: 427beff2ad274f38f9de682f48f550cdcf5fc505

URL: https://github.com/llvm/llvm-project/commit/427beff2ad274f38f9de682f48f550cdcf5fc505
DIFF: https://github.com/llvm/llvm-project/commit/427beff2ad274f38f9de682f48f550cdcf5fc505.diff

LOG: [OpenMP][MLIR] Add `private` clause to `omp.target` (#91202)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a40676d071e62..a641588eaa8d4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1787,7 +1787,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
                        UnitAttr:$nowait,
                        Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
                        Variadic<OpenMP_PointerLikeType>:$has_device_addr,
-                       Variadic<AnyType>:$map_operands);
+                       Variadic<AnyType>:$map_operands,
+                       Variadic<AnyType>:$private_vars,
+                       OptionalAttr<SymbolRefArrayAttr>:$privatizers);
+
   let regions = (region AnyRegion:$region);
 
   let builders = [
@@ -1802,6 +1805,7 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
     | `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
     | `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
     | `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
+    | `private` `(` custom<PrivateList>($private_vars, type($private_vars), $privatizers) `)`
     | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) $region attr-dict
   }];

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 0799090cdea98..e016a326ecc78 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -470,13 +470,17 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
                                       ValueRange argsSubrange,
                                       StringRef clauseName, ValueRange operands,
                                       TypeRange types, ArrayAttr symbols) {
-  p << clauseName << "(";
+  if (!clauseName.empty())
+    p << clauseName << "(";
+
   llvm::interleaveComma(
       llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
         auto [sym, op, arg, type] = t;
         p << sym << " " << op << " -> " << arg << " : " << type;
       });
-  p << ") ";
+
+  if (!clauseName.empty())
+    p << ") ";
 }
 
 static ParseResult parseParallelRegion(
@@ -1048,6 +1052,49 @@ static void printMapEntries(OpAsmPrinter &p, Operation *op,
   }
 }
 
+static ParseResult parsePrivateList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
+    SmallVectorImpl<Type> &privateOperandTypes, ArrayAttr &privatizerSymbols) {
+  SmallVector<SymbolRefAttr> privateSymRefs;
+  SmallVector<OpAsmParser::Argument> regionPrivateArgs;
+  // Each entry: `@privatizer %var -> %region_arg : type`, comma-separated.
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseAttribute(privateSymRefs.emplace_back()) ||
+            parser.parseOperand(privateOperands.emplace_back()) ||
+            parser.parseArrow() ||
+            parser.parseArgument(regionPrivateArgs.emplace_back()) ||
+            parser.parseColonType(privateOperandTypes.emplace_back()))
+          return failure();
+        return success();
+      })))
+    return failure();
+  // Only symbols, operands and types are returned; region args are dropped.
+  SmallVector<Attribute> privateSymAttrs(privateSymRefs.begin(),
+                                         privateSymRefs.end());
+  privatizerSymbols = ArrayAttr::get(parser.getContext(), privateSymAttrs);
+
+  return success();
+}
+/// Prints `private` clause entries of omp.target, without the parentheses.
+static void printPrivateList(OpAsmPrinter &p, Operation *op,
+                             ValueRange privateVarOperands,
+                             TypeRange privateVarTypes,
+                             ArrayAttr privatizerSymbols) {
+  // TODO: Remove target-specific logic from this function.
+  auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
+  assert(targetOp);
+  // Block args assumed laid out as [map args..., private args...]; unchecked.
+  auto &region = op->getRegion(0);
+  auto *argsBegin = region.front().getArguments().begin();
+  MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
+                               argsBegin + targetOp.getMapOperands().size() +
+                                   privateVarTypes.size());
+  printClauseWithRegionArgs(
+      p, op, argsSubrange, /*clauseName=*/llvm::StringRef{}, privateVarOperands,
+      privateVarTypes, privatizerSymbols);
+}
+
 static void printCaptureType(OpAsmPrinter &p, Operation *op,
                              VariableCaptureKindAttr mapCaptureType) {
   std::string typeCapStr;
@@ -1256,13 +1303,14 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                      const TargetClauseOps &clauses) {
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
-  // inReductionDeclSymbols, privateVars, privatizers, reductionVars,
-  // reductionByRefAttr, reductionDeclSymbols.
+  // inReductionDeclSymbols, reductionVars, reductionByRefAttr,
+  // reductionDeclSymbols.
   TargetOp::build(
       builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
       makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
       clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
-      clauses.mapVars);
+      clauses.mapVars, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privatizers));
 }
 
 LogicalResult TargetOp::verify() {

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 511e7d396c687..138c2c9d418dc 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2087,7 +2087,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
   // expected-error @below {{op expected as many depend values as depend variables}}
     "omp.target"(%data_var) ({
       "omp.terminator"() : () -> ()
-    }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
+    }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0, 0>} : (memref<i32>) -> ()
    "func.return"() : () -> ()
 }
 

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60fc10f9d64b7..828c9d2c3b84f 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -737,7 +737,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %devic
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : ( i1, si32, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0,0>} : ( i1, si32, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2550,3 +2550,41 @@ func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !
   }
   return
 }
+
+// CHECK-LABEL: omp_target_private
+func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+  %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+  %mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+  // Case 1: `private` clause alone; its single block arg is the private arg.
+  // CHECK: omp.target
+  // CHECK-SAME: private(
+  // CHECK-SAME:   @x.privatizer %{{[^[:space:]]+}} -> %[[PRIV_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : !llvm.ptr
+  // CHECK-SAME: )
+  omp.target private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+  // CHECK: ^bb0(%[[PRIV_ARG]]: !llvm.ptr):
+  ^bb0(%priv_arg: !llvm.ptr):
+    omp.terminator
+  }
+  // Case 2: map_entries + private; private block args follow the map args.
+  // CHECK: omp.target
+
+  // CHECK-SAME: map_entries(
+  // CHECK-SAME:   %{{[^[:space:]]+}} -> %[[MAP1_ARG:[^[:space:]]+]],
+  // CHECK-SAME:   %{{[^[:space:]]+}} -> %[[MAP2_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : memref<?xi32>, memref<?xi32>
+  // CHECK-SAME: )
+  // Expect `private` to be printed after `map_entries`.
+  // CHECK-SAME: private(
+  // CHECK-SAME:   @x.privatizer %{{[^[:space:]]+}} -> %[[PRIV_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : !llvm.ptr
+  // CHECK-SAME: )
+  omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+  // CHECK: ^bb0(%[[MAP1_ARG]]: memref<?xi32>, %[[MAP2_ARG]]: memref<?xi32>
+  // CHECK-SAME: , %[[PRIV_ARG]]: !llvm.ptr):
+  ^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>, %priv_arg: !llvm.ptr):
+    omp.terminator
+  }
+
+  return
+}


        


More information about the Mlir-commits mailing list