[flang-commits] [flang] [mlir] [MLIR][OpenMP] Support basic materialization for `omp.private` ops (PR #81715)
via flang-commits
flang-commits at lists.llvm.org
Wed Feb 14 00:20:10 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Kareem Ergawy (ergawy)
<details>
<summary>Changes</summary>
Adds basic support for materializing delayed privatization. So far, the
restrictions on the implementation are:
- Only `private` clauses are supported (`firstprivate` support will be
added in a later PR).
- Only single-block `omp.private -> alloc` regions are supported
(multi-block ones will be supported in a later PR).
### This is a follow-up to both #81414 and #81452; only the latest commit (the one with the same title as the PR) is relevant to this PR.
---
Patch is 39.55 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81715.diff
12 Files Affected:
- (modified) flang/lib/Lower/OpenMP.cpp (+2-1)
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+10-2)
- (modified) mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp (+19-13)
- (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+3-1)
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+141-48)
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+108-6)
- (modified) mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir (+26)
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+56)
- (added) mlir/test/Dialect/OpenMP/ops-2.mlir (+74)
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+5-5)
- (removed) mlir/test/Dialect/OpenMP/roundtrip.mlir (-21)
- (added) mlir/test/Target/LLVMIR/openmp-private.mlir (+91)
``````````diff
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 24f91765cb439b..74b2727961a03d 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -2640,7 +2640,8 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
? nullptr
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
reductionDeclSymbols),
- procBindKindAttr);
+ procBindKindAttr, /*private_vars=*/llvm::SmallVector<mlir::Value>{},
+ /*privatizers=*/nullptr);
}
static mlir::omp::SectionOp
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 0adf186ae0c7e9..9ed40904f22ae2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -221,6 +221,12 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove]> {
attr-dict
}];
+ let builders = [
+ OpBuilder<(ins CArg<"TypeRange">:$result,
+ CArg<"StringAttr">:$sym_name,
+ CArg<"TypeAttr">:$type)>
+ ];
+
let hasVerifier = 1;
}
@@ -270,7 +276,9 @@ def ParallelOp : OpenMP_Op<"parallel", [
Variadic<AnyType>:$allocators_vars,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
- OptionalAttr<ProcBindKindAttr>:$proc_bind_val);
+ OptionalAttr<ProcBindKindAttr>:$proc_bind_val,
+ Variadic<AnyType>:$private_vars,
+ OptionalAttr<SymbolRefArrayAttr>:$privatizers);
let regions = (region AnyRegion:$region);
@@ -291,7 +299,7 @@ def ParallelOp : OpenMP_Op<"parallel", [
$allocators_vars, type($allocators_vars)
) `)`
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
- ) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
+ ) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions, $private_vars, type($private_vars), $privatizers) attr-dict
}];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 730858ffc67a71..2eba4fba105c7b 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -200,16 +200,20 @@ struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> {
}
};
-struct ReductionDeclareOpConversion
- : public ConvertOpToLLVMPattern<omp::ReductionDeclareOp> {
- using ConvertOpToLLVMPattern<omp::ReductionDeclareOp>::ConvertOpToLLVMPattern;
+template <typename OpType>
+struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
+ using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(omp::ReductionDeclareOp curOp, OpAdaptor adaptor,
+ matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto newOp = rewriter.create<omp::ReductionDeclareOp>(
+ auto newOp = rewriter.create<OpType>(
curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(),
TypeAttr::get(this->getTypeConverter()->convertType(
curOp.getTypeAttr().getValue())));
+
+ if constexpr (std::is_same_v<OpType, mlir::omp::PrivateClauseOp>)
+ newOp.setDataSharingType(curOp.getDataSharingType());
+
for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx),
newOp.getRegion(idx).end());
@@ -231,11 +235,12 @@ void mlir::configureOpenMPToLLVMConversionLegality(
mlir::omp::DataOp, mlir::omp::OrderedRegionOp, mlir::omp::ParallelOp,
mlir::omp::WsLoopOp, mlir::omp::SimdLoopOp, mlir::omp::MasterOp,
mlir::omp::SectionOp, mlir::omp::SectionsOp, mlir::omp::SingleOp,
- mlir::omp::TaskGroupOp, mlir::omp::TaskOp>([&](Operation *op) {
- return typeConverter.isLegal(&op->getRegion(0)) &&
- typeConverter.isLegal(op->getOperandTypes()) &&
- typeConverter.isLegal(op->getResultTypes());
- });
+ mlir::omp::TaskGroupOp, mlir::omp::TaskOp, mlir::omp::PrivateClauseOp>(
+ [&](Operation *op) {
+ return typeConverter.isLegal(&op->getRegion(0)) &&
+ typeConverter.isLegal(op->getOperandTypes()) &&
+ typeConverter.isLegal(op->getResultTypes());
+ });
target.addDynamicallyLegalOp<
mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp, mlir::omp::FlushOp,
mlir::omp::ThreadprivateOp, mlir::omp::YieldOp, mlir::omp::EnterDataOp,
@@ -267,9 +272,10 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
patterns.add<
AtomicReadOpConversion, MapInfoOpConversion, ReductionOpConversion,
- ReductionDeclareOpConversion, RegionOpConversion<omp::CriticalOp>,
- RegionOpConversion<omp::MasterOp>, ReductionOpConversion,
- RegionOpConversion<omp::OrderedRegionOp>,
+ MultiRegionOpConversion<omp::ReductionDeclareOp>,
+ MultiRegionOpConversion<omp::PrivateClauseOp>,
+ RegionOpConversion<omp::CriticalOp>, RegionOpConversion<omp::MasterOp>,
+ ReductionOpConversion, RegionOpConversion<omp::OrderedRegionOp>,
RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::WsLoopOp>,
RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SectionOp>,
RegionOpConversion<omp::SimdLoopOp>, RegionOpConversion<omp::SingleOp>,
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ea5f31ee8c6aa7..464a647564aced 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -450,7 +450,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocators_vars = */ llvm::SmallVector<Value>{},
/* reduction_vars = */ llvm::SmallVector<Value>{},
/* reductions = */ ArrayAttr{},
- /* proc_bind_val = */ omp::ClauseProcBindKindAttr{});
+ /* proc_bind_val = */ omp::ClauseProcBindKindAttr{},
+ /* private_vars = */ ValueRange(),
+ /* privatizers = */ nullptr);
{
OpBuilder::InsertionGuard guard(rewriter);
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 13fc01d58eced5..3d18b9fe13e42c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -430,68 +430,102 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
// Parser, printer and verifier for ReductionVarList
//===----------------------------------------------------------------------===//
-ParseResult
-parseReductionClause(OpAsmParser &parser, Region ®ion,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
- SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
- SmallVectorImpl<OpAsmParser::Argument> &privates) {
- if (failed(parser.parseOptionalKeyword("reduction")))
- return failure();
-
+ParseResult parseClauseWithRegionArgs(
+ OpAsmParser &parser, Region ®ion,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
+ SmallVectorImpl<Type> &types, ArrayAttr &symbols,
+ SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs) {
SmallVector<SymbolRefAttr> reductionVec;
+ unsigned regionArgOffset = regionPrivateArgs.size();
if (failed(
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseOperand(operands.emplace_back()) ||
parser.parseArrow() ||
- parser.parseArgument(privates.emplace_back()) ||
+ parser.parseArgument(regionPrivateArgs.emplace_back()) ||
parser.parseColonType(types.emplace_back()))
return failure();
return success();
})))
return failure();
- for (auto [prv, type] : llvm::zip_equal(privates, types)) {
+ auto *argsBegin = regionPrivateArgs.begin();
+ MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
+ argsBegin + regionArgOffset + types.size());
+ for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
prv.type = type;
}
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
- reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
+ symbols = ArrayAttr::get(parser.getContext(), reductions);
return success();
}
-static void printReductionClause(OpAsmPrinter &p, Operation *op,
- ValueRange reductionArgs, ValueRange operands,
- TypeRange types, ArrayAttr reductionSymbols) {
- p << "reduction(";
+static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
+ ValueRange argsSubrange,
+ StringRef clauseName, ValueRange operands,
+ TypeRange types, ArrayAttr symbols) {
+ p << clauseName << "(";
llvm::interleaveComma(
- llvm::zip_equal(reductionSymbols, operands, reductionArgs, types), p,
- [&p](auto t) {
+ llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
auto [sym, op, arg, type] = t;
p << sym << " " << op << " -> " << arg << " : " << type;
});
p << ") ";
}
-static ParseResult
-parseParallelRegion(OpAsmParser &parser, Region ®ion,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
- SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {
+static ParseResult parseParallelRegion(
+ OpAsmParser &parser, Region ®ion,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
+ SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
+ llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
+ llvm::SmallVectorImpl<Type> &privateVarsTypes,
+ ArrayAttr &privatizerSymbols) {
+ llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
- llvm::SmallVector<OpAsmParser::Argument> privates;
- if (succeeded(parseReductionClause(parser, region, operands, types,
- reductionSymbols, privates)))
- return parser.parseRegion(region, privates);
+ if (succeeded(parser.parseOptionalKeyword("reduction"))) {
+ if (failed(parseClauseWithRegionArgs(parser, region, reductionVarOperands,
+ reductionVarTypes, reductionSymbols,
+ regionPrivateArgs)))
+ return failure();
+ }
- return parser.parseRegion(region);
+ if (succeeded(parser.parseOptionalKeyword("private"))) {
+ if (failed(parseClauseWithRegionArgs(parser, region, privateVarOperands,
+ privateVarsTypes, privatizerSymbols,
+ regionPrivateArgs)))
+ return failure();
+ }
+
+ return parser.parseRegion(region, regionPrivateArgs);
}
static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
- ValueRange operands, TypeRange types,
- ArrayAttr reductionSymbols) {
- if (reductionSymbols)
- printReductionClause(p, op, region.front().getArguments(), operands, types,
- reductionSymbols);
+ ValueRange reductionVarOperands,
+ TypeRange reductionVarTypes,
+ ArrayAttr reductionSymbols,
+ ValueRange privateVarOperands,
+ TypeRange privateVarTypes,
+ ArrayAttr privatizerSymbols) {
+ if (reductionSymbols) {
+ auto *argsBegin = region.front().getArguments().begin();
+ MutableArrayRef argsSubrange(argsBegin,
+ argsBegin + reductionVarTypes.size());
+ printClauseWithRegionArgs(p, op, argsSubrange, "reduction",
+ reductionVarOperands, reductionVarTypes,
+ reductionSymbols);
+ }
+
+ if (privatizerSymbols) {
+ auto *argsBegin = region.front().getArguments().begin();
+ MutableArrayRef argsSubrange(argsBegin + reductionVarOperands.size(),
+ argsBegin + reductionVarOperands.size() +
+ privateVarTypes.size());
+ printClauseWithRegionArgs(p, op, argsSubrange, "private",
+ privateVarOperands, privateVarTypes,
+ privatizerSymbols);
+ }
+
p.printRegion(region, /*printEntryBlockArgs=*/false);
}
@@ -1008,9 +1042,8 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
}
if (always || close || implicit) {
- return emitError(
- op->getLoc(),
- "present, mapper and iterator map type modifiers are permitted");
+ return emitError(op->getLoc(), "present, mapper and iterator map "
+ "type modifiers are permitted");
}
to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
@@ -1070,14 +1103,63 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
/*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
/*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
- /*proc_bind_val=*/nullptr);
+ /*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
+ /*privatizers=*/nullptr);
state.addAttributes(attributes);
}
+static LogicalResult verifyPrivateVarList(ParallelOp &op) {
+ auto privateVars = op.getPrivateVars();
+ auto privatizers = op.getPrivatizersAttr();
+
+ if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
+ return success();
+
+ auto numPrivateVars = privateVars.size();
+ auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();
+
+ if (numPrivateVars != numPrivatizers)
+ return op.emitError() << "inconsistent number of private variables and "
+ "privatizer op symbols, private vars: "
+ << numPrivateVars
+ << " vs. privatizer op symbols: " << numPrivatizers;
+
+ for (auto privateVarInfo : llvm::zip(privateVars, privatizers)) {
+ Type varType = std::get<0>(privateVarInfo).getType();
+ SymbolRefAttr privatizerSym =
+ std::get<1>(privateVarInfo).cast<SymbolRefAttr>();
+ PrivateClauseOp privatizerOp =
+ SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
+ privatizerSym);
+
+ if (privatizerOp == nullptr)
+ return op.emitError() << "failed to lookup privatizer op with symbol: '"
+ << privatizerSym << "'";
+
+ Type privatizerType = privatizerOp.getType();
+
+ if (varType != privatizerType)
+ return op.emitError()
+ << "type mismatch between a "
+ << (privatizerOp.getDataSharingType() ==
+ DataSharingClauseType::Private
+ ? "private"
+ : "firstprivate")
+ << " variable and its privatizer op, var type: " << varType
+ << " vs. privatizer op type: " << privatizerType;
+ }
+
+ return success();
+}
+
LogicalResult ParallelOp::verify() {
if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
+
+ if (failed(verifyPrivateVarList(*this)))
+ return failure();
+
return verifyReductionVarList(*this, getReductions(), getReductionVars());
}
@@ -1111,8 +1193,8 @@ LogicalResult TeamsOp::verify() {
return emitError("expected num_teams upper bound to be defined if the "
"lower bound is defined");
if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
- return emitError(
- "expected num_teams upper bound and lower bound to be the same type");
+ return emitError("expected num_teams upper bound and lower bound to be "
+ "the same type");
}
// Check for allocate clause restrictions
@@ -1174,9 +1256,10 @@ parseWsLoop(OpAsmParser &parser, Region ®ion,
// Parse an optional reduction clause
llvm::SmallVector<OpAsmParser::Argument> privates;
- bool hasReduction = succeeded(
- parseReductionClause(parser, region, reductionOperands, reductionTypes,
- reductionSymbols, privates));
+ bool hasReduction = succeeded(parser.parseOptionalKeyword("reduction")) &&
+ succeeded(parseClauseWithRegionArgs(
+ parser, region, reductionOperands, reductionTypes,
+ reductionSymbols, privates));
if (parser.parseKeyword("for"))
return failure();
@@ -1223,8 +1306,9 @@ void printWsLoop(OpAsmPrinter &p, Operation *op, Region ®ion,
if (reductionSymbols) {
auto reductionArgs =
region.front().getArguments().drop_front(loopVarTypes.size());
- printReductionClause(p, op, reductionArgs, reductionOperands,
- reductionTypes, reductionSymbols);
+ printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
+ reductionOperands, reductionTypes,
+ reductionSymbols);
}
p << " for ";
@@ -1464,9 +1548,9 @@ LogicalResult TaskLoopOp::verify() {
}
if (getGrainSize() && getNumTasks()) {
- return emitError(
- "the grainsize clause and num_tasks clause are mutually exclusive and "
- "may not appear on the same taskloop directive");
+ return emitError("the grainsize clause and num_tasks clause are mutually "
+ "exclusive and "
+ "may not appear on the same taskloop directive");
}
return success();
}
@@ -1535,7 +1619,8 @@ LogicalResult OrderedOp::verify() {
}
LogicalResult OrderedRegionOp::verify() {
- // TODO: The code generation for ordered simd directive is not supported yet.
+ // TODO: The code generation for ordered simd directive is not supported
+ // yet.
if (getSimd())
return failure();
@@ -1753,6 +1838,15 @@ LogicalResult DataBoundsOp::verify() {
return success();
}
+void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ TypeRange /*result_types*/, StringAttr symName,
+ TypeAttr type) {
+ PrivateClauseOp::build(
+ odsBuilder, odsState, symName, type,
+ DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
+ DataSharingClauseType::Private));
+}
+
LogicalResult PrivateClauseOp::verify() {
Type symType = getType();
@@ -1785,8 +1879,7 @@ LogicalResult PrivateClauseOp::verify() {
if (region.getNumArguments() != expectedNumArgs)
return mlir::emitError(region.getLoc())
- << "`" << regionName << "`: "
- << "expected " << expectedNumArgs
+ << "`" << regionName << "`: " << "expected " << expectedNumArgs
<< " region arguments, got: " << region.getNumArguments();
for (Block &block : region) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 78a2ad76a1e3b8..6b59dc7377fc74 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1000,6 +1000,26 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
return success();
}
+/// Replace the region arguments of the parallel op (which correspond to private
+/// variables) with the actual private variables t...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81715
More information about the flang-commits
mailing list