[flang-commits] [flang] eba35cc - [flang][do concurrent] Re-model `reduce` to match how reductions are modelled in OpenMP and OpenACC (#145837)
via flang-commits
flang-commits at lists.llvm.org
Thu Jul 10 21:39:36 PDT 2025
Author: Kareem Ergawy
Date: 2025-07-11T06:39:30+02:00
New Revision: eba35cc1c0e4e2c59f9fd1f7a6f3b17cb4d8c765
URL: https://github.com/llvm/llvm-project/commit/eba35cc1c0e4e2c59f9fd1f7a6f3b17cb4d8c765
DIFF: https://github.com/llvm/llvm-project/commit/eba35cc1c0e4e2c59f9fd1f7a6f3b17cb4d8c765.diff
LOG: [flang][do concurrent] Re-model `reduce` to match how reductions are modelled in OpenMP and OpenACC (#145837)
This PR proposes re-modelling `reduce` specifiers to match OpenMP and
OpenACC. In particular, this PR includes the following:
* A new `fir` op, `fir.declare_reduction`, which is identical to OpenMP's
`omp.declare_reduction` op.
* Updates the `reduce` clause on `fir.do_concurrent.loop` to use the
new op.
* Re-uses the `ReductionProcessor` component to emit reductions for
`do concurrent` just as we do for OpenMP. To do this, the
`ReductionProcessor` had to be refactored to be more general.
* Updates the mapping of `do concurrent` to `fir.do_loop ... unordered`
nests to use the new reduction model.
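Generalizing `ReductionProcessor` essentially means templating the region emitters on the declaration op type and selecting the matching terminator at compile time, as the `genYield<ParentDeclOpType>` helper in the diff does. Below is a minimal, self-contained C++ sketch of that pattern. The op structs and the string output are hypothetical stand-ins for the real MLIR ops and `fir::FirOpBuilder` calls; this is not the actual flang API.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-ins for mlir::omp::DeclareReductionOp and
// fir::DeclareReductionOp; the real ops come from the MLIR/FIR dialects.
struct OmpDeclareReductionOp {};
struct FirDeclareReductionOp {};

// Pick the terminator that matches the parent declaration op at compile time,
// in the spirit of genYield<ParentDeclOpType>. Strings stand in for the ops a
// real builder would create.
template <typename ParentDeclOpType>
std::string genYield(const std::string &value) {
  if constexpr (std::is_same_v<ParentDeclOpType, OmpDeclareReductionOp>)
    return "omp.yield(" + value + ")";
  else
    return "fir.yield(" + value + ")";
}

// Simplified, hypothetical combiner emitter for an integer `add` reduction:
// the region body is shared; only the terminator depends on the op type.
template <typename DeclRedOpType>
std::string genAddCombiner(const std::string &lhs, const std::string &rhs) {
  std::string body = "%sum = arith.addi " + lhs + ", " + rhs + "\n";
  return body + genYield<DeclRedOpType>("%sum");
}
```

With this shape, `genAddCombiner<FirDeclareReductionOp>("%a", "%b")` emits the same body as the OpenMP variant but ends in `fir.yield(%sum)`. Using `if constexpr` keeps a single copy of the combiner logic without any runtime dispatch.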
Unfortunately, this is a big PR that would be difficult to divide into
smaller parts: the changes are rooted in the `fir` TableGen definitions for
`do concurrent`, and modifying those MLIR definitions cascades to the other
components, which have to be updated so that nothing breaks.
This PR goes in the same direction we went for the `private/local`
specifiers. Now the `do concurrent` and OpenMP (and OpenACC) dialects
are modelled in essentially the same way, which should make mapping between
them more straightforward.
PR stack:
- https://github.com/llvm/llvm-project/pull/145837 (this one)
- https://github.com/llvm/llvm-project/pull/146025
- https://github.com/llvm/llvm-project/pull/146028
- https://github.com/llvm/llvm-project/pull/146033
Added:
flang/test/Lower/do_concurrent_reduce.f90
Modified:
flang/include/flang/Optimizer/Dialect/FIRAttr.td
flang/include/flang/Optimizer/Dialect/FIROps.td
flang/lib/Lower/Bridge.cpp
flang/lib/Lower/OpenMP/ClauseProcessor.cpp
flang/lib/Lower/OpenMP/ClauseProcessor.h
flang/lib/Lower/OpenMP/Clauses.h
flang/lib/Lower/OpenMP/OpenMP.cpp
flang/lib/Lower/OpenMP/ReductionProcessor.cpp
flang/lib/Lower/OpenMP/ReductionProcessor.h
flang/lib/Lower/Support/Utils.cpp
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/lib/Optimizer/Dialect/FIROps.cpp
flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
flang/test/Fir/do_concurrent.fir
flang/test/Fir/invalid.fir
flang/test/Lower/loops.f90
flang/test/Lower/loops3.f90
flang/test/Transforms/do_concurrent-to-do_loop-unodered.fir
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.td b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
index 2845080030b92..7bd96ac3ea631 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRAttr.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.td
@@ -112,7 +112,7 @@ def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum",
I32BitEnumAttrCaseBit<"MIN", 7, "min">,
I32BitEnumAttrCaseBit<"IAND", 8, "iand">,
I32BitEnumAttrCaseBit<"IOR", 9, "ior">,
- I32BitEnumAttrCaseBit<"EIOR", 10, "eior">
+ I32BitEnumAttrCaseBit<"IEOR", 10, "ieor">
]> {
let separator = ", ";
let cppNamespace = "::fir";
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index f440580f0878a..c3d3582a50e7f 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3518,7 +3518,7 @@ def fir_BoxTotalElementsOp
def YieldOp : fir_Op<"yield",
[Pure, ReturnLike, Terminator,
- ParentOneOf<["LocalitySpecifierOp"]>]> {
+ ParentOneOf<["LocalitySpecifierOp", "DeclareReductionOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
"fir.yield" yields SSA values from a fir dialect op region and
@@ -3656,6 +3656,103 @@ def fir_LocalitySpecifierOp : fir_Op<"local", [IsolatedFromAbove]> {
let hasRegionVerifier = 1;
}
+def fir_DeclareReductionOp : fir_Op<"declare_reduction", [IsolatedFromAbove,
+ Symbol]> {
+ let summary = "declares a reduction kind";
+ let description = [{
+ Note: this operation is adapted from omp::DeclareReductionOp. There is a lot
+ duplication at the moment. TODO Combine both ops into one. See:
+ https://discourse.llvm.org/t/dialect-for-data-locality-sharing-specifiers-clauses-in-openmp-openacc-and-do-concurrent/86108.
+
+ Declares a `do concurrent` reduction. This requires two mandatory and three
+ optional regions.
+
+ 1. The optional alloc region specifies how to allocate the thread-local
+ reduction value. This region should not contain control flow and all
+ IR should be suitable for inlining straight into an entry block. In
+ the common case this is expected to contain only allocas. It is
+ expected to `fir.yield` the allocated value on all control paths.
+ If allocation is conditional (e.g. only allocate if the mold is
+ allocated), this should be done in the initilizer region and this
+ region not included. The alloc region is not used for by-value
+ reductions (where allocation is implicit).
+ 2. The initializer region specifies how to initialize the thread-local
+ reduction value. This is usually the neutral element of the reduction.
+ For convenience, the region has an argument that contains the value
+ of the reduction accumulator at the start of the reduction. If an alloc
+ region is specified, there is a second block argument containing the
+ address of the allocated memory. The initializer region is expected to
+ `fir.yield` the new value on all control flow paths.
+ 3. The reduction region specifies how to combine two values into one, i.e.
+ the reduction operator. It accepts the two values as arguments and is
+ expected to `fir.yield` the combined value on all control flow paths.
+ 4. The atomic reduction region is optional and specifies how two values
+ can be combined atomically given local accumulator variables. It is
+ expected to store the combined value in the first accumulator variable.
+ 5. The cleanup region is optional and specifies how to clean up any memory
+ allocated by the initializer region. The region has an argument that
+ contains the value of the thread-local reduction accumulator. This will
+ be executed after the reduction has completed.
+
+ Note that the MLIR type system does not allow for type-polymorphic
+ reductions. Separate reduction declarations should be created for different
+ element and accumulator types.
+
+ For initializer and reduction regions, the operand to `fir.yield` must
+ match the parent operation's results.
+ }];
+
+ let arguments = (ins SymbolNameAttr:$sym_name,
+ TypeAttr:$type);
+
+ let regions = (region MaxSizedRegion<1>:$allocRegion,
+ AnyRegion:$initializerRegion,
+ AnyRegion:$reductionRegion,
+ AnyRegion:$atomicReductionRegion,
+ AnyRegion:$cleanupRegion);
+
+ let assemblyFormat = "$sym_name `:` $type attr-dict-with-keyword "
+ "( `alloc` $allocRegion^ )? "
+ "`init` $initializerRegion "
+ "`combiner` $reductionRegion "
+ "( `atomic` $atomicReductionRegion^ )? "
+ "( `cleanup` $cleanupRegion^ )? ";
+
+ let extraClassDeclaration = [{
+ mlir::BlockArgument getAllocMoldArg() {
+ auto &region = getAllocRegion();
+ return region.empty() ? nullptr : region.getArgument(0);
+ }
+ mlir::BlockArgument getInitializerMoldArg() {
+ return getInitializerRegion().getArgument(0);
+ }
+ mlir::BlockArgument getInitializerAllocArg() {
+ return getAllocRegion().empty() ?
+ nullptr : getInitializerRegion().getArgument(1);
+ }
+ mlir::BlockArgument getReductionLhsArg() {
+ return getReductionRegion().getArgument(0);
+ }
+ mlir::BlockArgument getReductionRhsArg() {
+ return getReductionRegion().getArgument(1);
+ }
+ mlir::BlockArgument getAtomicReductionLhsArg() {
+ auto &region = getAtomicReductionRegion();
+ return region.empty() ? nullptr : region.getArgument(0);
+ }
+ mlir::BlockArgument getAtomicReductionRhsArg() {
+ auto &region = getAtomicReductionRegion();
+ return region.empty() ? nullptr : region.getArgument(1);
+ }
+ mlir::BlockArgument getCleanupAllocArg() {
+ auto &region = getCleanupRegion();
+ return region.empty() ? nullptr : region.getArgument(0);
+ }
+ }];
+
+ let hasRegionVerifier = 1;
+}
+
def fir_DoConcurrentOp : fir_Op<"do_concurrent",
[SingleBlock, AutomaticAllocationScope]> {
let summary = "do concurrent loop wrapper";
@@ -3694,6 +3791,25 @@ def fir_LocalSpecifier {
);
}
+def fir_ReduceSpecifier {
+ dag arguments = (ins
+ Variadic<AnyType>:$reduce_vars,
+ OptionalAttr<DenseBoolArrayAttr>:$reduce_byref,
+
+ // This introduces redundency in how reductions are modelled. In particular,
+ // a single reduction is represented by 2 attributes:
+ //
+ // 1. `$reduce_syms` which is a list of `DeclareReductionOp`s.
+ // 2. `$reduce_attrs` which is an array of `fir::ReduceAttr` values.
+ //
+ // The first makes it easier to map `do concurrent` to parallization models
+ // (e.g. OpenMP and OpenACC) while the second makes it easier to map it to
+ // nests of `fir.do_loop ... unodered` ops.
+ OptionalAttr<SymbolRefArrayAttr>:$reduce_syms,
+ OptionalAttr<ArrayAttr>:$reduce_attrs
+ );
+}
+
def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getLoopInductionVars"]>,
@@ -3703,7 +3819,7 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
let description = [{
An operation that models a Fortran `do concurrent` loop's header and block.
This is a single-region single-block terminator op that is expected to
- terminate the region of a `omp.do_concurrent` wrapper op.
+ terminate the region of a `fir.do_concurrent` wrapper op.
This op borrows from both `scf.parallel` and `fir.do_loop` ops. Similar to
`scf.parallel`, a loop nest takes 3 groups of SSA values as operands that
@@ -3741,8 +3857,6 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
- `lowerBound`: The group of SSA values for the nest's lower bounds.
- `upperBound`: The group of SSA values for the nest's upper bounds.
- `step`: The group of SSA values for the nest's steps.
- - `reduceOperands`: The reduction SSA values, if any.
- - `reduceAttrs`: Attributes to store reduction operations, if any.
- `loopAnnotation`: Loop metadata to be passed down the compiler pipeline to
LLVM.
}];
@@ -3751,12 +3865,12 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
Variadic<Index>:$lowerBound,
Variadic<Index>:$upperBound,
Variadic<Index>:$step,
- Variadic<AnyType>:$reduceOperands,
- OptionalAttr<ArrayAttr>:$reduceAttrs,
OptionalAttr<LoopAnnotationAttr>:$loopAnnotation
);
- let arguments = !con(opArgs, fir_LocalSpecifier.arguments);
+ let arguments = !con(opArgs,
+ fir_LocalSpecifier.arguments,
+ fir_ReduceSpecifier.arguments);
let regions = (region SizedRegion<1>:$region);
@@ -3777,12 +3891,18 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
getNumLocalOperands());
}
+ mlir::Block::BlockArgListType getRegionReduceArgs() {
+ return getBody()->getArguments().slice(getNumInductionVars()
+ + getNumLocalOperands(),
+ getNumReduceOperands());
+ }
+
/// Number of operands controlling the loop
unsigned getNumControlOperands() { return getLowerBound().size() * 3; }
// Get Number of reduction operands
unsigned getNumReduceOperands() {
- return getReduceOperands().size();
+ return getReduceVars().size();
}
mlir::Operation::operand_range getLocalOperands() {
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index cd55d10314740..e2d3fe964d49b 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -12,6 +12,7 @@
#include "flang/Lower/Bridge.h"
+#include "OpenMP/ReductionProcessor.h"
#include "flang/Lower/Allocatable.h"
#include "flang/Lower/CallInterface.h"
#include "flang/Lower/Coarray.h"
@@ -127,9 +128,8 @@ struct IncrementLoopInfo {
bool isConcurrent;
llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
- llvm::SmallVector<
- std::pair<fir::ReduceOperationEnum, const Fortran::semantics::Symbol *>>
- reduceSymList;
+ llvm::SmallVector<const Fortran::semantics::Symbol *> reduceSymList;
+ llvm::SmallVector<fir::ReduceOperationEnum> reduceOperatorList;
llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
mlir::Value loopVariable = nullptr;
@@ -1993,7 +1993,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
case Fortran::parser::ReductionOperator::Operator::Ior:
return fir::ReduceOperationEnum::IOR;
case Fortran::parser::ReductionOperator::Operator::Ieor:
- return fir::ReduceOperationEnum::EIOR;
+ return fir::ReduceOperationEnum::IEOR;
}
llvm_unreachable("illegal reduction operator");
}
@@ -2027,8 +2027,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get<Fortran::parser::ReductionOperator>(reduceList->t));
for (const Fortran::parser::Name &x :
std::get<std::list<Fortran::parser::Name>>(reduceList->t)) {
- info.reduceSymList.push_back(
- std::make_pair(reduce_operation, x.symbol));
+ info.reduceSymList.push_back(x.symbol);
+ info.reduceOperatorList.push_back(reduce_operation);
}
}
}
@@ -2089,6 +2089,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
assign.u = Fortran::evaluate::Assignment::BoundsSpec{};
genAssignment(assign);
}
+
for (const Fortran::semantics::Symbol *sym : info.sharedSymList) {
const auto *hostDetails =
sym->detailsIf<Fortran::semantics::HostAssocDetails>();
@@ -2112,6 +2113,45 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
}
+ llvm::SmallVector<bool> reduceVarByRef;
+ llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
+ llvm::SmallVector<mlir::Attribute> nestReduceAttrs;
+
+ for (const auto &reduceOp : info.reduceOperatorList)
+ nestReduceAttrs.push_back(
+ fir::ReduceAttr::get(builder->getContext(), reduceOp));
+
+ llvm::SmallVector<mlir::Value> reduceVars;
+ Fortran::lower::omp::ReductionProcessor rp;
+ rp.processReductionArguments<fir::DeclareReductionOp>(
+ toLocation(), *this, info.reduceOperatorList, reduceVars,
+ reduceVarByRef, reductionDeclSymbols, info.reduceSymList);
+
+ doConcurrentLoopOp.getReduceVarsMutable().assign(reduceVars);
+ doConcurrentLoopOp.setReduceSymsAttr(
+ reductionDeclSymbols.empty()
+ ? nullptr
+ : mlir::ArrayAttr::get(builder->getContext(),
+ reductionDeclSymbols));
+ doConcurrentLoopOp.setReduceAttrsAttr(
+ nestReduceAttrs.empty()
+ ? nullptr
+ : mlir::ArrayAttr::get(builder->getContext(), nestReduceAttrs));
+ doConcurrentLoopOp.setReduceByrefAttr(
+ reduceVarByRef.empty() ? nullptr
+ : mlir::DenseBoolArrayAttr::get(
+ builder->getContext(), reduceVarByRef));
+
+ for (auto [sym, reduceVar] :
+ llvm::zip_equal(info.reduceSymList, reduceVars)) {
+ auto arg = doConcurrentLoopOp.getRegion().begin()->addArgument(
+ reduceVar.getType(), doConcurrentLoopOp.getLoc());
+ bindSymbol(*sym, hlfir::translateToExtendedValue(
+ reduceVar.getLoc(), *builder, hlfir::Entity{arg},
+ /*contiguousHint=*/true)
+ .first);
+ }
+
// Note that allocatable, types with ultimate components, and type
// requiring finalization are forbidden in LOCAL/LOCAL_INIT (F2023 C1130),
// so no clean-up needs to be generated for these entities.
@@ -2203,6 +2243,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
}
+ // Introduce a `do concurrent` scope to bind symbols corresponding to local,
+ // local_init, and reduce region arguments.
+ if (!incrementLoopNestInfo.empty() &&
+ incrementLoopNestInfo.back().isConcurrent)
+ localSymbols.pushScope();
+
// Increment loop begin code. (Infinite/while code was already generated.)
if (!infiniteLoop && !whileCondition)
genFIRIncrementLoopBegin(incrementLoopNestInfo, doStmtEval.dirs);
@@ -2226,6 +2272,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// This call may generate a branch in some contexts.
genFIR(endDoEval, unstructuredContext);
+
+ if (!incrementLoopNestInfo.empty() &&
+ incrementLoopNestInfo.back().isConcurrent)
+ localSymbols.popScope();
}
/// Generate FIR to evaluate loop control values (lower, upper and step).
@@ -2408,19 +2458,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
info.stepVariable = builder->createTemporary(loc, stepValue.getType());
builder->create<fir::StoreOp>(loc, stepValue, info.stepVariable);
}
-
- if (genDoConcurrent && nestReduceOperands.empty()) {
- // Create DO CONCURRENT reduce operands and attributes
- for (const auto &reduceSym : info.reduceSymList) {
- const fir::ReduceOperationEnum reduceOperation = reduceSym.first;
- const Fortran::semantics::Symbol *sym = reduceSym.second;
- fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr);
- nestReduceOperands.push_back(fir::getBase(exv));
- auto reduceAttr =
- fir::ReduceAttr::get(builder->getContext(), reduceOperation);
- nestReduceAttrs.push_back(reduceAttr);
- }
- }
}
for (auto [info, lowerValue, upperValue, stepValue] :
@@ -2518,11 +2555,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->setInsertionPointToEnd(loopWrapperOp.getBody());
auto loopOp = builder->create<fir::DoConcurrentLoopOp>(
- loc, nestLBs, nestUBs, nestSts, nestReduceOperands,
- nestReduceAttrs.empty()
- ? nullptr
- : mlir::ArrayAttr::get(builder->getContext(), nestReduceAttrs),
- nullptr, /*local_vars=*/std::nullopt, /*local_syms=*/nullptr);
+ loc, nestLBs, nestUBs, nestSts, /*loopAnnotation=*/nullptr,
+ /*local_vars=*/std::nullopt,
+ /*local_syms=*/nullptr, /*reduce_vars=*/std::nullopt,
+ /*reduce_byref=*/nullptr, /*reduce_syms=*/nullptr,
+ /*reduce_attrs=*/nullptr);
llvm::SmallVector<mlir::Type> loopBlockArgTypes(
incrementLoopNestInfo.size(), builder->getIndexType());
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 7bea427099a28..5aebfc901e8ac 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -12,6 +12,7 @@
#include "ClauseProcessor.h"
#include "Clauses.h"
+#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/ConvertExprToHLFIR.h"
@@ -25,6 +26,21 @@ namespace Fortran {
namespace lower {
namespace omp {
+using ReductionModifier =
+ Fortran::lower::omp::clause::Reduction::ReductionModifier;
+
+mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
+ switch (mod) {
+ case ReductionModifier::Default:
+ return mlir::omp::ReductionModifier::defaultmod;
+ case ReductionModifier::Inscan:
+ return mlir::omp::ReductionModifier::inscan;
+ case ReductionModifier::Task:
+ return mlir::omp::ReductionModifier::task;
+ }
+ return mlir::omp::ReductionModifier::defaultmod;
+}
+
/// Check for unsupported map operand types.
static void checkMapType(mlir::Location location, mlir::Type type) {
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(type))
@@ -1076,6 +1092,18 @@ bool ClauseProcessor::processIf(
});
return found;
}
+
+template <typename T>
+void collectReductionSyms(
+ const T &reduction,
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
+ const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
+ for (const Object &object : objectList) {
+ const semantics::Symbol *symbol = object.sym();
+ reductionSyms.push_back(symbol);
+ }
+}
+
bool ClauseProcessor::processInReduction(
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
@@ -1085,10 +1113,14 @@ bool ClauseProcessor::processInReduction(
llvm::SmallVector<bool> inReduceVarByRef;
llvm::SmallVector<mlir::Attribute> inReductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
+ collectReductionSyms(clause, inReductionSyms);
+
ReductionProcessor rp;
- rp.processReductionArguments<omp::clause::InReduction>(
- currentLocation, converter, clause, inReductionVars,
- inReduceVarByRef, inReductionDeclSymbols, inReductionSyms);
+ rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
+ currentLocation, converter,
+ std::get<typename omp::clause::ReductionOperatorList>(clause.t),
+ inReductionVars, inReduceVarByRef, inReductionDeclSymbols,
+ inReductionSyms);
// Copy local lists into the output.
llvm::copy(inReductionVars, std::back_inserter(result.inReductionVars));
@@ -1416,10 +1448,23 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<bool> reduceVarByRef;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
+ collectReductionSyms(clause, reductionSyms);
+
+ auto mod = std::get<std::optional<ReductionModifier>>(clause.t);
+ if (mod.has_value()) {
+ if (mod.value() == ReductionModifier::Task)
+ TODO(currentLocation, "Reduction modifier `task` is not supported");
+ else
+ result.reductionMod = mlir::omp::ReductionModifierAttr::get(
+ converter.getFirOpBuilder().getContext(),
+ translateReductionModifier(mod.value()));
+ }
+
ReductionProcessor rp;
- rp.processReductionArguments<omp::clause::Reduction>(
- currentLocation, converter, clause, reductionVars, reduceVarByRef,
- reductionDeclSymbols, reductionSyms, &result.reductionMod);
+ rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
+ currentLocation, converter,
+ std::get<typename omp::clause::ReductionOperatorList>(clause.t),
+ reductionVars, reduceVarByRef, reductionDeclSymbols, reductionSyms);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
@@ -1435,21 +1480,25 @@ bool ClauseProcessor::processTaskReduction(
return findRepeatableClause<omp::clause::TaskReduction>(
[&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
llvm::SmallVector<mlir::Value> taskReductionVars;
- llvm::SmallVector<bool> TaskReduceVarByRef;
- llvm::SmallVector<mlir::Attribute> TaskReductionDeclSymbols;
- llvm::SmallVector<const semantics::Symbol *> TaskReductionSyms;
+ llvm::SmallVector<bool> taskReduceVarByRef;
+ llvm::SmallVector<mlir::Attribute> taskReductionDeclSymbols;
+ llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
+ collectReductionSyms(clause, taskReductionSyms);
+
ReductionProcessor rp;
- rp.processReductionArguments<omp::clause::TaskReduction>(
- currentLocation, converter, clause, taskReductionVars,
- TaskReduceVarByRef, TaskReductionDeclSymbols, TaskReductionSyms);
+ rp.processReductionArguments<mlir::omp::DeclareReductionOp>(
+ currentLocation, converter,
+ std::get<typename omp::clause::ReductionOperatorList>(clause.t),
+ taskReductionVars, taskReduceVarByRef, taskReductionDeclSymbols,
+ taskReductionSyms);
// Copy local lists into the output.
llvm::copy(taskReductionVars,
std::back_inserter(result.taskReductionVars));
- llvm::copy(TaskReduceVarByRef,
+ llvm::copy(taskReduceVarByRef,
std::back_inserter(result.taskReductionByref));
- llvm::copy(TaskReductionDeclSymbols,
+ llvm::copy(taskReductionDeclSymbols,
std::back_inserter(result.taskReductionSyms));
- llvm::copy(TaskReductionSyms, std::back_inserter(outReductionSyms));
+ llvm::copy(taskReductionSyms, std::back_inserter(outReductionSyms));
});
}
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 3d8c4a337a4a4..46b749fb66c86 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -14,7 +14,6 @@
#include "ClauseFinder.h"
#include "Clauses.h"
-#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
diff --git a/flang/lib/Lower/OpenMP/Clauses.h b/flang/lib/Lower/OpenMP/Clauses.h
index d7ab21d428e32..7f317f05f67b7 100644
--- a/flang/lib/Lower/OpenMP/Clauses.h
+++ b/flang/lib/Lower/OpenMP/Clauses.h
@@ -179,6 +179,7 @@ using IteratorSpecifier = tomp::type::IteratorSpecifierT<TypeTy, IdTy, ExprTy>;
using DefinedOperator = tomp::type::DefinedOperatorT<IdTy, ExprTy>;
using ProcedureDesignator = tomp::type::ProcedureDesignatorT<IdTy, ExprTy>;
using ReductionOperator = tomp::type::ReductionIdentifierT<IdTy, ExprTy>;
+using ReductionOperatorList = List<ReductionOperator>;
using DependenceType = tomp::type::DependenceType;
using Prescriptiveness = tomp::type::Prescriptiveness;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 0a56e888ac44b..65e852cbcc911 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -17,7 +17,6 @@
#include "Clauses.h"
#include "DataSharingProcessor.h"
#include "Decomposer.h"
-#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 330cef7b54c74..d14fc1f7a52da 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -12,6 +12,7 @@
#include "ReductionProcessor.h"
+#include "Clauses.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/Support/PrivateReductionUtils.h"
@@ -21,8 +22,6 @@
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
-#include "flang/Optimizer/Support/FatalError.h"
-#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
#include <type_traits>
@@ -40,35 +39,35 @@ namespace lower {
namespace omp {
// explicit template declarations
-template void
-ReductionProcessor::processReductionArguments<omp::clause::Reduction>(
+template void ReductionProcessor::processReductionArguments<
+ mlir::omp::DeclareReductionOp, omp::clause::ReductionOperatorList>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
- const omp::clause::Reduction &reduction,
+ const omp::clause::ReductionOperatorList &redOperatorList,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
- mlir::omp::ReductionModifierAttr *reductionMod);
+ const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
-template void
-ReductionProcessor::processReductionArguments<omp::clause::TaskReduction>(
+template void ReductionProcessor::processReductionArguments<
+ fir::DeclareReductionOp, llvm::SmallVector<fir::ReduceOperationEnum>>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
- const omp::clause::TaskReduction &reduction,
+ const llvm::SmallVector<fir::ReduceOperationEnum> &redOperatorList,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
- mlir::omp::ReductionModifierAttr *reductionMod);
+ const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
-template void
-ReductionProcessor::processReductionArguments<omp::clause::InReduction>(
- mlir::Location currentLocation, lower::AbstractConverter &converter,
- const omp::clause::InReduction &reduction,
- llvm::SmallVectorImpl<mlir::Value> &reductionVars,
- llvm::SmallVectorImpl<bool> &reduceVarByRef,
- llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
- mlir::omp::ReductionModifierAttr *reductionMod);
+template mlir::omp::DeclareReductionOp
+ReductionProcessor::createDeclareReduction<mlir::omp::DeclareReductionOp>(
+ AbstractConverter &converter, llvm::StringRef reductionOpName,
+ const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
+ bool isByRef);
+
+template fir::DeclareReductionOp
+ReductionProcessor::createDeclareReduction<fir::DeclareReductionOp>(
+ AbstractConverter &converter, llvm::StringRef reductionOpName,
+ const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
+ bool isByRef);
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
const omp::clause::ProcedureDesignator &pd) {
@@ -106,6 +105,37 @@ ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
}
}
+ReductionProcessor::ReductionIdentifier
+ReductionProcessor::getReductionType(const fir::ReduceOperationEnum &redOp) {
+ switch (redOp) {
+ case fir::ReduceOperationEnum::Add:
+ return ReductionIdentifier::ADD;
+ case fir::ReduceOperationEnum::Multiply:
+ return ReductionIdentifier::MULTIPLY;
+
+ case fir::ReduceOperationEnum::AND:
+ return ReductionIdentifier::AND;
+ case fir::ReduceOperationEnum::OR:
+ return ReductionIdentifier::OR;
+
+ case fir::ReduceOperationEnum::EQV:
+ return ReductionIdentifier::EQV;
+ case fir::ReduceOperationEnum::NEQV:
+ return ReductionIdentifier::NEQV;
+
+ case fir::ReduceOperationEnum::IAND:
+ return ReductionIdentifier::IAND;
+ case fir::ReduceOperationEnum::IEOR:
+ return ReductionIdentifier::IEOR;
+ case fir::ReduceOperationEnum::IOR:
+ return ReductionIdentifier::IOR;
+ case fir::ReduceOperationEnum::MAX:
+ return ReductionIdentifier::MAX;
+ case fir::ReduceOperationEnum::MIN:
+ return ReductionIdentifier::MIN;
+ }
+}
+
bool ReductionProcessor::supportedIntrinsicProcReduction(
const omp::clause::ProcedureDesignator &pd) {
semantics::Symbol *sym = pd.v.sym();
@@ -136,28 +166,29 @@ ReductionProcessor::getReductionName(llvm::StringRef name,
return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
}
-std::string ReductionProcessor::getReductionName(
- omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
- const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
+std::string
+ReductionProcessor::getReductionName(ReductionIdentifier redId,
+ const fir::KindMapping &kindMap,
+ mlir::Type ty, bool isByRef) {
std::string reductionName;
- switch (intrinsicOp) {
- case omp::clause::DefinedOperator::IntrinsicOperator::Add:
+ switch (redId) {
+ case ReductionIdentifier::ADD:
reductionName = "add_reduction";
break;
- case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
+ case ReductionIdentifier::MULTIPLY:
reductionName = "multiply_reduction";
break;
- case omp::clause::DefinedOperator::IntrinsicOperator::AND:
+ case ReductionIdentifier::AND:
reductionName = "and_reduction";
break;
- case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
+ case ReductionIdentifier::EQV:
reductionName = "eqv_reduction";
break;
- case omp::clause::DefinedOperator::IntrinsicOperator::OR:
+ case ReductionIdentifier::OR:
reductionName = "or_reduction";
break;
- case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
+ case ReductionIdentifier::NEQV:
reductionName = "neqv_reduction";
break;
default:
@@ -334,8 +365,18 @@ mlir::Value ReductionProcessor::createScalarCombiner(
return reductionOp;
}
+template <typename ParentDeclOpType>
+static void genYield(fir::FirOpBuilder &builder, mlir::Location loc,
+ mlir::Value yieldedValue) {
+ if constexpr (std::is_same_v<ParentDeclOpType, mlir::omp::DeclareReductionOp>)
+ builder.create<mlir::omp::YieldOp>(loc, yieldedValue);
+ else
+ builder.create<fir::YieldOp>(loc, yieldedValue);
+}
+
/// Create reduction combiner region for reduction variables which are boxed
/// arrays
+template <typename DeclRedOpType>
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
ReductionProcessor::ReductionIdentifier redId,
fir::BaseBoxType boxTy, mlir::Value lhs,
@@ -369,7 +410,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value result = ReductionProcessor::createScalarCombiner(
builder, loc, redId, eleTy, lhs, rhs);
builder.create<fir::StoreOp>(loc, result, lhsValAddr);
- builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
+ genYield<DeclRedOpType>(builder, loc, lhsAddr);
return;
}
@@ -408,10 +449,11 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
builder.setInsertionPointAfter(nest.outerOp);
- builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
+ genYield<DeclRedOpType>(builder, loc, lhsAddr);
}
// generate combiner region for reduction operations
+template <typename DeclRedOpType>
static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
ReductionProcessor::ReductionIdentifier redId,
mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
@@ -426,15 +468,15 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
builder, loc, redId, ty, lhsLoaded, rhsLoaded);
if (isByRef) {
builder.create<fir::StoreOp>(loc, result, lhs);
- builder.create<mlir::omp::YieldOp>(loc, lhs);
+ genYield<DeclRedOpType>(builder, loc, lhs);
} else {
- builder.create<mlir::omp::YieldOp>(loc, result);
+ genYield<DeclRedOpType>(builder, loc, result);
}
return;
}
// all arrays should have been boxed
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
- genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
+ genBoxCombiner<DeclRedOpType>(builder, loc, redId, boxTy, lhs, rhs);
return;
}
@@ -454,15 +496,13 @@ static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
return ty;
}
+template <typename OpType>
static void createReductionAllocAndInitRegions(
- AbstractConverter &converter, mlir::Location loc,
- mlir::omp::DeclareReductionOp &reductionDecl,
+ AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl,
const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
bool isByRef) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
- auto yield = [&](mlir::Value ret) {
- builder.create<mlir::omp::YieldOp>(loc, ret);
- };
+ auto yield = [&](mlir::Value ret) { genYield<OpType>(builder, loc, ret); };
mlir::Block *allocBlock = nullptr;
mlir::Block *initBlock = nullptr;
@@ -512,7 +552,8 @@ static void createReductionAllocAndInitRegions(
yield(boxAlloca);
}
-mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
+template <typename OpType>
+OpType ReductionProcessor::createDeclareReduction(
AbstractConverter &converter, llvm::StringRef reductionOpName,
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
bool isByRef) {
@@ -522,8 +563,7 @@ mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
assert(!reductionOpName.empty());
- auto decl =
- module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
+ auto decl = module.lookupSymbol<OpType>(reductionOpName);
if (decl)
return decl;
@@ -532,8 +572,7 @@ mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
if (!isByRef)
type = valTy;
- decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
- type);
+ decl = modBuilder.create<OpType>(loc, reductionOpName, type);
createReductionAllocAndInitRegions(converter, loc, decl, redId, type,
isByRef);
@@ -544,7 +583,7 @@ mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
- genCombiner(builder, loc, redId, type, op1, op2, isByRef);
+ genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef);
return decl;
}
@@ -563,64 +602,41 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}
-mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
- switch (mod) {
- case ReductionModifier::Default:
- return mlir::omp::ReductionModifier::defaultmod;
- case ReductionModifier::Inscan:
- return mlir::omp::ReductionModifier::inscan;
- case ReductionModifier::Task:
- return mlir::omp::ReductionModifier::task;
- }
- return mlir::omp::ReductionModifier::defaultmod;
-}
-
-template <class T>
+template <typename OpType, typename RedOperatorListTy>
void ReductionProcessor::processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
- const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ const RedOperatorListTy &redOperatorList,
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
- mlir::omp::ReductionModifierAttr *reductionMod) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
- if constexpr (std::is_same_v<T, omp::clause::Reduction>) {
- auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
- if (mod.has_value()) {
- if (mod.value() == ReductionModifier::Task)
- TODO(currentLocation, "Reduction modifier `task` is not supported");
- else
- *reductionMod = mlir::omp::ReductionModifierAttr::get(
- firOpBuilder.getContext(), translateReductionModifier(mod.value()));
- }
- }
-
- mlir::omp::DeclareReductionOp decl;
- const auto &redOperatorList{
- std::get<typename T::ReductionIdentifiers>(reduction.t)};
- assert(redOperatorList.size() == 1 && "Expecting single operator");
- const auto &redOperator = redOperatorList.front();
- const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
-
- if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
- if (const auto *reductionIntrinsic =
- std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
- if (!ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic)) {
+ const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
+ if constexpr (std::is_same_v<RedOperatorListTy,
+ omp::clause::ReductionOperatorList>) {
+ // For OpenMP reduction clauses, check if the reduction operator is
+ // supported.
+ assert(redOperatorList.size() == 1 && "Expecting single operator");
+ const Fortran::lower::omp::clause::ReductionOperator &redOperator =
+ redOperatorList.front();
+
+ if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
+ if (const auto *reductionIntrinsic =
+ std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
+ if (!ReductionProcessor::supportedIntrinsicProcReduction(
+ *reductionIntrinsic)) {
+ return;
+ }
+ } else {
return;
}
- } else {
- return;
}
}
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
// Reduction variable processing common to both intrinsic operators and
// procedure designators
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
- for (const Object &object : objectList) {
- const semantics::Symbol *symbol = object.sym();
- reductionSymbols.push_back(symbol);
+ for (const semantics::Symbol *symbol : reductionSymbols) {
mlir::Value symVal = converter.getSymbolAddress(*symbol);
mlir::Type eleType;
auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
@@ -672,52 +688,63 @@ void ReductionProcessor::processReductionArguments(
reduceVarByRef.push_back(doReductionByRef(symVal));
}
+ unsigned idx = 0;
for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
const auto &kindMap = firOpBuilder.getKindMap();
std::string reductionName;
ReductionIdentifier redId;
- if (const auto &redDefinedOp =
- std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
- const auto &intrinsicOp{
- std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
- redDefinedOp->u)};
- redId = getReductionType(intrinsicOp);
- switch (redId) {
- case ReductionIdentifier::ADD:
- case ReductionIdentifier::MULTIPLY:
- case ReductionIdentifier::AND:
- case ReductionIdentifier::EQV:
- case ReductionIdentifier::OR:
- case ReductionIdentifier::NEQV:
- break;
- default:
- TODO(currentLocation,
- "Reduction of some intrinsic operators is not supported");
- break;
- }
-
- reductionName = getReductionName(intrinsicOp, kindMap, redType, isByRef);
- } else if (const auto *reductionIntrinsic =
- std::get_if<omp::clause::ProcedureDesignator>(
- &redOperator.u)) {
- if (!ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic)) {
- TODO(currentLocation, "Unsupported intrinsic proc reduction");
+ if constexpr (std::is_same_v<RedOperatorListTy,
+ omp::clause::ReductionOperatorList>) {
+ const Fortran::lower::omp::clause::ReductionOperator &redOperator =
+ redOperatorList.front();
+ if (const auto &redDefinedOp =
+ std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
+ const auto &intrinsicOp{
+ std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
+ redDefinedOp->u)};
+ redId = getReductionType(intrinsicOp);
+ switch (redId) {
+ case ReductionIdentifier::ADD:
+ case ReductionIdentifier::MULTIPLY:
+ case ReductionIdentifier::AND:
+ case ReductionIdentifier::EQV:
+ case ReductionIdentifier::OR:
+ case ReductionIdentifier::NEQV:
+ break;
+ default:
+ TODO(currentLocation,
+ "Reduction of some intrinsic operators is not supported");
+ break;
+ }
+
+ reductionName = getReductionName(redId, kindMap, redType, isByRef);
+ } else if (const auto *reductionIntrinsic =
+ std::get_if<omp::clause::ProcedureDesignator>(
+ &redOperator.u)) {
+ if (!ReductionProcessor::supportedIntrinsicProcReduction(
+ *reductionIntrinsic)) {
+ TODO(currentLocation, "Unsupported intrinsic proc reduction");
+ }
+ redId = getReductionType(*reductionIntrinsic);
+ reductionName =
+ getReductionName(getRealName(*reductionIntrinsic).ToString(),
+ kindMap, redType, isByRef);
+ } else {
+ TODO(currentLocation, "Unexpected reduction type");
}
- redId = getReductionType(*reductionIntrinsic);
- reductionName =
- getReductionName(getRealName(*reductionIntrinsic).ToString(), kindMap,
- redType, isByRef);
} else {
- TODO(currentLocation, "Unexpected reduction type");
+ // `do concurrent` reductions
+ redId = getReductionType(redOperatorList[idx]);
+ reductionName = getReductionName(redId, kindMap, redType, isByRef);
}
- decl = createDeclareReduction(converter, reductionName, redId, redType,
- currentLocation, isByRef);
+ OpType decl = createDeclareReduction<OpType>(
+ converter, reductionName, redId, redType, currentLocation, isByRef);
reductionDeclSymbols.push_back(
mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
+ ++idx;
}
}
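The `genYield` helper added above chooses the region terminator at compile time: `omp.yield` when the enclosing declaration is `mlir::omp::DeclareReductionOp`, `fir.yield` otherwise. The following is a minimal standalone sketch of the same `if constexpr` / `std::is_same_v` dispatch. It uses hypothetical stand-in types (`OmpDeclareReduction`, `FirDeclareReduction`) in place of the real MLIR ops and returns the terminator name as a string rather than building an op, so it compiles without MLIR.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-ins for mlir::omp::DeclareReductionOp and
// fir::DeclareReductionOp; only their identity as types matters here.
struct OmpDeclareReduction {};
struct FirDeclareReduction {};

// Returns the name of the terminator that a region nested in
// ParentDeclOpType should end with. The untaken branch is discarded at
// compile time, so each instantiation contains only one of the two paths.
template <typename ParentDeclOpType>
std::string yieldOpName() {
  if constexpr (std::is_same_v<ParentDeclOpType, OmpDeclareReduction>)
    return "omp.yield";
  else
    return "fir.yield";
}

// Mirrors how genCombiner/genBoxCombiner pass the op type along: the
// template parameter is forwarded and never inspected at run time.
template <typename DeclRedOpType>
std::string combinerTerminator() {
  return yieldOpName<DeclRedOpType>();
}
```

Because the choice is made per instantiation, a single code path can emit both OpenMP and `do concurrent` reduction declarations without carrying a runtime flag.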
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index a7198b48f6b4e..95b4b077bdc46 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -19,7 +19,6 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
@@ -65,6 +64,9 @@ class ReductionProcessor {
static ReductionIdentifier
getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp);
+ static ReductionIdentifier
+ getReductionType(const fir::ReduceOperationEnum &pd);
+
static bool
supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd);
@@ -78,10 +80,9 @@ class ReductionProcessor {
const fir::KindMapping &kindMap,
mlir::Type ty, bool isByRef);
- static std::string
- getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
- const fir::KindMapping &kindMap, mlir::Type ty,
- bool isByRef);
+ static std::string getReductionName(ReductionIdentifier redId,
+ const fir::KindMapping &kindMap,
+ mlir::Type ty, bool isByRef);
/// This function returns the identity value of the operator \p
/// reductionOpName. For example:
@@ -113,22 +114,23 @@ class ReductionProcessor {
/// symbol table. The declaration has a constant initializer with the neutral
/// value `initValue`, and the reduction combiner carried over from `reduce`.
/// TODO: add atomic region.
- static mlir::omp::DeclareReductionOp
- createDeclareReduction(AbstractConverter &builder,
- llvm::StringRef reductionOpName,
- const ReductionIdentifier redId, mlir::Type type,
- mlir::Location loc, bool isByRef);
+ template <typename OpType>
+ static OpType createDeclareReduction(AbstractConverter &builder,
+ llvm::StringRef reductionOpName,
+ const ReductionIdentifier redId,
+ mlir::Type type, mlir::Location loc,
+ bool isByRef);
/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
- template <class T>
+ template <typename OpType, typename RedOperatorListTy>
static void processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
- const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
+ const RedOperatorListTy &redOperatorList,
+ llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
- mlir::omp::ReductionModifierAttr *reductionMod = nullptr);
+ const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
};
template <typename FloatOp, typename IntegerOp>
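In the header, `createDeclareReduction` becomes a template over the declaration op type. Its body in the `.cpp` hunks follows a lookup-or-create pattern. `module.lookupSymbol<OpType>(reductionOpName)` returns an existing declaration, and a new one is built only on a miss. The sketch below illustrates that memoization under stated assumptions. `OmpDecl`, `FirDecl`, `symbolTable` and `getOrCreateDeclareReduction` are hypothetical names. For simplicity the sketch keeps one map per op type, whereas an MLIR module has a single symbol namespace.

```cpp
#include <map>
#include <memory>
#include <string>

// Hypothetical declaration records standing in for omp.declare_reduction
// and fir.declare_reduction; each only remembers its symbol name.
struct OmpDecl { std::string symName; };
struct FirDecl { std::string symName; };

// Hypothetical per-op-type symbol table (a simplification of querying the
// enclosing mlir::ModuleOp).
template <typename OpType>
std::map<std::string, std::unique_ptr<OpType>> &symbolTable() {
  static std::map<std::string, std::unique_ptr<OpType>> table;
  return table;
}

// Lookup-or-create: reuse a declaration with the same name if one exists,
// otherwise create and register it.
template <typename OpType>
OpType *getOrCreateDeclareReduction(const std::string &name) {
  auto &table = symbolTable<OpType>();
  auto it = table.find(name);
  if (it != table.end())
    return it->second.get();
  auto decl = std::make_unique<OpType>(OpType{name});
  OpType *raw = decl.get();
  table.emplace(name, std::move(decl));
  return raw;
}
```

Because the generated names encode the operator, the element type and whether the reduction is by-ref, repeated reductions of the same kind end up sharing one declaration.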
diff --git a/flang/lib/Lower/Support/Utils.cpp b/flang/lib/Lower/Support/Utils.cpp
index c65f51ce6cacd..b9d2574a76ad0 100644
--- a/flang/lib/Lower/Support/Utils.cpp
+++ b/flang/lib/Lower/Support/Utils.cpp
@@ -668,9 +668,7 @@ void privatizeSymbol(
const semantics::Symbol *sym =
isDoConcurrent ? &symToPrivatize->GetUltimate() : symToPrivatize;
- const lower::SymbolBox hsb = isDoConcurrent
- ? converter.shallowLookupSymbol(*sym)
- : converter.lookupOneLevelUpSymbol(*sym);
+ const lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(*sym);
assert(hsb && "Host symbol box not found");
mlir::Location symLoc = hsb.getAddr().getLoc();
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index 3bbc32f23bcfa..c35d757be0f09 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -3342,26 +3342,26 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
}
};
-struct LocalitySpecifierOpConversion
- : public fir::FIROpConversion<fir::LocalitySpecifierOp> {
- using FIROpConversion::FIROpConversion;
+template <typename OpTy>
+struct DoConcurrentSpecifierOpConversion : public fir::FIROpConversion<OpTy> {
+ using fir::FIROpConversion<OpTy>::FIROpConversion;
llvm::LogicalResult
- matchAndRewrite(fir::LocalitySpecifierOp localizer, OpAdaptor adaptor,
+ matchAndRewrite(OpTy specifier, typename OpTy::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
#ifdef EXPENSIVE_CHECKS
auto uses = mlir::SymbolTable::getSymbolUses(
- localizer, localizer->getParentOfType<mlir::ModuleOp>());
+ specifier, specifier->getParentOfType<mlir::ModuleOp>());
- // `fir.local` ops are not supposed to have any uses at this point (i.e.
- // during lowering to LLVM). In case of serialization, the
- // `fir.do_concurrent` users are expected to have been lowered to
+ // `fir.local|fir.declare_reduction` ops are not supposed to have any uses
+ // at this point (i.e. during lowering to LLVM). In case of serialization,
+ // the `fir.do_concurrent` users are expected to have been lowered to
// `fir.do_loop` nests. In case of parallelization, the `fir.do_concurrent`
// users are expected to have been lowered to the target parallel model
// (e.g. OpenMP).
assert(uses && uses->empty());
#endif
- rewriter.eraseOp(localizer);
+ rewriter.eraseOp(specifier);
return mlir::success();
}
};
@@ -4330,20 +4330,22 @@ void fir::populateFIRToLLVMConversionPatterns(
BoxTypeCodeOpConversion, BoxTypeDescOpConversion, CallOpConversion,
CmpcOpConversion, VolatileCastOpConversion, ConvertOpConversion,
CoordinateOpConversion, CopyOpConversion, DTEntryOpConversion,
- DeclareOpConversion, DivcOpConversion, EmboxOpConversion,
- EmboxCharOpConversion, EmboxProcOpConversion, ExtractValueOpConversion,
- FieldIndexOpConversion, FirEndOpConversion, FreeMemOpConversion,
- GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
- IsPresentOpConversion, LenParamIndexOpConversion, LoadOpConversion,
- LocalitySpecifierOpConversion, MulcOpConversion, NegcOpConversion,
- NoReassocOpConversion, SelectCaseOpConversion, SelectOpConversion,
- SelectRankOpConversion, SelectTypeOpConversion, ShapeOpConversion,
- ShapeShiftOpConversion, ShiftOpConversion, SliceOpConversion,
- StoreOpConversion, StringLitOpConversion, SubcOpConversion,
- TypeDescOpConversion, TypeInfoOpConversion, UnboxCharOpConversion,
- UnboxProcOpConversion, UndefOpConversion, UnreachableOpConversion,
- XArrayCoorOpConversion, XEmboxOpConversion, XReboxOpConversion,
- ZeroOpConversion>(converter, options);
+ DeclareOpConversion,
+ DoConcurrentSpecifierOpConversion<fir::LocalitySpecifierOp>,
+ DoConcurrentSpecifierOpConversion<fir::DeclareReductionOp>,
+ DivcOpConversion, EmboxOpConversion, EmboxCharOpConversion,
+ EmboxProcOpConversion, ExtractValueOpConversion, FieldIndexOpConversion,
+ FirEndOpConversion, FreeMemOpConversion, GlobalLenOpConversion,
+ GlobalOpConversion, InsertOnRangeOpConversion, IsPresentOpConversion,
+ LenParamIndexOpConversion, LoadOpConversion, MulcOpConversion,
+ NegcOpConversion, NoReassocOpConversion, SelectCaseOpConversion,
+ SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
+ ShapeOpConversion, ShapeShiftOpConversion, ShiftOpConversion,
+ SliceOpConversion, StoreOpConversion, StringLitOpConversion,
+ SubcOpConversion, TypeDescOpConversion, TypeInfoOpConversion,
+ UnboxCharOpConversion, UnboxProcOpConversion, UndefOpConversion,
+ UnreachableOpConversion, XArrayCoorOpConversion, XEmboxOpConversion,
+ XReboxOpConversion, ZeroOpConversion>(converter, options);
// Patterns that are populated without a type converter do not trigger
// target materializations for the operands of the root op.
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index ecfa2939e96a6..6b40e7015fdd8 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -5041,6 +5041,9 @@ void fir::BoxTotalElementsOp::getCanonicalizationPatterns(
// LocalitySpecifierOp
//===----------------------------------------------------------------------===//
+// TODO This is a copy of omp::PrivateClauseOp::verifyRegions(). Once we find a
+// solution to merge both ops into one this duplication will not be needed. See:
+// https://discourse.llvm.org/t/dialect-for-data-locality-sharing-specifiers-clauses-in-openmp-openacc-and-do-concurrent/86108.
llvm::LogicalResult fir::LocalitySpecifierOp::verifyRegions() {
mlir::Type argType = getArgType();
auto verifyTerminator = [&](mlir::Operation *terminator,
@@ -5136,6 +5139,84 @@ llvm::LogicalResult fir::LocalitySpecifierOp::verifyRegions() {
return llvm::success();
}
+// TODO This is a copy of omp::DeclareReductionOp::verifyRegions(). Once we
+// find a solution to merge both ops into one this duplication will not be
+// needed.
+mlir::LogicalResult fir::DeclareReductionOp::verifyRegions() {
+ if (!getAllocRegion().empty()) {
+ for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
+ if (yieldOp.getResults().size() != 1 ||
+ yieldOp.getResults().getTypes()[0] != getType())
+ return emitOpError() << "expects alloc region to yield a value "
+ "of the reduction type";
+ }
+ }
+
+ if (getInitializerRegion().empty())
+ return emitOpError() << "expects non-empty initializer region";
+ mlir::Block &initializerEntryBlock = getInitializerRegion().front();
+
+ if (initializerEntryBlock.getNumArguments() == 1) {
+ if (!getAllocRegion().empty())
+ return emitOpError() << "expects two arguments to the initializer region "
+ "when an allocation region is used";
+ } else if (initializerEntryBlock.getNumArguments() == 2) {
+ if (getAllocRegion().empty())
+ return emitOpError() << "expects one argument to the initializer region "
+ "when no allocation region is used";
+ } else {
+ return emitOpError()
+ << "expects one or two arguments to the initializer region";
+ }
+
+ for (mlir::Value arg : initializerEntryBlock.getArguments())
+ if (arg.getType() != getType())
+ return emitOpError() << "expects initializer region argument to match "
+ "the reduction type";
+
+ for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
+ if (yieldOp.getResults().size() != 1 ||
+ yieldOp.getResults().getTypes()[0] != getType())
+ return emitOpError() << "expects initializer region to yield a value "
+ "of the reduction type";
+ }
+
+ if (getReductionRegion().empty())
+ return emitOpError() << "expects non-empty reduction region";
+ mlir::Block &reductionEntryBlock = getReductionRegion().front();
+ if (reductionEntryBlock.getNumArguments() != 2 ||
+ reductionEntryBlock.getArgumentTypes()[0] !=
+ reductionEntryBlock.getArgumentTypes()[1] ||
+ reductionEntryBlock.getArgumentTypes()[0] != getType())
+ return emitOpError() << "expects reduction region with two arguments of "
+ "the reduction type";
+ for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
+ if (yieldOp.getResults().size() != 1 ||
+ yieldOp.getResults().getTypes()[0] != getType())
+ return emitOpError() << "expects reduction region to yield a value "
+ "of the reduction type";
+ }
+
+ if (!getAtomicReductionRegion().empty()) {
+ mlir::Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
+ if (atomicReductionEntryBlock.getNumArguments() != 2 ||
+ atomicReductionEntryBlock.getArgumentTypes()[0] !=
+ atomicReductionEntryBlock.getArgumentTypes()[1])
+ return emitOpError() << "expects atomic reduction region with two "
+ "arguments of the same type";
+ }
+
+ if (getCleanupRegion().empty())
+ return mlir::success();
+ mlir::Block &cleanupEntryBlock = getCleanupRegion().front();
+ if (cleanupEntryBlock.getNumArguments() != 1 ||
+ cleanupEntryBlock.getArgument(0).getType() != getType())
+ return emitOpError() << "expects cleanup region with one argument "
+ "of the reduction type";
+
+ return mlir::success();
+}
+
//===----------------------------------------------------------------------===//
// DoConcurrentOp
//===----------------------------------------------------------------------===//
@@ -5157,6 +5238,97 @@ llvm::LogicalResult fir::DoConcurrentOp::verify() {
// DoConcurrentLoopOp
//===----------------------------------------------------------------------===//
+static mlir::ParseResult parseSpecifierList(
+ mlir::OpAsmParser &parser, mlir::OperationState &result,
+ llvm::StringRef specifierKeyword, llvm::StringRef symsAttrName,
+ llvm::SmallVectorImpl<mlir::OpAsmParser::Argument> &regionArgs,
+ llvm::SmallVectorImpl<mlir::Type> &regionArgTypes,
+ int32_t &numSpecifierOperands, bool isReduce = false) {
+ auto &builder = parser.getBuilder();
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> specifierOperands;
+
+ if (failed(parser.parseOptionalKeyword(specifierKeyword)))
+ return mlir::success();
+
+ std::size_t oldArgTypesSize = regionArgTypes.size();
+ if (failed(parser.parseLParen()))
+ return mlir::failure();
+
+ llvm::SmallVector<bool> isByRefVec;
+ llvm::SmallVector<mlir::SymbolRefAttr> spceifierSymbolVec;
+ llvm::SmallVector<fir::ReduceAttr> attributes;
+
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ if (isReduce)
+ isByRefVec.push_back(
+ parser.parseOptionalKeyword("byref").succeeded());
+
+ if (failed(parser.parseAttribute(spceifierSymbolVec.emplace_back())))
+ return mlir::failure();
+
+ if (isReduce &&
+ failed(parser.parseAttribute(attributes.emplace_back())))
+ return mlir::failure();
+
+ if (parser.parseOperand(specifierOperands.emplace_back()) ||
+ parser.parseArrow() ||
+ parser.parseArgument(regionArgs.emplace_back()))
+ return mlir::failure();
+
+ return mlir::success();
+ })))
+ return mlir::failure();
+
+ if (failed(parser.parseColon()))
+ return mlir::failure();
+
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ if (failed(parser.parseType(regionArgTypes.emplace_back())))
+ return mlir::failure();
+
+ return mlir::success();
+ })))
+ return mlir::failure();
+
+ if (regionArgs.size() != regionArgTypes.size())
+ return parser.emitError(parser.getNameLoc(), "mismatch in number of " +
+ specifierKeyword.str() +
+ " arg and types");
+
+ if (failed(parser.parseRParen()))
+ return mlir::failure();
+
+ for (auto operandType :
+ llvm::zip_equal(specifierOperands,
+ llvm::drop_begin(regionArgTypes, oldArgTypesSize)))
+ if (parser.resolveOperand(std::get<0>(operandType),
+ std::get<1>(operandType), result.operands))
+ return mlir::failure();
+
+ if (isReduce)
+ result.addAttribute(
+ fir::DoConcurrentLoopOp::getReduceByrefAttrName(result.name),
+ isByRefVec.empty()
+ ? nullptr
+ : mlir::DenseBoolArrayAttr::get(builder.getContext(), isByRefVec));
+
+ llvm::SmallVector<mlir::Attribute> symbolAttrs(spceifierSymbolVec.begin(),
+ spceifierSymbolVec.end());
+ result.addAttribute(symsAttrName, builder.getArrayAttr(symbolAttrs));
+
+ if (isReduce) {
+ llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+ attributes.end());
+ result.addAttribute(
+ fir::DoConcurrentLoopOp::getReduceAttrsAttrName(result.name),
+ builder.getArrayAttr(arrayAttr));
+ }
+
+ numSpecifierOperands = specifierOperands.size();
+
+ return mlir::success();
+}
+
mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
mlir::OperationState &result) {
auto &builder = parser.getBuilder();
@@ -5192,90 +5364,26 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
parser.resolveOperands(steps, builder.getIndexType(), result.operands))
return mlir::failure();
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
- llvm::SmallVector<mlir::Type> reduceArgTypes;
- if (succeeded(parser.parseOptionalKeyword("reduce"))) {
- // Parse reduction attributes and variables.
- llvm::SmallVector<fir::ReduceAttr> attributes;
- if (failed(parser.parseCommaSeparatedList(
- mlir::AsmParser::Delimiter::Paren, [&]() {
- if (parser.parseAttribute(attributes.emplace_back()) ||
- parser.parseArrow() ||
- parser.parseOperand(reduceOperands.emplace_back()) ||
- parser.parseColonType(reduceArgTypes.emplace_back()))
- return mlir::failure();
- return mlir::success();
- })))
- return mlir::failure();
- // Resolve input operands.
- for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
- if (parser.resolveOperand(std::get<0>(operand_type),
- std::get<1>(operand_type), result.operands))
- return mlir::failure();
- llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
- attributes.end());
- result.addAttribute(getReduceAttrsAttrName(result.name),
- builder.getArrayAttr(arrayAttr));
- }
-
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> localOperands;
- if (succeeded(parser.parseOptionalKeyword("local"))) {
- std::size_t oldArgTypesSize = argTypes.size();
- if (failed(parser.parseLParen()))
- return mlir::failure();
-
- llvm::SmallVector<mlir::SymbolRefAttr> localSymbolVec;
- if (failed(parser.parseCommaSeparatedList([&]() {
- if (failed(parser.parseAttribute(localSymbolVec.emplace_back())))
- return mlir::failure();
-
- if (parser.parseOperand(localOperands.emplace_back()) ||
- parser.parseArrow() ||
- parser.parseArgument(regionArgs.emplace_back()))
- return mlir::failure();
-
- return mlir::success();
- })))
- return mlir::failure();
-
- if (failed(parser.parseColon()))
- return mlir::failure();
-
- if (failed(parser.parseCommaSeparatedList([&]() {
- if (failed(parser.parseType(argTypes.emplace_back())))
- return mlir::failure();
-
- return mlir::success();
- })))
- return mlir::failure();
-
- if (regionArgs.size() != argTypes.size())
- return parser.emitError(parser.getNameLoc(),
- "mismatch in number of local arg and types");
-
- if (failed(parser.parseRParen()))
- return mlir::failure();
-
- for (auto operandType : llvm::zip_equal(
- localOperands, llvm::drop_begin(argTypes, oldArgTypesSize)))
- if (parser.resolveOperand(std::get<0>(operandType),
- std::get<1>(operandType), result.operands))
- return mlir::failure();
+ int32_t numLocalOperands = 0;
+ if (failed(parseSpecifierList(parser, result, "local",
+ getLocalSymsAttrName(result.name), regionArgs,
+ argTypes, numLocalOperands)))
+ return mlir::failure();
- llvm::SmallVector<mlir::Attribute> symbolAttrs(localSymbolVec.begin(),
- localSymbolVec.end());
- result.addAttribute(getLocalSymsAttrName(result.name),
- builder.getArrayAttr(symbolAttrs));
- }
+ int32_t numReduceOperands = 0;
+ if (failed(parseSpecifierList(
+ parser, result, "reduce", getReduceSymsAttrName(result.name),
+ regionArgs, argTypes, numReduceOperands, /*isReduce=*/true)))
+ return mlir::failure();
// Set `operandSegmentSizes` attribute.
- result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
- builder.getDenseI32ArrayAttr(
- {static_cast<int32_t>(lower.size()),
- static_cast<int32_t>(upper.size()),
- static_cast<int32_t>(steps.size()),
- static_cast<int32_t>(reduceOperands.size()),
- static_cast<int32_t>(localOperands.size())}));
+ result.addAttribute(
+ DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(lower.size()),
+ static_cast<int32_t>(upper.size()),
+ static_cast<int32_t>(steps.size()),
+ static_cast<int32_t>(numLocalOperands),
+ static_cast<int32_t>(numReduceOperands)}));
// Now parse the body.
for (auto [arg, type] : llvm::zip_equal(regionArgs, argTypes))
@@ -5297,17 +5405,6 @@ void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
<< ") = (" << getLowerBound() << ") to (" << getUpperBound() << ") step ("
<< getStep() << ")";
- if (!getReduceOperands().empty()) {
- p << " reduce(";
- auto attrs = getReduceAttrsAttr();
- auto operands = getReduceOperands();
- llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) {
- p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
- << std::get<1>(it).getType();
- });
- p << ')';
- }
-
if (!getLocalVars().empty()) {
p << " local(";
llvm::interleaveComma(llvm::zip_equal(getLocalSymsAttr(), getLocalVars(),
@@ -5322,13 +5419,34 @@ void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
p << ")";
}
+ if (!getReduceVars().empty()) {
+ p << " reduce(";
+ llvm::interleaveComma(
+ llvm::zip_equal(getReduceByrefAttr().asArrayRef(), getReduceSymsAttr(),
+ getReduceAttrsAttr(), getReduceVars(),
+ getRegionReduceArgs()),
+ p, [&](auto it) {
+ if (std::get<0>(it))
+ p << "byref ";
+
+ p << std::get<1>(it) << " " << std::get<2>(it) << " "
+ << std::get<3>(it) << " -> " << std::get<4>(it);
+ });
+ p << " : ";
+ llvm::interleaveComma(getReduceVars(), p,
+ [&](auto it) { p << it.getType(); });
+ p << ")";
+ }
+
p << ' ';
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
p.printOptionalAttrDict(
(*this)->getAttrs(),
/*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
+ DoConcurrentLoopOp::getLocalSymsAttrName(),
+ DoConcurrentLoopOp::getReduceSymsAttrName(),
DoConcurrentLoopOp::getReduceAttrsAttrName(),
- DoConcurrentLoopOp::getLocalSymsAttrName()});
+ DoConcurrentLoopOp::getReduceByrefAttrName()});
}
llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions() {
@@ -5340,6 +5458,7 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
mlir::Operation::operand_range ubValues = getUpperBound();
mlir::Operation::operand_range stepValues = getStep();
mlir::Operation::operand_range localVars = getLocalVars();
+ mlir::Operation::operand_range reduceVars = getReduceVars();
if (lbValues.empty())
return emitOpError(
@@ -5353,7 +5472,8 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
// Check that the body defines the same number of block arguments as the
// number of tuple elements in step.
mlir::Block *body = getBody();
- unsigned numIndVarArgs = body->getNumArguments() - localVars.size();
+ unsigned numIndVarArgs =
+ body->getNumArguments() - localVars.size() - reduceVars.size();
if (numIndVarArgs != stepValues.size())
return emitOpError() << "expects the same number of induction variables: "
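`fir::DeclareReductionOp::verifyRegions()` includes a rule for the initializer region. The number of entry-block arguments is tied to whether an alloc region exists: the block takes one argument when there is no alloc region and two when there is one. The sketch below isolates just that check. `InitRegionShape` and `checkInitArgs` are hypothetical, and the function returns the diagnostic text instead of calling `emitOpError()`.

```cpp
#include <optional>
#include <string>

// Hypothetical summary of the parts of a declare_reduction op that the
// initializer-region argument check looks at.
struct InitRegionShape {
  bool hasAllocRegion;
  unsigned numInitArgs;
};

// Returns std::nullopt when the shape is valid, otherwise the message the
// verifier above would emit.
std::optional<std::string> checkInitArgs(const InitRegionShape &s) {
  if (s.numInitArgs == 1) {
    if (s.hasAllocRegion)
      return "expects two arguments to the initializer region "
             "when an allocation region is used";
  } else if (s.numInitArgs == 2) {
    if (!s.hasAllocRegion)
      return "expects one argument to the initializer region "
             "when no allocation region is used";
  } else {
    return "expects one or two arguments to the initializer region";
  }
  return std::nullopt;
}
```

The remaining checks (argument and yield types, the two-argument reduction region, the optional atomic and cleanup regions) follow the same early-return pattern.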
diff --git a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
index e440852b3103a..506c8e66dbdfa 100644
--- a/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
+++ b/flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
@@ -234,6 +234,10 @@ class DoConcurrentConversion
loop.setLocalSymsAttr(nullptr);
}
+ for (auto [reduceVar, reduceArg] :
+ llvm::zip_equal(loop.getReduceVars(), loop.getRegionReduceArgs()))
+ rewriter.replaceAllUsesWith(reduceArg, reduceVar);
+
// Collect iteration variable(s) allocations so that we can move them
// outside the `fir.do_concurrent` wrapper.
llvm::SmallVector<mlir::Operation *> opsToMove;
@@ -257,12 +261,16 @@ class DoConcurrentConversion
innermostUnorderdLoop = rewriter.create<fir::DoLoopOp>(
doConcurentOp.getLoc(), lb, ub, st,
/*unordred=*/true, /*finalCountValue=*/false,
- /*iterArgs=*/std::nullopt, loop.getReduceOperands(),
+ /*iterArgs=*/std::nullopt, loop.getReduceVars(),
loop.getReduceAttrsAttr());
ivArgs.push_back(innermostUnorderdLoop.getInductionVar());
rewriter.setInsertionPointToStart(innermostUnorderdLoop.getBody());
}
+ loop.getRegion().front().eraseArguments(loop.getNumInductionVars() +
+ loop.getNumLocalOperands(),
+ loop.getNumReduceOperands());
+
rewriter.inlineBlockBefore(
&loopBlock, innermostUnorderdLoop.getBody()->getTerminator(), ivArgs);
rewriter.eraseOp(doConcurentOp);
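The verifier hunk and the `DoConcurrentConversion` hunk above both assume the same entry-block layout: induction variables first, then `local` arguments, then `reduce` arguments. The conversion first redirects every use of a reduce block argument to its original reduction operand. Only then does it erase those arguments, starting at index `getNumInductionVars() + getNumLocalOperands()`. The following sketch models that bookkeeping with a vector of argument names. `LoopBlock` and its members are hypothetical stand-ins for the MLIR block and op accessors, not real Flang types.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical model of the entry block of fir.do_concurrent.loop, with the
// assumed layout [induction vars..., local args..., reduce args...].
struct LoopBlock {
  std::vector<std::string> args;
  std::size_t numLocal = 0;
  std::size_t numReduce = 0;

  // Same arithmetic as the updated verifier's numIndVarArgs.
  std::size_t numInductionVars() const {
    assert(args.size() >= numLocal + numReduce && "malformed block");
    return args.size() - numLocal - numReduce;
  }

  // Mirrors eraseArguments(numIVs + numLocal, numReduce): drop the trailing
  // reduce arguments once their uses have been replaced.
  void eraseReduceArgs() {
    const std::size_t start = numInductionVars() + numLocal;
    const auto first = args.begin() + static_cast<std::ptrdiff_t>(start);
    args.erase(first, first + static_cast<std::ptrdiff_t>(numReduce));
    numReduce = 0;
  }
};
```

For the `dc_2d_reduction` test, the block has arguments `(%i_iv, %j_iv, %sum_arg)` with no locals and one reduction. The model therefore reports two induction variables and keeps `%i_iv` and `%j_iv` after erasure.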
diff --git a/flang/test/Fir/do_concurrent.fir b/flang/test/Fir/do_concurrent.fir
index cc1197ba56bd7..6e2173447855e 100644
--- a/flang/test/Fir/do_concurrent.fir
+++ b/flang/test/Fir/do_concurrent.fir
@@ -63,7 +63,7 @@ func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
%j = fir.alloca i32
fir.do_concurrent.loop
(%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st)
- reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) {
+ reduce(@add_reduction_i32 #fir.reduce_attr<add> %sum -> %sum_arg : !fir.ref<i32>) {
%0 = fir.convert %i_iv : (index) -> i32
fir.store %0 to %i : !fir.ref<i32>
@@ -83,7 +83,7 @@ func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
// CHECK: %[[I:.*]] = fir.alloca i32
// CHECK: %[[J:.*]] = fir.alloca i32
// CHECK: fir.do_concurrent.loop
-// CHECK-SAME: (%[[I_IV:.*]], %[[J_IV:.*]]) = (%[[I_LB]], %[[J_LB]]) to (%[[I_UB]], %[[J_UB]]) step (%[[I_ST]], %[[J_ST]]) reduce(#fir.reduce_attr<add> -> %[[SUM]] : !fir.ref<i32>) {
+// CHECK-SAME: (%[[I_IV:.*]], %[[J_IV:.*]]) = (%[[I_LB]], %[[J_LB]]) to (%[[I_UB]], %[[J_UB]]) step (%[[I_ST]], %[[J_ST]]) reduce(@add_reduction_i32 #fir.reduce_attr<add> %[[SUM]] -> %{{.*}} : !fir.ref<i32>) {
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
// CHECK: %[[J_IV_CVT:.*]] = fir.convert %[[J_IV]] : (index) -> i32
@@ -161,3 +161,62 @@ func.func @do_concurrent_with_locality_specs() {
// CHECK: }
// CHECK: return
// CHECK: }
+
+func.func @dc_reduce() {
+ %3 = fir.alloca i32 {bindc_name = "s", uniq_name = "dc_reduce"}
+ %4:2 = hlfir.declare %3 {uniq_name = "dc_reduce"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %c1 = arith.constant 1 : index
+ fir.do_concurrent {
+ fir.do_concurrent.loop (%arg0) = (%c1) to (%c1) step (%c1) reduce(byref @add_reduction_i32 #fir.reduce_attr<add> %4#0 -> %arg1 : !fir.ref<i32>) {
+ }
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @dc_reduce() {
+// CHECK: %[[S_ALLOC:.*]] = fir.alloca i32 {bindc_name = "s", uniq_name = "dc_reduce"}
+// CHECK: %[[S_DECL:.*]]:2 = hlfir.declare %[[S_ALLOC]] {uniq_name = "dc_reduce"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: fir.do_concurrent {
+// CHECK: fir.do_concurrent.loop (%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) reduce(byref @add_reduction_i32 #fir.reduce_attr<add> %[[S_DECL]]#0 -> %[[S_ARG:.*]] : !fir.ref<i32>) {
+// CHECK: }
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+func.func @dc_reduce_2() {
+ %3 = fir.alloca i32 {bindc_name = "s", uniq_name = "dc_reduce"}
+ %4:2 = hlfir.declare %3 {uniq_name = "dc_reduce"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+ %5 = fir.alloca i32 {bindc_name = "m", uniq_name = "dc_reduce"}
+ %6:2 = hlfir.declare %5 {uniq_name = "dc_reduce"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+ %c1 = arith.constant 1 : index
+
+ fir.do_concurrent {
+ fir.do_concurrent.loop (%arg0) = (%c1) to (%c1) step (%c1)
+ reduce(@add_reduction_i32 #fir.reduce_attr<add> %4#0 -> %arg1,
+ @mul_reduction_i32 #fir.reduce_attr<multiply> %6#0 -> %arg2
+ : !fir.ref<i32>, !fir.ref<i32>) {
+ }
+ }
+
+ return
+}
+
+// CHECK-LABEL: func.func @dc_reduce_2() {
+// CHECK: %[[S_ALLOC:.*]] = fir.alloca i32 {bindc_name = "s", uniq_name = "dc_reduce"}
+// CHECK: %[[S_DECL:.*]]:2 = hlfir.declare %[[S_ALLOC]] {uniq_name = "dc_reduce"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+// CHECK: %[[M_ALLOC:.*]] = fir.alloca i32 {bindc_name = "m", uniq_name = "dc_reduce"}
+// CHECK: %[[M_DECL:.*]]:2 = hlfir.declare %[[M_ALLOC]] {uniq_name = "dc_reduce"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+// CHECK: fir.do_concurrent {
+// CHECK: fir.do_concurrent.loop (%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{[^[:space:]]+}})
+// CHECK-SAME: reduce(
+// CHECK-SAME: @add_reduction_i32 #fir.reduce_attr<add> %[[S_DECL]]#0 -> %[[S_ARG:[^,]+]],
+// CHECK-SAME: @mul_reduction_i32 #fir.reduce_attr<multiply> %[[M_DECL]]#0 -> %[[M_ARG:[^[:space:]]+]]
+// CHECK-SAME: : !fir.ref<i32>, !fir.ref<i32>) {
+// CHECK: }
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index aca0ecc1abdc1..e32ea7ad3c729 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -1256,8 +1256,8 @@ func.func @dc_invalid_reduction(%arg0: index, %arg1: index) {
%sum = fir.alloca i32
// expected-error@+2 {{'fir.do_concurrent.loop' op mismatch in number of reduction variables and reduction attributes}}
fir.do_concurrent {
- "fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array<i32: 1, 1, 1, 1, 0>}> ({
- ^bb0(%arg3: index):
+ "fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 1>}> ({
+ ^bb0(%arg3: index, %sum_arg: i32):
%tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
}) : (index, index, index, !fir.ref<i32>) -> ()
}
@@ -1266,6 +1266,20 @@ func.func @dc_invalid_reduction(%arg0: index, %arg1: index) {
// -----
+func.func @dc_reduce_no_attr() {
+ %3 = fir.alloca i32 {bindc_name = "s", uniq_name = "dc_reduce"}
+ %4:2 = hlfir.declare %3 {uniq_name = "dc_reduce"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ %c1 = arith.constant 1 : index
+ // expected-error@+2 {{expected attribute value}}
+ fir.do_concurrent {
+ fir.do_concurrent.loop (%arg0) = (%c1) to (%c1) step (%c1) reduce(@add_reduction_i32 %4#0 -> %arg1 : !fir.ref<i32>) {
+ }
+ }
+ return
+}
+
+// -----
+
// Should fail when volatility changes from a fir.convert
func.func @bad_convert_volatile(%arg0: !fir.ref<i32>) -> !fir.ref<i32, volatile> {
// expected-error@+1 {{op this conversion does not preserve volatility}}
diff --git a/flang/test/Lower/do_concurrent_reduce.f90 b/flang/test/Lower/do_concurrent_reduce.f90
new file mode 100644
index 0000000000000..8591a21e2b9e0
--- /dev/null
+++ b/flang/test/Lower/do_concurrent_reduce.f90
@@ -0,0 +1,41 @@
+! RUN: %flang_fc1 -emit-hlfir -mmlir --enable-delayed-privatization-staging=true -o - %s | FileCheck %s
+
+subroutine do_concurrent_reduce
+ implicit none
+ integer :: s, i
+
+ do concurrent (i=1:10) reduce(+:s)
+ s = s + 1
+ end do
+end
+
+! CHECK-LABEL: fir.declare_reduction @add_reduction_i32 : i32 init {
+! CHECK: ^bb0(%[[ARG0:.*]]: i32):
+! CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
+! CHECK: fir.yield(%[[VAL_0]] : i32)
+! CHECK: } combiner {
+! CHECK: ^bb0(%[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: i32):
+! CHECK: %[[VAL_3:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : i32
+! CHECK: fir.yield(%[[VAL_3]] : i32)
+! CHECK: }
+
+! CHECK-LABEL: func.func @_QPdo_concurrent_reduce() {
+! CHECK: %[[S_ALLOC:.*]] = fir.alloca i32 {bindc_name = "s", uniq_name = "_QFdo_concurrent_reduceEs"}
+! CHECK: %[[S_DECL:.*]]:2 = hlfir.declare %[[S_ALLOC]] {uniq_name = "_QFdo_concurrent_reduceEs"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+! CHECK: fir.do_concurrent {
+! CHECK: %[[VAL_0:.*]] = fir.alloca i32 {bindc_name = "i"}
+! CHECK: %[[VAL_1:.*]]:2 = hlfir.declare %[[VAL_0]] {uniq_name = "_QFdo_concurrent_reduceEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: fir.do_concurrent.loop (%{{.*}}) = (%{{.*}}) to (%{{.*}}) step (%{{[^[:space:]]+}})
+! CHECK-SAME: reduce(@add_reduction_i32 #fir.reduce_attr<add> %[[S_DECL]]#0 -> %[[S_ARG:.*]] : !fir.ref<i32>) {
+
+! CHECK: %[[S_ARG_DECL:.*]]:2 = hlfir.declare %[[S_ARG]] {uniq_name = "_QFdo_concurrent_reduceEs"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[S_ARG_VAL:.*]] = fir.load %[[S_ARG_DECL]]#0 : !fir.ref<i32>
+! CHECK: %[[C1:.*]] = arith.constant 1 : i32
+! CHECK: %[[RED_UPDATE:.*]] = arith.addi %[[S_ARG_VAL]], %[[C1]] : i32
+! CHECK: hlfir.assign %[[RED_UPDATE]] to %[[S_ARG_DECL]]#0 : i32, !fir.ref<i32>
+
+! CHECK: }
+! CHECK: }
+! CHECK: return
+! CHECK: }
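The `fir.declare_reduction @add_reduction_i32` checked above pairs an `init` region, which yields the identity `0`, with a `combiner` region that applies `arith.addi`. The following C++ sketch is a rough, hypothetical model of what that pair means for `reduce(+:s)`. Each iteration updates a copy seeded from `init`, and the copy is folded back into the original with `combiner`. This is an assumption made for illustration only. It is not how Flang lowers the loop: the `fir.do_loop ... unordered` conversion in this commit simply replaces the reduce arguments with the original variables. `DeclareReduction`, `addReductionI32`, and `runReduceLoop` are made-up names.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>

// Hypothetical stand-in for a fir.declare_reduction: an identity value
// (the `init` region) and a binary combiner (the `combiner` region).
struct DeclareReduction {
  std::function<std::int32_t()> init;
  std::function<std::int32_t(std::int32_t, std::int32_t)> combiner;
};

// Analogue of @add_reduction_i32: init yields 0, combiner adds.
const DeclareReduction addReductionI32{
    [] { return std::int32_t{0}; },
    [](std::int32_t a, std::int32_t b) { return a + b; }};

// Models `do concurrent (i=lb:ub) reduce(+:s); s = s + 1; end do` by giving
// each iteration its own copy seeded with init() and combining it back.
std::int32_t runReduceLoop(const DeclareReduction &red, std::int32_t s,
                           int lb, int ub) {
  for (int i = lb; i <= ub; ++i) {
    std::int32_t priv = red.init();  // per-iteration copy starts at identity
    priv = priv + 1;                 // loop body: s = s + 1
    s = red.combiner(s, priv);       // fold the copy into the original
  }
  return s;
}
```

Because `init` yields the identity of the combiner, the final value does not depend on how the iterations are grouped or ordered. That property is what lets `do concurrent` iterations run in any order.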
diff --git a/flang/test/Lower/loops.f90 b/flang/test/Lower/loops.f90
index 60df27a591dc3..64f14ff972272 100644
--- a/flang/test/Lower/loops.f90
+++ b/flang/test/Lower/loops.f90
@@ -1,4 +1,4 @@
-! RUN: bbc -emit-fir -hlfir=false -o - %s | FileCheck %s
+! RUN: bbc -emit-fir -hlfir=false --enable-delayed-privatization=false -o - %s | FileCheck %s
! CHECK-LABEL: loop_test
subroutine loop_test
diff --git a/flang/test/Lower/loops3.f90 b/flang/test/Lower/loops3.f90
index 84db1972cca16..2965b954b49a8 100644
--- a/flang/test/Lower/loops3.f90
+++ b/flang/test/Lower/loops3.f90
@@ -12,7 +12,7 @@ subroutine loop_test
! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm"}
! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFloop_testEsum) : !fir.ref<i32>
- ! CHECK: fir.do_concurrent.loop ({{.*}}) = ({{.*}}) to ({{.*}}) step ({{.*}}) reduce(#fir.reduce_attr<add> -> %[[VAL_1:.*]] : !fir.ref<i32>, #fir.reduce_attr<max> -> %[[VAL_0:.*]] : !fir.ref<f32>) {
+ ! CHECK: fir.do_concurrent.loop ({{.*}}) = ({{.*}}) to ({{.*}}) step ({{.*}}) reduce(@add_reduction_i32 #fir.reduce_attr<add> %[[VAL_1]] -> %{{.*}}, @other_reduction_f32 #fir.reduce_attr<max> %[[VAL_0]] -> %{{.*}} : {{.*}}) {
do concurrent (i=1:5, j=1:5, k=1:5) local(tmp) reduce(+:sum) reduce(max:m)
tmp = i + j + k
sum = tmp + sum
diff --git a/flang/test/Transforms/do_concurrent-to-do_loop-unodered.fir b/flang/test/Transforms/do_concurrent-to-do_loop-unodered.fir
index d9ef36b175598..c550ab8a97d4c 100644
--- a/flang/test/Transforms/do_concurrent-to-do_loop-unodered.fir
+++ b/flang/test/Transforms/do_concurrent-to-do_loop-unodered.fir
@@ -86,7 +86,7 @@ func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
%j = fir.alloca i32
fir.do_concurrent.loop
(%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st)
- reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) {
+ reduce(@add_reduction_i32 #fir.reduce_attr<add> %sum -> %sum_arg : !fir.ref<i32>) {
%0 = fir.convert %i_iv : (index) -> i32
fir.store %0 to %i : !fir.ref<i32>