[flang-commits] [flang] [mlir] [MLIR][OpenMP] Normalize handling of entry block arguments (PR #109808)
Sergio Afonso via flang-commits
flang-commits at lists.llvm.org
Tue Sep 24 07:56:19 PDT 2024
https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/109808
This patch introduces a new MLIR interface for the OpenMP dialect aimed at providing a uniform way of verifying and handling entry block arguments defined by OpenMP clauses.
The approach consists of defining a set of overridable methods that return the number of block arguments the operation holds for each of the clauses that may define them. These return 0 by default, but are overridden by the corresponding clause through the `extraClassDeclaration` mechanism.
Another set of interface methods, implemented in terms of the previously described ones, returns the actual lists of block arguments. These implicitly define a standardized ordering among the lists of block arguments associated with each clause, based on the alphabetical ordering of the clause names (`in_reduction`, `map`, `private`, `reduction`, `task_reduction`). They should be the preferred way of matching operation arguments to the entry block arguments of that operation's first region.
The printing/parsing of `omp.parallel` is updated to follow the expected order between the `private` and `reduction` clauses, and the MLIR to LLVM IR translation pass is updated to access block arguments through the new interface. Unit tests of operations impacted by the additional verification checks and by the reordering of entry block arguments are updated accordingly.
>From 82a0c88f02c1eac5f4136a429dc331bab6c5ed58 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 18 Sep 2024 11:47:36 +0100
Subject: [PATCH] [MLIR][OpenMP] Normalize handling of entry block arguments
This patch introduces a new MLIR interface for the OpenMP dialect aimed at
providing a uniform way of verifying and handling entry block arguments defined
by OpenMP clauses.
The approach consists of defining a set of overridable methods that return the
number of block arguments the operation holds for each of the clauses that may
define them. These return 0 by default, but are overridden by the corresponding
clause through the `extraClassDeclaration` mechanism.
Another set of interface methods, implemented in terms of the previously
described ones, returns the actual lists of block arguments. These implicitly
define a standardized ordering among the lists of block arguments associated
with each clause, based on the alphabetical ordering of the clause names
(`in_reduction`, `map`, `private`, `reduction`, `task_reduction`). They should
be the preferred way of matching operation arguments to the entry block
arguments of that operation's first region.
The printing/parsing of `omp.parallel` is updated to follow the expected order
between the `private` and `reduction` clauses, and the MLIR to LLVM IR
translation pass is updated to access block arguments through the new
interface. Unit tests of operations impacted by the additional verification
checks and by the reordering of entry block arguments are updated accordingly.
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 29 ++++---
.../delayed-privatization-reduction-byref.f90 | 4 +-
.../delayed-privatization-reduction.f90 | 4 +-
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 39 ++++++---
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 7 +-
.../Dialect/OpenMP/OpenMPOpsInterfaces.td | 80 +++++++++++++++++++
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 34 ++++----
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 45 +++++------
mlir/test/Dialect/OpenMP/invalid.mlir | 4 +
mlir/test/Dialect/OpenMP/ops.mlir | 23 ++++--
mlir/test/Target/LLVMIR/openmp-private.mlir | 2 +-
11 files changed, 195 insertions(+), 76 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 960286732c90c2..e9095d631beb7b 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -472,17 +472,26 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
/// \param [in] infoAccessor - for a private variable, this returns the
/// data we want to merge: type or location.
/// \param [out] allRegionArgsInfo - the merged list of region info.
+/// \param [in] addBeforePrivate - `true` if the passed information goes before
+/// private information.
template <typename OMPOp, typename InfoTy>
static void
mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef<InfoTy> currentList,
llvm::function_ref<InfoTy(mlir::Value)> infoAccessor,
- llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo) {
+ llvm::SmallVectorImpl<InfoTy> &allRegionArgsInfo,
+ bool addBeforePrivate) {
mlir::OperandRange privateVars = op.getPrivateVars();
- llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
- [](InfoTy i) { return i; });
+ if (addBeforePrivate)
+ llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
+ [](InfoTy i) { return i; });
+
llvm::transform(privateVars, std::back_inserter(allRegionArgsInfo),
infoAccessor);
+
+ if (!addBeforePrivate)
+ llvm::transform(currentList, std::back_inserter(allRegionArgsInfo),
+ [](InfoTy i) { return i; });
}
//===----------------------------------------------------------------------===//
@@ -868,12 +877,12 @@ static void genBodyOfTargetOp(
mergePrivateVarsInfo(targetOp, mapSymTypes,
llvm::function_ref<mlir::Type(mlir::Value)>{
[](mlir::Value v) { return v.getType(); }},
- allRegionArgTypes);
+ allRegionArgTypes, /*addBeforePrivate=*/true);
mergePrivateVarsInfo(targetOp, mapSymLocs,
llvm::function_ref<mlir::Location(mlir::Value)>{
[](mlir::Value v) { return v.getLoc(); }},
- allRegionArgLocs);
+ allRegionArgLocs, /*addBeforePrivate=*/true);
mlir::Block *regionBlock = firOpBuilder.createBlock(
®ion, {}, allRegionArgTypes, allRegionArgLocs);
@@ -1478,21 +1487,21 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
mergePrivateVarsInfo(parallelOp, reductionTypes,
llvm::function_ref<mlir::Type(mlir::Value)>{
[](mlir::Value v) { return v.getType(); }},
- allRegionArgTypes);
+ allRegionArgTypes, /*addBeforePrivate=*/false);
llvm::SmallVector<mlir::Location> allRegionArgLocs;
mergePrivateVarsInfo(parallelOp, llvm::ArrayRef(reductionLocs),
llvm::function_ref<mlir::Location(mlir::Value)>{
[](mlir::Value v) { return v.getLoc(); }},
- allRegionArgLocs);
+ allRegionArgLocs, /*addBeforePrivate=*/false);
mlir::Region ®ion = parallelOp.getRegion();
firOpBuilder.createBlock(®ion, /*insertPt=*/{}, allRegionArgTypes,
allRegionArgLocs);
- llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
- allSymbols.append(dsp->getDelayedPrivSymbols().begin(),
- dsp->getDelayedPrivSymbols().end());
+ llvm::SmallVector<const semantics::Symbol *> allSymbols(
+ dsp->getDelayedPrivSymbols());
+ allSymbols.append(reductionSyms.begin(), reductionSyms.end());
unsigned argIdx = 0;
for (const semantics::Symbol *arg : allSymbols) {
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
index 29439571179322..6c00bb23f15b96 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction-byref.f90
@@ -26,5 +26,5 @@ subroutine red_and_delayed_private
! CHECK-LABEL: _QPred_and_delayed_private
! CHECK: omp.parallel
-! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
-! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
+! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(byref @[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90 b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
index d814b2b0ff0f31..38139e52ce95cb 100644
--- a/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
+++ b/flang/test/Lower/OpenMP/delayed-privatization-reduction.f90
@@ -29,5 +29,5 @@ subroutine red_and_delayed_private
! CHECK-LABEL: _QPred_and_delayed_private
! CHECK: omp.parallel
-! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
-! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
+! CHECK-SAME: private(@[[PRIVATIZER_SYM]] %{{.*}} -> %arg0 : !fir.ref<i32>)
+! CHECK-SAME: reduction(@[[REDUCTION_SYM]] %{{.*}} -> %arg1 : !fir.ref<i32>) {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index c579ba6e751d2b..876d53766a0ca1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -451,7 +451,7 @@ class OpenMP_InReductionClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
- ReductionClauseInterface
+ BlockArgOpenMPOpInterface, ReductionClauseInterface
];
let arguments = (ins
@@ -472,6 +472,8 @@ class OpenMP_InReductionClauseSkip<
return SmallVector<Value>(getInReductionVars().begin(),
getInReductionVars().end());
}
+
+ unsigned numInReductionBlockArgs() { return getInReductionVars().size(); }
}];
// Description varies depending on the operation.
@@ -575,6 +577,8 @@ class OpenMP_MapClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
+ // Not adding the BlockArgOpenMPOpInterface here because omp.target is the
+ // only operation defining block arguments for `map` clauses.
MapClauseOwningOpInterface
];
@@ -923,6 +927,10 @@ class OpenMP_PrivateClauseSkip<
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
+ let traits = [
+ BlockArgOpenMPOpInterface
+ ];
+
let arguments = (ins
Variadic<AnyType>:$private_vars,
OptionalAttr<SymbolRefArrayAttr>:$private_syms
@@ -933,6 +941,10 @@ class OpenMP_PrivateClauseSkip<
custom<PrivateList>($private_vars, type($private_vars), $private_syms) `)`
}];
+ let extraClassDeclaration = [{
+ unsigned numPrivateBlockArgs() { return getPrivateVars().size(); }
+ }];
+
// TODO: Add description.
}
@@ -973,7 +985,7 @@ class OpenMP_ReductionClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
- ReductionClauseInterface
+ BlockArgOpenMPOpInterface, ReductionClauseInterface
];
let arguments = (ins
@@ -991,6 +1003,7 @@ class OpenMP_ReductionClauseSkip<
let extraClassDeclaration = [{
/// Returns the number of reduction variables.
unsigned getNumReductionVars() { return getReductionVars().size(); }
+ unsigned numReductionBlockArgs() { return getReductionVars().size(); }
}];
// Description varies depending on the operation.
@@ -1104,7 +1117,7 @@ class OpenMP_TaskReductionClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
- ReductionClauseInterface
+ BlockArgOpenMPOpInterface, ReductionClauseInterface
];
let arguments = (ins
@@ -1119,6 +1132,18 @@ class OpenMP_TaskReductionClauseSkip<
$task_reduction_byref, $task_reduction_syms) `)`
}];
+ let extraClassDeclaration = [{
+ /// Returns the reduction variables.
+ SmallVector<Value> getReductionVars() {
+ return SmallVector<Value>(getTaskReductionVars().begin(),
+ getTaskReductionVars().end());
+ }
+
+ unsigned numTaskReductionBlockArgs() {
+ return getTaskReductionVars().size();
+ }
+ }];
+
let description = [{
The `task_reduction` clause specifies a reduction among tasks. For each list
item, the number of copies is unspecified. Any copies associated with the
@@ -1130,14 +1155,6 @@ class OpenMP_TaskReductionClauseSkip<
attribute, and whether the reduction variable should be passed into the
reduction region by value or by reference in `task_reduction_byref`.
}];
-
- let extraClassDeclaration = [{
- /// Returns the reduction variables.
- SmallVector<Value> getReductionVars() {
- return SmallVector<Value>(getTaskReductionVars().begin(),
- getTaskReductionVars().end());
- }
- }];
}
def OpenMP_TaskReductionClause : OpenMP_TaskReductionClauseSkip<>;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 9d2123a2e9bf52..326bdd3bbc9463 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1043,7 +1043,8 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
//===----------------------------------------------------------------------===//
def TargetOp : OpenMP_Op<"target", traits = [
- AttrSizedOperandSegments, IsolatedFromAbove, OutlineableOpenMPOpInterface
+ AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove,
+ OutlineableOpenMPOpInterface
], clauses = [
// TODO: Complete clause list (defaultmap, uses_allocators).
OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
@@ -1065,6 +1066,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
];
+ let extraClassDeclaration = [{
+ unsigned numMapBlockArgs() { return getMapVars().size(); }
+ }] # clausesExtraClassDeclaration;
+
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 0078e22b1c89a6..030075eaf45b14 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -15,6 +15,86 @@
include "mlir/IR/OpBase.td"
+def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
+ let description = [{
+ OpenMP operations that define entry block arguments as part of the
+ representation of its clauses.
+ }];
+
+ let cppNamespace = "::mlir::omp";
+
+ let methods = [
+ // Default-implemented methods to be overriden by the corresponding clauses.
+ InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
+ "unsigned", "numInReductionBlockArgs", (ins), [{}], [{
+ return 0;
+ }]>,
+ InterfaceMethod<"Get number of block arguments defined by `map`.",
+ "unsigned", "numMapBlockArgs", (ins), [{}], [{
+ return 0;
+ }]>,
+ InterfaceMethod<"Get number of block arguments defined by `private`.",
+ "unsigned", "numPrivateBlockArgs", (ins), [{}], [{
+ return 0;
+ }]>,
+ InterfaceMethod<"Get number of block arguments defined by `reduction`.",
+ "unsigned", "numReductionBlockArgs", (ins), [{}], [{
+ return 0;
+ }]>,
+ InterfaceMethod<"Get number of block arguments defined by `task_reduction`.",
+ "unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
+ return 0;
+ }]>,
+
+ // Unified access methods for clause-associated entry block arguments.
+ InterfaceMethod<"Get block arguments defined by `in_reduction`.",
+ "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+ "getInReductionBlockArgs", (ins), [{
+ return $_op->getRegion(0).getArguments().take_front(
+ $_op.numInReductionBlockArgs());
+ }]>,
+ InterfaceMethod<"Get block arguments defined by `map`.",
+ "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+ "getMapBlockArgs", (ins), [{
+ return $_op->getRegion(0).getArguments().slice(
+ $_op.numInReductionBlockArgs(), $_op.numMapBlockArgs());
+ }]>,
+ InterfaceMethod<"Get block arguments defined by `private`.",
+ "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+ "getPrivateBlockArgs", (ins), [{
+ return $_op->getRegion(0).getArguments().slice(
+ $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs(),
+ $_op.numPrivateBlockArgs());
+ }]>,
+ InterfaceMethod<"Get block arguments defined by `reduction`.",
+ "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+ "getReductionBlockArgs", (ins), [{
+ return $_op->getRegion(0).getArguments().slice(
+ $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
+ $_op.numPrivateBlockArgs(), $_op.numReductionBlockArgs());
+ }]>,
+ InterfaceMethod<"Get block arguments defined by `task_reduction`.",
+ "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+ "getTaskReductionBlockArgs", (ins), [{
+ return $_op->getRegion(0).getArguments().slice(
+ $_op.numInReductionBlockArgs() + $_op.numMapBlockArgs() +
+ $_op.numPrivateBlockArgs() + $_op.numReductionBlockArgs(),
+ $_op.numTaskReductionBlockArgs());
+ }]>,
+ ];
+
+ let verify = [{
+ auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
+ unsigned expectedArgs = iface.numInReductionBlockArgs() +
+ iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
+ iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs();
+ if ($_op->getRegion(0).getNumArguments() < expectedArgs)
+ return $_op->emitOpError() << "expected at least " << expectedArgs
+ << " entry block argument(s)";
+ return ::mlir::success();
+ }];
+}
+
def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
let description = [{
OpenMP operations whose region will be outlined will implement this
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index db47276dcefe95..7ca7a2afbdbbc4 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -536,13 +536,6 @@ static ParseResult parseParallelRegion(
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
- if (succeeded(parser.parseOptionalKeyword("reduction"))) {
- if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
- reductionTypes, reductionByref,
- reductionSyms, regionPrivateArgs)))
- return failure();
- }
-
if (succeeded(parser.parseOptionalKeyword("private"))) {
auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
@@ -557,6 +550,13 @@ static ParseResult parseParallelRegion(
}
}
+ if (succeeded(parser.parseOptionalKeyword("reduction"))) {
+ if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
+ reductionTypes, reductionByref,
+ reductionSyms, regionPrivateArgs)))
+ return failure();
+ }
+
return parser.parseRegion(region, regionPrivateArgs);
}
@@ -566,18 +566,9 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms, ValueRange privateVars,
TypeRange privateTypes, ArrayAttr privateSyms) {
- if (reductionSyms) {
- auto *argsBegin = region.front().getArguments().begin();
- MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
- printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
- reductionTypes, reductionByref, reductionSyms);
- }
-
if (privateSyms) {
auto *argsBegin = region.front().getArguments().begin();
- MutableArrayRef argsSubrange(argsBegin + reductionVars.size(),
- argsBegin + reductionVars.size() +
- privateTypes.size());
+ MutableArrayRef argsSubrange(argsBegin, argsBegin + privateTypes.size());
mlir::SmallVector<bool> isByRefVec;
isByRefVec.resize(privateTypes.size(), false);
DenseBoolArrayAttr isByRef =
@@ -587,6 +578,15 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
privateTypes, isByRef, privateSyms);
}
+ if (reductionSyms) {
+ auto *argsBegin = region.front().getArguments().begin();
+ MutableArrayRef argsSubrange(argsBegin + privateVars.size(),
+ argsBegin + privateVars.size() +
+ reductionTypes.size());
+ printClauseWithRegionArgs(p, op, argsSubrange, "reduction", reductionVars,
+ reductionTypes, reductionByref, reductionSyms);
+ }
+
p.printRegion(region, /*printEntryBlockArgs=*/false);
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0cba8d80681f13..769cdd57656b51 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -920,7 +920,7 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
DenseMap<Value, llvm::Value *> reductionVariableMap;
MutableArrayRef<BlockArgument> reductionArgs =
- sectionsOp.getRegion().getArguments();
+ cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
if (failed(allocAndInitializeReductionVars(
sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
@@ -954,8 +954,10 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
// variables
assert(region.getNumArguments() ==
sectionsOp.getRegion().getNumArguments());
- for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
- sectionsOp.getRegion().getArguments(), region.getArguments())) {
+ for (auto [sectionsArg, sectionArg] :
+ llvm::zip_equal(cast<omp::BlockArgOpenMPOpInterface>(*sectionsOp)
+ .getReductionBlockArgs(),
+ region.getArguments())) {
llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
assert(llvmVal);
moduleTranslation.mapValue(sectionArg, llvmVal);
@@ -1216,7 +1218,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
DenseMap<Value, llvm::Value *> reductionVariableMap;
MutableArrayRef<BlockArgument> reductionArgs =
- wsloopOp.getRegion().getArguments();
+ cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
if (failed(allocAndInitializeReductionVars(
wsloopOp, reductionArgs, builder, moduleTranslation, allocaIP,
@@ -1329,31 +1331,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
class OmpParallelOpConversionManager {
public:
OmpParallelOpConversionManager(omp::ParallelOp opInst)
- : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
- privateArgBeginIdx(opInst.getNumReductionVars()),
- privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
- auto privateVarsIt = privateVars.begin();
-
- for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
- ++argIdx, ++privateVarsIt)
- mlir::replaceAllUsesInRegionWith(region.getArgument(argIdx),
- *privateVarsIt, region);
+ : region(opInst.getRegion()),
+ privateBlockArgs(cast<omp::BlockArgOpenMPOpInterface>(*opInst)
+ .getPrivateBlockArgs()),
+ privateVars(opInst.getPrivateVars()) {
+ for (auto [blockArg, var] : llvm::zip_equal(privateBlockArgs, privateVars))
+ mlir::replaceAllUsesInRegionWith(blockArg, var, region);
}
~OmpParallelOpConversionManager() {
- auto privateVarsIt = privateVars.begin();
-
- for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
- ++argIdx, ++privateVarsIt)
- mlir::replaceAllUsesInRegionWith(*privateVarsIt,
- region.getArgument(argIdx), region);
+ for (auto [blockArg, var] : llvm::zip_equal(privateBlockArgs, privateVars))
+ mlir::replaceAllUsesInRegionWith(var, blockArg, region);
}
private:
Region ®ion;
+ llvm::MutableArrayRef<BlockArgument> privateBlockArgs;
OperandRange privateVars;
- unsigned privateArgBeginIdx;
- unsigned privateArgEndIdx;
};
/// Converts the OpenMP parallel operation to LLVM IR.
@@ -1382,9 +1376,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
DenseMap<Value, llvm::Value *> reductionVariableMap;
MutableArrayRef<BlockArgument> reductionArgs =
- opInst.getRegion().getArguments().slice(
- opInst.getNumAllocateVars() + opInst.getNumAllocatorsVars(),
- opInst.getNumReductionVars());
+ cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();
allocaIP =
InsertPointTy(allocaIP.getBlock(),
@@ -3400,6 +3392,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
auto &targetRegion = targetOp.getRegion();
DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
SmallVector<Value> mapVars = targetOp.getMapVars();
+ ArrayRef<BlockArgument> mapBlockArgs =
+ cast<omp::BlockArgOpenMPOpInterface>(opInst).getMapBlockArgs();
llvm::Function *llvmOutlinedFn = nullptr;
// TODO: It can also be false if a compile-time constant `false` IF clause is
@@ -3428,11 +3422,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvmOutlinedFn->addFnAttr(attr);
builder.restoreIP(codeGenIP);
- for (auto [argIndex, mapOp] : llvm::enumerate(mapVars)) {
+ for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
llvm::Value *mapOpValue =
moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
- const auto &arg = targetRegion.front().getArgument(argIndex);
moduleTranslation.mapValue(arg, mapOpValue);
}
llvm::BasicBlock *exitBlock = convertOmpOpRegions(
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index d8745f1015af83..5e182dea52b40e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1459,6 +1459,7 @@ func.func @omp_sections(%data_var : memref<i32>) -> () {
func.func @omp_sections(%data_var : memref<i32>) -> () {
// expected-error @below {{expected as many reduction symbol references as reduction variables}}
"omp.sections" (%data_var) ({
+ ^bb0(%arg0: memref<i32>):
omp.terminator
}) {operandSegmentSizes = array<i32: 0,0,0,1>} : (memref<i32>) -> ()
return
@@ -1650,6 +1651,7 @@ func.func @omp_task_depend(%data_var: memref<i32>) {
func.func @omp_task(%ptr: !llvm.ptr) {
// expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}}
omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr) {
+ ^bb0(%arg0: !llvm.ptr):
// CHECK: "test.foo"() : () -> ()
"test.foo"() : () -> ()
// CHECK: omp.terminator
@@ -1674,6 +1676,7 @@ combiner {
func.func @omp_task(%ptr: !llvm.ptr) {
// expected-error @below {{op accumulator variable used more than once}}
omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr, @add_f32 -> %ptr : !llvm.ptr) {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
// CHECK: "test.foo"() : () -> ()
"test.foo"() : () -> ()
// CHECK: omp.terminator
@@ -1704,6 +1707,7 @@ atomic {
func.func @omp_task(%mem: memref<1xf32>) {
// expected-error @below {{op expected accumulator ('memref<1xf32>') to be the same type as reduction declaration ('!llvm.ptr')}}
omp.task in_reduction(@add_i32 -> %mem : memref<1xf32>) {
+ ^bb0(%arg0: memref<1xf32>):
// CHECK: "test.foo"() : () -> ()
"test.foo"() : () -> ()
// CHECK: omp.terminator
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index e7d3e67ca7e05b..2116071f8523a3 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1096,6 +1096,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
// CHECK: omp.teams reduction(@add_f32 -> %{{.+}} : !llvm.ptr) {
omp.teams reduction(@add_f32 -> %0 : !llvm.ptr) {
+ ^bb0(%arg0: !llvm.ptr):
%1 = arith.constant 2.0 : f32
// CHECK: omp.terminator
omp.terminator
@@ -1104,6 +1105,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
// Test reduction byref
// CHECK: omp.teams reduction(byref @add_f32 -> %{{.+}} : !llvm.ptr) {
omp.teams reduction(byref @add_f32 -> %0 : !llvm.ptr) {
+ ^bb0(%arg0: !llvm.ptr):
%1 = arith.constant 2.0 : f32
// CHECK: omp.terminator
omp.terminator
@@ -1125,6 +1127,7 @@ func.func @sections_reduction() {
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
// CHECK: omp.sections reduction(@add_f32 -> {{.+}} : !llvm.ptr)
omp.sections reduction(@add_f32 -> %0 : !llvm.ptr) {
+ ^bb0(%arg0: !llvm.ptr):
// CHECK: omp.section
omp.section {
%1 = arith.constant 2.0 : f32
@@ -1146,6 +1149,7 @@ func.func @sections_reduction_byref() {
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
// CHECK: omp.sections reduction(byref @add_f32 -> {{.+}} : !llvm.ptr)
omp.sections reduction(byref @add_f32 -> %0 : !llvm.ptr) {
+ ^bb0(%arg0: !llvm.ptr):
// CHECK: omp.section
omp.section {
%1 = arith.constant 2.0 : f32
@@ -1245,6 +1249,7 @@ func.func @sections_reduction2() {
%0 = memref.alloca() : memref<1xf32>
// CHECK: omp.sections reduction(@add2_f32 -> %{{.+}} : memref<1xf32>)
omp.sections reduction(@add2_f32 -> %0 : memref<1xf32>) {
+ ^bb0(%arg0: !llvm.ptr):
omp.section {
%1 = arith.constant 2.0 : f32
omp.terminator
@@ -1901,6 +1906,7 @@ func.func @omp_sectionsop(%data_var1 : memref<i32>, %data_var2 : memref<i32>,
// CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr)
"omp.sections" (%redn_var) ({
+ ^bb0(%arg0: !llvm.ptr):
// CHECK: omp.terminator
omp.terminator
}) {operandSegmentSizes = array<i32: 0,0,0,1>, reduction_byref = array<i1: false>, reduction_syms=[@add_f32]} : (!llvm.ptr) -> ()
@@ -1913,6 +1919,7 @@ func.func @omp_sectionsop(%data_var1 : memref<i32>, %data_var2 : memref<i32>,
// CHECK: omp.sections reduction(@add_f32 -> %{{.*}} : !llvm.ptr) {
omp.sections reduction(@add_f32 -> %redn_var : !llvm.ptr) {
+ ^bb0(%arg0: !llvm.ptr):
// CHECK: omp.terminator
omp.terminator
}
@@ -2087,6 +2094,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
%1 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
// CHECK: omp.task in_reduction(@add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr) {
omp.task in_reduction(@add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr) {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
// CHECK: "test.foo"() : () -> ()
"test.foo"() : () -> ()
// CHECK: omp.terminator
@@ -2096,6 +2104,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
// Checking `in_reduction` clause (mixed) byref
// CHECK: omp.task in_reduction(byref @add_f32 -> %[[redn_var1]] : !llvm.ptr, @add_f32 -> %[[redn_var2]] : !llvm.ptr) {
omp.task in_reduction(byref @add_f32 -> %0 : !llvm.ptr, @add_f32 -> %1 : !llvm.ptr) {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
// CHECK: "test.foo"() : () -> ()
"test.foo"() : () -> ()
// CHECK: omp.terminator
@@ -2129,6 +2138,7 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
in_reduction(@add_f32 -> %0 : !llvm.ptr, byref @add_f32 -> %1 : !llvm.ptr)
// CHECK-SAME: priority(%[[i32_var]] : i32) untied
priority(%i32_var : i32) untied {
+ ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
// CHECK: "test.foo"() : () -> ()
"test.foo"() : () -> ()
// CHECK: omp.terminator
@@ -2306,6 +2316,7 @@ func.func @omp_taskgroup_clauses() -> () {
%testf32 = "test.f32"() : () -> (!llvm.ptr)
// CHECK: omp.taskgroup allocate(%{{.+}}: memref<i32> -> %{{.+}}: memref<i32>) task_reduction(@add_f32 -> %{{.+}}: !llvm.ptr)
omp.taskgroup allocate(%testmemref : memref<i32> -> %testmemref : memref<i32>) task_reduction(@add_f32 -> %testf32 : !llvm.ptr) {
+ ^bb0(%arg0 : !llvm.ptr):
// CHECK: omp.task
omp.task {
"test.foo"() : () -> ()
@@ -2783,15 +2794,15 @@ omp.private {type = firstprivate} @y.privatizer : !llvm.ptr alloc {
// CHECK-LABEL: parallel_op_reduction_and_private
func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !llvm.ptr, %reduc_var: !llvm.ptr, %reduc_var2: !llvm.ptr) {
// CHECK: omp.parallel
- // CHECK-SAME: reduction(
- // CHECK-SAME: @add_f32 %[[REDUC_VAR:[^[:space:]]+]] -> %[[REDUC_ARG:[^[:space:]]+]] : !llvm.ptr,
- // CHECK-SAME: @add_f32 %[[REDUC_VAR2:[^[:space:]]+]] -> %[[REDUC_ARG2:[^[:space:]]+]] : !llvm.ptr)
- //
// CHECK-SAME: private(
// CHECK-SAME: @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]] : !llvm.ptr,
// CHECK-SAME: @y.privatizer %[[PRIV_VAR2:[^[:space:]]+]] -> %[[PRIV_ARG2:[^[:space:]]+]] : !llvm.ptr)
- omp.parallel reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr)
- private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr) {
+ //
+ // CHECK-SAME: reduction(
+ // CHECK-SAME: @add_f32 %[[REDUC_VAR:[^[:space:]]+]] -> %[[REDUC_ARG:[^[:space:]]+]] : !llvm.ptr,
+ // CHECK-SAME: @add_f32 %[[REDUC_VAR2:[^[:space:]]+]] -> %[[REDUC_ARG2:[^[:space:]]+]] : !llvm.ptr)
+ omp.parallel private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr, @y.privatizer %priv_var2 -> %priv_arg2 : !llvm.ptr)
+ reduction(@add_f32 %reduc_var -> %reduc_arg : !llvm.ptr, @add_f32 %reduc_var2 -> %reduc_arg2 : !llvm.ptr) {
// CHECK: llvm.load %[[PRIV_ARG]]
%0 = llvm.load %priv_arg : !llvm.ptr -> f32
// CHECK: llvm.load %[[PRIV_ARG2]]
diff --git a/mlir/test/Target/LLVMIR/openmp-private.mlir b/mlir/test/Target/LLVMIR/openmp-private.mlir
index 21167668bbee16..a06e44fc5cfe01 100644
--- a/mlir/test/Target/LLVMIR/openmp-private.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-private.mlir
@@ -206,7 +206,7 @@ llvm.func @private_and_reduction_() attributes {fir.internal_name = "_QPprivate_
%0 = llvm.mlir.constant(1 : i64) : i64
%1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr
%2 = llvm.alloca %0 x f32 {bindc_name = "to_priv"} : (i64) -> !llvm.ptr
- omp.parallel reduction(byref @reducer.part %1 -> %arg0 : !llvm.ptr) private(@privatizer.part %2 -> %arg1 : !llvm.ptr) {
+ omp.parallel private(@privatizer.part %2 -> %arg1 : !llvm.ptr) reduction(byref @reducer.part %1 -> %arg0 : !llvm.ptr) {
%3 = llvm.load %arg0 : !llvm.ptr -> !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
%4 = llvm.mlir.constant(8.000000e+00 : f32) : f32
llvm.store %4, %arg1 : f32, !llvm.ptr
More information about the flang-commits
mailing list