[Mlir-commits] [flang] [mlir] [MLIR][OpenMP] Add `private` clause to `omp.parallel` (PR #81452)
Kareem Ergawy
llvmlistbot at llvm.org
Mon Feb 12 20:39:24 PST 2024
https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/81452
>From a2aba2c8fbeadf88c661306de2e538436da4de57 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Sun, 11 Feb 2024 09:37:59 -0600
Subject: [PATCH] [MLIR][OpenMP] Add `private` clause to `omp.parallel`
Extends the `omp.parallel` op by adding a `private` clause to model
[first]private variables. This uses the `omp.private` op to map
privatized variables to their corresponding privatizers.
Example `omp.parallel` op with a `private` variable:
```
omp.parallel private(@x.privatizer %arg0 -> %arg1 : !llvm.ptr) {
^bb0(%arg1: !llvm.ptr):
// ... use %arg1 ...
omp.terminator
}
```
Whether the variable is private or firstprivate is determined by the
attributes of the corresponding `omp.private` op.
---
flang/lib/Lower/OpenMP.cpp | 3 +-
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 8 +-
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 4 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 115 +++++++++++++++++-
mlir/test/Dialect/OpenMP/invalid.mlir | 59 +++++++++
mlir/test/Dialect/OpenMP/ops.mlir | 10 +-
mlir/test/Dialect/OpenMP/roundtrip.mlir | 11 +-
7 files changed, 199 insertions(+), 11 deletions(-)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index e5887620d503b9..791f2c29610205 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2640,7 +2640,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
- procBindKindAttr);
+ procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
+ /*privatizers=*/nullptr);
}
static mlir::omp::SectionOp
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 44f3e5b8dbc361..159efc39a00944 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -270,7 +270,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
Variadic<AnyType>:$allocators_vars,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
- OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
+ OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
+ Variadic<AnyType>:$private_vars,
+ OptionalAttr<SymbolRefArrayAttr>:$privatizers);
let regions = (region AnyRegion:$region);
@@ -291,6 +293,10 @@ def ParallelOp : OpenMP_Op<"parallel", [
$allocators_vars, type($allocators_vars)
) `)`
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
+ | `private` `(`
+ custom<PrivateVarList>(
+ $private_vars, type($private_vars), $privatizers
+ ) `)`
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 2f8b3f7e11de15..4f22272d9aa3d1 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -420,7 +420,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocators_vars = */ llvm::SmallVector<Value>{},
/* reduction_vars = */ llvm::SmallVector<Value>{},
/* reductions = */ ArrayAttr{},
- /* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
+ /* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
+ /* private_vars = */ ValueRange(),
+ /* privatizers = */ nullptr);
{
OpBuilder::InsertionGuard guard(rewriter);
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index ef08bd87efc93a..ca87b7f55dd511 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -491,7 +491,7 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
ArrayAttr reductionSymbols) {
if (reductionSymbols)
printReductionClause(p, op, region, operands, types, reductionSymbols);
- p.printRegion(region, /*printEntryBlockArgs=*/false);
+ p.printRegion(region, /*printEntryBlockArgs=*/true);
}
/// reduction-entry-list ::= reduction-entry
@@ -1057,14 +1057,63 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
- /*proc_bind_val=*/nullptr);
+ /*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
+ /*privatizers=*/nullptr);
state.addAttributes(attributes);
}
+/// Checks that every privatized operand is paired with an `omp.private` symbol
+/// that resolves to an op of the same type; emits an error on `op` otherwise.
+static LogicalResult verifyPrivateVarList(ParallelOp &op) {
+  auto privateVars = op.getPrivateVars();
+  auto privatizers = op.getPrivatizersAttr();
+  // Bail out early so that a null `privatizers` is never iterated below.
+  if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
+    return success();
+
+  auto numPrivateVars = privateVars.size();
+  auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();
+
+  if (numPrivateVars != numPrivatizers)
+    return op.emitError() << "inconsistent number of private variables and "
+                             "privatizer op symbols, private vars: "
+                          << numPrivateVars
+                          << " vs. privatizer op symbols: " << numPrivatizers;
+
+  for (auto privateVarInfo : llvm::zip(privateVars, privatizers)) {
+    Type varType = std::get<0>(privateVarInfo).getType();
+    auto privatizerSym = llvm::cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
+    PrivateClauseOp privatizerOp =
+        SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
+                                                              privatizerSym);
+
+    if (privatizerOp == nullptr)
+      return op.emitError() << "failed to lookup privatizer op with symbol: '"
+                            << privatizerSym << "'";
+
+    Type privatizerType = privatizerOp.getType();
+    if (varType != privatizerType)
+      return op.emitError()
+             << "type mismatch between a "
+             << (privatizerOp.getDataSharingType() ==
+                         DataSharingClauseType::Private
+                     ? "private"
+                     : "firstprivate")
+             << " variable and its privatizer op, var type: " << varType
+             << " vs. privatizer op type: " << privatizerType;
+  }
+
+  return success();
+}
+
LogicalResult ParallelOp::verify() {
if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
+  // Each private variable must pair with a type-matching privatizer symbol.
+ if (failed(verifyPrivateVarList(*this)))
+ return failure();
+
return verifyReductionVarList(*this, getReductions(), getReductionVars());
}
@@ -1729,6 +1778,68 @@ LogicalResult PrivateClauseOp::verify() {
return success();
}
+/// Parses the operand list of the `private` clause, i.e. a comma-separated
+/// list of `@privatizer %var -> %region_arg : type` entries. The region
+/// argument name is consumed but not recorded by this function.
+static ParseResult parsePrivateVarList(
+    OpAsmParser &parser,
+    llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> &privateVarsOperands,
+    llvm::SmallVector<Type, 1> &privateVarsTypes, ArrayAttr &privatizersAttr) {
+  SmallVector<Attribute> privatizerSymbols;
+
+  auto parseEntry = [&]() -> ParseResult {
+    SymbolRefAttr symbol;
+    OpAsmParser::UnresolvedOperand hostVar;
+    OpAsmParser::UnresolvedOperand blockArg;
+    Type varType;
+
+    if (parser.parseAttribute(symbol))
+      return failure();
+    if (parser.parseOperand(hostVar) || parser.parseArrow())
+      return failure();
+    if (parser.parseOperand(blockArg) || parser.parseColon())
+      return failure();
+    if (parser.parseType(varType))
+      return failure();
+
+    privatizerSymbols.push_back(symbol);
+    privateVarsOperands.push_back(hostVar);
+    privateVarsTypes.push_back(varType);
+    return success();
+  };
+
+  if (failed(parser.parseCommaSeparatedList(parseEntry)))
+    return failure();
+
+  privatizersAttr = ArrayAttr::get(parser.getContext(), privatizerSymbols);
+  return success();
+}
+
+/// Prints the `private` clause operands as a comma-separated list of
+/// `@privatizer %var -> %region_arg : type` entries, pairing each privatized
+/// operand with the matching entry block argument of the op's first region.
+static void printPrivateVarList(OpAsmPrinter &printer, Operation *op,
+                                OperandRange privateVars,
+                                TypeRange privateVarTypes,
+                                std::optional<ArrayAttr> privatizersAttr) {
+  // A privatizer symbol list is required to print the clause.
+  assert(privatizersAttr && "expected privatizer symbols to be present");
+
+  // Separators are emitted between entries only, never after the last one.
+  llvm::interleaveComma(
+      llvm::zip(*privatizersAttr, privateVars,
+                op->getRegion(0).getArguments(), privateVarTypes),
+      printer, [&](auto privateVarArgInfo) {
+        printer << std::get<0>(privateVarArgInfo) << " "
+                << std::get<1>(privateVarArgInfo) << " -> ";
+
+        std::get<2>(privateVarArgInfo)
+            .printAsOperand(printer.getStream(), OpPrintingFlags());
+
+        printer << " : " << std::get<3>(privateVarArgInfo);
+      });
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 59b42390b206f1..6b93bce5653901 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1801,3 +1801,62 @@ omp.private {type = firstprivate} @x.privatizer : f32 alloc {
^bb0(%arg0: f32):
omp.yield(%arg0 : f32)
}
+
+// -----
+
+func.func @private_type_mismatch(%arg0: index) {
+// expected-error @below {{type mismatch between a private variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
+ omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
+ ^bb0(%arg2: !llvm.ptr):
+ omp.terminator
+ }
+
+ return
+}
+
+omp.private {type = private} @var1.privatizer : !llvm.ptr alloc {
+^bb0(%arg0: !llvm.ptr):
+ omp.yield(%arg0 : !llvm.ptr)
+}
+
+// -----
+
+func.func @firstprivate_type_mismatch(%arg0: index) {
+ // expected-error @below {{type mismatch between a firstprivate variable and its privatizer op, var type: 'index' vs. privatizer op type: '!llvm.ptr'}}
+ omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
+ ^bb0(%arg2: !llvm.ptr):
+ omp.terminator
+ }
+
+ return
+}
+
+omp.private {type = firstprivate} @var1.privatizer : !llvm.ptr alloc {
+^bb0(%arg0: !llvm.ptr):
+ omp.yield(%arg0 : !llvm.ptr)
+} copy {
+^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ omp.yield(%arg0 : !llvm.ptr)
+}
+
+// -----
+
+func.func @undefined_privatizer(%arg0: index) {
+ // expected-error @below {{failed to lookup privatizer op with symbol: '@var1.privatizer'}}
+ omp.parallel private(@var1.privatizer %arg0 -> %arg2 : index) {
+ ^bb0(%arg2: !llvm.ptr):
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+func.func @undefined_privatizer(%arg0: !llvm.ptr) {
+ // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1>, privatizers = [@x.privatizer, @y.privatizer]}> ({
+ ^bb0(%arg2: !llvm.ptr):
+ omp.terminator
+ }) : (!llvm.ptr) -> ()
+ return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 651405964c0675..0bfe0e4f51059b 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -59,7 +59,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel"(%num_threads, %data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,1,1,1,0>} : (i32, memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 0,1,1,1,0,0>} : (i32, memref<i32>, memref<i32>) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -68,22 +68,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel"(%if_cond, %data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,0,1,1,0>} : (i1, memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,0,1,1,0,0>} : (i1, memref<i32>, memref<i32>) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,1,0>, proc_bind_val = #omp<procbindkind spread>} : (i1, i32, memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_val = #omp<procbindkind spread>} : (i1, i32, memref<i32>, memref<i32>) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,1,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (memref<i32>, memref<i32>) -> ()
return
}
diff --git a/mlir/test/Dialect/OpenMP/roundtrip.mlir b/mlir/test/Dialect/OpenMP/roundtrip.mlir
index 2553442638ee84..bb4e50da2cc1cb 100644
--- a/mlir/test/Dialect/OpenMP/roundtrip.mlir
+++ b/mlir/test/Dialect/OpenMP/roundtrip.mlir
@@ -1,5 +1,15 @@
// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s
+// CHECK-LABEL: parallel_op_privatizers
+func.func @parallel_op_privatizers(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+ // CHECK: omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr)
+ omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr) {
+ ^bb0(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
+ omp.terminator
+ }
+ return
+}
+
// CHECK: omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
// CHECK: ^bb0(%arg0: {{.*}}):
@@ -18,4 +28,3 @@ omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
omp.yield(%arg0 : !llvm.ptr)
}
-
More information about the Mlir-commits
mailing list