[Mlir-commits] [mlir] 427beff - [OpenMP][MLIR] Add `private` clause to `omp.target` (#91202)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu May 9 19:20:47 PDT 2024
Author: Kareem Ergawy
Date: 2024-05-10T04:20:43+02:00
New Revision: 427beff2ad274f38f9de682f48f550cdcf5fc505
URL: https://github.com/llvm/llvm-project/commit/427beff2ad274f38f9de682f48f550cdcf5fc505
DIFF: https://github.com/llvm/llvm-project/commit/427beff2ad274f38f9de682f48f550cdcf5fc505.diff
LOG: [OpenMP][MLIR] Add `private` clause to `omp.target` (#91202)
Added:
Modified:
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/test/Dialect/OpenMP/invalid.mlir
mlir/test/Dialect/OpenMP/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a40676d071e62..a641588eaa8d4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1787,7 +1787,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
Variadic<OpenMP_PointerLikeType>:$has_device_addr,
- Variadic<AnyType>:$map_operands);
+ Variadic<AnyType>:$map_operands,
+ Variadic<AnyType>:$private_vars,
+ OptionalAttr<SymbolRefArrayAttr>:$privatizers);
+
let regions = (region AnyRegion:$region);
let builders = [
@@ -1802,6 +1805,7 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
| `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
| `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
+ | `private` `(` custom<PrivateList>($private_vars, type($private_vars), $privatizers) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) $region attr-dict
}];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 0799090cdea98..e016a326ecc78 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -470,13 +470,17 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
ValueRange argsSubrange,
StringRef clauseName, ValueRange operands,
TypeRange types, ArrayAttr symbols) {
- p << clauseName << "(";
+ if (!clauseName.empty())
+ p << clauseName << "(";
+
llvm::interleaveComma(
llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
auto [sym, op, arg, type] = t;
p << sym << " " << op << " -> " << arg << " : " << type;
});
- p << ") ";
+
+ if (!clauseName.empty())
+ p << ") ";
}
static ParseResult parseParallelRegion(
@@ -1048,6 +1052,49 @@ static void printMapEntries(OpAsmPrinter &p, Operation *op,
}
}
+static ParseResult parsePrivateList( // Parses a comma-separated list of `@sym %operand -> %arg : type` entries for the `private` clause; the surrounding `private(` `)` come from the assembly format.
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
+ SmallVectorImpl<Type> &privateOperandTypes, ArrayAttr &privatizerSymbols) {
+ SmallVector<SymbolRefAttr> privateSymRefs;
+ SmallVector<OpAsmParser::Argument> regionPrivateArgs; // NOTE(review): parsed for syntax only and never used afterwards; presumably the region's entry block (^bb0) declares the actual arguments, as in the tests.
+
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ if (parser.parseAttribute(privateSymRefs.emplace_back()) || // privatizer symbol, e.g. @x.privatizer
+ parser.parseOperand(privateOperands.emplace_back()) || // value passed into the op
+ parser.parseArrow() ||
+ parser.parseArgument(regionPrivateArgs.emplace_back()) || // name used for it inside the region
+ parser.parseColonType(privateOperandTypes.emplace_back()))
+ return failure();
+ return success();
+ })))
+ return failure();
+
+ SmallVector<Attribute> privateSymAttrs(privateSymRefs.begin(),
+ privateSymRefs.end()); // copy into plain Attribute values to build the ArrayAttr below
+ privatizerSymbols = ArrayAttr::get(parser.getContext(), privateSymAttrs);
+
+ return success();
+}
+
+static void printPrivateList(OpAsmPrinter &p, Operation *op, // Prints `@sym %var -> %arg : type, ...` for the `private` clause; an empty clause name makes printClauseWithRegionArgs omit the parentheses, which the assembly format supplies.
+ ValueRange privateVarOperands,
+ TypeRange privateVarTypes,
+ ArrayAttr privatizerSymbols) {
+ // TODO: Remove target-specific logic from this function.
+ auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
+ assert(targetOp); // only omp.target uses this printer so far
+
+ auto &region = op->getRegion(0);
+ auto *argsBegin = region.front().getArguments().begin();
+ MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(), // private block args start right after the map_entries block args
+ argsBegin + targetOp.getMapOperands().size() +
+ privateVarTypes.size());
+ printClauseWithRegionArgs(
+ p, op, argsSubrange, /*clauseName=*/llvm::StringRef{}, privateVarOperands,
+ privateVarTypes, privatizerSymbols);
+}
+
static void printCaptureType(OpAsmPrinter &p, Operation *op,
VariableCaptureKindAttr mapCaptureType) {
std::string typeCapStr;
@@ -1256,13 +1303,14 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
const TargetClauseOps &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
- // inReductionDeclSymbols, privateVars, privatizers, reductionVars,
- // reductionByRefAttr, reductionDeclSymbols.
+ // inReductionDeclSymbols, reductionVars, reductionByRefAttr,
+ // reductionDeclSymbols.
TargetOp::build(
builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
- clauses.mapVars);
+ clauses.mapVars, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privatizers));
}
LogicalResult TargetOp::verify() {
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 511e7d396c687..138c2c9d418dc 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2087,7 +2087,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
"omp.target"(%data_var) ({
"omp.terminator"() : () -> ()
- }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
+ }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0, 0>} : (memref<i32>) -> ()
"func.return"() : () -> ()
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60fc10f9d64b7..828c9d2c3b84f 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -737,7 +737,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : ( i1, si32, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0,0>} : ( i1, si32, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2550,3 +2550,41 @@ func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !
}
return
}
+
+// CHECK-LABEL: omp_target_private
+func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () { // Checks the printed form of the omp.target private clause, alone and combined with map_entries.
+ %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+ %mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+
+ // CHECK: omp.target
+ // CHECK-SAME: private(
+ // CHECK-SAME: @x.privatizer %{{[^[:space:]]+}} -> %[[PRIV_ARG:[^[:space:]]+]]
+ // CHECK-SAME: : !llvm.ptr
+ // CHECK-SAME: )
+ omp.target private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) { // case 1: private clause only
+ // CHECK: ^bb0(%[[PRIV_ARG]]: !llvm.ptr):
+ ^bb0(%priv_arg: !llvm.ptr):
+ omp.terminator
+ }
+
+ // CHECK: omp.target
+
+ // CHECK-SAME: map_entries(
+ // CHECK-SAME: %{{[^[:space:]]+}} -> %[[MAP1_ARG:[^[:space:]]+]],
+ // CHECK-SAME: %{{[^[:space:]]+}} -> %[[MAP2_ARG:[^[:space:]]+]]
+ // CHECK-SAME: : memref<?xi32>, memref<?xi32>
+ // CHECK-SAME: )
+
+ // CHECK-SAME: private(
+ // CHECK-SAME: @x.privatizer %{{[^[:space:]]+}} -> %[[PRIV_ARG:[^[:space:]]+]]
+ // CHECK-SAME: : !llvm.ptr
+ // CHECK-SAME: )
+ omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) { // case 2: map_entries plus private
+ // CHECK: ^bb0(%[[MAP1_ARG]]: memref<?xi32>, %[[MAP2_ARG]]: memref<?xi32>
+ // CHECK-SAME: , %[[PRIV_ARG]]: !llvm.ptr):
+ ^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>, %priv_arg: !llvm.ptr): // map block args come first, then the private arg
+ omp.terminator
+ }
+
+ return
+}
More information about the Mlir-commits
mailing list