[Mlir-commits] [mlir] [OpenMP][MLIR] Add `private` clause to `omp.target` (PR #91202)
Kareem Ergawy
llvmlistbot at llvm.org
Tue May 7 01:36:12 PDT 2024
https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/91202
>From d53d96bc8e110e26172dce0eca6dd7ebbae3d9a1 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Mon, 6 May 2024 08:12:55 -0500
Subject: [PATCH 1/2] [OpenMP][MLIR] Add `private` clause to `omp.target`
Starts the effort to support delayed privatization for `omp.target`.
This PR extends the `omp.target` MLIR op with a `private` clause, similar
to the one `omp.parallel` already has, in order to model privatized
variables. The private region arguments follow the `map_entries` region
arguments in the target region's entry block.
---
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 6 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 55 +++++++++++++++++--
mlir/test/Dialect/OpenMP/invalid.mlir | 2 +-
mlir/test/Dialect/OpenMP/ops.mlir | 40 +++++++++++++-
4 files changed, 96 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a40676d071e620..a641588eaa8d42 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1787,7 +1787,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
Variadic<OpenMP_PointerLikeType>:$has_device_addr,
- Variadic<AnyType>:$map_operands);
+ Variadic<AnyType>:$map_operands,
+ Variadic<AnyType>:$private_vars,
+ OptionalAttr<SymbolRefArrayAttr>:$privatizers);
+
let regions = (region AnyRegion:$region);
let builders = [
@@ -1802,6 +1805,7 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
| `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
| `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
+ | `private` `(` custom<PrivateList>($private_vars, type($private_vars), $privatizers) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) $region attr-dict
}];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 0799090cdea981..cedcc40864d663 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -469,14 +469,18 @@ ParseResult parseClauseWithRegionArgs(
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
ValueRange argsSubrange,
StringRef clauseName, ValueRange operands,
- TypeRange types, ArrayAttr symbols) {
- p << clauseName << "(";
+ TypeRange types, ArrayAttr symbols,
+ bool printPrefixSuffix = true) {
+ if (printPrefixSuffix)
+ p << clauseName << "(";
+
llvm::interleaveComma(
llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
auto [sym, op, arg, type] = t;
p << sym << " " << op << " -> " << arg << " : " << type;
});
- p << ") ";
+ if (printPrefixSuffix)
+ p << ") ";
}
static ParseResult parseParallelRegion(
@@ -1048,6 +1052,48 @@ static void printMapEntries(OpAsmPrinter &p, Operation *op,
}
}
+static ParseResult parsePrivateList(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
+ SmallVectorImpl<Type> &privateOperandTypes, ArrayAttr &privatizerSymbols) {
+ SmallVector<SymbolRefAttr> privateSymRefs;
+ SmallVector<OpAsmParser::Argument> regionPrivateArgs;
+
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ if (parser.parseAttribute(privateSymRefs.emplace_back()) ||
+ parser.parseOperand(privateOperands.emplace_back()) ||
+ parser.parseArrow() ||
+ parser.parseArgument(regionPrivateArgs.emplace_back()) ||
+ parser.parseColonType(privateOperandTypes.emplace_back()))
+ return failure();
+ return success();
+ })))
+ return failure();
+
+ SmallVector<Attribute> privateSymAttrs(privateSymRefs.begin(),
+ privateSymRefs.end());
+ privatizerSymbols = ArrayAttr::get(parser.getContext(), privateSymAttrs);
+
+ return success();
+}
+
+static void printPrivateList(OpAsmPrinter &p, Operation *op,
+ ValueRange privateVarOperands,
+ TypeRange privateVarTypes,
+ ArrayAttr privatizerSymbols) {
+ auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
+ assert(targetOp);
+
+ auto &region = op->getRegion(0);
+ auto *argsBegin = region.front().getArguments().begin();
+ MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
+ argsBegin + targetOp.getMapOperands().size() +
+ privateVarTypes.size());
+ printClauseWithRegionArgs(p, op, argsSubrange, "private", privateVarOperands,
+ privateVarTypes, privatizerSymbols,
+ /*printPrefixSuffix=*/false);
+}
+
static void printCaptureType(OpAsmPrinter &p, Operation *op,
VariableCaptureKindAttr mapCaptureType) {
std::string typeCapStr;
@@ -1262,7 +1308,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
- clauses.mapVars);
+ clauses.mapVars, clauses.privateVars,
+ ArrayAttr::get(builder.getContext(), clauses.privatizers));
}
LogicalResult TargetOp::verify() {
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 511e7d396c6875..138c2c9d418dc3 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2087,7 +2087,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
"omp.target"(%data_var) ({
"omp.terminator"() : () -> ()
- }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
+ }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0, 0>} : (memref<i32>) -> ()
"func.return"() : () -> ()
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60fc10f9d64b73..f0b76c117a4568 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -737,7 +737,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : ( i1, si32, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0, 0>} : ( i1, si32, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2550,3 +2550,41 @@ func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !
}
return
}
+
+// CHECK-LABEL: omp_target_private
+func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+ %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+ %mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+
+ // CHECK: omp.target
+ // CHECK-SAME: private(
+ // CHECK-SAME: @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+ // CHECK-SAME: : !llvm.ptr
+ // CHECK-SAME: )
+ omp.target private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+ // CHECK: ^bb0(%[[PRIV_ARG]]: !llvm.ptr):
+ ^bb0(%priv_arg: !llvm.ptr):
+ omp.terminator
+ }
+
+ // CHECK: omp.target
+
+ // CHECK-SAME: map_entries(
+ // CHECK-SAME: %[[MAP1_VAR:[^[:space:]]+]] -> %[[MAP1_ARG:[^[:space:]]+]],
+ // CHECK-SAME: %[[MAP2_VAR:[^[:space:]]+]] -> %[[MAP2_ARG:[^[:space:]]+]]
+ // CHECK-SAME: : memref<?xi32>, memref<?xi32>
+ // CHECK-SAME: )
+
+ // CHECK-SAME: private(
+ // CHECK-SAME: @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+ // CHECK-SAME: : !llvm.ptr
+ // CHECK-SAME: )
+ omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+ // CHECK: ^bb0(%[[MAP1_ARG]]: memref<?xi32>, %[[MAP2_ARG]]: memref<?xi32>
+ // CHECK-SAME: , %[[PRIV_ARG]]: !llvm.ptr):
+ ^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>, %priv_arg: !llvm.ptr):
+ omp.terminator
+ }
+
+ return
+}
>From d6f265b626969f51efc5f596ab8f82bbf0113674 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Tue, 7 May 2024 03:35:55 -0500
Subject: [PATCH 2/2] Address review comments (drop printPrefixSuffix flag, use makeArrayAttr for privatizers)
---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 20 ++++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index cedcc40864d663..432a0b0600e039 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -469,9 +469,8 @@ ParseResult parseClauseWithRegionArgs(
static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
ValueRange argsSubrange,
StringRef clauseName, ValueRange operands,
- TypeRange types, ArrayAttr symbols,
- bool printPrefixSuffix = true) {
- if (printPrefixSuffix)
+ TypeRange types, ArrayAttr symbols) {
+ if (!clauseName.empty())
p << clauseName << "(";
llvm::interleaveComma(
@@ -479,7 +478,8 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
auto [sym, op, arg, type] = t;
p << sym << " " << op << " -> " << arg << " : " << type;
});
- if (printPrefixSuffix)
+
+ if (!clauseName.empty())
p << ") ";
}
@@ -1089,9 +1089,9 @@ static void printPrivateList(OpAsmPrinter &p, Operation *op,
MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
argsBegin + targetOp.getMapOperands().size() +
privateVarTypes.size());
- printClauseWithRegionArgs(p, op, argsSubrange, "private", privateVarOperands,
- privateVarTypes, privatizerSymbols,
- /*printPrefixSuffix=*/false);
+ printClauseWithRegionArgs(
+ p, op, argsSubrange, /*clauseName=*/llvm::StringRef{}, privateVarOperands,
+ privateVarTypes, privatizerSymbols);
}
static void printCaptureType(OpAsmPrinter &p, Operation *op,
@@ -1302,14 +1302,14 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
const TargetClauseOps &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
- // inReductionDeclSymbols, privateVars, privatizers, reductionVars,
- // reductionByRefAttr, reductionDeclSymbols.
+ // inReductionDeclSymbols, reductionVars, reductionByRefAttr,
+ // reductionDeclSymbols.
TargetOp::build(
builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
clauses.mapVars, clauses.privateVars,
- ArrayAttr::get(builder.getContext(), clauses.privatizers));
+ makeArrayAttr(ctx, clauses.privatizers));
}
LogicalResult TargetOp::verify() {
More information about the Mlir-commits
mailing list